mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)
This commit is contained in:
@@ -4,9 +4,110 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { convertToFunctionResponse } from './coreToolScheduler.js';
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
CoreToolScheduler,
|
||||
ToolCall,
|
||||
ValidatingToolCall,
|
||||
} from './coreToolScheduler.js';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolResult,
|
||||
} from '../index.js';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { convertToFunctionResponse } from './coreToolScheduler.js';
|
||||
|
||||
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
||||
shouldConfirm = false;
|
||||
executeFn = vi.fn();
|
||||
|
||||
constructor(name = 'mockTool') {
|
||||
super(name, name, 'A mock tool', {});
|
||||
}
|
||||
|
||||
async shouldConfirmExecute(
|
||||
_params: Record<string, unknown>,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.shouldConfirm) {
|
||||
return {
|
||||
type: 'exec',
|
||||
title: 'Confirm Mock Tool',
|
||||
command: 'do_thing',
|
||||
rootCommand: 'do_thing',
|
||||
onConfirm: async () => {},
|
||||
};
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: Record<string, unknown>,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
this.executeFn(params);
|
||||
return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
|
||||
}
|
||||
}
|
||||
|
||||
describe('CoreToolScheduler', () => {
|
||||
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||
const mockTool = new MockTool();
|
||||
mockTool.shouldConfirm = true;
|
||||
const toolRegistry = {
|
||||
getTool: () => mockTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {} as any,
|
||||
config: {} as any,
|
||||
registerTool: () => {},
|
||||
getToolByName: () => mockTool,
|
||||
getToolByDisplayName: () => mockTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request = { callId: '1', name: 'mockTool', args: {} };
|
||||
|
||||
abortController.abort();
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
const _waitingCall = onToolCallsUpdate.mock
|
||||
.calls[1][0][0] as ValidatingToolCall;
|
||||
const confirmationDetails = await mockTool.shouldConfirmExecute(
|
||||
{},
|
||||
abortController.signal,
|
||||
);
|
||||
if (confirmationDetails) {
|
||||
await scheduler.handleConfirmationResponse(
|
||||
'1',
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
abortController.signal,
|
||||
);
|
||||
}
|
||||
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
});
|
||||
});
|
||||
|
||||
describe('convertToFunctionResponse', () => {
|
||||
const toolName = 'testTool';
|
||||
|
||||
@@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions {
|
||||
export class CoreToolScheduler {
|
||||
private toolRegistry: Promise<ToolRegistry>;
|
||||
private toolCalls: ToolCall[] = [];
|
||||
private abortController: AbortController;
|
||||
private outputUpdateHandler?: OutputUpdateHandler;
|
||||
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||
private onToolCallsUpdate?: ToolCallsUpdateHandler;
|
||||
@@ -220,7 +219,6 @@ export class CoreToolScheduler {
|
||||
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
|
||||
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
||||
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
|
||||
this.abortController = new AbortController();
|
||||
}
|
||||
|
||||
private setStatusInternal(
|
||||
@@ -379,6 +377,7 @@ export class CoreToolScheduler {
|
||||
|
||||
async schedule(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
if (this.isRunning()) {
|
||||
throw new Error(
|
||||
@@ -426,7 +425,7 @@ export class CoreToolScheduler {
|
||||
} else {
|
||||
const confirmationDetails = await toolInstance.shouldConfirmExecute(
|
||||
reqInfo.args,
|
||||
this.abortController.signal,
|
||||
signal,
|
||||
);
|
||||
|
||||
if (confirmationDetails) {
|
||||
@@ -438,6 +437,7 @@ export class CoreToolScheduler {
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
signal,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
@@ -460,7 +460,7 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
}
|
||||
this.attemptExecutionOfScheduledCalls();
|
||||
this.attemptExecutionOfScheduledCalls(signal);
|
||||
this.checkAndNotifyCompletion();
|
||||
}
|
||||
|
||||
@@ -468,6 +468,7 @@ export class CoreToolScheduler {
|
||||
callId: string,
|
||||
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
||||
outcome: ToolConfirmationOutcome,
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
const toolCall = this.toolCalls.find(
|
||||
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||
@@ -477,7 +478,7 @@ export class CoreToolScheduler {
|
||||
await originalOnConfirm(outcome);
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
||||
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
@@ -497,7 +498,7 @@ export class CoreToolScheduler {
|
||||
|
||||
const modifyResults = await editTool.onModify(
|
||||
waitingToolCall.request.args as unknown as EditToolParams,
|
||||
this.abortController.signal,
|
||||
signal,
|
||||
outcome,
|
||||
);
|
||||
|
||||
@@ -513,10 +514,10 @@ export class CoreToolScheduler {
|
||||
} else {
|
||||
this.setStatusInternal(callId, 'scheduled');
|
||||
}
|
||||
this.attemptExecutionOfScheduledCalls();
|
||||
this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
|
||||
private attemptExecutionOfScheduledCalls(): void {
|
||||
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
|
||||
const allCallsFinalOrScheduled = this.toolCalls.every(
|
||||
(call) =>
|
||||
call.status === 'scheduled' ||
|
||||
@@ -553,17 +554,13 @@ export class CoreToolScheduler {
|
||||
: undefined;
|
||||
|
||||
scheduledCall.tool
|
||||
.execute(
|
||||
scheduledCall.request.args,
|
||||
this.abortController.signal,
|
||||
liveOutputCallback,
|
||||
)
|
||||
.execute(scheduledCall.request.args, signal, liveOutputCallback)
|
||||
.then((toolResult: ToolResult) => {
|
||||
if (this.abortController.signal.aborted) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
this.abortController.signal.reason || 'Execution aborted.',
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
return;
|
||||
}
|
||||
@@ -613,29 +610,10 @@ export class CoreToolScheduler {
|
||||
if (this.onAllToolCallsComplete) {
|
||||
this.onAllToolCallsComplete(completedCalls);
|
||||
}
|
||||
this.abortController = new AbortController();
|
||||
this.notifyToolCallsUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
cancelAll(reason: string = 'User initiated cancellation.'): void {
|
||||
if (!this.abortController.signal.aborted) {
|
||||
this.abortController.abort(reason);
|
||||
}
|
||||
this.abortController = new AbortController();
|
||||
|
||||
const callsToCancel = [...this.toolCalls];
|
||||
callsToCancel.forEach((call) => {
|
||||
if (
|
||||
call.status !== 'error' &&
|
||||
call.status !== 'success' &&
|
||||
call.status !== 'cancelled'
|
||||
) {
|
||||
this.setStatusInternal(call.request.callId, 'cancelled', reason);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private notifyToolCallsUpdate(): void {
|
||||
if (this.onToolCallsUpdate) {
|
||||
this.onToolCallsUpdate([...this.toolCalls]);
|
||||
|
||||
@@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||
};
|
||||
}
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
return {
|
||||
llmContent: 'Command was cancelled by user before it could start.',
|
||||
returnDisplay: 'Command cancelled by user.',
|
||||
};
|
||||
}
|
||||
|
||||
// wrap command to append subprocess pids (via pgrep) to temporary file
|
||||
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
|
||||
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
||||
|
||||
Reference in New Issue
Block a user