feat: Add client-initiated tool call handling (#1292)

This commit is contained in:
Abhi
2025-06-22 01:35:36 -04:00
committed by GitHub
parent 5cf8dc4f07
commit c9950b3cb2
8 changed files with 363 additions and 136 deletions

View File

@@ -88,7 +88,12 @@ describe('CoreToolScheduler', () => {
});
const abortController = new AbortController();
const request = { callId: '1', name: 'mockTool', args: {} };
const request = {
callId: '1',
name: 'mockTool',
args: {},
isClientInitiated: false,
};
abortController.abort();
await scheduler.schedule([request], abortController.signal);

View File

@@ -62,6 +62,7 @@ describe('executeToolCall', () => {
callId: 'call1',
name: 'testTool',
args: { param1: 'value1' },
isClientInitiated: false,
};
const toolResult: ToolResult = {
llmContent: 'Tool executed successfully',
@@ -99,6 +100,7 @@ describe('executeToolCall', () => {
callId: 'call2',
name: 'nonExistentTool',
args: {},
isClientInitiated: false,
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined);
@@ -133,6 +135,7 @@ describe('executeToolCall', () => {
callId: 'call3',
name: 'testTool',
args: { param1: 'value1' },
isClientInitiated: false,
};
const executionError = new Error('Tool execution failed');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@@ -164,6 +167,7 @@ describe('executeToolCall', () => {
callId: 'call4',
name: 'testTool',
args: { param1: 'value1' },
isClientInitiated: false,
};
const cancellationError = new Error('Operation cancelled');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@@ -206,6 +210,7 @@ describe('executeToolCall', () => {
callId: 'call5',
name: 'testTool',
args: {},
isClientInitiated: false,
};
const imageDataPart: Part = {
inlineData: { mimeType: 'image/png', data: 'base64data' },

View File

@@ -132,8 +132,13 @@ describe('Turn', () => {
const mockResponseStream = (async function* () {
yield {
functionCalls: [
{ id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
{ name: 'tool2', args: { arg2: 'val2' } }, // No ID
{
id: 'fc1',
name: 'tool1',
args: { arg1: 'val1' },
isClientInitiated: false,
},
{ name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID
],
} as unknown as GenerateContentResponse;
})();
@@ -156,6 +161,7 @@ describe('Turn', () => {
callId: 'fc1',
name: 'tool1',
args: { arg1: 'val1' },
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@@ -163,7 +169,11 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
expect.objectContaining({
name: 'tool2',
args: { arg2: 'val2' },
isClientInitiated: false,
}),
);
expect(event2.value.callId).toEqual(
expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
@@ -301,6 +311,7 @@ describe('Turn', () => {
callId: 'fc1',
name: 'undefined_tool_name',
args: { arg1: 'val1' },
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@@ -308,7 +319,12 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
expect.objectContaining({
callId: 'fc2',
name: 'tool2',
args: {},
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
@@ -319,6 +335,7 @@ describe('Turn', () => {
callId: 'fc3',
name: 'undefined_tool_name',
args: {},
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[2]).toEqual(event3.value);

View File

@@ -57,6 +57,7 @@ export interface ToolCallRequestInfo {
callId: string;
name: string;
args: Record<string, unknown>;
isClientInitiated: boolean;
}
export interface ToolCallResponseInfo {
@@ -139,11 +140,7 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
readonly pendingToolCalls: Array<{
callId: string;
name: string;
args: Record<string, unknown>;
}>;
readonly pendingToolCalls: ToolCallRequestInfo[];
private debugResponses: GenerateContentResponse[];
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
@@ -254,11 +251,17 @@ export class Turn {
const name = fnCall.name || 'undefined_tool_name';
const args = (fnCall.args || {}) as Record<string, unknown>;
this.pendingToolCalls.push({ callId, name, args });
const toolCallRequest: ToolCallRequestInfo = {
callId,
name,
args,
isClientInitiated: false,
};
this.pendingToolCalls.push(toolCallRequest);
// Yield a request for the tool call, not the pending/confirming status
const value: ToolCallRequestInfo = { callId, name, args };
return { type: GeminiEventType.ToolCallRequest, value };
return { type: GeminiEventType.ToolCallRequest, value: toolCallRequest };
}
getDebugResponses(): GenerateContentResponse[] {