mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)
This commit is contained in:
@@ -18,6 +18,7 @@ import {
|
||||
import { Config } from '@gemini-cli/core';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { Dispatch, SetStateAction } from 'react';
|
||||
|
||||
// --- MOCKS ---
|
||||
const mockSendMessageStream = vi
|
||||
@@ -309,16 +310,41 @@ describe('useGeminiStream', () => {
|
||||
|
||||
const client = geminiClient || mockConfig.getGeminiClient();
|
||||
|
||||
const { result, rerender } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
client,
|
||||
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
mockSetShowHelp,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false, // shellModeActive
|
||||
),
|
||||
const { result, rerender } = renderHook(
|
||||
(props: {
|
||||
client: any;
|
||||
addItem: UseHistoryManagerReturn['addItem'];
|
||||
setShowHelp: Dispatch<SetStateAction<boolean>>;
|
||||
config: Config;
|
||||
onDebugMessage: (message: string) => void;
|
||||
handleSlashCommand: (
|
||||
command: PartListUnion,
|
||||
) =>
|
||||
| import('./slashCommandProcessor.js').SlashCommandActionReturn
|
||||
| boolean;
|
||||
shellModeActive: boolean;
|
||||
}) =>
|
||||
useGeminiStream(
|
||||
props.client,
|
||||
props.addItem,
|
||||
props.setShowHelp,
|
||||
props.config,
|
||||
props.onDebugMessage,
|
||||
props.handleSlashCommand,
|
||||
props.shellModeActive,
|
||||
),
|
||||
{
|
||||
initialProps: {
|
||||
client,
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand:
|
||||
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||
shellModeActive: false,
|
||||
},
|
||||
},
|
||||
);
|
||||
return {
|
||||
result,
|
||||
@@ -326,7 +352,6 @@ describe('useGeminiStream', () => {
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSendMessageStream,
|
||||
client,
|
||||
// mockFilter removed
|
||||
};
|
||||
};
|
||||
|
||||
@@ -423,24 +448,29 @@ describe('useGeminiStream', () => {
|
||||
} as TrackedCancelledToolCall,
|
||||
];
|
||||
|
||||
const hookResult = await act(async () =>
|
||||
renderTestHook(simplifiedToolCalls),
|
||||
);
|
||||
|
||||
const {
|
||||
rerender,
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSendMessageStream: localMockSendMessageStream,
|
||||
} = hookResult!;
|
||||
client,
|
||||
} = renderTestHook(simplifiedToolCalls);
|
||||
|
||||
// It seems the initial render + effect run should be enough.
|
||||
// If rerender was for a specific state change, it might still be needed.
|
||||
// For now, let's test if the initial effect run (covered by the first act) is sufficient.
|
||||
// If not, we can add back: await act(async () => { rerender({}); });
|
||||
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']);
|
||||
act(() => {
|
||||
rerender({
|
||||
client,
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand:
|
||||
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||
shellModeActive: false,
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(localMockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
|
||||
expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
|
||||
});
|
||||
|
||||
const expectedMergedResponse = mergePartListUnions([
|
||||
@@ -479,12 +509,21 @@ describe('useGeminiStream', () => {
|
||||
client,
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
rerender({} as any);
|
||||
act(() => {
|
||||
rerender({
|
||||
client,
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand:
|
||||
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||
shellModeActive: false,
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
|
||||
expect(client.addHistory).toHaveBeenCalledTimes(2);
|
||||
expect(client.addHistory).toHaveBeenCalledWith({
|
||||
role: 'user',
|
||||
|
||||
@@ -83,28 +83,24 @@ export const useGeminiStream = (
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const logger = useLogger();
|
||||
|
||||
const [
|
||||
toolCalls,
|
||||
scheduleToolCalls,
|
||||
cancelAllToolCalls,
|
||||
markToolsAsSubmitted,
|
||||
] = useReactToolScheduler(
|
||||
(completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
// The new useEffect will handle submitting their responses.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
setPendingHistoryItem,
|
||||
);
|
||||
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||
useReactToolScheduler(
|
||||
(completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
// The new useEffect will handle submitting their responses.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
setPendingHistoryItem,
|
||||
);
|
||||
|
||||
const pendingToolCallGroupDisplay = useMemo(
|
||||
() =>
|
||||
@@ -143,10 +139,15 @@ export const useGeminiStream = (
|
||||
return StreamingState.Idle;
|
||||
}, [isResponding, toolCalls]);
|
||||
|
||||
useEffect(() => {
|
||||
if (streamingState === StreamingState.Idle) {
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}, [streamingState]);
|
||||
|
||||
useInput((_input, key) => {
|
||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||
abortControllerRef.current?.abort();
|
||||
cancelAllToolCalls(); // Also cancel any pending/executing tool calls
|
||||
}
|
||||
});
|
||||
|
||||
@@ -191,7 +192,7 @@ export const useGeminiStream = (
|
||||
name: toolName,
|
||||
args: toolArgs,
|
||||
};
|
||||
scheduleToolCalls([toolCallRequest]);
|
||||
scheduleToolCalls([toolCallRequest], abortSignal);
|
||||
}
|
||||
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
|
||||
}
|
||||
@@ -330,9 +331,8 @@ export const useGeminiStream = (
|
||||
userMessageTimestamp,
|
||||
);
|
||||
setIsResponding(false);
|
||||
cancelAllToolCalls();
|
||||
},
|
||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls],
|
||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem],
|
||||
);
|
||||
|
||||
const handleErrorEvent = useCallback(
|
||||
@@ -365,6 +365,7 @@ export const useGeminiStream = (
|
||||
async (
|
||||
stream: AsyncIterable<GeminiEvent>,
|
||||
userMessageTimestamp: number,
|
||||
signal: AbortSignal,
|
||||
): Promise<StreamProcessingStatus> => {
|
||||
let geminiMessageBuffer = '';
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
@@ -401,7 +402,7 @@ export const useGeminiStream = (
|
||||
}
|
||||
}
|
||||
if (toolCallRequests.length > 0) {
|
||||
scheduleToolCalls(toolCallRequests);
|
||||
scheduleToolCalls(toolCallRequests, signal);
|
||||
}
|
||||
return StreamProcessingStatus.Completed;
|
||||
},
|
||||
@@ -453,6 +454,7 @@ export const useGeminiStream = (
|
||||
const processingStatus = await processGeminiStreamEvents(
|
||||
stream,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
@@ -476,7 +478,6 @@ export const useGeminiStream = (
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
abortControllerRef.current = null; // Always reset
|
||||
setIsResponding(false);
|
||||
}
|
||||
},
|
||||
|
||||
@@ -32,8 +32,8 @@ import {
|
||||
|
||||
export type ScheduleFn = (
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
) => void;
|
||||
export type CancelFn = (reason?: string) => void;
|
||||
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
|
||||
|
||||
export type TrackedScheduledToolCall = ScheduledToolCall & {
|
||||
@@ -69,7 +69,7 @@ export function useReactToolScheduler(
|
||||
setPendingHistoryItem: React.Dispatch<
|
||||
React.SetStateAction<HistoryItemWithoutId | null>
|
||||
>,
|
||||
): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
|
||||
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
|
||||
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
||||
TrackedToolCall[]
|
||||
>([]);
|
||||
@@ -172,15 +172,11 @@ export function useReactToolScheduler(
|
||||
);
|
||||
|
||||
const schedule: ScheduleFn = useCallback(
|
||||
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
||||
scheduler.schedule(request);
|
||||
},
|
||||
[scheduler],
|
||||
);
|
||||
|
||||
const cancel: CancelFn = useCallback(
|
||||
(reason: string = 'unspecified') => {
|
||||
scheduler.cancelAll(reason);
|
||||
async (
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
) => {
|
||||
scheduler.schedule(request, signal);
|
||||
},
|
||||
[scheduler],
|
||||
);
|
||||
@@ -198,7 +194,7 @@ export function useReactToolScheduler(
|
||||
[],
|
||||
);
|
||||
|
||||
return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
|
||||
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -137,7 +137,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
@@ -290,7 +290,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -337,7 +337,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -374,7 +374,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -410,7 +410,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -451,7 +451,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -507,7 +507,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -579,7 +579,7 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -634,102 +634,6 @@ describe('useReactToolScheduler', () => {
|
||||
expect(result.current[0]).toEqual([]);
|
||||
});
|
||||
|
||||
it.skip('should cancel tool calls before execution (e.g. when status is scheduled)', async () => {
|
||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
|
||||
(mockTool.execute as Mock).mockReturnValue(new Promise(() => {}));
|
||||
|
||||
const { result } = renderScheduler();
|
||||
const schedule = result.current[1];
|
||||
const cancel = result.current[2];
|
||||
const request: ToolCallRequestInfo = {
|
||||
callId: 'cancelCall',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
act(() => {
|
||||
cancel();
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'cancelled',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
responseParts: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
response: expect.objectContaining({
|
||||
error:
|
||||
'[Operation Cancelled] Reason: User cancelled before execution',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(mockTool.execute).not.toHaveBeenCalled();
|
||||
expect(result.current[0]).toEqual([]);
|
||||
});
|
||||
|
||||
it.skip('should cancel tool calls that are awaiting approval', async () => {
|
||||
mockToolRegistry.getTool.mockReturnValue(mockToolRequiresConfirmation);
|
||||
const { result } = renderScheduler();
|
||||
const schedule = result.current[1];
|
||||
const cancelFn = result.current[2];
|
||||
const request: ToolCallRequestInfo = {
|
||||
callId: 'cancelApprovalCall',
|
||||
name: 'mockToolRequiresConfirmation',
|
||||
args: {},
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
act(() => {
|
||||
cancelFn();
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'cancelled',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
responseParts: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
response: expect.objectContaining({
|
||||
error:
|
||||
'[Operation Cancelled] Reason: User cancelled during approval',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
});
|
||||
|
||||
it('should schedule and execute multiple tool calls', async () => {
|
||||
const tool1 = {
|
||||
...mockTool,
|
||||
@@ -766,7 +670,7 @@ describe('useReactToolScheduler', () => {
|
||||
];
|
||||
|
||||
act(() => {
|
||||
schedule(requests);
|
||||
schedule(requests, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
@@ -848,13 +752,13 @@ describe('useReactToolScheduler', () => {
|
||||
};
|
||||
|
||||
act(() => {
|
||||
schedule(request1);
|
||||
schedule(request1, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(() => schedule(request2)).toThrow(
|
||||
expect(() => schedule(request2, new AbortController().signal)).toThrow(
|
||||
'Cannot schedule tool calls while other tool calls are running',
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user