# 🚀 Sync Gemini CLI v0.2.1 - Major Feature Update (#483)

This commit is contained in:
tanzhenxin
2025-09-01 14:48:55 +08:00
committed by GitHub
parent 1610c1586e
commit 2572faf726
292 changed files with 19401 additions and 5941 deletions

View File

@@ -0,0 +1,73 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`ShellTool > getDescription > should return the non-windows description when not on windows 1`] = `
"
This tool executes a given shell command as \`bash -c <command>\`.
**Background vs Foreground Execution:**
You should decide whether commands should run in background or foreground based on their nature:
**Use background execution (is_background: true) for:**
- Long-running development servers: \`npm run start\`, \`npm run dev\`, \`yarn dev\`, \`bun run start\`
- Build watchers: \`npm run watch\`, \`webpack --watch\`
- Database servers: \`mongod\`, \`mysql\`, \`redis-server\`
- Web servers: \`python -m http.server\`, \`php -S localhost:8000\`
- Any command expected to run indefinitely until manually stopped
**Use foreground execution (is_background: false) for:**
- One-time commands: \`ls\`, \`cat\`, \`grep\`
- Build commands: \`npm run build\`, \`make\`
- Installation commands: \`npm install\`, \`pip install\`
- Git operations: \`git commit\`, \`git push\`
- Test runs: \`npm test\`, \`pytest\`
Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
The following information is returned:
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\`"
`;
exports[`ShellTool > getDescription > should return the windows description when on windows 1`] = `
"
This tool executes a given shell command as \`cmd.exe /c <command>\`.
**Background vs Foreground Execution:**
You should decide whether commands should run in background or foreground based on their nature:
**Use background execution (is_background: true) for:**
- Long-running development servers: \`npm run start\`, \`npm run dev\`, \`yarn dev\`, \`bun run start\`
- Build watchers: \`npm run watch\`, \`webpack --watch\`
- Database servers: \`mongod\`, \`mysql\`, \`redis-server\`
- Web servers: \`python -m http.server\`, \`php -S localhost:8000\`
- Any command expected to run indefinitely until manually stopped
**Use foreground execution (is_background: false) for:**
- One-time commands: \`ls\`, \`cat\`, \`grep\`
- Build commands: \`npm run build\`, \`make\`
- Installation commands: \`npm install\`, \`pip install\`
- Git operations: \`git commit\`, \`git push\`
- Test runs: \`npm test\`, \`pytest\`
The following information is returned:
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\`"
`;

View File

