feat(core): Introduce DeclarativeTool and ToolInvocation. (#5613)

This commit is contained in:
joshualitt
2025-08-06 10:50:02 -07:00
committed by GitHub
parent 882a97aff9
commit 6133bea388
24 changed files with 991 additions and 681 deletions

View File

@@ -239,65 +239,62 @@ class GeminiAgent implements Agent {
);
}
let toolCallId;
const confirmationDetails = await tool.shouldConfirmExecute(
args,
abortSignal,
);
if (confirmationDetails) {
let content: acp.ToolCallContent | null = null;
if (confirmationDetails.type === 'edit') {
content = {
type: 'diff',
path: confirmationDetails.fileName,
oldText: confirmationDetails.originalContent,
newText: confirmationDetails.newContent,
};
}
const result = await this.client.requestToolCallConfirmation({
label: tool.getDescription(args),
icon: tool.icon,
content,
confirmation: toAcpToolCallConfirmation(confirmationDetails),
locations: tool.toolLocations(args),
});
await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome));
switch (result.outcome) {
case 'reject':
return errorResponse(
new Error(`Tool "${fc.name}" not allowed to run by the user.`),
);
case 'cancel':
return errorResponse(
new Error(`Tool "${fc.name}" was canceled by the user.`),
);
case 'allow':
case 'alwaysAllow':
case 'alwaysAllowMcpServer':
case 'alwaysAllowTool':
break;
default: {
const resultOutcome: never = result.outcome;
throw new Error(`Unexpected: ${resultOutcome}`);
}
}
toolCallId = result.id;
} else {
const result = await this.client.pushToolCall({
icon: tool.icon,
label: tool.getDescription(args),
locations: tool.toolLocations(args),
});
toolCallId = result.id;
}
let toolCallId: number | undefined = undefined;
try {
const toolResult: ToolResult = await tool.execute(args, abortSignal);
const invocation = tool.build(args);
const confirmationDetails =
await invocation.shouldConfirmExecute(abortSignal);
if (confirmationDetails) {
let content: acp.ToolCallContent | null = null;
if (confirmationDetails.type === 'edit') {
content = {
type: 'diff',
path: confirmationDetails.fileName,
oldText: confirmationDetails.originalContent,
newText: confirmationDetails.newContent,
};
}
const result = await this.client.requestToolCallConfirmation({
label: invocation.getDescription(),
icon: tool.icon,
content,
confirmation: toAcpToolCallConfirmation(confirmationDetails),
locations: invocation.toolLocations(),
});
await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome));
switch (result.outcome) {
case 'reject':
return errorResponse(
new Error(`Tool "${fc.name}" not allowed to run by the user.`),
);
case 'cancel':
return errorResponse(
new Error(`Tool "${fc.name}" was canceled by the user.`),
);
case 'allow':
case 'alwaysAllow':
case 'alwaysAllowMcpServer':
case 'alwaysAllowTool':
break;
default: {
const resultOutcome: never = result.outcome;
throw new Error(`Unexpected: ${resultOutcome}`);
}
}
toolCallId = result.id;
} else {
const result = await this.client.pushToolCall({
icon: tool.icon,
label: invocation.getDescription(),
locations: invocation.toolLocations(),
});
toolCallId = result.id;
}
const toolResult: ToolResult = await invocation.execute(abortSignal);
const toolCallContent = toToolCallContent(toolResult);
await this.client.updateToolCall({
@@ -320,12 +317,13 @@ class GeminiAgent implements Agent {
return convertToFunctionResponse(fc.name, callId, toolResult.llmContent);
} catch (e) {
const error = e instanceof Error ? e : new Error(String(e));
await this.client.updateToolCall({
toolCallId,
status: 'error',
content: { type: 'markdown', markdown: error.message },
});
if (toolCallId) {
await this.client.updateToolCall({
toolCallId,
status: 'error',
content: { type: 'markdown', markdown: error.message },
});
}
return errorResponse(error);
}
}
@@ -408,7 +406,7 @@ class GeminiAgent implements Agent {
`Path ${pathName} not found directly, attempting glob search.`,
);
try {
const globResult = await globTool.execute(
const globResult = await globTool.buildAndExecute(
{
pattern: `**/*${pathName}*`,
path: this.config.getTargetDir(),
@@ -530,12 +528,15 @@ class GeminiAgent implements Agent {
respectGitIgnore, // Use configuration setting
};
const toolCall = await this.client.pushToolCall({
icon: readManyFilesTool.icon,
label: readManyFilesTool.getDescription(toolArgs),
});
let toolCallId: number | undefined = undefined;
try {
const result = await readManyFilesTool.execute(toolArgs, abortSignal);
const invocation = readManyFilesTool.build(toolArgs);
const toolCall = await this.client.pushToolCall({
icon: readManyFilesTool.icon,
label: invocation.getDescription(),
});
toolCallId = toolCall.id;
const result = await invocation.execute(abortSignal);
const content = toToolCallContent(result) || {
type: 'markdown',
markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`,
@@ -578,14 +579,16 @@ class GeminiAgent implements Agent {
return processedQueryParts;
} catch (error: unknown) {
await this.client.updateToolCall({
toolCallId: toolCall.id,
status: 'error',
content: {
type: 'markdown',
markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`,
},
});
if (toolCallId) {
await this.client.updateToolCall({
toolCallId,
status: 'error',
content: {
type: 'markdown',
markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`,
},
});
}
throw error;
}
}

View File

@@ -8,6 +8,7 @@ import * as fs from 'fs/promises';
import * as path from 'path';
import { PartListUnion, PartUnion } from '@google/genai';
import {
AnyToolInvocation,
Config,
getErrorMessage,
isNodeError,
@@ -254,7 +255,7 @@ export async function handleAtCommand({
`Path ${pathName} not found directly, attempting glob search.`,
);
try {
const globResult = await globTool.execute(
const globResult = await globTool.buildAndExecute(
{
pattern: `**/*${pathName}*`,
path: dir,
@@ -411,12 +412,14 @@ export async function handleAtCommand({
};
let toolCallDisplay: IndividualToolCallDisplay;
let invocation: AnyToolInvocation | undefined = undefined;
try {
const result = await readManyFilesTool.execute(toolArgs, signal);
invocation = readManyFilesTool.build(toolArgs);
const result = await invocation.execute(signal);
toolCallDisplay = {
callId: `client-read-${userMessageTimestamp}`,
name: readManyFilesTool.displayName,
description: readManyFilesTool.getDescription(toolArgs),
description: invocation.getDescription(),
status: ToolCallStatus.Success,
resultDisplay:
result.returnDisplay ||
@@ -466,7 +469,9 @@ export async function handleAtCommand({
toolCallDisplay = {
callId: `client-read-${userMessageTimestamp}`,
name: readManyFilesTool.displayName,
description: readManyFilesTool.getDescription(toolArgs),
description:
invocation?.getDescription() ??
'Error attempting to execute tool to read files',
status: ToolCallStatus.Error,
resultDisplay: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`,
confirmationDetails: undefined,

View File

@@ -21,6 +21,7 @@ import {
EditorType,
AuthType,
GeminiEventType as ServerGeminiEventType,
AnyToolInvocation,
} from '@google/gemini-cli-core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -452,9 +453,13 @@ describe('useGeminiStream', () => {
},
tool: {
name: 'tool1',
displayName: 'tool1',
description: 'desc1',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(),
endTime: Date.now(),
} as TrackedCompletedToolCall,
@@ -469,9 +474,13 @@ describe('useGeminiStream', () => {
responseSubmittedToGemini: false,
tool: {
name: 'tool2',
displayName: 'tool2',
description: 'desc2',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(),
liveOutput: '...',
} as TrackedExecutingToolCall,
@@ -506,6 +515,12 @@ describe('useGeminiStream', () => {
status: 'success',
responseSubmittedToGemini: false,
response: { callId: 'call1', responseParts: toolCall1ResponseParts },
tool: {
displayName: 'MockTool',
},
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall,
{
request: {
@@ -584,6 +599,12 @@ describe('useGeminiStream', () => {
status: 'cancelled',
response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
responseSubmittedToGemini: false,
tool: {
displayName: 'mock tool',
},
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
} as TrackedCancelledToolCall,
];
const client = new MockedGeminiClientClass(mockConfig);
@@ -644,9 +665,13 @@ describe('useGeminiStream', () => {
},
tool: {
name: 'toolA',
displayName: 'toolA',
description: 'descA',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
status: 'cancelled',
response: {
callId: 'cancel-1',
@@ -668,9 +693,13 @@ describe('useGeminiStream', () => {
},
tool: {
name: 'toolB',
displayName: 'toolB',
description: 'descB',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
status: 'cancelled',
response: {
callId: 'cancel-2',
@@ -760,9 +789,13 @@ describe('useGeminiStream', () => {
responseSubmittedToGemini: false,
tool: {
name: 'tool1',
displayName: 'tool1',
description: 'desc',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(),
} as TrackedExecutingToolCall,
];
@@ -980,8 +1013,13 @@ describe('useGeminiStream', () => {
tool: {
name: 'tool1',
description: 'desc1',
getDescription: vi.fn(),
build: vi.fn().mockImplementation((_) => ({
getDescription: () => `Mock description`,
})),
} as any,
invocation: {
getDescription: () => `Mock description`,
},
startTime: Date.now(),
liveOutput: '...',
} as TrackedExecutingToolCall,
@@ -1131,9 +1169,13 @@ describe('useGeminiStream', () => {
},
tool: {
name: 'save_memory',
displayName: 'save_memory',
description: 'Saves memory',
getDescription: vi.fn(),
build: vi.fn(),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
};
// Capture the onComplete callback

View File

@@ -17,7 +17,6 @@ import {
OutputUpdateHandler,
AllToolCallsCompleteHandler,
ToolCallsUpdateHandler,
Tool,
ToolCall,
Status as CoreStatus,
EditorType,
@@ -216,23 +215,20 @@ export function mapToDisplay(
const toolDisplays = toolCalls.map(
(trackedCall): IndividualToolCallDisplay => {
let displayName = trackedCall.request.name;
let description = '';
let displayName: string;
let description: string;
let renderOutputAsMarkdown = false;
const currentToolInstance =
'tool' in trackedCall && trackedCall.tool
? (trackedCall as { tool: Tool }).tool
: undefined;
if (currentToolInstance) {
displayName = currentToolInstance.displayName;
description = currentToolInstance.getDescription(
trackedCall.request.args,
);
renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown;
} else if ('request' in trackedCall && 'args' in trackedCall.request) {
if (trackedCall.status === 'error') {
displayName =
trackedCall.tool === undefined
? trackedCall.request.name
: trackedCall.tool.displayName;
description = JSON.stringify(trackedCall.request.args);
} else {
displayName = trackedCall.tool.displayName;
description = trackedCall.invocation.getDescription();
renderOutputAsMarkdown = trackedCall.tool.isOutputMarkdown;
}
const baseDisplayProperties: Omit<
@@ -256,7 +252,6 @@ export function mapToDisplay(
case 'error':
return {
...baseDisplayProperties,
name: currentToolInstance?.displayName ?? trackedCall.request.name,
status: mapCoreStatusToDisplayStatus(trackedCall.status),
resultDisplay: trackedCall.response.resultDisplay,
confirmationDetails: undefined,

View File

@@ -15,7 +15,6 @@ import { PartUnion, FunctionResponse } from '@google/genai';
import {
Config,
ToolCallRequestInfo,
Tool,
ToolRegistry,
ToolResult,
ToolCallConfirmationDetails,
@@ -25,6 +24,9 @@ import {
Status as ToolCallStatusType,
ApprovalMode,
Icon,
BaseTool,
AnyDeclarativeTool,
AnyToolInvocation,
} from '@google/gemini-cli-core';
import {
HistoryItemWithoutId,
@@ -53,46 +55,55 @@ const mockConfig = {
getDebugMode: () => false,
};
const mockTool: Tool = {
name: 'mockTool',
displayName: 'Mock Tool',
description: 'A mock tool for testing',
icon: Icon.Hammer,
toolLocations: vi.fn(),
isOutputMarkdown: false,
canUpdateOutput: false,
schema: {},
validateToolParams: vi.fn(),
execute: vi.fn(),
shouldConfirmExecute: vi.fn(),
getDescription: vi.fn((args) => `Description for ${JSON.stringify(args)}`),
};
class MockTool extends BaseTool<object, ToolResult> {
constructor(
name: string,
displayName: string,
canUpdateOutput = false,
shouldConfirm = false,
isOutputMarkdown = false,
) {
super(
name,
displayName,
'A mock tool for testing',
Icon.Hammer,
{},
isOutputMarkdown,
canUpdateOutput,
);
if (shouldConfirm) {
this.shouldConfirmExecute = vi.fn(
async (): Promise<ToolCallConfirmationDetails | false> => ({
type: 'edit',
title: 'Mock Tool Requires Confirmation',
onConfirm: mockOnUserConfirmForToolConfirmation,
fileName: 'mockToolRequiresConfirmation.ts',
fileDiff: 'Mock tool requires confirmation',
originalContent: 'Original content',
newContent: 'New content',
}),
);
}
}
const mockToolWithLiveOutput: Tool = {
...mockTool,
name: 'mockToolWithLiveOutput',
displayName: 'Mock Tool With Live Output',
canUpdateOutput: true,
};
execute = vi.fn();
shouldConfirmExecute = vi.fn();
}
const mockTool = new MockTool('mockTool', 'Mock Tool');
const mockToolWithLiveOutput = new MockTool(
'mockToolWithLiveOutput',
'Mock Tool With Live Output',
true,
);
let mockOnUserConfirmForToolConfirmation: Mock;
const mockToolRequiresConfirmation: Tool = {
...mockTool,
name: 'mockToolRequiresConfirmation',
displayName: 'Mock Tool Requires Confirmation',
shouldConfirmExecute: vi.fn(
async (): Promise<ToolCallConfirmationDetails | false> => ({
type: 'edit',
title: 'Mock Tool Requires Confirmation',
onConfirm: mockOnUserConfirmForToolConfirmation,
fileName: 'mockToolRequiresConfirmation.ts',
fileDiff: 'Mock tool requires confirmation',
originalContent: 'Original content',
newContent: 'New content',
}),
),
};
const mockToolRequiresConfirmation = new MockTool(
'mockToolRequiresConfirmation',
'Mock Tool Requires Confirmation',
false,
true,
);
describe('useReactToolScheduler in YOLO Mode', () => {
let onComplete: Mock;
@@ -646,28 +657,21 @@ describe('useReactToolScheduler', () => {
});
it('should schedule and execute multiple tool calls', async () => {
const tool1 = {
...mockTool,
name: 'tool1',
displayName: 'Tool 1',
execute: vi.fn().mockResolvedValue({
llmContent: 'Output 1',
returnDisplay: 'Display 1',
summary: 'Summary 1',
} as ToolResult),
shouldConfirmExecute: vi.fn().mockResolvedValue(null),
};
const tool2 = {
...mockTool,
name: 'tool2',
displayName: 'Tool 2',
execute: vi.fn().mockResolvedValue({
llmContent: 'Output 2',
returnDisplay: 'Display 2',
summary: 'Summary 2',
} as ToolResult),
shouldConfirmExecute: vi.fn().mockResolvedValue(null),
};
const tool1 = new MockTool('tool1', 'Tool 1');
tool1.execute.mockResolvedValue({
llmContent: 'Output 1',
returnDisplay: 'Display 1',
summary: 'Summary 1',
} as ToolResult);
tool1.shouldConfirmExecute.mockResolvedValue(null);
const tool2 = new MockTool('tool2', 'Tool 2');
tool2.execute.mockResolvedValue({
llmContent: 'Output 2',
returnDisplay: 'Display 2',
summary: 'Summary 2',
} as ToolResult);
tool2.shouldConfirmExecute.mockResolvedValue(null);
mockToolRegistry.getTool.mockImplementation((name) => {
if (name === 'tool1') return tool1;
@@ -805,20 +809,7 @@ describe('mapToDisplay', () => {
args: { foo: 'bar' },
};
const baseTool: Tool = {
name: 'testTool',
displayName: 'Test Tool Display',
description: 'Test Description',
isOutputMarkdown: false,
canUpdateOutput: false,
schema: {},
icon: Icon.Hammer,
toolLocations: vi.fn(),
validateToolParams: vi.fn(),
execute: vi.fn(),
shouldConfirmExecute: vi.fn(),
getDescription: vi.fn((args) => `Desc: ${JSON.stringify(args)}`),
};
const baseTool = new MockTool('testTool', 'Test Tool Display');
const baseResponse: ToolCallResponseInfo = {
callId: 'testCallId',
@@ -840,13 +831,15 @@ describe('mapToDisplay', () => {
// This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist.
type MapToDisplayExtraProps =
| {
tool?: Tool;
tool?: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
liveOutput?: string;
response?: ToolCallResponseInfo;
confirmationDetails?: ToolCallConfirmationDetails;
}
| {
tool: Tool;
tool: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
response?: ToolCallResponseInfo;
confirmationDetails?: ToolCallConfirmationDetails;
}
@@ -857,10 +850,12 @@ describe('mapToDisplay', () => {
}
| {
confirmationDetails: ToolCallConfirmationDetails;
tool?: Tool;
tool?: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
response?: ToolCallResponseInfo;
};
const baseInvocation = baseTool.build(baseRequest.args);
const testCases: Array<{
name: string;
status: ToolCallStatusType;
@@ -873,7 +868,7 @@ describe('mapToDisplay', () => {
{
name: 'validating',
status: 'validating',
extraProps: { tool: baseTool },
extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Executing,
expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args),
@@ -883,6 +878,7 @@ describe('mapToDisplay', () => {
status: 'awaiting_approval',
extraProps: {
tool: baseTool,
invocation: baseInvocation,
confirmationDetails: {
onConfirm: vi.fn(),
type: 'edit',
@@ -903,7 +899,7 @@ describe('mapToDisplay', () => {
{
name: 'scheduled',
status: 'scheduled',
extraProps: { tool: baseTool },
extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Pending,
expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args),
@@ -911,7 +907,7 @@ describe('mapToDisplay', () => {
{
name: 'executing no live output',
status: 'executing',
extraProps: { tool: baseTool },
extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Executing,
expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args),
@@ -919,7 +915,11 @@ describe('mapToDisplay', () => {
{
name: 'executing with live output',
status: 'executing',
extraProps: { tool: baseTool, liveOutput: 'Live test output' },
extraProps: {
tool: baseTool,
invocation: baseInvocation,
liveOutput: 'Live test output',
},
expectedStatus: ToolCallStatus.Executing,
expectedResultDisplay: 'Live test output',
expectedName: baseTool.displayName,
@@ -928,7 +928,11 @@ describe('mapToDisplay', () => {
{
name: 'success',
status: 'success',
extraProps: { tool: baseTool, response: baseResponse },
extraProps: {
tool: baseTool,
invocation: baseInvocation,
response: baseResponse,
},
expectedStatus: ToolCallStatus.Success,
expectedResultDisplay: baseResponse.resultDisplay as any,
expectedName: baseTool.displayName,
@@ -970,6 +974,7 @@ describe('mapToDisplay', () => {
status: 'cancelled',
extraProps: {
tool: baseTool,
invocation: baseInvocation,
response: {
...baseResponse,
resultDisplay: 'Cancelled display',
@@ -1030,12 +1035,21 @@ describe('mapToDisplay', () => {
request: { ...baseRequest, callId: 'call1' },
status: 'success',
tool: baseTool,
invocation: baseTool.build(baseRequest.args),
response: { ...baseResponse, callId: 'call1' },
} as ToolCall;
const toolForCall2 = new MockTool(
baseTool.name,
baseTool.displayName,
false,
false,
true,
);
const toolCall2: ToolCall = {
request: { ...baseRequest, callId: 'call2' },
status: 'executing',
tool: { ...baseTool, isOutputMarkdown: true },
tool: toolForCall2,
invocation: toolForCall2.build(baseRequest.args),
liveOutput: 'markdown output',
} as ToolCall;