fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)

This commit is contained in:
N. Taylor Mullen
2025-06-08 15:42:49 -07:00
committed by GitHub
parent 7868ef8229
commit f2ea78d0e4
7 changed files with 235 additions and 209 deletions

View File

@@ -4,9 +4,110 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { convertToFunctionResponse } from './coreToolScheduler.js';
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi } from 'vitest';
import {
CoreToolScheduler,
ToolCall,
ValidatingToolCall,
} from './coreToolScheduler.js';
import {
BaseTool,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolResult,
} from '../index.js';
import { Part, PartListUnion } from '@google/genai';
import { convertToFunctionResponse } from './coreToolScheduler.js';
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
shouldConfirm = false;
executeFn = vi.fn();
constructor(name = 'mockTool') {
super(name, name, 'A mock tool', {});
}
async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'exec',
title: 'Confirm Mock Tool',
command: 'do_thing',
rootCommand: 'do_thing',
onConfirm: async () => {},
};
}
return false;
}
async execute(
params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolResult> {
this.executeFn(params);
return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
}
}
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
mockTool.shouldConfirm = true;
const toolRegistry = {
getTool: () => mockTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
config: {} as any,
registerTool: () => {},
getToolByName: () => mockTool,
getToolByDisplayName: () => mockTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const scheduler = new CoreToolScheduler({
toolRegistry: Promise.resolve(toolRegistry as any),
onAllToolCallsComplete,
onToolCallsUpdate,
});
const abortController = new AbortController();
const request = { callId: '1', name: 'mockTool', args: {} };
abortController.abort();
await scheduler.schedule([request], abortController.signal);
const _waitingCall = onToolCallsUpdate.mock
.calls[1][0][0] as ValidatingToolCall;
const confirmationDetails = await mockTool.shouldConfirmExecute(
{},
abortController.signal,
);
if (confirmationDetails) {
await scheduler.handleConfirmationResponse(
'1',
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
abortController.signal,
);
}
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls[0].status).toBe('cancelled');
});
});
describe('convertToFunctionResponse', () => {
const toolName = 'testTool';

View File

@@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions {
export class CoreToolScheduler {
private toolRegistry: Promise<ToolRegistry>;
private toolCalls: ToolCall[] = [];
private abortController: AbortController;
private outputUpdateHandler?: OutputUpdateHandler;
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
private onToolCallsUpdate?: ToolCallsUpdateHandler;
@@ -220,7 +219,6 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
this.abortController = new AbortController();
}
private setStatusInternal(
@@ -379,6 +377,7 @@ export class CoreToolScheduler {
async schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
): Promise<void> {
if (this.isRunning()) {
throw new Error(
@@ -426,7 +425,7 @@ export class CoreToolScheduler {
} else {
const confirmationDetails = await toolInstance.shouldConfirmExecute(
reqInfo.args,
this.abortController.signal,
signal,
);
if (confirmationDetails) {
@@ -438,6 +437,7 @@ export class CoreToolScheduler {
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
),
};
this.setStatusInternal(
@@ -460,7 +460,7 @@ export class CoreToolScheduler {
);
}
}
this.attemptExecutionOfScheduledCalls();
this.attemptExecutionOfScheduledCalls(signal);
this.checkAndNotifyCompletion();
}
@@ -468,6 +468,7 @@ export class CoreToolScheduler {
callId: string,
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
outcome: ToolConfirmationOutcome,
signal: AbortSignal,
): Promise<void> {
const toolCall = this.toolCalls.find(
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
@@ -477,7 +478,7 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
if (outcome === ToolConfirmationOutcome.Cancel) {
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
@@ -497,7 +498,7 @@ export class CoreToolScheduler {
const modifyResults = await editTool.onModify(
waitingToolCall.request.args as unknown as EditToolParams,
this.abortController.signal,
signal,
outcome,
);
@@ -513,10 +514,10 @@ export class CoreToolScheduler {
} else {
this.setStatusInternal(callId, 'scheduled');
}
this.attemptExecutionOfScheduledCalls();
this.attemptExecutionOfScheduledCalls(signal);
}
private attemptExecutionOfScheduledCalls(): void {
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
const allCallsFinalOrScheduled = this.toolCalls.every(
(call) =>
call.status === 'scheduled' ||
@@ -553,17 +554,13 @@ export class CoreToolScheduler {
: undefined;
scheduledCall.tool
.execute(
scheduledCall.request.args,
this.abortController.signal,
liveOutputCallback,
)
.execute(scheduledCall.request.args, signal, liveOutputCallback)
.then((toolResult: ToolResult) => {
if (this.abortController.signal.aborted) {
if (signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
this.abortController.signal.reason || 'Execution aborted.',
'User cancelled tool execution.',
);
return;
}
@@ -613,29 +610,10 @@ export class CoreToolScheduler {
if (this.onAllToolCallsComplete) {
this.onAllToolCallsComplete(completedCalls);
}
this.abortController = new AbortController();
this.notifyToolCallsUpdate();
}
}
cancelAll(reason: string = 'User initiated cancellation.'): void {
if (!this.abortController.signal.aborted) {
this.abortController.abort(reason);
}
this.abortController = new AbortController();
const callsToCancel = [...this.toolCalls];
callsToCancel.forEach((call) => {
if (
call.status !== 'error' &&
call.status !== 'success' &&
call.status !== 'cancelled'
) {
this.setStatusInternal(call.request.callId, 'cancelled', reason);
}
});
}
private notifyToolCallsUpdate(): void {
if (this.onToolCallsUpdate) {
this.onToolCallsUpdate([...this.toolCalls]);

View File

@@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
};
}
if (abortSignal.aborted) {
return {
llmContent: 'Command was cancelled by user before it could start.',
returnDisplay: 'Command cancelled by user.',
};
}
// wrap command to append subprocess pids (via pgrep) to temporary file
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
const tempFilePath = path.join(os.tmpdir(), tempFileName);