mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00
feat: useToolScheduler hook to manage parallel tool calls (#448)
This commit is contained in:
@@ -155,10 +155,9 @@ export class GeminiClient {
|
||||
signal?: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
let turns = 0;
|
||||
const availableTools = this.config.getToolRegistry().getAllTools();
|
||||
while (turns < this.MAX_TURNS) {
|
||||
turns++;
|
||||
const turn = new Turn(chat, availableTools);
|
||||
const turn = new Turn(chat);
|
||||
const resultStream = turn.run(request, signal);
|
||||
let seenError = false;
|
||||
for await (const event of resultStream) {
|
||||
|
||||
@@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
// --- Types for Server Logic ---
|
||||
|
||||
// Define a simpler structure for Tool execution results within the server
|
||||
interface ServerToolExecutionOutcome {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: ToolResult;
|
||||
error?: Error;
|
||||
confirmationDetails: ToolCallConfirmationDetails | undefined;
|
||||
}
|
||||
|
||||
// Define a structure for tools passed to the server
|
||||
export interface ServerTool {
|
||||
name: string;
|
||||
@@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent =
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
private readonly availableTools: Map<string, ServerTool>;
|
||||
private pendingToolCalls: Array<{
|
||||
callId: string;
|
||||
name: string;
|
||||
@@ -128,11 +115,7 @@ export class Turn {
|
||||
private confirmationDetails: ToolCallConfirmationDetails[];
|
||||
private debugResponses: GenerateContentResponse[];
|
||||
|
||||
constructor(
|
||||
private readonly chat: Chat,
|
||||
availableTools: ServerTool[],
|
||||
) {
|
||||
this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
|
||||
constructor(private readonly chat: Chat) {
|
||||
this.pendingToolCalls = [];
|
||||
this.fnResponses = [];
|
||||
this.confirmationDetails = [];
|
||||
@@ -160,12 +143,9 @@ export class Turn {
|
||||
yield { type: GeminiEventType.Content, value: text };
|
||||
}
|
||||
|
||||
if (!resp.functionCalls) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle function calls (requesting tool execution)
|
||||
for (const fnCall of resp.functionCalls) {
|
||||
const functionCalls = resp.functionCalls ?? [];
|
||||
for (const fnCall of functionCalls) {
|
||||
const event = this.handlePendingFunctionCall(fnCall);
|
||||
if (event) {
|
||||
yield event;
|
||||
@@ -184,80 +164,6 @@ export class Turn {
|
||||
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
|
||||
return;
|
||||
}
|
||||
|
||||
// Execute pending tool calls
|
||||
const toolPromises = this.pendingToolCalls.map(
|
||||
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
|
||||
const tool = this.availableTools.get(pendingToolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
...pendingToolCall,
|
||||
error: new Error(
|
||||
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
|
||||
),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const confirmationDetails = await tool.shouldConfirmExecute(
|
||||
pendingToolCall.args,
|
||||
);
|
||||
if (confirmationDetails) {
|
||||
return { ...pendingToolCall, confirmationDetails };
|
||||
}
|
||||
const result = await tool.execute(pendingToolCall.args, signal);
|
||||
return {
|
||||
...pendingToolCall,
|
||||
result,
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
} catch (execError: unknown) {
|
||||
return {
|
||||
...pendingToolCall,
|
||||
error: new Error(
|
||||
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
|
||||
),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
const outcomes = await Promise.all(toolPromises);
|
||||
|
||||
// Process outcomes and prepare function responses
|
||||
this.pendingToolCalls = []; // Clear pending calls for this turn
|
||||
|
||||
for (const outcome of outcomes) {
|
||||
if (outcome.confirmationDetails) {
|
||||
this.confirmationDetails.push(outcome.confirmationDetails);
|
||||
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
|
||||
request: {
|
||||
callId: outcome.callId,
|
||||
name: outcome.name,
|
||||
args: outcome.args,
|
||||
},
|
||||
details: outcome.confirmationDetails,
|
||||
};
|
||||
yield {
|
||||
type: GeminiEventType.ToolCallConfirmation,
|
||||
value: serverConfirmationetails,
|
||||
};
|
||||
}
|
||||
const responsePart = this.buildFunctionResponse(outcome);
|
||||
this.fnResponses.push(responsePart);
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: outcome.callId,
|
||||
responsePart,
|
||||
resultDisplay: outcome.result?.returnDisplay,
|
||||
error: outcome.error,
|
||||
};
|
||||
|
||||
// If aborted we're already yielding the user cancellations elsewhere.
|
||||
if (!signal?.aborted) {
|
||||
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private handlePendingFunctionCall(
|
||||
@@ -276,30 +182,6 @@ export class Turn {
|
||||
return { type: GeminiEventType.ToolCallRequest, value };
|
||||
}
|
||||
|
||||
// Builds the Part array expected by the Google GenAI API
|
||||
private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
|
||||
const { name, result, error } = outcome;
|
||||
if (error) {
|
||||
// Format error for the LLM
|
||||
const errorMessage = error?.message || String(error);
|
||||
console.error(`[Server Turn] Error executing tool ${name}:`, error);
|
||||
return {
|
||||
functionResponse: {
|
||||
name,
|
||||
id: outcome.callId,
|
||||
response: { error: `Tool execution failed: ${errorMessage}` },
|
||||
},
|
||||
};
|
||||
}
|
||||
return {
|
||||
functionResponse: {
|
||||
name,
|
||||
id: outcome.callId,
|
||||
response: { output: result?.llmContent ?? '' },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
||||
return this.confirmationDetails;
|
||||
}
|
||||
|
||||
@@ -171,23 +171,28 @@ export interface FileDiff {
|
||||
fileName: string;
|
||||
}
|
||||
|
||||
export interface ToolCallConfirmationDetails {
|
||||
export interface ToolCallConfirmationDetailsDefault {
|
||||
title: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface ToolEditConfirmationDetails
|
||||
extends ToolCallConfirmationDetails {
|
||||
extends ToolCallConfirmationDetailsDefault {
|
||||
fileName: string;
|
||||
fileDiff: string;
|
||||
}
|
||||
|
||||
export interface ToolExecuteConfirmationDetails
|
||||
extends ToolCallConfirmationDetails {
|
||||
extends ToolCallConfirmationDetailsDefault {
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
}
|
||||
|
||||
export type ToolCallConfirmationDetails =
|
||||
| ToolCallConfirmationDetailsDefault
|
||||
| ToolEditConfirmationDetails
|
||||
| ToolExecuteConfirmationDetails;
|
||||
|
||||
export enum ToolConfirmationOutcome {
|
||||
ProceedOnce,
|
||||
ProceedAlways,
|
||||
|
||||
Reference in New Issue
Block a user