@@ -36,6 +36,7 @@ import os from 'os';
import { ApprovalMode, Config } from '../config/config.js';
import { Content, Part, SchemaUnion } from '@google/genai';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
describe('EditTool', () => {
let tool: EditTool;
@@ -60,6 +61,7 @@ describe('EditTool', () => {
getApprovalMode: vi.fn(),
setApprovalMode: vi.fn(),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getFileSystemService: () => new StandardFileSystemService(),
getIdeClient: () => undefined,
getIdeMode: () => false,
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
@@ -393,7 +395,7 @@ describe('EditTool', () => {
});
});
it('should throw error if params are invalid', async () => {
it('should throw error if file path is not absolute', async () => {
const params: EditToolParams = {
file_path: 'relative.txt',
old_string: 'old',
@@ -402,6 +404,17 @@ describe('EditTool', () => {
expect(() => tool.build(params)).toThrow(/File path must be absolute/);
});
it('should throw error if file path is empty', async () => {
const params: EditToolParams = {
file_path: '',
old_string: 'old',
new_string: 'new',
};
expect(() => tool.build(params)).toThrow(
/The 'file_path' parameter must be non-empty./,
);
});
it('should edit an existing file and return diff with fileName', async () => {
const initialContent = 'This is some old text.';
const newContent = 'This is some new text.'; // old -> new

View File

@@ -19,7 +19,6 @@ import {
ToolResultDisplay,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
import { Config, ApprovalMode } from '../config/config.js';
@@ -125,7 +124,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
| undefined = undefined;
try {
currentContent = fs.readFileSync(params.file_path, 'utf8');
currentContent = await this.config
.getFileSystemService()
.readTextFile(params.file_path);
// Normalize line endings to LF for consistent processing.
currentContent = currentContent.replace(/\r\n/g, '\n');
fileExists = true;
@@ -339,7 +340,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
try {
this.ensureParentDirectoriesExist(this.params.file_path);
fs.writeFileSync(this.params.file_path, editData.newContent, 'utf8');
await this.config
.getFileSystemService()
.writeTextFile(this.params.file_path, editData.newContent);
let displayResult: ToolResultDisplay;
if (editData.isNewFile) {
@@ -471,13 +474,11 @@ Expectation for required parameters:
* @param params Parameters to validate
* @returns Error message string or null if valid
*/
override validateToolParams(params: EditToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
protected override validateToolParamValues(
params: EditToolParams,
): string | null {
if (!params.file_path) {
return "The 'file_path' parameter must be non-empty.";
}
if (!path.isAbsolute(params.file_path)) {
@@ -504,7 +505,9 @@ Expectation for required parameters:
getFilePath: (params: EditToolParams) => params.file_path,
getCurrentContent: async (params: EditToolParams): Promise<string> => {
try {
return fs.readFileSync(params.file_path, 'utf8');
return this.config
.getFileSystemService()
.readTextFile(params.file_path);
} catch (err) {
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
return '';
@@ -512,7 +515,9 @@ Expectation for required parameters:
},
getProposedContent: async (params: EditToolParams): Promise<string> => {
try {
const currentContent = fs.readFileSync(params.file_path, 'utf8');
const currentContent = await this.config
.getFileSystemService()
.readTextFile(params.file_path);
return applyReplacement(
currentContent,
params.old_string,

View File

@@ -150,6 +150,34 @@ describe('GlobTool', () => {
expect(result.returnDisplay).toBe('No files found');
});
it('should find files with special characters in the name', async () => {
await fs.writeFile(path.join(tempRootDir, 'file[1].txt'), 'content');
const params: GlobToolParams = { pattern: 'file[1].txt' };
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 1 file(s)');
expect(result.llmContent).toContain(
path.join(tempRootDir, 'file[1].txt'),
);
});
it('should find files with special characters like [] and () in the path', async () => {
const filePath = path.join(
tempRootDir,
'src/app/[test]/(dashboard)/testing/components/code.tsx',
);
await fs.mkdir(path.dirname(filePath), { recursive: true });
await fs.writeFile(filePath, 'content');
const params: GlobToolParams = {
pattern: 'src/app/[test]/(dashboard)/testing/components/code.tsx',
};
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 1 file(s)');
expect(result.llmContent).toContain(filePath);
});
it('should correctly sort files by modification time (newest first)', async () => {
const params: GlobToolParams = { pattern: '*.sortme' };
const invocation = globTool.build(params);

View File

@@ -6,8 +6,7 @@
import fs from 'fs';
import path from 'path';
import { glob } from 'glob';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { glob, escape } from 'glob';
import {
BaseDeclarativeTool,
BaseToolInvocation,
@@ -137,7 +136,13 @@ class GlobToolInvocation extends BaseToolInvocation<
let allEntries: GlobPath[] = [];
for (const searchDir of searchDirectories) {
const entries = (await glob(this.params.pattern, {
let pattern = this.params.pattern;
const fullPath = path.join(searchDir, pattern);
if (fs.existsSync(fullPath)) {
pattern = escape(pattern);
}
const entries = (await glob(pattern, {
cwd: searchDir,
withFileTypes: true,
nodir: true,
@@ -281,15 +286,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
/**
* Validates the parameters for the tool.
*/
override validateToolParams(params: GlobToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
protected override validateToolParamValues(
params: GlobToolParams,
): string | null {
const searchDirAbsolute = path.resolve(
this.config.getTargetDir(),
params.path || '.',

View File

@@ -17,7 +17,6 @@ import {
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js';
import { isGitRepository } from '../utils/gitUtils.js';
@@ -672,15 +671,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
override validateToolParams(params: GrepToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
protected override validateToolParamValues(
params: GrepToolParams,
): string | null {
try {
new RegExp(params.pattern);
} catch (error) {

View File

@@ -13,7 +13,6 @@ import {
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@@ -314,14 +313,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
override validateToolParams(params: LSToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
protected override validateToolParamValues(
params: LSToolParams,
): string | null {
if (!path.isAbsolute(params.path)) {
return `Path must be absolute: ${params.path}`;
}

View File

@@ -0,0 +1,54 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { afterEach, describe, expect, it, vi } from 'vitest';
import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js';
import { ToolRegistry } from './tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
return {
...originalModule,
McpClient: vi.fn(),
populateMcpServerCommand: vi.fn(() => ({
'test-server': {},
})),
};
});
describe('McpClientManager', () => {
afterEach(() => {
vi.restoreAllMocks();
});
it('should discover tools from all servers', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager(
{
'test-server': {},
},
'',
{} as ToolRegistry,
{} as PromptRegistry,
false,
{} as WorkspaceContext,
);
await manager.discoverAllMcpTools();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
});

View File

@@ -0,0 +1,115 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { MCPServerConfig } from '../config/config.js';
import { ToolRegistry } from './tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import {
McpClient,
MCPDiscoveryState,
populateMcpServerCommand,
} from './mcp-client.js';
import { getErrorMessage } from '../utils/errors.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
/**
* Manages the lifecycle of multiple MCP clients, including local child processes.
* This class is responsible for starting, stopping, and discovering tools from
* a collection of MCP servers defined in the configuration.
*/
export class McpClientManager {
private clients: Map<string, McpClient> = new Map();
private readonly mcpServers: Record<string, MCPServerConfig>;
private readonly mcpServerCommand: string | undefined;
private readonly toolRegistry: ToolRegistry;
private readonly promptRegistry: PromptRegistry;
private readonly debugMode: boolean;
private readonly workspaceContext: WorkspaceContext;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
constructor(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
workspaceContext: WorkspaceContext,
) {
this.mcpServers = mcpServers;
this.mcpServerCommand = mcpServerCommand;
this.toolRegistry = toolRegistry;
this.promptRegistry = promptRegistry;
this.debugMode = debugMode;
this.workspaceContext = workspaceContext;
}
/**
* Initiates the tool discovery process for all configured MCP servers.
* It connects to each server, discovers its available tools, and registers
* them with the `ToolRegistry`.
*/
async discoverAllMcpTools(): Promise<void> {
await this.stop();
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
const servers = populateMcpServerCommand(
this.mcpServers,
this.mcpServerCommand,
);
const discoveryPromises = Object.entries(servers).map(
async ([name, config]) => {
const client = new McpClient(
name,
config,
this.toolRegistry,
this.promptRegistry,
this.workspaceContext,
this.debugMode,
);
this.clients.set(name, client);
try {
await client.connect();
await client.discover();
} catch (error) {
// Log the error but don't let a single failed server stop the others
console.error(
`Error during discovery for server '${name}': ${getErrorMessage(
error,
)}`,
);
}
},
);
await Promise.all(discoveryPromises);
this.discoveryState = MCPDiscoveryState.COMPLETED;
}
/**
* Stops all running local MCP servers and closes all client connections.
* This is the cleanup method to be called on application exit.
*/
async stop(): Promise<void> {
const disconnectionPromises = Array.from(this.clients.entries()).map(
async ([name, client]) => {
try {
await client.disconnect();
} catch (error) {
console.error(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
);
}
},
);
await Promise.all(disconnectionPromises);
this.clients.clear();
}
getDiscoveryState(): MCPDiscoveryState {
return this.discoveryState;
}
}

View File

@@ -4,16 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
populateMcpServerCommand,
createTransport,
isEnabled,
discoverTools,
discoverPrompts,
hasValidTypes,
connectToMcpServer,
McpClient,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { ToolRegistry } from './tool-registry.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
vi.mock('./mcp-tool.js');
describe('mcp-client', () => {
afterEach(() => {
vi.restoreAllMocks();
});
describe('discoverTools', () => {
describe('McpClient', () => {
it('should discover tools', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
@@ -51,62 +59,43 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
const mockedToolRegistry = {
registerTool: vi.fn(),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
);
await client.connect();
await client.discover();
expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
it('should log an error if there is an error discovering a tool', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
const testError = new Error('Invalid tool name');
vi.mocked(DiscoveredMCPTool).mockImplementation(
(
_mcpCallableTool: GenAiLib.CallableTool,
_serverName: string,
name: string,
) => {
if (name === 'invalid tool name') {
throw testError;
}
return { name: 'validTool' } as DiscoveredMCPTool;
},
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
},
{
name: 'invalid tool name', // this will fail validation
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(tools[0].name).toBe('validTool');
expect(consoleErrorSpy).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
);
});
it('should skip tools if a parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
tool: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
@@ -132,11 +121,22 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
const mockedToolRegistry = {
registerTool: vi.fn(),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
);
await client.connect();
await client.discover();
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
@@ -145,147 +145,19 @@ describe('mcp-client', () => {
consoleWarnSpy.mockRestore();
});
it('should skip tools if a nested parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
it('should handle errors when discovering prompts', async () => {
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'object',
properties: {
nestedParam: {
description: 'a nested param with no type',
},
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should skip tool if an array item is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'array',
items: {
description: 'an array item with no type',
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should discover tool with no properties in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
it('should discover tool with empty properties object in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
properties: {},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
});
describe('connectToMcpServer', () => {
it('should register a roots/list handler', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
callTool: vi.fn(),
connect: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockRejectedValue(new Error('Test error')),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -293,148 +165,29 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockWorkspaceContext = {
getDirectories: vi
.fn()
.mockReturnValue(['/test/dir', '/another/project']),
} as unknown as WorkspaceContext;
await connectToMcpServer(
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => Promise.resolve({ functionDeclarations: [] }),
} as unknown as GenAiLib.CallableTool);
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
{} as ToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
mockWorkspaceContext,
);
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
roots: {},
});
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
const roots = await handler();
expect(roots).toEqual({
roots: [
{
uri: pathToFileURL('/test/dir').toString(),
name: 'dir',
},
{
uri: pathToFileURL('/another/project').toString(),
name: 'project',
},
],
});
});
});
describe('discoverPrompts', () => {
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
} as unknown as PromptRegistry;
it('should discover and log prompts', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [
{ name: 'prompt1', description: 'desc1' },
{ name: 'prompt2' },
],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledWith(
{ method: 'prompts/list', params: {} },
expect.anything(),
await client.connect();
await expect(client.discover()).rejects.toThrow(
'No prompts or tools found on the server.',
);
});
it('should do nothing if no prompts are discovered', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should do nothing if the server has no prompt support', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).not.toHaveBeenCalled();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should log an error if discovery fails', async () => {
const testError = new Error('test error');
testError.message = 'test error';
const mockRequest = vi.fn().mockRejectedValue(testError);
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering prompts from test-server: ${testError.message}`,
`Error discovering prompts from test-server: Test error`,
);
consoleErrorSpy.mockRestore();
});
});
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
@@ -458,17 +211,6 @@ describe('mcp-client', () => {
});
describe('createTransport', () => {
const originalEnv = process.env;
beforeEach(() => {
vi.resetModules();
process.env = {};
});
afterEach(() => {
process.env = originalEnv;
});
describe('should connect via httpUrl', () => {
it('without headers', async () => {
const transport = await createTransport(
@@ -558,7 +300,7 @@ describe('mcp-client', () => {
command: 'test-command',
args: ['--foo', 'bar'],
cwd: 'test/cwd',
env: { FOO: 'bar' },
env: { ...process.env, FOO: 'bar' },
stderr: 'pipe',
});
});

View File

@@ -36,7 +36,7 @@ import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { Unsubscribe, WorkspaceContext } from '../utils/workspaceContext.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@@ -69,6 +69,134 @@ export enum MCPDiscoveryState {
COMPLETED = 'completed',
}
/**
* A client for a single MCP server.
*
* This class is responsible for connecting to, discovering tools from, and
* managing the state of a single MCP server.
*/
export class McpClient {
private client: Client;
private transport: Transport | undefined;
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
private isDisconnecting = false;
constructor(
private readonly serverName: string,
private readonly serverConfig: MCPServerConfig,
private readonly toolRegistry: ToolRegistry,
private readonly promptRegistry: PromptRegistry,
private readonly workspaceContext: WorkspaceContext,
private readonly debugMode: boolean,
) {
this.client = new Client({
name: `gemini-cli-mcp-client-${this.serverName}`,
version: '0.0.1',
});
}
/**
* Connects to the MCP server.
*/
async connect(): Promise<void> {
this.isDisconnecting = false;
this.updateStatus(MCPServerStatus.CONNECTING);
try {
this.transport = await this.createTransport();
this.client.onerror = (error) => {
if (this.isDisconnecting) {
return;
}
console.error(`MCP ERROR (${this.serverName}):`, error.toString());
this.updateStatus(MCPServerStatus.DISCONNECTED);
};
this.client.registerCapabilities({
roots: {},
});
this.client.setRequestHandler(ListRootsRequestSchema, async () => {
const roots = [];
for (const dir of this.workspaceContext.getDirectories()) {
roots.push({
uri: pathToFileURL(dir).toString(),
name: basename(dir),
});
}
return {
roots,
};
});
await this.client.connect(this.transport, {
timeout: this.serverConfig.timeout,
});
this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) {
this.updateStatus(MCPServerStatus.DISCONNECTED);
throw error;
}
}
/**
* Discovers tools and prompts from the MCP server.
*/
async discover(): Promise<void> {
if (this.status !== MCPServerStatus.CONNECTED) {
throw new Error('Client is not connected.');
}
const prompts = await this.discoverPrompts();
const tools = await this.discoverTools();
if (prompts.length === 0 && tools.length === 0) {
throw new Error('No prompts or tools found on the server.');
}
for (const tool of tools) {
this.toolRegistry.registerTool(tool);
}
}
/**
* Disconnects from the MCP server.
*/
async disconnect(): Promise<void> {
this.isDisconnecting = true;
if (this.transport) {
await this.transport.close();
}
this.client.close();
this.updateStatus(MCPServerStatus.DISCONNECTED);
}
/**
* Returns the current status of the client.
*/
getStatus(): MCPServerStatus {
return this.status;
}
private updateStatus(status: MCPServerStatus): void {
this.status = status;
updateMCPServerStatus(this.serverName, status);
}
private async createTransport(): Promise<Transport> {
return createTransport(this.serverName, this.serverConfig, this.debugMode);
}
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
return discoverTools(this.serverName, this.serverConfig, this.client);
}
private async discoverPrompts(): Promise<Prompt[]> {
return discoverPrompts(this.serverName, this.client, this.promptRegistry);
}
}
/**
* Map to track the status of each MCP server within the core package
*/
@@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener(
/**
* Update the status of an MCP server
*/
function updateMCPServerStatus(
export function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
@@ -227,10 +355,16 @@ async function handleAutomaticOAuth(
};
// Perform OAuth authentication
// Pass the server URL for proper discovery
const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url;
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig);
await MCPOAuthProvider.authenticate(
mcpServerName,
oauthAuthConfig,
serverUrl,
);
console.log(
`OAuth authentication successful for server '${mcpServerName}'`,
@@ -442,7 +576,7 @@ export function hasValidTypes(schema: unknown): boolean {
const s = schema as Record<string, unknown>;
if (!s.type) {
if (!s['type']) {
// These keywords contain an array of schemas that should be validated.
//
// If no top level type was given, then they must each have a type.
@@ -464,9 +598,9 @@ export function hasValidTypes(schema: unknown): boolean {
if (!hasSubSchema) return false;
}
if (s.type === 'object' && s.properties) {
if (typeof s.properties === 'object' && s.properties !== null) {
for (const prop of Object.values(s.properties)) {
if (s['type'] === 'object' && s['properties']) {
if (typeof s['properties'] === 'object' && s['properties'] !== null) {
for (const prop of Object.values(s['properties'])) {
if (!hasValidTypes(prop)) {
return false;
}
@@ -474,8 +608,8 @@ export function hasValidTypes(schema: unknown): boolean {
}
}
if (s.type === 'array' && s.items) {
if (!hasValidTypes(s.items)) {
if (s['type'] === 'array' && s['items']) {
if (!hasValidTypes(s['items'])) {
return false;
}
}
@@ -671,7 +805,9 @@ export async function connectToMcpServer(
});
mcpClient.registerCapabilities({
roots: {},
roots: {
listChanged: true,
},
});
mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
@@ -687,6 +823,32 @@ export async function connectToMcpServer(
};
});
let unlistenDirectories: Unsubscribe | undefined =
workspaceContext.onDirectoriesChanged(async () => {
try {
await mcpClient.notification({
method: 'notifications/roots/list_changed',
});
} catch (_) {
// If this fails, its almost certainly because the connection was closed
// and we should just stop listening for future directory changes.
unlistenDirectories?.();
unlistenDirectories = undefined;
}
});
// Attempt to pro-actively unsubscribe if the mcp client closes. This API is
// very brittle though so we don't have any guarantees, hence the try/catch
// above as well.
//
// Be a good steward and don't just bash over onclose.
const oldOnClose = mcpClient.onclose;
mcpClient.onclose = () => {
oldOnClose?.();
unlistenDirectories?.();
unlistenDirectories = undefined;
};
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
// TODO: remove this hack once GenAI SDK does callTool with request options
if ('callTool' in mcpClient) {
@@ -933,12 +1095,15 @@ export async function connectToMcpServer(
};
// Perform OAuth authentication
// Pass the server URL for proper discovery
const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url;
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(
mcpServerName,
oauthAuthConfig,
serverUrl,
);
// Retry connection with OAuth token
@@ -1037,7 +1202,7 @@ export async function connectToMcpServer(
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
}
if (process.env.SANDBOX) {
if (process.env['SANDBOX']) {
conciseError += ` (check sandbox availability)`;
}

View File

@@ -86,7 +86,7 @@ describe('DiscoveredMCPTool', () => {
inputSchema,
);
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
invocation.constructor.allowlist.clear();
});
@@ -185,8 +185,100 @@ describe('DiscoveredMCPTool', () => {
).rejects.toThrow(expectedError);
});
it.each([
{ isErrorValue: true, description: 'true (bool)' },
{ isErrorValue: 'true', description: '"true" (str)' },
])(
'should consider a ToolResult with isError $description to be a failure',
async ({ isErrorValue }) => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'isErrorTrueCase' };
const errorResponse = { isError: isErrorValue };
const mockMcpToolResponseParts: Part[] = [
{
functionResponse: {
name: serverToolName,
response: { error: errorResponse },
},
},
];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
const expectedError = new Error(
`MCP tool '${serverToolName}' reported tool error with response: ${JSON.stringify(
mockMcpToolResponseParts,
)}`,
);
const invocation = tool.build(params);
await expect(
invocation.execute(new AbortController().signal),
).rejects.toThrow(expectedError);
},
);
it.each([
{ isErrorValue: false, description: 'false (bool)' },
{ isErrorValue: 'false', description: '"false" (str)' },
])(
'should consider a ToolResult with isError ${description} to be a success',
async ({ isErrorValue }) => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'isErrorFalseCase' };
const mockToolSuccessResultObject = {
success: true,
details: 'executed',
};
const mockFunctionResponseContent = [
{
type: 'text',
text: JSON.stringify(mockToolSuccessResultObject),
},
];
const errorResponse = { isError: isErrorValue };
const mockMcpToolResponseParts: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
error: errorResponse,
content: mockFunctionResponseContent,
},
},
},
];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
const invocation = tool.build(params);
const toolResult = await invocation.execute(
new AbortController().signal,
);
const stringifiedResponseContent = JSON.stringify(
mockToolSuccessResultObject,
);
expect(toolResult.llmContent).toEqual([
{ text: stringifiedResponseContent },
]);
expect(toolResult.returnDisplay).toBe(stringifiedResponseContent);
},
);
it('should handle a simple text response correctly', async () => {
const params = { query: 'test' };
const params = { param: 'test' };
const successMessage = 'This is a success message.';
// Simulate the response from the GenAI SDK, which wraps the MCP
@@ -220,7 +312,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an AudioBlock response', async () => {
const params = { action: 'play' };
const params = { param: 'play' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -257,7 +349,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a ResourceLinkBlock response', async () => {
const params = { resource: 'get' };
const params = { param: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -291,7 +383,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded text ResourceBlock response', async () => {
const params = { resource: 'get' };
const params = { param: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -323,7 +415,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded binary ResourceBlock response', async () => {
const params = { resource: 'get' };
const params = { param: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -365,7 +457,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a mix of content block types', async () => {
const params = { action: 'complex' };
const params = { param: 'complex' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -408,7 +500,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should ignore unknown content block types', async () => {
const params = { action: 'test' };
const params = { param: 'test' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -434,7 +526,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a complex mix of content block types', async () => {
const params = { action: 'super-complex' };
const params = { param: 'super-complex' };
const sdkResponse: Part[] = [
{
functionResponse: {
@@ -504,14 +596,14 @@ describe('DiscoveredMCPTool', () => {
undefined,
true,
);
const invocation = trustedTool.build({});
const invocation = trustedTool.build({ param: 'mock' });
expect(
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if server is allowlisted', async () => {
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
invocation.constructor.allowlist.add(serverName);
expect(
await invocation.shouldConfirmExecute(new AbortController().signal),
@@ -520,7 +612,7 @@ describe('DiscoveredMCPTool', () => {
it('should return false if tool is allowlisted', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`;
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
invocation.constructor.allowlist.add(toolAllowlistKey);
expect(
await invocation.shouldConfirmExecute(new AbortController().signal),
@@ -528,7 +620,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should return confirmation details if not trusted and not allowlisted', async () => {
const invocation = tool.build({});
const invocation = tool.build({ param: 'mock' });
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
@@ -551,7 +643,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should add server to allowlist on ProceedAlwaysServer', async () => {
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
@@ -575,7 +667,7 @@ describe('DiscoveredMCPTool', () => {
it('should add tool to allowlist on ProceedAlwaysTool', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`;
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
@@ -598,7 +690,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle Cancel confirmation outcome', async () => {
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
@@ -625,7 +717,7 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle ProceedOnce confirmation outcome', async () => {
const invocation = tool.build({}) as any;
const invocation = tool.build({ param: 'mock' }) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);

View File

@@ -104,6 +104,28 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
return confirmationDetails;
}
// Determine if the response contains tool errors
// This is needed because CallToolResults should return errors inside the response.
// ref: https://modelcontextprotocol.io/specification/2025-06-18/schema#calltoolresult
isMCPToolError(rawResponseParts: Part[]): boolean {
const functionResponse = rawResponseParts?.[0]?.functionResponse;
const response = functionResponse?.response;
interface McpError {
isError?: boolean | string;
}
if (response) {
const error = (response as { error?: McpError })?.error;
const isError = error?.isError;
if (error && (isError === true || isError === 'true')) {
return true;
}
}
return false;
}
async execute(): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
@@ -113,6 +135,14 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
// Ensure the response is not an error
if (this.isMCPToolError(rawResponseParts)) {
throw new Error(
`MCP tool '${this.serverToolName}' reported tool error with response: ${JSON.stringify(rawResponseParts)}`,
);
}
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
@@ -241,7 +271,7 @@ function transformResourceLinkBlock(block: McpResourceLinkBlock): Part {
*/
function transformMcpContentToParts(sdkResponse: Part[]): Part[] {
const funcResponse = sdkResponse?.[0]?.functionResponse;
const mcpContent = funcResponse?.response?.content as McpContentBlock[];
const mcpContent = funcResponse?.response?.['content'] as McpContentBlock[];
const toolName = funcResponse?.name || 'unknown tool';
if (!Array.isArray(mcpContent)) {
@@ -278,8 +308,9 @@ function transformMcpContentToParts(sdkResponse: Part[]): Part[] {
* @returns A formatted string representing the tool's output.
*/
function getStringifiedResultForDisplay(rawResponse: Part[]): string {
const mcpContent = rawResponse?.[0]?.functionResponse?.response
?.content as McpContentBlock[];
const mcpContent = rawResponse?.[0]?.functionResponse?.response?.[
'content'
] as McpContentBlock[];
if (!Array.isArray(mcpContent)) {
return '```json\n' + JSON.stringify(rawResponse, null, 2) + '\n```';

View File

@@ -20,7 +20,6 @@ import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@@ -396,15 +395,9 @@ export class MemoryTool
);
}
override validateToolParams(params: SaveMemoryParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
protected override validateToolParamValues(
params: SaveMemoryParams,
): string | null {
if (params.fact.trim() === '') {
return 'Parameter "fact" must be a non-empty string.';
}

View File

@@ -13,6 +13,7 @@ import fs from 'fs';
import fsp from 'fs/promises';
import { Config } from '../config/config.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { ToolInvocation, ToolResult } from './tools.js';
@@ -29,6 +30,7 @@ describe('ReadFileTool', () => {
const mockConfigInstance = {
getFileService: () => new FileDiscoveryService(tempRootDir),
getFileSystemService: () => new StandardFileSystemService(),
getTargetDir: () => tempRootDir,
getWorkspaceContext: () => createMockWorkspaceContext(tempRootDir),
} as unknown as Config;
@@ -69,6 +71,15 @@ describe('ReadFileTool', () => {
);
});
it('should throw error if path is empty', () => {
const params: ReadFileToolParams = {
absolute_path: '',
};
expect(() => tool.build(params)).toThrow(
/The 'absolute_path' parameter must be non-empty./,
);
});
it('should throw error if offset is negative', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),

View File

@@ -5,7 +5,6 @@
*/
import path from 'path';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import {
BaseDeclarativeTool,
@@ -74,6 +73,7 @@ class ReadFileToolInvocation extends BaseToolInvocation<
const result = await processSingleFileContent(
this.params.absolute_path,
this.config.getTargetDir(),
this.config.getFileSystemService(),
this.params.offset,
this.params.limit,
);
@@ -198,18 +198,14 @@ export class ReadFileTool extends BaseDeclarativeTool<
);
}
protected override validateToolParams(
protected override validateToolParamValues(
params: ReadFileToolParams,
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
const filePath = params.absolute_path;
if (params.absolute_path.trim() === '') {
return "The 'absolute_path' parameter must be non-empty.";
}
const filePath = params.absolute_path;
if (!path.isAbsolute(filePath)) {
return `File path must be absolute, but was relative: ${filePath}. You must provide an absolute path.`;
}

View File

@@ -14,6 +14,7 @@ import fs from 'fs'; // Actual fs for setup
import os from 'os';
import { Config } from '../config/config.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
vi.mock('mime-types', () => {
const lookup = (filename: string) => {
@@ -59,6 +60,7 @@ describe('ReadManyFilesTool', () => {
const fileService = new FileDiscoveryService(tempRootDir);
const mockConfig = {
getFileService: () => fileService,
getFileSystemService: () => new StandardFileSystemService(),
getFileFilteringOptions: () => ({
respectGitIgnore: true,
@@ -456,6 +458,7 @@ describe('ReadManyFilesTool', () => {
const fileService = new FileDiscoveryService(tempDir1);
const mockConfig = {
getFileService: () => fileService,
getFileSystemService: () => new StandardFileSystemService(),
getFileFilteringOptions: () => ({
respectGitIgnore: true,
respectGeminiIgnore: true,
@@ -524,6 +527,43 @@ describe('ReadManyFilesTool', () => {
expect(truncatedFileContent).toContain('L200');
expect(truncatedFileContent).not.toContain('L2400');
});
it('should read files with special characters like [] and () in the path', async () => {
const filePath = 'src/app/[test]/(dashboard)/testing/components/code.tsx';
createFile(filePath, 'Content of receive-detail');
const params = { paths: [filePath] };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const expectedPath = path.join(tempRootDir, filePath);
expect(result.llmContent).toEqual([
`--- ${expectedPath} ---
Content of receive-detail
`,
]);
expect(result.returnDisplay).toContain(
'Successfully read and concatenated content from **1 file(s)**',
);
});
it('should read files with special characters in the name', async () => {
createFile('file[1].txt', 'Content of file[1]');
const params = { paths: ['file[1].txt'] };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const expectedPath = path.join(tempRootDir, 'file[1].txt');
expect(result.llmContent).toEqual([
`--- ${expectedPath} ---
Content of file[1]
`,
]);
expect(result.returnDisplay).toContain(
'Successfully read and concatenated content from **1 file(s)**',
);
});
});
describe('Batch Processing', () => {

View File

@@ -11,10 +11,10 @@ import {
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import * as fs from 'fs';
import * as path from 'path';
import { glob } from 'glob';
import { glob, escape } from 'glob';
import { getCurrentGeminiMdFilename } from './memoryTool.js';
import {
detectFileType,
@@ -245,18 +245,27 @@ ${finalExclusionPatternsForDescription
const workspaceDirs = this.config.getWorkspaceContext().getDirectories();
for (const dir of workspaceDirs) {
const entriesInDir = await glob(
searchPatterns.map((p) => p.replace(/\\/g, '/')),
{
cwd: dir,
ignore: effectiveExcludes,
nodir: true,
dot: true,
absolute: true,
nocase: true,
signal,
},
);
const processedPatterns = [];
for (const p of searchPatterns) {
const normalizedP = p.replace(/\\/g, '/');
const fullPath = path.join(dir, normalizedP);
if (fs.existsSync(fullPath)) {
processedPatterns.push(escape(normalizedP));
} else {
// The path does not exist or is not a file, so we treat it as a glob pattern.
processedPatterns.push(normalizedP);
}
}
const entriesInDir = await glob(processedPatterns, {
cwd: dir,
ignore: effectiveExcludes,
nodir: true,
dot: true,
absolute: true,
nocase: true,
signal,
});
for (const entry of entriesInDir) {
allEntries.add(entry);
}
@@ -388,6 +397,7 @@ ${finalExclusionPatternsForDescription
const fileReadResult = await processSingleFileContent(
filePath,
this.config.getTargetDir(),
this.config.getFileSystemService(),
);
if (fileReadResult.error) {
@@ -626,19 +636,6 @@ Use this tool when the user's query implies needing the content of several files
);
}
protected override validateToolParams(
params: ReadManyFilesParams,
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
return null;
}
protected createInvocation(
params: ReadManyFilesParams,
): ToolInvocation<ReadManyFilesParams, ToolResult> {

View File

@@ -61,6 +61,7 @@ describe('ShellTool', () => {
name: 'Qwen-Coder',
email: 'qwen-coder@alibabacloud.com',
}),
getShouldUseNodePtyShell: vi.fn().mockReturnValue(false),
} as unknown as Config;
shellTool = new ShellTool(mockConfig);
@@ -151,13 +152,12 @@ describe('ShellTool', () => {
const fullResult: ShellExecutionResult = {
rawOutput: Buffer.from(result.output || ''),
output: 'Success',
stdout: 'Success',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
...result,
};
resolveExecutionPromise(fullResult);
@@ -183,6 +183,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
expect(result.llmContent).toContain('Background PIDs: 54322');
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
@@ -208,6 +211,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -231,6 +237,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -254,6 +263,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -267,13 +279,12 @@ describe('ShellTool', () => {
resolveShellExecution({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
@@ -281,6 +292,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -295,16 +309,14 @@ describe('ShellTool', () => {
error,
exitCode: 1,
output: 'err',
stderr: 'err',
rawOutput: Buffer.from('err'),
stdout: '',
signal: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
const result = await promise;
// The final llmContent should contain the user's command, not the wrapper
expect(result.llmContent).toContain('Error: wrapped command failed');
expect(result.llmContent).not.toContain('pgrep');
});
@@ -344,13 +356,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
output: 'long output',
rawOutput: Buffer.from('long output'),
stdout: 'long output',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
const result = await promise;
@@ -402,7 +413,6 @@ describe('ShellTool', () => {
// First chunk, should be throttled.
mockShellOutputCallback({
type: 'data',
stream: 'stdout',
chunk: 'hello ',
});
expect(updateOutputMock).not.toHaveBeenCalled();
@@ -413,24 +423,22 @@ describe('ShellTool', () => {
// Send a second chunk. THIS event triggers the update with the CUMULATIVE content.
mockShellOutputCallback({
type: 'data',
stream: 'stderr',
chunk: 'world',
chunk: 'hello world',
});
// It should have been called once now with the combined output.
expect(updateOutputMock).toHaveBeenCalledOnce();
expect(updateOutputMock).toHaveBeenCalledWith('hello \nworld');
expect(updateOutputMock).toHaveBeenCalledWith('hello world');
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
});
@@ -472,13 +480,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
});
@@ -494,13 +501,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -513,6 +519,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -524,13 +533,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -542,6 +550,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -553,13 +564,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -571,6 +581,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -582,13 +595,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -599,6 +611,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -610,13 +625,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -627,6 +641,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -638,13 +655,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -656,6 +672,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -674,13 +693,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -691,6 +709,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
@@ -709,13 +730,12 @@ describe('ShellTool', () => {
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
@@ -727,6 +747,9 @@ describe('ShellTool', () => {
expect.any(String),
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
);
});
});
@@ -765,6 +788,20 @@ describe('ShellTool', () => {
).toThrow();
});
});
describe('getDescription', () => {
it('should return the windows description when on windows', () => {
vi.mocked(os.platform).mockReturnValue('win32');
const shellTool = new ShellTool(mockConfig);
expect(shellTool.description).toMatchSnapshot();
});
it('should return the non-windows description when not on windows', () => {
vi.mocked(os.platform).mockReturnValue('linux');
const shellTool = new ShellTool(mockConfig);
expect(shellTool.description).toMatchSnapshot();
});
});
});
describe('validateToolParams', () => {

View File

@@ -19,7 +19,6 @@ import {
ToolConfirmationOutcome,
Kind,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { summarizeToolOutput } from '../utils/summarizer.js';
import {
@@ -102,6 +101,8 @@ class ShellToolInvocation extends BaseToolInvocation<
async execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
terminalColumns?: number,
terminalRows?: number,
): Promise<ToolResult> {
const strippedCommand = stripShellWrapper(this.params.command);
@@ -145,13 +146,11 @@ class ShellToolInvocation extends BaseToolInvocation<
this.params.directory || '',
);
let cumulativeStdout = '';
let cumulativeStderr = '';
let cumulativeOutput = '';
let lastUpdateTime = Date.now();
let isBinaryStream = false;
const { result: resultPromise } = ShellExecutionService.execute(
const { result: resultPromise } = await ShellExecutionService.execute(
commandToExecute,
cwd,
(event: ShellOutputEvent) => {
@@ -164,15 +163,9 @@ class ShellToolInvocation extends BaseToolInvocation<
switch (event.type) {
case 'data':
if (isBinaryStream) break; // Don't process text if we are in binary mode
if (event.stream === 'stdout') {
cumulativeStdout += event.chunk;
} else {
cumulativeStderr += event.chunk;
}
currentDisplayOutput =
cumulativeStdout +
(cumulativeStderr ? `\n${cumulativeStderr}` : '');
if (isBinaryStream) break;
cumulativeOutput = event.chunk;
currentDisplayOutput = cumulativeOutput;
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
shouldUpdate = true;
}
@@ -203,6 +196,9 @@ class ShellToolInvocation extends BaseToolInvocation<
}
},
signal,
this.config.getShouldUseNodePtyShell(),
terminalColumns,
terminalRows,
);
const result = await resultPromise;
@@ -234,7 +230,7 @@ class ShellToolInvocation extends BaseToolInvocation<
if (result.aborted) {
llmContent = 'Command was cancelled by user before it could complete.';
if (result.output.trim()) {
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${result.output}`;
llmContent += ` Below is the output before it was cancelled:\n${result.output}`;
} else {
llmContent += ' There was no output before it was cancelled.';
}
@@ -248,8 +244,7 @@ class ShellToolInvocation extends BaseToolInvocation<
llmContent = [
`Command: ${this.params.command}`,
`Directory: ${this.params.directory || '(root)'}`,
`Stdout: ${result.stdout || '(empty)'}`,
`Stderr: ${result.stderr || '(empty)'}`,
`Output: ${result.output || '(empty)'}`,
`Error: ${finalError}`, // Use the cleaned error string.
`Exit Code: ${result.exitCode ?? '(none)'}`,
`Signal: ${result.signal ?? '(none)'}`,
@@ -345,18 +340,10 @@ Co-authored-by: ${gitCoAuthorSettings.name} <${gitCoAuthorSettings.email}>`;
}
}
export class ShellTool extends BaseDeclarativeTool<
ShellToolParams,
ToolResult
> {
static Name: string = 'run_shell_command';
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
super(
ShellTool.Name,
'Shell',
`This tool executes a given shell command as \`bash -c <command>\`.
function getShellToolDescription(): string {
const platform = os.platform();
const toolDescription = `
${platform === 'win32' ? 'This tool executes a given shell command as `cmd.exe /c <command>`.' : 'This tool executes a given shell command as `bash -c <command>`. '}
**Background vs Foreground Execution:**
You should decide whether commands should run in background or foreground based on their nature:
@@ -375,7 +362,7 @@ export class ShellTool extends BaseDeclarativeTool<
- Git operations: \`git commit\`, \`git push\`
- Test runs: \`npm test\`, \`pytest\`
Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
${platform === 'win32' ? '' : 'Command is executed as a subprocess that leads its own process group. Command process group can be terminated as `kill -- -PGID` or signaled as `kill -s SIGNAL -- -PGID`.'}
The following information is returned:
@@ -387,14 +374,38 @@ export class ShellTool extends BaseDeclarativeTool<
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\``,
Process Group PGID: Process group started or \`(none)\``;
return toolDescription;
}
function getCommandDescription(): string {
if (os.platform() === 'win32') {
return 'Exact command to execute as `cmd.exe /c <command>`';
} else {
return 'Exact bash command to execute as `bash -c <command>`';
}
}
export class ShellTool extends BaseDeclarativeTool<
ShellToolParams,
ToolResult
> {
static Name: string = 'run_shell_command';
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
super(
ShellTool.Name,
'Shell',
getShellToolDescription(),
Kind.Execute,
{
type: 'object',
properties: {
command: {
type: 'string',
description: 'Exact bash command to execute as `bash -c <command>`',
description: getCommandDescription(),
},
is_background: {
type: 'boolean',
@@ -419,7 +430,9 @@ export class ShellTool extends BaseDeclarativeTool<
);
}
override validateToolParams(params: ShellToolParams): string | null {
protected override validateToolParamValues(
params: ShellToolParams,
): string | null {
const commandCheck = isCommandAllowed(params.command, this.config);
if (!commandCheck.allowed) {
if (!commandCheck.reason) {
@@ -430,13 +443,6 @@ export class ShellTool extends BaseDeclarativeTool<
}
return commandCheck.reason;
}
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!params.command.trim()) {
return 'Command cannot be empty.';
}

View File

@@ -75,7 +75,9 @@ describe('TodoWriteTool', () => {
};
const result = tool.validateToolParams(params);
expect(result).toContain('must NOT have fewer than 1 characters');
expect(result).toContain(
'Each todo must have a non-empty "content" string',
);
});
it('should reject todos with empty id', () => {
@@ -87,7 +89,7 @@ describe('TodoWriteTool', () => {
};
const result = tool.validateToolParams(params);
expect(result).toContain('non-empty "id"');
expect(result).toContain('non-empty "id" string');
});
it('should reject todos with invalid status', () => {
@@ -103,7 +105,9 @@ describe('TodoWriteTool', () => {
};
const result = tool.validateToolParams(params);
expect(result).toContain('must be equal to one of the allowed values');
expect(result).toContain(
'Each todo must have a valid "status" (pending, in_progress, completed)',
);
});
it('should reject todos with duplicate IDs', () => {

View File

@@ -17,7 +17,6 @@ import * as path from 'path';
import * as process from 'process';
import { QWEN_DIR } from '../utils/paths.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { Config } from '../config/config.js';
export interface TodoItem {
@@ -248,7 +247,8 @@ When in doubt, use this tool. Being proactive with task management demonstrates
const TODO_SUBDIR = 'todos';
function getTodoFilePath(sessionId?: string): string {
const homeDir = process.env.HOME || process.env.USERPROFILE || process.cwd();
const homeDir =
process.env['HOME'] || process.env['USERPROFILE'] || process.cwd();
const todoDir = path.join(homeDir, QWEN_DIR, TODO_SUBDIR);
// Use sessionId if provided, otherwise fall back to 'default'
@@ -383,7 +383,7 @@ export async function readTodosForSession(
export async function listTodoSessions(): Promise<string[]> {
try {
const homeDir =
process.env.HOME || process.env.USERPROFILE || process.cwd();
process.env['HOME'] || process.env['USERPROFILE'] || process.cwd();
const todoDir = path.join(homeDir, QWEN_DIR, TODO_SUBDIR);
const files = await fs.readdir(todoDir);
return files
@@ -415,14 +415,6 @@ export class TodoWriteTool extends BaseDeclarativeTool<
}
override validateToolParams(params: TodoWriteParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
// Validate todos array
if (!Array.isArray(params.todos)) {
return 'Parameter "todos" must be an array.';

View File

@@ -13,6 +13,7 @@ export enum ToolErrorType {
UNKNOWN = 'unknown',
UNHANDLED_EXCEPTION = 'unhandled_exception',
TOOL_NOT_REGISTERED = 'tool_not_registered',
EXECUTION_FAILED = 'execution_failed',
// File System Errors
FILE_NOT_FOUND = 'file_not_found',

View File

@@ -22,15 +22,17 @@ import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/tools.js';
import { McpClientManager } from './mcp-client-manager.js';
vi.mock('node:fs');
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
// Mock ./mcp-client.js to control its behavior within tool-registry tests
vi.mock('./mcp-client.js', () => ({
discoverMcpTools: mockDiscoverMcpTools,
}));
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
return {
...originalModule,
};
});
// Mock node:child_process
vi.mock('node:child_process', async () => {
@@ -142,7 +144,6 @@ describe('ToolRegistry', () => {
clear: vi.fn(),
removePromptsByServer: vi.fn(),
} as any);
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
afterEach(() => {
@@ -310,6 +311,10 @@ describe('ToolRegistry', () => {
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
const discoverSpy = vi.spyOn(
McpClientManager.prototype,
'discoverAllMcpTools',
);
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
@@ -323,38 +328,7 @@ describe('ToolRegistry', () => {
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
config.getPromptRegistry(),
false,
expect.any(Object),
);
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
},
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
config.getPromptRegistry(),
false,
expect.any(Object),
);
expect(discoverSpy).toHaveBeenCalled();
});
});
});

View File

@@ -5,56 +5,47 @@
*/
import { FunctionDeclaration } from '@google/genai';
import { AnyDeclarativeTool, Kind, ToolResult, BaseTool } from './tools.js';
import {
AnyDeclarativeTool,
Kind,
ToolResult,
BaseDeclarativeTool,
BaseToolInvocation,
ToolInvocation,
} from './tools.js';
import { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
import { discoverMcpTools } from './mcp-client.js';
import { connectAndDiscover } from './mcp-client.js';
import { McpClientManager } from './mcp-client-manager.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { parse } from 'shell-quote';
type ToolParams = Record<string, unknown>;
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
class DiscoveredToolInvocation extends BaseToolInvocation<
ToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
name: string,
override readonly description: string,
override readonly parameterSchema: Record<string, unknown>,
private readonly toolName: string,
params: ToolParams,
) {
const discoveryCmd = config.getToolDiscoveryCommand()!;
const callCommand = config.getToolCallCommand()!;
description += `
This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
Tool discovery and call commands can be configured in project or user settings.
When called, the tool call command is executed as a subprocess.
On success, tool output is returned as a json string.
Otherwise, the following information is returned:
Stdout: Output on stdout stream. Can be \`(empty)\` or partial.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
`;
super(
name,
name,
description,
Kind.Other,
parameterSchema,
false, // isOutputMarkdown
false, // canUpdateOutput
);
super(params);
}
async execute(params: ToolParams): Promise<ToolResult> {
getDescription(): string {
return `Calling discovered tool: ${this.toolName}`;
}
async execute(
_signal: AbortSignal,
_updateOutput?: (output: string) => void,
): Promise<ToolResult> {
const callCommand = this.config.getToolCallCommand()!;
const child = spawn(callCommand, [this.name]);
child.stdin.write(JSON.stringify(params));
const child = spawn(callCommand, [this.toolName]);
child.stdin.write(JSON.stringify(this.params));
child.stdin.end();
let stdout = '';
@@ -124,12 +115,67 @@ Signal: Signal number or \`(none)\` if no signal was received.
}
}
export class DiscoveredTool extends BaseDeclarativeTool<
ToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
name: string,
override readonly description: string,
override readonly parameterSchema: Record<string, unknown>,
) {
const discoveryCmd = config.getToolDiscoveryCommand()!;
const callCommand = config.getToolCallCommand()!;
description += `
This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
Tool discovery and call commands can be configured in project or user settings.
When called, the tool call command is executed as a subprocess.
On success, tool output is returned as a json string.
Otherwise, the following information is returned:
Stdout: Output on stdout stream. Can be \`(empty)\` or partial.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
`;
super(
name,
name,
description,
Kind.Other,
parameterSchema,
false, // isOutputMarkdown
false, // canUpdateOutput
);
}
protected createInvocation(
params: ToolParams,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredToolInvocation(this.config, this.name, params);
}
}
export class ToolRegistry {
private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
private mcpClientManager: McpClientManager;
constructor(config: Config) {
this.config = config;
this.mcpClientManager = new McpClientManager(
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
}
/**
@@ -184,14 +230,7 @@ export class ToolRegistry {
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await discoverMcpTools(
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
await this.mcpClientManager.discoverAllMcpTools();
}
/**
@@ -206,14 +245,14 @@ export class ToolRegistry {
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await discoverMcpTools(
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
await this.mcpClientManager.discoverAllMcpTools();
}
/**
* Restarts all MCP servers and re-discovers tools.
*/
async restartMcpServers(): Promise<void> {
await this.discoverMcpTools();
}
/**
@@ -233,9 +272,9 @@ export class ToolRegistry {
const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {
await discoverMcpTools(
{ [serverName]: serverConfig },
undefined,
await connectAndDiscover(
serverName,
serverConfig,
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),

View File

@@ -4,8 +4,119 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { hasCycleInSchema } from './tools.js'; // Added getStringifiedResultForDisplay
import { describe, it, expect, vi } from 'vitest';
import {
DeclarativeTool,
hasCycleInSchema,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
class TestToolInvocation implements ToolInvocation<object, ToolResult> {
constructor(
readonly params: object,
private readonly executeFn: () => Promise<ToolResult>,
) {}
getDescription(): string {
return 'A test invocation';
}
toolLocations() {
return [];
}
shouldConfirmExecute(): Promise<false> {
return Promise.resolve(false);
}
execute(): Promise<ToolResult> {
return this.executeFn();
}
}
class TestTool extends DeclarativeTool<object, ToolResult> {
private readonly buildFn: (params: object) => TestToolInvocation;
constructor(buildFn: (params: object) => TestToolInvocation) {
super('test-tool', 'Test Tool', 'A tool for testing', Kind.Other, {});
this.buildFn = buildFn;
}
build(params: object): ToolInvocation<object, ToolResult> {
return this.buildFn(params);
}
}
describe('DeclarativeTool', () => {
describe('validateBuildAndExecute', () => {
const abortSignal = new AbortController().signal;
it('should return INVALID_TOOL_PARAMS error if build fails', async () => {
const buildError = new Error('Invalid build parameters');
const buildFn = vi.fn().mockImplementation(() => {
throw buildError;
});
const tool = new TestTool(buildFn);
const params = { foo: 'bar' };
const result = await tool.validateBuildAndExecute(params, abortSignal);
expect(buildFn).toHaveBeenCalledWith(params);
expect(result).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${buildError.message}`,
returnDisplay: buildError.message,
error: {
message: buildError.message,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
});
});
it('should return EXECUTION_FAILED error if execute fails', async () => {
const executeError = new Error('Execution failed');
const executeFn = vi.fn().mockRejectedValue(executeError);
const invocation = new TestToolInvocation({}, executeFn);
const buildFn = vi.fn().mockReturnValue(invocation);
const tool = new TestTool(buildFn);
const params = { foo: 'bar' };
const result = await tool.validateBuildAndExecute(params, abortSignal);
expect(buildFn).toHaveBeenCalledWith(params);
expect(executeFn).toHaveBeenCalled();
expect(result).toEqual({
llmContent: `Error: Tool call execution failed. Reason: ${executeError.message}`,
returnDisplay: executeError.message,
error: {
message: executeError.message,
type: ToolErrorType.EXECUTION_FAILED,
},
});
});
it('should return the result of execute on success', async () => {
const successResult: ToolResult = {
llmContent: 'Success!',
returnDisplay: 'Success!',
summary: 'Tool executed successfully',
};
const executeFn = vi.fn().mockResolvedValue(successResult);
const invocation = new TestToolInvocation({}, executeFn);
const buildFn = vi.fn().mockReturnValue(invocation);
const tool = new TestTool(buildFn);
const params = { foo: 'bar' };
const result = await tool.validateBuildAndExecute(params, abortSignal);
expect(buildFn).toHaveBeenCalledWith(params);
expect(executeFn).toHaveBeenCalled();
expect(result).toEqual(successResult);
});
});
});
describe('hasCycleInSchema', () => {
it('should detect a simple direct cycle', () => {

View File

@@ -7,6 +7,7 @@
import { FunctionDeclaration, PartListUnion } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import { DiffUpdateResult } from '../ide/ideContext.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
/**
* Represents a validated and ready-to-execute tool call.
@@ -86,42 +87,6 @@ export abstract class BaseToolInvocation<
*/
export type AnyToolInvocation = ToolInvocation<object, ToolResult>;
/**
* An adapter that wraps the legacy `Tool` interface to make it compatible
* with the new `ToolInvocation` pattern.
*/
export class LegacyToolInvocation<
TParams extends object,
TResult extends ToolResult,
> implements ToolInvocation<TParams, TResult>
{
constructor(
private readonly legacyTool: BaseTool<TParams, TResult>,
readonly params: TParams,
) {}
getDescription(): string {
return this.legacyTool.getDescription(this.params);
}
toolLocations(): ToolLocation[] {
return this.legacyTool.toolLocations(this.params);
}
shouldConfirmExecute(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
return this.legacyTool.shouldConfirmExecute(this.params, abortSignal);
}
execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<TResult> {
return this.legacyTool.execute(this.params, signal, updateOutput);
}
}
/**
* Interface for a tool builder that validates parameters and creates invocations.
*/
@@ -206,7 +171,7 @@ export abstract class DeclarativeTool<
* @param params The raw parameters from the model.
* @returns An error message string if invalid, null otherwise.
*/
protected validateToolParams(_params: TParams): string | null {
validateToolParams(_params: TParams): string | null {
// Base implementation can be extended by subclasses.
return null;
}
@@ -236,6 +201,64 @@ export abstract class DeclarativeTool<
const invocation = this.build(params);
return invocation.execute(signal, updateOutput);
}
/**
* Similar to `build` but never throws.
* @param params The raw, untrusted parameters from the model.
* @returns A `ToolInvocation` instance.
*/
private silentBuild(
params: TParams,
): ToolInvocation<TParams, TResult> | Error {
try {
return this.build(params);
} catch (e) {
if (e instanceof Error) {
return e;
}
return new Error(String(e));
}
}
/**
* A convenience method that builds and executes the tool in one step.
* Never throws.
* @param params The raw, untrusted parameters from the model.
* @params abortSignal a signal to abort.
* @returns The result of the tool execution.
*/
async validateBuildAndExecute(
params: TParams,
abortSignal: AbortSignal,
): Promise<ToolResult> {
const invocationOrError = this.silentBuild(params);
if (invocationOrError instanceof Error) {
const errorMessage = invocationOrError.message;
return {
llmContent: `Error: Invalid parameters provided. Reason: ${errorMessage}`,
returnDisplay: errorMessage,
error: {
message: errorMessage,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}
try {
return await invocationOrError.execute(abortSignal);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
return {
llmContent: `Error: Tool call execution failed. Reason: ${errorMessage}`,
returnDisplay: errorMessage,
error: {
message: errorMessage,
type: ToolErrorType.EXECUTION_FAILED,
},
};
}
}
}
/**
@@ -256,6 +279,23 @@ export abstract class BaseDeclarativeTool<
return this.createInvocation(params);
}
override validateToolParams(params: TParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
return this.validateToolParamValues(params);
}
protected validateToolParamValues(_params: TParams): string | null {
// Base implementation can be extended by subclasses.
return null;
}
protected abstract createInvocation(
params: TParams,
): ToolInvocation<TParams, TResult>;
@@ -266,116 +306,6 @@ export abstract class BaseDeclarativeTool<
*/
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
/**
* Base implementation for tools with common functionality
* @deprecated Use `DeclarativeTool` for new tools.
*/
export abstract class BaseTool<
TParams extends object,
TResult extends ToolResult = ToolResult,
> extends DeclarativeTool<TParams, TResult> {
/**
* Creates a new instance of BaseTool
* @param name Internal name of the tool (used for API calls)
* @param displayName User-friendly display name of the tool
* @param description Description of what the tool does
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
* @param canUpdateOutput Whether the tool supports live (streaming) output
* @param parameterSchema JSON Schema defining the parameters
*/
constructor(
override readonly name: string,
override readonly displayName: string,
override readonly description: string,
override readonly kind: Kind,
override readonly parameterSchema: unknown,
override readonly isOutputMarkdown: boolean = true,
override readonly canUpdateOutput: boolean = false,
) {
super(
name,
displayName,
description,
kind,
parameterSchema,
isOutputMarkdown,
canUpdateOutput,
);
}
build(params: TParams): ToolInvocation<TParams, TResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
throw new Error(validationError);
}
return new LegacyToolInvocation(this, params);
}
/**
* Validates the parameters for the tool
* This is a placeholder implementation and should be overridden
* Should be called from both `shouldConfirmExecute` and `execute`
* `shouldConfirmExecute` should return false immediately if invalid
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
// eslint-disable-next-line @typescript-eslint/no-unused-vars
override validateToolParams(params: TParams): string | null {
// Implementation would typically use a JSON Schema validator
// This is a placeholder that should be implemented by derived classes
return null;
}
/**
* Gets a pre-execution description of the tool operation
* Default implementation that should be overridden by derived classes
* @param params Parameters for the tool execution
* @returns A markdown string describing what the tool will do
*/
getDescription(params: TParams): string {
return JSON.stringify(params);
}
/**
* Determines if the tool should prompt for confirmation before execution
* @param params Parameters for the tool execution
* @returns Whether or not execute should be confirmed by the user.
*/
shouldConfirmExecute(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
params: TParams,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
return Promise.resolve(false);
}
/**
* Determines what file system paths the tool will affect
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
params: TParams,
): ToolLocation[] {
return [];
}
/**
* Abstract method to execute the tool with the given parameters
* Must be implemented by derived classes
* @param params Parameters for the tool execution
* @param signal AbortSignal for tool cancellation
* @returns Result of the tool execution
*/
abstract execute(
params: TParams,
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<TResult>;
}
export interface ToolResult {
/**
* A short, one-line summary of the tool's action and result.

View File

@@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { SchemaValidator } from '../utils/schemaValidator.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
@@ -211,16 +210,9 @@ export class WebFetchTool extends BaseDeclarativeTool<
}
}
protected override validateToolParams(
protected override validateToolParamValues(
params: WebFetchToolParams,
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!params.url || params.url.trim() === '') {
return "The 'url' parameter cannot be empty.";
}

View File

@@ -0,0 +1,166 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { WebSearchTool, WebSearchToolParams } from './web-search.js';
import { Config } from '../config/config.js';
import { GeminiClient } from '../core/client.js';
// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
vi.mock('../config/config.js');
// Mock global fetch
const mockFetch = vi.fn();
global.fetch = mockFetch;
describe('WebSearchTool', () => {
const abortSignal = new AbortController().signal;
let mockGeminiClient: GeminiClient;
let tool: WebSearchTool;
beforeEach(() => {
vi.clearAllMocks();
const mockConfigInstance = {
getGeminiClient: () => mockGeminiClient,
getProxy: () => undefined,
getTavilyApiKey: () => 'test-api-key', // Add the missing method
} as unknown as Config;
mockGeminiClient = new GeminiClient(mockConfigInstance);
tool = new WebSearchTool(mockConfigInstance);
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('build', () => {
it('should return an invocation for a valid query', () => {
const params: WebSearchToolParams = { query: 'test query' };
const invocation = tool.build(params);
expect(invocation).toBeDefined();
expect(invocation.params).toEqual(params);
});
it('should throw an error for an empty query', () => {
const params: WebSearchToolParams = { query: '' };
expect(() => tool.build(params)).toThrow(
"The 'query' parameter cannot be empty.",
);
});
it('should throw an error for a query with only whitespace', () => {
const params: WebSearchToolParams = { query: ' ' };
expect(() => tool.build(params)).toThrow(
"The 'query' parameter cannot be empty.",
);
});
});
describe('getDescription', () => {
it('should return a description of the search', () => {
const params: WebSearchToolParams = { query: 'test query' };
const invocation = tool.build(params);
expect(invocation.getDescription()).toBe(
'Searching the web for: "test query"',
);
});
});
describe('execute', () => {
it('should return search results for a successful query', async () => {
const params: WebSearchToolParams = { query: 'successful query' };
// Mock the fetch response
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
answer: 'Here are your results.',
results: [],
}),
} as Response);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe(
'Web search results for "successful query":\n\nHere are your results.',
);
expect(result.returnDisplay).toBe(
'Search results for "successful query" returned.',
);
expect(result.sources).toEqual([]);
});
it('should handle no search results found', async () => {
const params: WebSearchToolParams = { query: 'no results query' };
// Mock the fetch response
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
answer: '',
results: [],
}),
} as Response);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe(
'No search results or information found for query: "no results query"',
);
expect(result.returnDisplay).toBe('No information found.');
});
it('should handle API errors gracefully', async () => {
const params: WebSearchToolParams = { query: 'error query' };
// Mock the fetch to reject
mockFetch.mockRejectedValueOnce(new Error('API Failure'));
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Error:');
expect(result.llmContent).toContain('API Failure');
expect(result.returnDisplay).toBe('Error performing web search.');
});
it('should correctly format results with sources', async () => {
const params: WebSearchToolParams = { query: 'grounding query' };
// Mock the fetch response
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
answer: 'This is a test response.',
results: [
{ title: 'Example Site', url: 'https://example.com' },
{ title: 'Google', url: 'https://google.com' },
],
}),
} as Response);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
const expectedLlmContent = `Web search results for "grounding query":
This is a test response.
Sources:
[1] Example Site (https://example.com)
[2] Google (https://google.com)`;
expect(result.llmContent).toBe(expectedLlmContent);
expect(result.returnDisplay).toBe(
'Search results for "grounding query" returned.',
);
expect(result.sources).toHaveLength(2);
});
});
});

View File

@@ -4,8 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTool, Kind, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
@@ -40,71 +46,24 @@ export interface WebSearchToolResult extends ToolResult {
sources?: Array<{ title: string; url: string }>;
}
/**
* A tool to perform web searches using Tavily API.
*/
export class WebSearchTool extends BaseTool<
class WebSearchToolInvocation extends BaseToolInvocation<
WebSearchToolParams,
WebSearchToolResult
> {
static readonly Name: string = 'web_search';
constructor(private readonly config: Config) {
super(
WebSearchTool.Name,
'TavilySearch',
'Performs a web search using the Tavily API and returns a concise answer with sources. Requires the TAVILY_API_KEY environment variable.',
Kind.Search,
{
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query to find information on the web.',
},
},
required: ['query'],
},
);
}
/**
* Validates the parameters for the WebSearchTool.
* @param params The parameters to validate
* @returns An error message string if validation fails, null if valid
*/
validateParams(params: WebSearchToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!params.query || params.query.trim() === '') {
return "The 'query' parameter cannot be empty.";
}
return null;
}
override getDescription(params: WebSearchToolParams): string {
return `Searching the web for: "${params.query}"`;
}
async execute(
constructor(
private readonly config: Config,
params: WebSearchToolParams,
_signal: AbortSignal,
): Promise<WebSearchToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
) {
super(params);
}
const apiKey = this.config.getTavilyApiKey() || process.env.TAVILY_API_KEY;
override getDescription(): string {
return `Searching the web for: "${this.params.query}"`;
}
async execute(signal: AbortSignal): Promise<WebSearchToolResult> {
const apiKey =
this.config.getTavilyApiKey() || process.env['TAVILY_API_KEY'];
if (!apiKey) {
return {
llmContent:
@@ -115,8 +74,6 @@ export class WebSearchTool extends BaseTool<
}
try {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 15000);
const response = await fetch('https://api.tavily.com/search', {
method: 'POST',
headers: {
@@ -124,14 +81,13 @@ export class WebSearchTool extends BaseTool<
},
body: JSON.stringify({
api_key: apiKey,
query: params.query,
query: this.params.query,
search_depth: 'advanced',
max_results: 5,
include_answer: true,
}),
signal: controller.signal,
signal,
});
clearTimeout(timeoutId);
if (!response.ok) {
const text = await response.text().catch(() => '');
@@ -166,18 +122,18 @@ export class WebSearchTool extends BaseTool<
if (!content.trim()) {
return {
llmContent: `No search results or information found for query: "${params.query}"`,
llmContent: `No search results or information found for query: "${this.params.query}"`,
returnDisplay: 'No information found.',
};
}
return {
llmContent: `Web search results for "${params.query}":\n\n${content}`,
returnDisplay: `Search results for "${params.query}" returned.`,
llmContent: `Web search results for "${this.params.query}":\n\n${content}`,
returnDisplay: `Search results for "${this.params.query}" returned.`,
sources,
};
} catch (error: unknown) {
const errorMessage = `Error during web search for query "${params.query}": ${getErrorMessage(
const errorMessage = `Error during web search for query "${this.params.query}": ${getErrorMessage(
error,
)}`;
console.error(errorMessage, error);
@@ -188,3 +144,52 @@ export class WebSearchTool extends BaseTool<
}
}
}
/**
* A tool to perform web searches using Google Search via the Gemini API.
*/
export class WebSearchTool extends BaseDeclarativeTool<
WebSearchToolParams,
WebSearchToolResult
> {
static readonly Name: string = 'web_search';
constructor(private readonly config: Config) {
super(
WebSearchTool.Name,
'TavilySearch',
'Performs a web search using the Tavily API and returns a concise answer with sources. Requires the TAVILY_API_KEY environment variable.',
Kind.Search,
{
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query to find information on the web.',
},
},
required: ['query'],
},
);
}
/**
* Validates the parameters for the WebSearchTool.
* @param params The parameters to validate
* @returns An error message string if validation fails, null if valid
*/
protected override validateToolParamValues(
params: WebSearchToolParams,
): string | null {
if (!params.query || params.query.trim() === '') {
return "The 'query' parameter cannot be empty.";
}
return null;
}
protected createInvocation(
params: WebSearchToolParams,
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
return new WebSearchToolInvocation(this.config, params);
}
}

View File

@@ -13,7 +13,11 @@ import {
vi,
type Mocked,
} from 'vitest';
import { WriteFileTool, WriteFileToolParams } from './write-file.js';
import {
getCorrectedFileContent,
WriteFileTool,
WriteFileToolParams,
} from './write-file.js';
import { ToolErrorType } from './tool-error.js';
import {
FileDiff,
@@ -33,6 +37,7 @@ import {
CorrectedEditResult,
} from '../utils/editCorrector.js';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
const rootDir = path.resolve(os.tmpdir(), 'qwen-code-test-root');
@@ -51,11 +56,13 @@ vi.mocked(ensureCorrectFileContent).mockImplementation(
);
// Mock Config
const fsService = new StandardFileSystemService();
const mockConfigInternal = {
getTargetDir: () => rootDir,
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
setApprovalMode: vi.fn(),
getGeminiClient: vi.fn(), // Initialize as a plain mock function
getFileSystemService: () => fsService,
getIdeClient: vi.fn(),
getIdeMode: vi.fn(() => false),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
@@ -174,74 +181,67 @@ describe('WriteFileTool', () => {
vi.clearAllMocks();
});
describe('validateToolParams', () => {
it('should return null for valid absolute path within root', () => {
describe('build', () => {
it('should return an invocation for a valid absolute path within root', () => {
const params = {
file_path: path.join(rootDir, 'test.txt'),
content: 'hello',
};
expect(tool.validateToolParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
expect(invocation.params).toEqual(params);
});
it('should return error for relative path', () => {
it('should throw an error for a relative path', () => {
const params = { file_path: 'test.txt', content: 'hello' };
expect(tool.validateToolParams(params)).toMatch(
/File path must be absolute/,
);
expect(() => tool.build(params)).toThrow(/File path must be absolute/);
});
it('should return error for path outside root', () => {
it('should throw an error for a path outside root', () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = {
file_path: outsidePath,
content: 'hello',
};
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
expect(() => tool.build(params)).toThrow(
/File path must be within one of the workspace directories/,
);
});
it('should return error if path is a directory', () => {
it('should throw an error if path is a directory', () => {
const dirAsFilePath = path.join(rootDir, 'a_directory');
fs.mkdirSync(dirAsFilePath);
const params = {
file_path: dirAsFilePath,
content: 'hello',
};
expect(tool.validateToolParams(params)).toMatch(
expect(() => tool.build(params)).toThrow(
`Path is a directory, not a file: ${dirAsFilePath}`,
);
});
it('should return error if the content is null', () => {
it('should throw an error if the content is null', () => {
const dirAsFilePath = path.join(rootDir, 'a_directory');
fs.mkdirSync(dirAsFilePath);
const params = {
file_path: dirAsFilePath,
content: null,
} as unknown as WriteFileToolParams; // Intentionally non-conforming
expect(tool.validateToolParams(params)).toMatch(
`params/content must be string`,
);
expect(() => tool.build(params)).toThrow('params/content must be string');
});
});
describe('getDescription', () => {
it('should return error if the file_path is empty', () => {
it('should throw error if the file_path is empty', () => {
const dirAsFilePath = path.join(rootDir, 'a_directory');
fs.mkdirSync(dirAsFilePath);
const params = {
file_path: '',
content: '',
};
expect(tool.getDescription(params)).toMatch(
`Model did not provide valid parameters for write file tool, missing or empty "file_path"`,
);
expect(() => tool.build(params)).toThrow(`Missing or empty "file_path"`);
});
});
describe('_getCorrectedFileContent', () => {
describe('getCorrectedFileContent', () => {
it('should call ensureCorrectFileContent for a new file', async () => {
const filePath = path.join(rootDir, 'new_corrected_file.txt');
const proposedContent = 'Proposed new content.';
@@ -250,8 +250,8 @@ describe('WriteFileTool', () => {
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
// @ts-expect-error _getCorrectedFileContent is private
const result = await tool._getCorrectedFileContent(
const result = await getCorrectedFileContent(
mockConfig,
filePath,
proposedContent,
abortSignal,
@@ -287,8 +287,8 @@ describe('WriteFileTool', () => {
occurrences: 1,
} as CorrectedEditResult);
// @ts-expect-error _getCorrectedFileContent is private
const result = await tool._getCorrectedFileContent(
const result = await getCorrectedFileContent(
mockConfig,
filePath,
proposedContent,
abortSignal,
@@ -319,19 +319,18 @@ describe('WriteFileTool', () => {
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
const readError = new Error('Permission denied');
const originalReadFileSync = fs.readFileSync;
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
throw readError;
});
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() =>
Promise.reject(readError),
);
// @ts-expect-error _getCorrectedFileContent is private
const result = await tool._getCorrectedFileContent(
const result = await getCorrectedFileContent(
mockConfig,
filePath,
proposedContent,
abortSignal,
);
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
expect(fsService.readTextFile).toHaveBeenCalledWith(filePath);
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
expect(result.correctedContent).toBe(proposedContent);
@@ -342,25 +341,12 @@ describe('WriteFileTool', () => {
code: undefined,
});
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
fs.chmodSync(filePath, 0o600);
});
});
describe('shouldConfirmExecute', () => {
const abortSignal = new AbortController().signal;
it('should return false if params are invalid (relative path)', async () => {
const params = { file_path: 'relative.txt', content: 'test' };
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
expect(confirmation).toBe(false);
});
it('should return false if params are invalid (outside root)', async () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = { file_path: outsidePath, content: 'test' };
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
expect(confirmation).toBe(false);
});
it('should return false if _getCorrectedFileContent returns an error', async () => {
const filePath = path.join(rootDir, 'confirm_error_file.txt');
@@ -368,15 +354,14 @@ describe('WriteFileTool', () => {
fs.writeFileSync(filePath, 'original', { mode: 0o000 });
const readError = new Error('Simulated read error for confirmation');
const originalReadFileSync = fs.readFileSync;
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
throw readError;
});
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() =>
Promise.reject(readError),
);
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(abortSignal);
expect(confirmation).toBe(false);
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
fs.chmodSync(filePath, 0o600);
});
@@ -387,8 +372,8 @@ describe('WriteFileTool', () => {
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); // Ensure this mock is active
const params = { file_path: filePath, content: proposedContent };
const confirmation = (await tool.shouldConfirmExecute(
params,
const invocation = tool.build(params);
const confirmation = (await invocation.shouldConfirmExecute(
abortSignal,
)) as ToolEditConfirmationDetails;
@@ -430,8 +415,8 @@ describe('WriteFileTool', () => {
});
const params = { file_path: filePath, content: proposedContent };
const confirmation = (await tool.shouldConfirmExecute(
params,
const invocation = tool.build(params);
const confirmation = (await invocation.shouldConfirmExecute(
abortSignal,
)) as ToolEditConfirmationDetails;
@@ -461,45 +446,20 @@ describe('WriteFileTool', () => {
describe('execute', () => {
const abortSignal = new AbortController().signal;
it('should return error if params are invalid (relative path)', async () => {
const params = { file_path: 'relative.txt', content: 'test' };
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toContain(
'Could not write file due to invalid parameters:',
);
expect(result.returnDisplay).toMatch(/File path must be absolute/);
expect(result.error).toEqual({
message: 'File path must be absolute: relative.txt',
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should return error if params are invalid (path outside root)', async () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = { file_path: outsidePath, content: 'test' };
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toContain(
'Could not write file due to invalid parameters:',
);
expect(result.returnDisplay).toContain(
'File path must be within one of the workspace directories',
);
expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS);
});
it('should return error if _getCorrectedFileContent returns an error during execute', async () => {
const filePath = path.join(rootDir, 'execute_error_file.txt');
const params = { file_path: filePath, content: 'test content' };
fs.writeFileSync(filePath, 'original', { mode: 0o000 });
const readError = new Error('Simulated read error for execute');
const originalReadFileSync = fs.readFileSync;
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
throw readError;
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() => {
const readError = new Error('Simulated read error for execute');
return Promise.reject(readError);
});
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toContain('Error checking existing file:');
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Error checking existing file');
expect(result.returnDisplay).toMatch(
/Error checking existing file: Simulated read error for execute/,
);
@@ -509,7 +469,6 @@ describe('WriteFileTool', () => {
type: ToolErrorType.FILE_WRITE_FAILURE,
});
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
fs.chmodSync(filePath, 0o600);
});
@@ -520,11 +479,9 @@ describe('WriteFileTool', () => {
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
const params = { file_path: filePath, content: proposedContent };
const invocation = tool.build(params);
const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
@@ -533,7 +490,7 @@ describe('WriteFileTool', () => {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
const result = await tool.execute(params, abortSignal);
const result = await invocation.execute(abortSignal);
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
proposedContent,
@@ -544,7 +501,8 @@ describe('WriteFileTool', () => {
/Successfully created and wrote to new file/,
);
expect(fs.existsSync(filePath)).toBe(true);
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedContent);
const writtenContent = await fsService.readTextFile(filePath);
expect(writtenContent).toBe(correctedContent);
const display = result.returnDisplay as FileDiff;
expect(display.fileName).toBe('execute_new_corrected_file.txt');
expect(display.fileDiff).toMatch(
@@ -578,11 +536,9 @@ describe('WriteFileTool', () => {
});
const params = { file_path: filePath, content: proposedContent };
const invocation = tool.build(params);
const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
@@ -591,7 +547,7 @@ describe('WriteFileTool', () => {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
const result = await tool.execute(params, abortSignal);
const result = await invocation.execute(abortSignal);
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
filePath,
@@ -605,7 +561,8 @@ describe('WriteFileTool', () => {
abortSignal,
);
expect(result.llmContent).toMatch(/Successfully overwrote file/);
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
const writtenContent = await fsService.readTextFile(filePath);
expect(writtenContent).toBe(correctedProposedContent);
const display = result.returnDisplay as FileDiff;
expect(display.fileName).toBe('execute_existing_corrected_file.txt');
expect(display.fileDiff).toMatch(
@@ -623,11 +580,9 @@ describe('WriteFileTool', () => {
mockEnsureCorrectFileContent.mockResolvedValue(content); // Ensure this mock is active
const params = { file_path: filePath, content };
const invocation = tool.build(params);
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
@@ -636,7 +591,7 @@ describe('WriteFileTool', () => {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
await tool.execute(params, abortSignal);
await invocation.execute(abortSignal);
expect(fs.existsSync(dirPath)).toBe(true);
expect(fs.statSync(dirPath).isDirectory()).toBe(true);
@@ -654,7 +609,8 @@ describe('WriteFileTool', () => {
content,
modified_by_user: true,
};
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toMatch(/User modified the `content`/);
});
@@ -669,7 +625,8 @@ describe('WriteFileTool', () => {
content,
modified_by_user: false,
};
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).not.toMatch(/User modified the `content`/);
});
@@ -683,7 +640,8 @@ describe('WriteFileTool', () => {
file_path: filePath,
content,
};
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).not.toMatch(/User modified the `content`/);
});
@@ -695,7 +653,7 @@ describe('WriteFileTool', () => {
file_path: path.join(rootDir, 'file.txt'),
content: 'test content',
};
expect(tool.validateToolParams(params)).toBeNull();
expect(() => tool.build(params)).not.toThrow();
});
it('should reject paths outside workspace root', () => {
@@ -703,24 +661,9 @@ describe('WriteFileTool', () => {
file_path: '/etc/passwd',
content: 'malicious',
};
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
expect(() => tool.build(params)).toThrow(
/File path must be within one of the workspace directories/,
);
expect(error).toContain(rootDir);
});
it('should provide clear error message with workspace directories', () => {
const outsidePath = path.join(tempDir, 'outside-root.txt');
const params = {
file_path: outsidePath,
content: 'test',
};
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
);
expect(error).toContain(rootDir);
});
});
@@ -731,50 +674,50 @@ describe('WriteFileTool', () => {
const filePath = path.join(rootDir, 'permission_denied_file.txt');
const content = 'test content';
// Mock writeFileSync to throw EACCES error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
// Mock FileSystemService writeTextFile to throw EACCES error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error('Permission denied') as NodeJS.ErrnoException;
error.code = 'EACCES';
throw error;
return Promise.reject(error);
});
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.PERMISSION_DENIED);
expect(result.llmContent).toContain(
`Permission denied writing to file: ${filePath} (EACCES)`,
);
expect(result.returnDisplay).toContain('Permission denied');
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
expect(result.returnDisplay).toContain(
`Permission denied writing to file: ${filePath} (EACCES)`,
);
});
it('should return NO_SPACE_LEFT error when write fails with ENOSPC', async () => {
const filePath = path.join(rootDir, 'no_space_file.txt');
const content = 'test content';
// Mock writeFileSync to throw ENOSPC error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
// Mock FileSystemService writeTextFile to throw ENOSPC error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error(
'No space left on device',
) as NodeJS.ErrnoException;
error.code = 'ENOSPC';
throw error;
return Promise.reject(error);
});
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.NO_SPACE_LEFT);
expect(result.llmContent).toContain(
`No space left on device: ${filePath} (ENOSPC)`,
);
expect(result.returnDisplay).toContain('No space left');
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
expect(result.returnDisplay).toContain(
`No space left on device: ${filePath} (ENOSPC)`,
);
});
it('should return TARGET_IS_DIRECTORY error when write fails with EISDIR', async () => {
@@ -790,25 +733,26 @@ describe('WriteFileTool', () => {
return originalExistsSync(path as string);
});
// Mock writeFileSync to throw EISDIR error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
// Mock FileSystemService writeTextFile to throw EISDIR error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error('Is a directory') as NodeJS.ErrnoException;
error.code = 'EISDIR';
throw error;
return Promise.reject(error);
});
const params = { file_path: dirPath, content };
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.TARGET_IS_DIRECTORY);
expect(result.llmContent).toContain(
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
);
expect(result.returnDisplay).toContain('Target is a directory');
expect(result.returnDisplay).toContain(
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
);
vi.spyOn(fs, 'existsSync').mockImplementation(originalExistsSync);
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
});
it('should return FILE_WRITE_FAILURE for generic write errors', async () => {
@@ -818,19 +762,22 @@ describe('WriteFileTool', () => {
// Ensure fs.existsSync is not mocked for this test
vi.restoreAllMocks();
// Mock writeFileSync to throw generic error
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
throw new Error('Generic write error');
});
// Mock FileSystemService writeTextFile to throw generic error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() =>
Promise.reject(new Error('Generic write error')),
);
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE);
expect(result.llmContent).toContain(
'Error writing to file: Generic write error',
);
expect(result.returnDisplay).toContain('Generic write error');
expect(result.returnDisplay).toContain(
'Error writing to file: Generic write error',
);
});
});
});

View File

@@ -9,17 +9,18 @@ import path from 'path';
import * as Diff from 'diff';
import { Config, ApprovalMode } from '../config/config.js';
import {
BaseTool,
ToolResult,
BaseDeclarativeTool,
BaseToolInvocation,
FileDiff,
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
ToolCallConfirmationDetails,
Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolInvocation,
ToolLocation,
ToolResult,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js';
import {
@@ -67,113 +68,101 @@ interface GetCorrectedFileContentResult {
error?: { message: string; code?: string };
}
/**
* Implementation of the WriteFile tool logic
*/
export class WriteFileTool
extends BaseTool<WriteFileToolParams, ToolResult>
implements ModifiableDeclarativeTool<WriteFileToolParams>
{
static readonly Name: string = 'write_file';
export async function getCorrectedFileContent(
config: Config,
filePath: string,
proposedContent: string,
abortSignal: AbortSignal,
): Promise<GetCorrectedFileContentResult> {
let originalContent = '';
let fileExists = false;
let correctedContent = proposedContent;
constructor(private readonly config: Config) {
super(
WriteFileTool.Name,
'WriteFile',
`Writes content to a specified file in the local filesystem.
try {
originalContent = await config
.getFileSystemService()
.readTextFile(filePath);
fileExists = true; // File exists and was read
} catch (err) {
if (isNodeError(err) && err.code === 'ENOENT') {
fileExists = false;
originalContent = '';
} else {
// File exists but could not be read (permissions, etc.)
fileExists = true; // Mark as existing but problematic
originalContent = ''; // Can't use its content
const error = {
message: getErrorMessage(err),
code: isNodeError(err) ? err.code : undefined,
};
// Return early as we can't proceed with content correction meaningfully
return { originalContent, correctedContent, fileExists, error };
}
}
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
Kind.Edit,
// If readError is set, we have returned.
// So, file was either read successfully (fileExists=true, originalContent set)
// or it was ENOENT (fileExists=false, originalContent='').
if (fileExists) {
// This implies originalContent is available
const { params: correctedParams } = await ensureCorrectEdit(
filePath,
originalContent,
{
properties: {
file_path: {
description:
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
type: 'string',
},
content: {
description: 'The content to write to the file.',
type: 'string',
},
},
required: ['file_path', 'content'],
type: 'object',
old_string: originalContent, // Treat entire current content as old_string
new_string: proposedContent,
file_path: filePath,
},
config.getGeminiClient(),
abortSignal,
);
correctedContent = correctedParams.new_string;
} else {
// This implies new file (ENOENT)
correctedContent = await ensureCorrectFileContent(
proposedContent,
config.getGeminiClient(),
abortSignal,
);
}
return { originalContent, correctedContent, fileExists };
}
override toolLocations(params: WriteFileToolParams): ToolLocation[] {
return [{ path: params.file_path }];
class WriteFileToolInvocation extends BaseToolInvocation<
WriteFileToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
params: WriteFileToolParams,
) {
super(params);
}
override validateToolParams(params: WriteFileToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
const filePath = params.file_path;
if (!path.isAbsolute(filePath)) {
return `File path must be absolute: ${filePath}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(filePath)) {
const directories = workspaceContext.getDirectories();
return `File path must be within one of the workspace directories: ${directories.join(', ')}`;
}
try {
// This check should be performed only if the path exists.
// If it doesn't exist, it's a new file, which is valid for writing.
if (fs.existsSync(filePath)) {
const stats = fs.lstatSync(filePath);
if (stats.isDirectory()) {
return `Path is a directory, not a file: ${filePath}`;
}
}
} catch (statError: unknown) {
// If fs.existsSync is true but lstatSync fails (e.g., permissions, race condition where file is deleted)
// this indicates an issue with accessing the path that should be reported.
return `Error accessing path properties for validation: ${filePath}. Reason: ${statError instanceof Error ? statError.message : String(statError)}`;
}
return null;
override toolLocations(): ToolLocation[] {
return [{ path: this.params.file_path }];
}
override getDescription(params: WriteFileToolParams): string {
if (!params.file_path) {
return `Model did not provide valid parameters for write file tool, missing or empty "file_path"`;
}
override getDescription(): string {
const relativePath = makeRelative(
params.file_path,
this.params.file_path,
this.config.getTargetDir(),
);
return `Writing to ${shortenPath(relativePath)}`;
}
/**
* Handles the confirmation prompt for the WriteFile tool.
*/
override async shouldConfirmExecute(
params: WriteFileToolParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
const validationError = this.validateToolParams(params);
if (validationError) {
return false;
}
const correctedContentResult = await this._getCorrectedFileContent(
params.file_path,
params.content,
const correctedContentResult = await getCorrectedFileContent(
this.config,
this.params.file_path,
this.params.content,
abortSignal,
);
@@ -184,10 +173,10 @@ export class WriteFileTool
const { originalContent, correctedContent } = correctedContentResult;
const relativePath = makeRelative(
params.file_path,
this.params.file_path,
this.config.getTargetDir(),
);
const fileName = path.basename(params.file_path);
const fileName = path.basename(this.params.file_path);
const fileDiff = Diff.createPatch(
fileName,
@@ -202,14 +191,14 @@ export class WriteFileTool
const ideConfirmation =
this.config.getIdeMode() &&
ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected
? ideClient.openDiff(params.file_path, correctedContent)
? ideClient.openDiff(this.params.file_path, correctedContent)
: undefined;
const confirmationDetails: ToolEditConfirmationDetails = {
type: 'edit',
title: `Confirm Write: ${shortenPath(relativePath)}`,
fileName,
filePath: params.file_path,
filePath: this.params.file_path,
fileDiff,
originalContent,
newContent: correctedContent,
@@ -221,7 +210,7 @@ export class WriteFileTool
if (ideConfirmation) {
const result = await ideConfirmation;
if (result.status === 'accepted' && result.content) {
params.content = result.content;
this.params.content = result.content;
}
}
},
@@ -230,32 +219,20 @@ export class WriteFileTool
return confirmationDetails;
}
async execute(
params: WriteFileToolParams,
abortSignal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Could not write file due to invalid parameters: ${validationError}`,
returnDisplay: validationError,
error: {
message: validationError,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}
const correctedContentResult = await this._getCorrectedFileContent(
params.file_path,
params.content,
async execute(abortSignal: AbortSignal): Promise<ToolResult> {
const { file_path, content, ai_proposed_content, modified_by_user } =
this.params;
const correctedContentResult = await getCorrectedFileContent(
this.config,
file_path,
content,
abortSignal,
);
if (correctedContentResult.error) {
const errDetails = correctedContentResult.error;
const errorMsg = errDetails.code
? `Error checking existing file '${params.file_path}': ${errDetails.message} (${errDetails.code})`
? `Error checking existing file '${file_path}': ${errDetails.message} (${errDetails.code})`
: `Error checking existing file: ${errDetails.message}`;
return {
llmContent: errorMsg,
@@ -280,15 +257,17 @@ export class WriteFileTool
!correctedContentResult.fileExists);
try {
const dirName = path.dirname(params.file_path);
const dirName = path.dirname(file_path);
if (!fs.existsSync(dirName)) {
fs.mkdirSync(dirName, { recursive: true });
}
fs.writeFileSync(params.file_path, fileContent, 'utf8');
await this.config
.getFileSystemService()
.writeTextFile(file_path, fileContent);
// Generate diff for display result
const fileName = path.basename(params.file_path);
const fileName = path.basename(file_path);
// If there was a readError, originalContent in correctedContentResult is '',
// but for the diff, we want to show the original content as it was before the write if possible.
// However, if it was unreadable, currentContentForDiff will be empty.
@@ -305,23 +284,22 @@ export class WriteFileTool
DEFAULT_DIFF_OPTIONS,
);
const originallyProposedContent =
params.ai_proposed_content || params.content;
const originallyProposedContent = ai_proposed_content || content;
const diffStat = getDiffStat(
fileName,
currentContentForDiff,
originallyProposedContent,
params.content,
content,
);
const llmSuccessMessageParts = [
isNewFile
? `Successfully created and wrote to new file: ${params.file_path}.`
: `Successfully overwrote file: ${params.file_path}.`,
? `Successfully created and wrote to new file: ${file_path}.`
: `Successfully overwrote file: ${file_path}.`,
];
if (params.modified_by_user) {
if (modified_by_user) {
llmSuccessMessageParts.push(
`User modified the \`content\` to be: ${params.content}`,
`User modified the \`content\` to be: ${content}`,
);
}
@@ -334,8 +312,8 @@ export class WriteFileTool
};
const lines = fileContent.split('\n').length;
const mimetype = getSpecificMimeType(params.file_path);
const extension = path.extname(params.file_path); // Get extension
const mimetype = getSpecificMimeType(file_path);
const extension = path.extname(file_path); // Get extension
if (isNewFile) {
recordFileOperationMetric(
this.config,
@@ -367,17 +345,17 @@ export class WriteFileTool
if (isNodeError(error)) {
// Handle specific Node.js errors with their error codes
errorMsg = `Error writing to file '${params.file_path}': ${error.message} (${error.code})`;
errorMsg = `Error writing to file '${file_path}': ${error.message} (${error.code})`;
// Log specific error types for better debugging
if (error.code === 'EACCES') {
errorMsg = `Permission denied writing to file: ${params.file_path} (${error.code})`;
errorMsg = `Permission denied writing to file: ${file_path} (${error.code})`;
errorType = ToolErrorType.PERMISSION_DENIED;
} else if (error.code === 'ENOSPC') {
errorMsg = `No space left on device: ${params.file_path} (${error.code})`;
errorMsg = `No space left on device: ${file_path} (${error.code})`;
errorType = ToolErrorType.NO_SPACE_LEFT;
} else if (error.code === 'EISDIR') {
errorMsg = `Target is a directory, not a file: ${params.file_path} (${error.code})`;
errorMsg = `Target is a directory, not a file: ${file_path} (${error.code})`;
errorType = ToolErrorType.TARGET_IS_DIRECTORY;
}
@@ -401,63 +379,84 @@ export class WriteFileTool
};
}
}
}
private async _getCorrectedFileContent(
filePath: string,
proposedContent: string,
abortSignal: AbortSignal,
): Promise<GetCorrectedFileContentResult> {
let originalContent = '';
let fileExists = false;
let correctedContent = proposedContent;
/**
* Implementation of the WriteFile tool logic
*/
export class WriteFileTool
extends BaseDeclarativeTool<WriteFileToolParams, ToolResult>
implements ModifiableDeclarativeTool<WriteFileToolParams>
{
static readonly Name: string = 'write_file';
constructor(private readonly config: Config) {
super(
WriteFileTool.Name,
'WriteFile',
`Writes content to a specified file in the local filesystem.
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
Kind.Edit,
{
properties: {
file_path: {
description:
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
type: 'string',
},
content: {
description: 'The content to write to the file.',
type: 'string',
},
},
required: ['file_path', 'content'],
type: 'object',
},
);
}
protected override validateToolParamValues(
params: WriteFileToolParams,
): string | null {
const filePath = params.file_path;
if (!filePath) {
return `Missing or empty "file_path"`;
}
if (!path.isAbsolute(filePath)) {
return `File path must be absolute: ${filePath}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(filePath)) {
const directories = workspaceContext.getDirectories();
return `File path must be within one of the workspace directories: ${directories.join(
', ',
)}`;
}
try {
originalContent = fs.readFileSync(filePath, 'utf8');
fileExists = true; // File exists and was read
} catch (err) {
if (isNodeError(err) && err.code === 'ENOENT') {
fileExists = false;
originalContent = '';
} else {
// File exists but could not be read (permissions, etc.)
fileExists = true; // Mark as existing but problematic
originalContent = ''; // Can't use its content
const error = {
message: getErrorMessage(err),
code: isNodeError(err) ? err.code : undefined,
};
// Return early as we can't proceed with content correction meaningfully
return { originalContent, correctedContent, fileExists, error };
if (fs.existsSync(filePath)) {
const stats = fs.lstatSync(filePath);
if (stats.isDirectory()) {
return `Path is a directory, not a file: ${filePath}`;
}
}
} catch (statError: unknown) {
return `Error accessing path properties for validation: ${filePath}. Reason: ${
statError instanceof Error ? statError.message : String(statError)
}`;
}
// If readError is set, we have returned.
// So, file was either read successfully (fileExists=true, originalContent set)
// or it was ENOENT (fileExists=false, originalContent='').
return null;
}
if (fileExists) {
// This implies originalContent is available
const { params: correctedParams } = await ensureCorrectEdit(
filePath,
originalContent,
{
old_string: originalContent, // Treat entire current content as old_string
new_string: proposedContent,
file_path: filePath,
},
this.config.getGeminiClient(),
abortSignal,
);
correctedContent = correctedParams.new_string;
} else {
// This implies new file (ENOENT)
correctedContent = await ensureCorrectFileContent(
proposedContent,
this.config.getGeminiClient(),
abortSignal,
);
}
return { originalContent, correctedContent, fileExists };
protected createInvocation(
params: WriteFileToolParams,
): ToolInvocation<WriteFileToolParams, ToolResult> {
return new WriteFileToolInvocation(this.config, params);
}
getModifyContext(
@@ -466,7 +465,8 @@ export class WriteFileTool
return {
getFilePath: (params: WriteFileToolParams) => params.file_path,
getCurrentContent: async (params: WriteFileToolParams) => {
const correctedContentResult = await this._getCorrectedFileContent(
const correctedContentResult = await getCorrectedFileContent(
this.config,
params.file_path,
params.content,
abortSignal,
@@ -474,7 +474,8 @@ export class WriteFileTool
return correctedContentResult.originalContent;
},
getProposedContent: async (params: WriteFileToolParams) => {
const correctedContentResult = await this._getCorrectedFileContent(
const correctedContentResult = await getCorrectedFileContent(
this.config,
params.file_path,
params.content,
abortSignal,