feat: useToolScheduler hook to manage parallel tool calls (#448)

This commit is contained in:
Brandon Keiji
2025-05-22 05:57:53 +00:00
committed by GitHub
parent efee7c6cce
commit 02eec5c8ca
6 changed files with 109 additions and 369 deletions

View File

@@ -155,10 +155,9 @@ export class GeminiClient {
signal?: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> {
let turns = 0;
const availableTools = this.config.getToolRegistry().getAllTools();
while (turns < this.MAX_TURNS) {
turns++;
const turn = new Turn(chat, availableTools);
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
let seenError = false;
for await (const event of resultStream) {

View File

@@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { reportError } from '../utils/errorReporting.js';
import { getErrorMessage } from '../utils/errors.js';
// --- Types for Server Logic ---
// Define a simpler structure for Tool execution results within the server
interface ServerToolExecutionOutcome {
callId: string;
name: string;
args: Record<string, unknown>;
result?: ToolResult;
error?: Error;
confirmationDetails: ToolCallConfirmationDetails | undefined;
}
// Define a structure for tools passed to the server
export interface ServerTool {
name: string;
@@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
private readonly availableTools: Map<string, ServerTool>;
private pendingToolCalls: Array<{
callId: string;
name: string;
@@ -128,11 +115,7 @@ export class Turn {
private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[];
constructor(
private readonly chat: Chat,
availableTools: ServerTool[],
) {
this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
constructor(private readonly chat: Chat) {
this.pendingToolCalls = [];
this.fnResponses = [];
this.confirmationDetails = [];
@@ -160,12 +143,9 @@ export class Turn {
yield { type: GeminiEventType.Content, value: text };
}
if (!resp.functionCalls) {
continue;
}
// Handle function calls (requesting tool execution)
for (const fnCall of resp.functionCalls) {
const functionCalls = resp.functionCalls ?? [];
for (const fnCall of functionCalls) {
const event = this.handlePendingFunctionCall(fnCall);
if (event) {
yield event;
@@ -184,80 +164,6 @@ export class Turn {
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
return;
}
// Execute pending tool calls
const toolPromises = this.pendingToolCalls.map(
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
const tool = this.availableTools.get(pendingToolCall.name);
if (!tool) {
return {
...pendingToolCall,
error: new Error(
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
),
confirmationDetails: undefined,
};
}
try {
const confirmationDetails = await tool.shouldConfirmExecute(
pendingToolCall.args,
);
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
}
const result = await tool.execute(pendingToolCall.args, signal);
return {
...pendingToolCall,
result,
confirmationDetails: undefined,
};
} catch (execError: unknown) {
return {
...pendingToolCall,
error: new Error(
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
),
confirmationDetails: undefined,
};
}
},
);
const outcomes = await Promise.all(toolPromises);
// Process outcomes and prepare function responses
this.pendingToolCalls = []; // Clear pending calls for this turn
for (const outcome of outcomes) {
if (outcome.confirmationDetails) {
this.confirmationDetails.push(outcome.confirmationDetails);
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
request: {
callId: outcome.callId,
name: outcome.name,
args: outcome.args,
},
details: outcome.confirmationDetails,
};
yield {
type: GeminiEventType.ToolCallConfirmation,
value: serverConfirmationetails,
};
}
const responsePart = this.buildFunctionResponse(outcome);
this.fnResponses.push(responsePart);
const responseInfo: ToolCallResponseInfo = {
callId: outcome.callId,
responsePart,
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
// If aborted we're already yielding the user cancellations elsewhere.
if (!signal?.aborted) {
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
}
}
private handlePendingFunctionCall(
@@ -276,30 +182,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value };
}
// Builds the Part array expected by the Google GenAI API
private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
const { name, result, error } = outcome;
if (error) {
// Format error for the LLM
const errorMessage = error?.message || String(error);
console.error(`[Server Turn] Error executing tool ${name}:`, error);
return {
functionResponse: {
name,
id: outcome.callId,
response: { error: `Tool execution failed: ${errorMessage}` },
},
};
}
return {
functionResponse: {
name,
id: outcome.callId,
response: { output: result?.llmContent ?? '' },
},
};
}
getConfirmationDetails(): ToolCallConfirmationDetails[] {
return this.confirmationDetails;
}

View File

@@ -171,23 +171,28 @@ export interface FileDiff {
fileName: string;
}
export interface ToolCallConfirmationDetails {
export interface ToolCallConfirmationDetailsDefault {
title: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
}
export interface ToolEditConfirmationDetails
extends ToolCallConfirmationDetails {
extends ToolCallConfirmationDetailsDefault {
fileName: string;
fileDiff: string;
}
export interface ToolExecuteConfirmationDetails
extends ToolCallConfirmationDetails {
extends ToolCallConfirmationDetailsDefault {
command: string;
rootCommand: string;
}
export type ToolCallConfirmationDetails =
| ToolCallConfirmationDetailsDefault
| ToolEditConfirmationDetails
| ToolExecuteConfirmationDetails;
export enum ToolConfirmationOutcome {
ProceedOnce,
ProceedAlways,