fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)

This commit is contained in:
N. Taylor Mullen
2025-06-08 15:42:49 -07:00
committed by GitHub
parent 7868ef8229
commit f2ea78d0e4
7 changed files with 235 additions and 209 deletions

View File

@@ -18,6 +18,7 @@ import {
import { Config } from '@gemini-cli/core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { Dispatch, SetStateAction } from 'react';
// --- MOCKS ---
const mockSendMessageStream = vi
@@ -309,16 +310,41 @@ describe('useGeminiStream', () => {
const client = geminiClient || mockConfig.getGeminiClient();
const { result, rerender } = renderHook(() =>
useGeminiStream(
client,
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
mockSetShowHelp,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false, // shellModeActive
),
const { result, rerender } = renderHook(
(props: {
client: any;
addItem: UseHistoryManagerReturn['addItem'];
setShowHelp: Dispatch<SetStateAction<boolean>>;
config: Config;
onDebugMessage: (message: string) => void;
handleSlashCommand: (
command: PartListUnion,
) =>
| import('./slashCommandProcessor.js').SlashCommandActionReturn
| boolean;
shellModeActive: boolean;
}) =>
useGeminiStream(
props.client,
props.addItem,
props.setShowHelp,
props.config,
props.onDebugMessage,
props.handleSlashCommand,
props.shellModeActive,
),
{
initialProps: {
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
},
},
);
return {
result,
@@ -326,7 +352,6 @@ describe('useGeminiStream', () => {
mockMarkToolsAsSubmitted,
mockSendMessageStream,
client,
// mockFilter removed
};
};
@@ -423,24 +448,29 @@ describe('useGeminiStream', () => {
} as TrackedCancelledToolCall,
];
const hookResult = await act(async () =>
renderTestHook(simplifiedToolCalls),
);
const {
rerender,
mockMarkToolsAsSubmitted,
mockSendMessageStream: localMockSendMessageStream,
} = hookResult!;
client,
} = renderTestHook(simplifiedToolCalls);
// It seems the initial render + effect run should be enough.
// If rerender was for a specific state change, it might still be needed.
// For now, let's test if the initial effect run (covered by the first act) is sufficient.
// If not, we can add back: await act(async () => { rerender({}); });
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']);
act(() => {
rerender({
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
});
});
await waitFor(() => {
expect(localMockSendMessageStream).toHaveBeenCalledTimes(1);
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
});
const expectedMergedResponse = mergePartListUnions([
@@ -479,12 +509,21 @@ describe('useGeminiStream', () => {
client,
);
await act(async () => {
rerender({} as any);
act(() => {
rerender({
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
});
});
await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
expect(client.addHistory).toHaveBeenCalledTimes(2);
expect(client.addHistory).toHaveBeenCalledWith({
role: 'user',

View File

@@ -83,28 +83,24 @@ export const useGeminiStream = (
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
const [
toolCalls,
scheduleToolCalls,
cancelAllToolCalls,
markToolsAsSubmitted,
] = useReactToolScheduler(
(completedToolCallsFromScheduler) => {
// This onComplete is called when ALL scheduled tools for a given batch are done.
if (completedToolCallsFromScheduler.length > 0) {
// Add the final state of these tools to the history for display.
// The new useEffect will handle submitting their responses.
addItem(
mapTrackedToolCallsToDisplay(
completedToolCallsFromScheduler as TrackedToolCall[],
),
Date.now(),
);
}
},
config,
setPendingHistoryItem,
);
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
(completedToolCallsFromScheduler) => {
// This onComplete is called when ALL scheduled tools for a given batch are done.
if (completedToolCallsFromScheduler.length > 0) {
// Add the final state of these tools to the history for display.
// The new useEffect will handle submitting their responses.
addItem(
mapTrackedToolCallsToDisplay(
completedToolCallsFromScheduler as TrackedToolCall[],
),
Date.now(),
);
}
},
config,
setPendingHistoryItem,
);
const pendingToolCallGroupDisplay = useMemo(
() =>
@@ -143,10 +139,15 @@ export const useGeminiStream = (
return StreamingState.Idle;
}, [isResponding, toolCalls]);
useEffect(() => {
if (streamingState === StreamingState.Idle) {
abortControllerRef.current = null;
}
}, [streamingState]);
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
cancelAllToolCalls(); // Also cancel any pending/executing tool calls
}
});
@@ -191,7 +192,7 @@ export const useGeminiStream = (
name: toolName,
args: toolArgs,
};
scheduleToolCalls([toolCallRequest]);
scheduleToolCalls([toolCallRequest], abortSignal);
}
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
}
@@ -330,9 +331,8 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
cancelAllToolCalls();
},
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls],
[addItem, pendingHistoryItemRef, setPendingHistoryItem],
);
const handleErrorEvent = useCallback(
@@ -365,6 +365,7 @@ export const useGeminiStream = (
async (
stream: AsyncIterable<GeminiEvent>,
userMessageTimestamp: number,
signal: AbortSignal,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
const toolCallRequests: ToolCallRequestInfo[] = [];
@@ -401,7 +402,7 @@ export const useGeminiStream = (
}
}
if (toolCallRequests.length > 0) {
scheduleToolCalls(toolCallRequests);
scheduleToolCalls(toolCallRequests, signal);
}
return StreamProcessingStatus.Completed;
},
@@ -453,6 +454,7 @@ export const useGeminiStream = (
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
abortSignal,
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
@@ -476,7 +478,6 @@ export const useGeminiStream = (
);
}
} finally {
abortControllerRef.current = null; // Always reset
setIsResponding(false);
}
},

View File

@@ -32,8 +32,8 @@ import {
export type ScheduleFn = (
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
) => void;
export type CancelFn = (reason?: string) => void;
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
export type TrackedScheduledToolCall = ScheduledToolCall & {
@@ -69,7 +69,7 @@ export function useReactToolScheduler(
setPendingHistoryItem: React.Dispatch<
React.SetStateAction<HistoryItemWithoutId | null>
>,
): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
TrackedToolCall[]
>([]);
@@ -172,15 +172,11 @@ export function useReactToolScheduler(
);
const schedule: ScheduleFn = useCallback(
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
scheduler.schedule(request);
},
[scheduler],
);
const cancel: CancelFn = useCallback(
(reason: string = 'unspecified') => {
scheduler.cancelAll(reason);
async (
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
) => {
scheduler.schedule(request, signal);
},
[scheduler],
);
@@ -198,7 +194,7 @@ export function useReactToolScheduler(
[],
);
return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
}
/**

View File

@@ -137,7 +137,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
@@ -290,7 +290,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -337,7 +337,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -374,7 +374,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -410,7 +410,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -451,7 +451,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -507,7 +507,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -579,7 +579,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request);
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -634,102 +634,6 @@ describe('useReactToolScheduler', () => {
expect(result.current[0]).toEqual([]);
});
it.skip('should cancel tool calls before execution (e.g. when status is scheduled)', async () => {
mockToolRegistry.getTool.mockReturnValue(mockTool);
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
(mockTool.execute as Mock).mockReturnValue(new Promise(() => {}));
const { result } = renderScheduler();
const schedule = result.current[1];
const cancel = result.current[2];
const request: ToolCallRequestInfo = {
callId: 'cancelCall',
name: 'mockTool',
args: {},
};
act(() => {
schedule(request);
});
await act(async () => {
await vi.runAllTimersAsync();
});
act(() => {
cancel();
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'cancelled',
request,
response: expect.objectContaining({
responseParts: expect.arrayContaining([
expect.objectContaining({
functionResponse: expect.objectContaining({
response: expect.objectContaining({
error:
'[Operation Cancelled] Reason: User cancelled before execution',
}),
}),
}),
]),
}),
}),
]);
expect(mockTool.execute).not.toHaveBeenCalled();
expect(result.current[0]).toEqual([]);
});
it.skip('should cancel tool calls that are awaiting approval', async () => {
mockToolRegistry.getTool.mockReturnValue(mockToolRequiresConfirmation);
const { result } = renderScheduler();
const schedule = result.current[1];
const cancelFn = result.current[2];
const request: ToolCallRequestInfo = {
callId: 'cancelApprovalCall',
name: 'mockToolRequiresConfirmation',
args: {},
};
act(() => {
schedule(request);
});
await act(async () => {
await vi.runAllTimersAsync();
});
act(() => {
cancelFn();
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'cancelled',
request,
response: expect.objectContaining({
responseParts: expect.arrayContaining([
expect.objectContaining({
functionResponse: expect.objectContaining({
response: expect.objectContaining({
error:
'[Operation Cancelled] Reason: User cancelled during approval',
}),
}),
}),
]),
}),
}),
]);
expect(result.current[0]).toEqual([]);
});
it('should schedule and execute multiple tool calls', async () => {
const tool1 = {
...mockTool,
@@ -766,7 +670,7 @@ describe('useReactToolScheduler', () => {
];
act(() => {
schedule(requests);
schedule(requests, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -848,13 +752,13 @@ describe('useReactToolScheduler', () => {
};
act(() => {
schedule(request1);
schedule(request1, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(() => schedule(request2)).toThrow(
expect(() => schedule(request2, new AbortController().signal)).toThrow(
'Cannot schedule tool calls while other tool calls are running',
);