mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
# 🚀 Sync Gemini CLI v0.2.1 - Major Feature Update (#483)
This commit is contained in:
73
packages/core/src/tools/__snapshots__/shell.test.ts.snap
Normal file
73
packages/core/src/tools/__snapshots__/shell.test.ts.snap
Normal file
@@ -0,0 +1,73 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`ShellTool > getDescription > should return the non-windows description when not on windows 1`] = `
|
||||
"
|
||||
This tool executes a given shell command as \`bash -c <command>\`.
|
||||
|
||||
**Background vs Foreground Execution:**
|
||||
You should decide whether commands should run in background or foreground based on their nature:
|
||||
|
||||
**Use background execution (is_background: true) for:**
|
||||
- Long-running development servers: \`npm run start\`, \`npm run dev\`, \`yarn dev\`, \`bun run start\`
|
||||
- Build watchers: \`npm run watch\`, \`webpack --watch\`
|
||||
- Database servers: \`mongod\`, \`mysql\`, \`redis-server\`
|
||||
- Web servers: \`python -m http.server\`, \`php -S localhost:8000\`
|
||||
- Any command expected to run indefinitely until manually stopped
|
||||
|
||||
**Use foreground execution (is_background: false) for:**
|
||||
- One-time commands: \`ls\`, \`cat\`, \`grep\`
|
||||
- Build commands: \`npm run build\`, \`make\`
|
||||
- Installation commands: \`npm install\`, \`pip install\`
|
||||
- Git operations: \`git commit\`, \`git push\`
|
||||
- Test runs: \`npm test\`, \`pytest\`
|
||||
|
||||
Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
|
||||
|
||||
The following information is returned:
|
||||
|
||||
Command: Executed command.
|
||||
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\`"
|
||||
`;
|
||||
|
||||
exports[`ShellTool > getDescription > should return the windows description when on windows 1`] = `
|
||||
"
|
||||
This tool executes a given shell command as \`cmd.exe /c <command>\`.
|
||||
|
||||
**Background vs Foreground Execution:**
|
||||
You should decide whether commands should run in background or foreground based on their nature:
|
||||
|
||||
**Use background execution (is_background: true) for:**
|
||||
- Long-running development servers: \`npm run start\`, \`npm run dev\`, \`yarn dev\`, \`bun run start\`
|
||||
- Build watchers: \`npm run watch\`, \`webpack --watch\`
|
||||
- Database servers: \`mongod\`, \`mysql\`, \`redis-server\`
|
||||
- Web servers: \`python -m http.server\`, \`php -S localhost:8000\`
|
||||
- Any command expected to run indefinitely until manually stopped
|
||||
|
||||
**Use foreground execution (is_background: false) for:**
|
||||
- One-time commands: \`ls\`, \`cat\`, \`grep\`
|
||||
- Build commands: \`npm run build\`, \`make\`
|
||||
- Installation commands: \`npm install\`, \`pip install\`
|
||||
- Git operations: \`git commit\`, \`git push\`
|
||||
- Test runs: \`npm test\`, \`pytest\`
|
||||
|
||||
|
||||
|
||||
The following information is returned:
|
||||
|
||||
Command: Executed command.
|
||||
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\`"
|
||||
`;
|
||||
@@ -36,6 +36,7 @@ import os from 'os';
|
||||
import { ApprovalMode, Config } from '../config/config.js';
|
||||
import { Content, Part, SchemaUnion } from '@google/genai';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
describe('EditTool', () => {
|
||||
let tool: EditTool;
|
||||
@@ -60,6 +61,7 @@ describe('EditTool', () => {
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
getIdeClient: () => undefined,
|
||||
getIdeMode: () => false,
|
||||
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
|
||||
@@ -393,7 +395,7 @@ describe('EditTool', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error if params are invalid', async () => {
|
||||
it('should throw error if file path is not absolute', async () => {
|
||||
const params: EditToolParams = {
|
||||
file_path: 'relative.txt',
|
||||
old_string: 'old',
|
||||
@@ -402,6 +404,17 @@ describe('EditTool', () => {
|
||||
expect(() => tool.build(params)).toThrow(/File path must be absolute/);
|
||||
});
|
||||
|
||||
it('should throw error if file path is empty', async () => {
|
||||
const params: EditToolParams = {
|
||||
file_path: '',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
/The 'file_path' parameter must be non-empty./,
|
||||
);
|
||||
});
|
||||
|
||||
it('should edit an existing file and return diff with fileName', async () => {
|
||||
const initialContent = 'This is some old text.';
|
||||
const newContent = 'This is some new text.'; // old -> new
|
||||
|
||||
@@ -19,7 +19,6 @@ import {
|
||||
ToolResultDisplay,
|
||||
} from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
@@ -125,7 +124,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
| undefined = undefined;
|
||||
|
||||
try {
|
||||
currentContent = fs.readFileSync(params.file_path, 'utf8');
|
||||
currentContent = await this.config
|
||||
.getFileSystemService()
|
||||
.readTextFile(params.file_path);
|
||||
// Normalize line endings to LF for consistent processing.
|
||||
currentContent = currentContent.replace(/\r\n/g, '\n');
|
||||
fileExists = true;
|
||||
@@ -339,7 +340,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
|
||||
try {
|
||||
this.ensureParentDirectoriesExist(this.params.file_path);
|
||||
fs.writeFileSync(this.params.file_path, editData.newContent, 'utf8');
|
||||
await this.config
|
||||
.getFileSystemService()
|
||||
.writeTextFile(this.params.file_path, editData.newContent);
|
||||
|
||||
let displayResult: ToolResultDisplay;
|
||||
if (editData.isNewFile) {
|
||||
@@ -471,13 +474,11 @@ Expectation for required parameters:
|
||||
* @param params Parameters to validate
|
||||
* @returns Error message string or null if valid
|
||||
*/
|
||||
override validateToolParams(params: EditToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
protected override validateToolParamValues(
|
||||
params: EditToolParams,
|
||||
): string | null {
|
||||
if (!params.file_path) {
|
||||
return "The 'file_path' parameter must be non-empty.";
|
||||
}
|
||||
|
||||
if (!path.isAbsolute(params.file_path)) {
|
||||
@@ -504,7 +505,9 @@ Expectation for required parameters:
|
||||
getFilePath: (params: EditToolParams) => params.file_path,
|
||||
getCurrentContent: async (params: EditToolParams): Promise<string> => {
|
||||
try {
|
||||
return fs.readFileSync(params.file_path, 'utf8');
|
||||
return this.config
|
||||
.getFileSystemService()
|
||||
.readTextFile(params.file_path);
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
return '';
|
||||
@@ -512,7 +515,9 @@ Expectation for required parameters:
|
||||
},
|
||||
getProposedContent: async (params: EditToolParams): Promise<string> => {
|
||||
try {
|
||||
const currentContent = fs.readFileSync(params.file_path, 'utf8');
|
||||
const currentContent = await this.config
|
||||
.getFileSystemService()
|
||||
.readTextFile(params.file_path);
|
||||
return applyReplacement(
|
||||
currentContent,
|
||||
params.old_string,
|
||||
|
||||
@@ -150,6 +150,34 @@ describe('GlobTool', () => {
|
||||
expect(result.returnDisplay).toBe('No files found');
|
||||
});
|
||||
|
||||
it('should find files with special characters in the name', async () => {
|
||||
await fs.writeFile(path.join(tempRootDir, 'file[1].txt'), 'content');
|
||||
const params: GlobToolParams = { pattern: 'file[1].txt' };
|
||||
const invocation = globTool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
expect(result.llmContent).toContain('Found 1 file(s)');
|
||||
expect(result.llmContent).toContain(
|
||||
path.join(tempRootDir, 'file[1].txt'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should find files with special characters like [] and () in the path', async () => {
|
||||
const filePath = path.join(
|
||||
tempRootDir,
|
||||
'src/app/[test]/(dashboard)/testing/components/code.tsx',
|
||||
);
|
||||
await fs.mkdir(path.dirname(filePath), { recursive: true });
|
||||
await fs.writeFile(filePath, 'content');
|
||||
|
||||
const params: GlobToolParams = {
|
||||
pattern: 'src/app/[test]/(dashboard)/testing/components/code.tsx',
|
||||
};
|
||||
const invocation = globTool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
expect(result.llmContent).toContain('Found 1 file(s)');
|
||||
expect(result.llmContent).toContain(filePath);
|
||||
});
|
||||
|
||||
it('should correctly sort files by modification time (newest first)', async () => {
|
||||
const params: GlobToolParams = { pattern: '*.sortme' };
|
||||
const invocation = globTool.build(params);
|
||||
|
||||
@@ -6,8 +6,7 @@
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { glob } from 'glob';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { glob, escape } from 'glob';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
@@ -137,7 +136,13 @@ class GlobToolInvocation extends BaseToolInvocation<
|
||||
let allEntries: GlobPath[] = [];
|
||||
|
||||
for (const searchDir of searchDirectories) {
|
||||
const entries = (await glob(this.params.pattern, {
|
||||
let pattern = this.params.pattern;
|
||||
const fullPath = path.join(searchDir, pattern);
|
||||
if (fs.existsSync(fullPath)) {
|
||||
pattern = escape(pattern);
|
||||
}
|
||||
|
||||
const entries = (await glob(pattern, {
|
||||
cwd: searchDir,
|
||||
withFileTypes: true,
|
||||
nodir: true,
|
||||
@@ -281,15 +286,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
/**
|
||||
* Validates the parameters for the tool.
|
||||
*/
|
||||
override validateToolParams(params: GlobToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
protected override validateToolParamValues(
|
||||
params: GlobToolParams,
|
||||
): string | null {
|
||||
const searchDirAbsolute = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
params.path || '.',
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import { isGitRepository } from '../utils/gitUtils.js';
|
||||
@@ -672,15 +671,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
override validateToolParams(params: GrepToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
protected override validateToolParamValues(
|
||||
params: GrepToolParams,
|
||||
): string | null {
|
||||
try {
|
||||
new RegExp(params.pattern);
|
||||
} catch (error) {
|
||||
|
||||
@@ -13,7 +13,6 @@ import {
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
|
||||
|
||||
@@ -314,14 +313,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
override validateToolParams(params: LSToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
protected override validateToolParamValues(
|
||||
params: LSToolParams,
|
||||
): string | null {
|
||||
if (!path.isAbsolute(params.path)) {
|
||||
return `Path must be absolute: ${params.path}`;
|
||||
}
|
||||
|
||||
54
packages/core/src/tools/mcp-client-manager.test.ts
Normal file
54
packages/core/src/tools/mcp-client-manager.test.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient } from './mcp-client.js';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
McpClient: vi.fn(),
|
||||
populateMcpServerCommand: vi.fn(() => ({
|
||||
'test-server': {},
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
describe('McpClientManager', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should discover tools from all servers', async () => {
|
||||
const mockedMcpClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{
|
||||
'test-server': {},
|
||||
},
|
||||
'',
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
false,
|
||||
{} as WorkspaceContext,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
115
packages/core/src/tools/mcp-client-manager.ts
Normal file
115
packages/core/src/tools/mcp-client-manager.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import {
|
||||
McpClient,
|
||||
MCPDiscoveryState,
|
||||
populateMcpServerCommand,
|
||||
} from './mcp-client.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
|
||||
/**
|
||||
* Manages the lifecycle of multiple MCP clients, including local child processes.
|
||||
* This class is responsible for starting, stopping, and discovering tools from
|
||||
* a collection of MCP servers defined in the configuration.
|
||||
*/
|
||||
export class McpClientManager {
|
||||
private clients: Map<string, McpClient> = new Map();
|
||||
private readonly mcpServers: Record<string, MCPServerConfig>;
|
||||
private readonly mcpServerCommand: string | undefined;
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly promptRegistry: PromptRegistry;
|
||||
private readonly debugMode: boolean;
|
||||
private readonly workspaceContext: WorkspaceContext;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
|
||||
constructor(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
) {
|
||||
this.mcpServers = mcpServers;
|
||||
this.mcpServerCommand = mcpServerCommand;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.promptRegistry = promptRegistry;
|
||||
this.debugMode = debugMode;
|
||||
this.workspaceContext = workspaceContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates the tool discovery process for all configured MCP servers.
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*/
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
await this.stop();
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
const servers = populateMcpServerCommand(
|
||||
this.mcpServers,
|
||||
this.mcpServerCommand,
|
||||
);
|
||||
|
||||
const discoveryPromises = Object.entries(servers).map(
|
||||
async ([name, config]) => {
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.promptRegistry,
|
||||
this.workspaceContext,
|
||||
this.debugMode,
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
} catch (error) {
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
console.error(
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
await Promise.all(discoveryPromises);
|
||||
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops all running local MCP servers and closes all client connections.
|
||||
* This is the cleanup method to be called on application exit.
|
||||
*/
|
||||
async stop(): Promise<void> {
|
||||
const disconnectionPromises = Array.from(this.clients.entries()).map(
|
||||
async ([name, client]) => {
|
||||
try {
|
||||
await client.disconnect();
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
await Promise.all(disconnectionPromises);
|
||||
this.clients.clear();
|
||||
}
|
||||
|
||||
getDiscoveryState(): MCPDiscoveryState {
|
||||
return this.discoveryState;
|
||||
}
|
||||
}
|
||||
@@ -4,16 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
populateMcpServerCommand,
|
||||
createTransport,
|
||||
isEnabled,
|
||||
discoverTools,
|
||||
discoverPrompts,
|
||||
hasValidTypes,
|
||||
connectToMcpServer,
|
||||
McpClient,
|
||||
} from './mcp-client.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
@@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('../mcp/oauth-provider.js');
|
||||
vi.mock('../mcp/oauth-token-storage.js');
|
||||
vi.mock('./mcp-tool.js');
|
||||
|
||||
describe('mcp-client', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('discoverTools', () => {
|
||||
describe('McpClient', () => {
|
||||
it('should discover tools', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () => ({
|
||||
functionDeclarations: [
|
||||
@@ -51,62 +59,43 @@ describe('mcp-client', () => {
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(1);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should log an error if there is an error discovering a tool', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const testError = new Error('Invalid tool name');
|
||||
vi.mocked(DiscoveredMCPTool).mockImplementation(
|
||||
(
|
||||
_mcpCallableTool: GenAiLib.CallableTool,
|
||||
_serverName: string,
|
||||
name: string,
|
||||
) => {
|
||||
if (name === 'invalid tool name') {
|
||||
throw testError;
|
||||
}
|
||||
return { name: 'validTool' } as DiscoveredMCPTool;
|
||||
},
|
||||
);
|
||||
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'validTool',
|
||||
},
|
||||
{
|
||||
name: 'invalid tool name', // this will fail validation
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(1);
|
||||
expect(tools[0].name).toBe('validTool');
|
||||
expect(consoleErrorSpy).toHaveBeenCalledOnce();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
`Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should skip tools if a parameter is missing a type', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
tool: vi.fn(),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
@@ -132,11 +121,22 @@ describe('mcp-client', () => {
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(1);
|
||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
||||
@@ -145,147 +145,19 @@ describe('mcp-client', () => {
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should skip tools if a nested parameter is missing a type', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
it('should handle errors when discovering prompts', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'invalidTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
nestedParam: {
|
||||
description: 'a nested param with no type',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(0);
|
||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
||||
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
||||
);
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should skip tool if an array item is missing a type', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'invalidTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: {
|
||||
type: 'array',
|
||||
items: {
|
||||
description: 'an array item with no type',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(0);
|
||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
||||
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
||||
);
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should discover tool with no properties in schema', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'validTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(1);
|
||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
||||
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should discover tool with empty properties object in schema', async () => {
|
||||
const mockedClient = {} as unknown as ClientLib.Client;
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'validTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
|
||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
||||
|
||||
expect(tools.length).toBe(1);
|
||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
||||
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('connectToMcpServer', () => {
|
||||
it('should register a roots/list handler', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
callTool: vi.fn(),
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
||||
request: vi.fn().mockRejectedValue(new Error('Test error')),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -293,148 +165,29 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockWorkspaceContext = {
|
||||
getDirectories: vi
|
||||
.fn()
|
||||
.mockReturnValue(['/test/dir', '/another/project']),
|
||||
} as unknown as WorkspaceContext;
|
||||
|
||||
await connectToMcpServer(
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () => Promise.resolve({ functionDeclarations: [] }),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
false,
|
||||
mockWorkspaceContext,
|
||||
);
|
||||
|
||||
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
|
||||
roots: {},
|
||||
});
|
||||
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
|
||||
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
|
||||
const roots = await handler();
|
||||
expect(roots).toEqual({
|
||||
roots: [
|
||||
{
|
||||
uri: pathToFileURL('/test/dir').toString(),
|
||||
name: 'dir',
|
||||
},
|
||||
{
|
||||
uri: pathToFileURL('/another/project').toString(),
|
||||
name: 'project',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverPrompts', () => {
|
||||
const mockedPromptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
|
||||
it('should discover and log prompts', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [
|
||||
{ name: 'prompt1', description: 'desc1' },
|
||||
{ name: 'prompt2' },
|
||||
],
|
||||
});
|
||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||
prompts: {},
|
||||
});
|
||||
const mockedClient = {
|
||||
getServerCapabilities: mockGetServerCapabilities,
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||
expect(mockRequest).toHaveBeenCalledWith(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
expect.anything(),
|
||||
await client.connect();
|
||||
await expect(client.discover()).rejects.toThrow(
|
||||
'No prompts or tools found on the server.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should do nothing if no prompts are discovered', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
});
|
||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||
prompts: {},
|
||||
});
|
||||
|
||||
const mockedClient = {
|
||||
getServerCapabilities: mockGetServerCapabilities,
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'debug')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||
|
||||
consoleLogSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should do nothing if the server has no prompt support', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
});
|
||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
|
||||
|
||||
const mockedClient = {
|
||||
getServerCapabilities: mockGetServerCapabilities,
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'debug')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||
expect(mockRequest).not.toHaveBeenCalled();
|
||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||
|
||||
consoleLogSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error if discovery fails', async () => {
|
||||
const testError = new Error('test error');
|
||||
testError.message = 'test error';
|
||||
const mockRequest = vi.fn().mockRejectedValue(testError);
|
||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||
prompts: {},
|
||||
});
|
||||
const mockedClient = {
|
||||
getServerCapabilities: mockGetServerCapabilities,
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
`Error discovering prompts from test-server: ${testError.message}`,
|
||||
`Error discovering prompts from test-server: Test error`,
|
||||
);
|
||||
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
const out = populateMcpServerCommand({}, undefined);
|
||||
@@ -458,17 +211,6 @@ describe('mcp-client', () => {
|
||||
});
|
||||
|
||||
describe('createTransport', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules();
|
||||
process.env = {};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('should connect via httpUrl', () => {
|
||||
it('without headers', async () => {
|
||||
const transport = await createTransport(
|
||||
@@ -558,7 +300,7 @@ describe('mcp-client', () => {
|
||||
command: 'test-command',
|
||||
args: ['--foo', 'bar'],
|
||||
cwd: 'test/cwd',
|
||||
env: { FOO: 'bar' },
|
||||
env: { ...process.env, FOO: 'bar' },
|
||||
stderr: 'pipe',
|
||||
});
|
||||
});
|
||||
|
||||
@@ -36,7 +36,7 @@ import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { basename } from 'node:path';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { Unsubscribe, WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
@@ -69,6 +69,134 @@ export enum MCPDiscoveryState {
|
||||
COMPLETED = 'completed',
|
||||
}
|
||||
|
||||
/**
|
||||
* A client for a single MCP server.
|
||||
*
|
||||
* This class is responsible for connecting to, discovering tools from, and
|
||||
* managing the state of a single MCP server.
|
||||
*/
|
||||
export class McpClient {
|
||||
private client: Client;
|
||||
private transport: Transport | undefined;
|
||||
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
||||
private isDisconnecting = false;
|
||||
|
||||
constructor(
|
||||
private readonly serverName: string,
|
||||
private readonly serverConfig: MCPServerConfig,
|
||||
private readonly toolRegistry: ToolRegistry,
|
||||
private readonly promptRegistry: PromptRegistry,
|
||||
private readonly workspaceContext: WorkspaceContext,
|
||||
private readonly debugMode: boolean,
|
||||
) {
|
||||
this.client = new Client({
|
||||
name: `gemini-cli-mcp-client-${this.serverName}`,
|
||||
version: '0.0.1',
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Connects to the MCP server.
|
||||
*/
|
||||
async connect(): Promise<void> {
|
||||
this.isDisconnecting = false;
|
||||
this.updateStatus(MCPServerStatus.CONNECTING);
|
||||
try {
|
||||
this.transport = await this.createTransport();
|
||||
|
||||
this.client.onerror = (error) => {
|
||||
if (this.isDisconnecting) {
|
||||
return;
|
||||
}
|
||||
console.error(`MCP ERROR (${this.serverName}):`, error.toString());
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||
};
|
||||
|
||||
this.client.registerCapabilities({
|
||||
roots: {},
|
||||
});
|
||||
|
||||
this.client.setRequestHandler(ListRootsRequestSchema, async () => {
|
||||
const roots = [];
|
||||
for (const dir of this.workspaceContext.getDirectories()) {
|
||||
roots.push({
|
||||
uri: pathToFileURL(dir).toString(),
|
||||
name: basename(dir),
|
||||
});
|
||||
}
|
||||
return {
|
||||
roots,
|
||||
};
|
||||
});
|
||||
|
||||
await this.client.connect(this.transport, {
|
||||
timeout: this.serverConfig.timeout,
|
||||
});
|
||||
|
||||
this.updateStatus(MCPServerStatus.CONNECTED);
|
||||
} catch (error) {
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools and prompts from the MCP server.
|
||||
*/
|
||||
async discover(): Promise<void> {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
throw new Error('Client is not connected.');
|
||||
}
|
||||
|
||||
const prompts = await this.discoverPrompts();
|
||||
const tools = await this.discoverTools();
|
||||
|
||||
if (prompts.length === 0 && tools.length === 0) {
|
||||
throw new Error('No prompts or tools found on the server.');
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
this.toolRegistry.registerTool(tool);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnects from the MCP server.
|
||||
*/
|
||||
async disconnect(): Promise<void> {
|
||||
this.isDisconnecting = true;
|
||||
if (this.transport) {
|
||||
await this.transport.close();
|
||||
}
|
||||
this.client.close();
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the current status of the client.
|
||||
*/
|
||||
getStatus(): MCPServerStatus {
|
||||
return this.status;
|
||||
}
|
||||
|
||||
private updateStatus(status: MCPServerStatus): void {
|
||||
this.status = status;
|
||||
updateMCPServerStatus(this.serverName, status);
|
||||
}
|
||||
|
||||
private async createTransport(): Promise<Transport> {
|
||||
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
||||
}
|
||||
|
||||
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
|
||||
return discoverTools(this.serverName, this.serverConfig, this.client);
|
||||
}
|
||||
|
||||
private async discoverPrompts(): Promise<Prompt[]> {
|
||||
return discoverPrompts(this.serverName, this.client, this.promptRegistry);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map to track the status of each MCP server within the core package
|
||||
*/
|
||||
@@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener(
|
||||
/**
|
||||
* Update the status of an MCP server
|
||||
*/
|
||||
function updateMCPServerStatus(
|
||||
export function updateMCPServerStatus(
|
||||
serverName: string,
|
||||
status: MCPServerStatus,
|
||||
): void {
|
||||
@@ -227,10 +355,16 @@ async function handleAutomaticOAuth(
|
||||
};
|
||||
|
||||
// Perform OAuth authentication
|
||||
// Pass the server URL for proper discovery
|
||||
const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url;
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig);
|
||||
await MCPOAuthProvider.authenticate(
|
||||
mcpServerName,
|
||||
oauthAuthConfig,
|
||||
serverUrl,
|
||||
);
|
||||
|
||||
console.log(
|
||||
`OAuth authentication successful for server '${mcpServerName}'`,
|
||||
@@ -442,7 +576,7 @@ export function hasValidTypes(schema: unknown): boolean {
|
||||
|
||||
const s = schema as Record<string, unknown>;
|
||||
|
||||
if (!s.type) {
|
||||
if (!s['type']) {
|
||||
// These keywords contain an array of schemas that should be validated.
|
||||
//
|
||||
// If no top level type was given, then they must each have a type.
|
||||
@@ -464,9 +598,9 @@ export function hasValidTypes(schema: unknown): boolean {
|
||||
if (!hasSubSchema) return false;
|
||||
}
|
||||
|
||||
if (s.type === 'object' && s.properties) {
|
||||
if (typeof s.properties === 'object' && s.properties !== null) {
|
||||
for (const prop of Object.values(s.properties)) {
|
||||
if (s['type'] === 'object' && s['properties']) {
|
||||
if (typeof s['properties'] === 'object' && s['properties'] !== null) {
|
||||
for (const prop of Object.values(s['properties'])) {
|
||||
if (!hasValidTypes(prop)) {
|
||||
return false;
|
||||
}
|
||||
@@ -474,8 +608,8 @@ export function hasValidTypes(schema: unknown): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
if (s.type === 'array' && s.items) {
|
||||
if (!hasValidTypes(s.items)) {
|
||||
if (s['type'] === 'array' && s['items']) {
|
||||
if (!hasValidTypes(s['items'])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -671,7 +805,9 @@ export async function connectToMcpServer(
|
||||
});
|
||||
|
||||
mcpClient.registerCapabilities({
|
||||
roots: {},
|
||||
roots: {
|
||||
listChanged: true,
|
||||
},
|
||||
});
|
||||
|
||||
mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
|
||||
@@ -687,6 +823,32 @@ export async function connectToMcpServer(
|
||||
};
|
||||
});
|
||||
|
||||
let unlistenDirectories: Unsubscribe | undefined =
|
||||
workspaceContext.onDirectoriesChanged(async () => {
|
||||
try {
|
||||
await mcpClient.notification({
|
||||
method: 'notifications/roots/list_changed',
|
||||
});
|
||||
} catch (_) {
|
||||
// If this fails, its almost certainly because the connection was closed
|
||||
// and we should just stop listening for future directory changes.
|
||||
unlistenDirectories?.();
|
||||
unlistenDirectories = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
// Attempt to pro-actively unsubscribe if the mcp client closes. This API is
|
||||
// very brittle though so we don't have any guarantees, hence the try/catch
|
||||
// above as well.
|
||||
//
|
||||
// Be a good steward and don't just bash over onclose.
|
||||
const oldOnClose = mcpClient.onclose;
|
||||
mcpClient.onclose = () => {
|
||||
oldOnClose?.();
|
||||
unlistenDirectories?.();
|
||||
unlistenDirectories = undefined;
|
||||
};
|
||||
|
||||
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
|
||||
// TODO: remove this hack once GenAI SDK does callTool with request options
|
||||
if ('callTool' in mcpClient) {
|
||||
@@ -933,12 +1095,15 @@ export async function connectToMcpServer(
|
||||
};
|
||||
|
||||
// Perform OAuth authentication
|
||||
// Pass the server URL for proper discovery
|
||||
const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url;
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(
|
||||
mcpServerName,
|
||||
oauthAuthConfig,
|
||||
serverUrl,
|
||||
);
|
||||
|
||||
// Retry connection with OAuth token
|
||||
@@ -1037,7 +1202,7 @@ export async function connectToMcpServer(
|
||||
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
|
||||
}
|
||||
|
||||
if (process.env.SANDBOX) {
|
||||
if (process.env['SANDBOX']) {
|
||||
conciseError += ` (check sandbox availability)`;
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
inputSchema,
|
||||
);
|
||||
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
invocation.constructor.allowlist.clear();
|
||||
});
|
||||
|
||||
@@ -185,8 +185,100 @@ describe('DiscoveredMCPTool', () => {
|
||||
).rejects.toThrow(expectedError);
|
||||
});
|
||||
|
||||
it.each([
|
||||
{ isErrorValue: true, description: 'true (bool)' },
|
||||
{ isErrorValue: 'true', description: '"true" (str)' },
|
||||
])(
|
||||
'should consider a ToolResult with isError $description to be a failure',
|
||||
async ({ isErrorValue }) => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
);
|
||||
const params = { param: 'isErrorTrueCase' };
|
||||
|
||||
const errorResponse = { isError: isErrorValue };
|
||||
const mockMcpToolResponseParts: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: { error: errorResponse },
|
||||
},
|
||||
},
|
||||
];
|
||||
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
|
||||
const expectedError = new Error(
|
||||
`MCP tool '${serverToolName}' reported tool error with response: ${JSON.stringify(
|
||||
mockMcpToolResponseParts,
|
||||
)}`,
|
||||
);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
await expect(
|
||||
invocation.execute(new AbortController().signal),
|
||||
).rejects.toThrow(expectedError);
|
||||
},
|
||||
);
|
||||
|
||||
it.each([
|
||||
{ isErrorValue: false, description: 'false (bool)' },
|
||||
{ isErrorValue: 'false', description: '"false" (str)' },
|
||||
])(
|
||||
'should consider a ToolResult with isError ${description} to be a success',
|
||||
async ({ isErrorValue }) => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
);
|
||||
const params = { param: 'isErrorFalseCase' };
|
||||
const mockToolSuccessResultObject = {
|
||||
success: true,
|
||||
details: 'executed',
|
||||
};
|
||||
const mockFunctionResponseContent = [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify(mockToolSuccessResultObject),
|
||||
},
|
||||
];
|
||||
|
||||
const errorResponse = { isError: isErrorValue };
|
||||
const mockMcpToolResponseParts: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: {
|
||||
error: errorResponse,
|
||||
content: mockFunctionResponseContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const toolResult = await invocation.execute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const stringifiedResponseContent = JSON.stringify(
|
||||
mockToolSuccessResultObject,
|
||||
);
|
||||
expect(toolResult.llmContent).toEqual([
|
||||
{ text: stringifiedResponseContent },
|
||||
]);
|
||||
expect(toolResult.returnDisplay).toBe(stringifiedResponseContent);
|
||||
},
|
||||
);
|
||||
|
||||
it('should handle a simple text response correctly', async () => {
|
||||
const params = { query: 'test' };
|
||||
const params = { param: 'test' };
|
||||
const successMessage = 'This is a success message.';
|
||||
|
||||
// Simulate the response from the GenAI SDK, which wraps the MCP
|
||||
@@ -220,7 +312,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle an AudioBlock response', async () => {
|
||||
const params = { action: 'play' };
|
||||
const params = { param: 'play' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -257,7 +349,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle a ResourceLinkBlock response', async () => {
|
||||
const params = { resource: 'get' };
|
||||
const params = { param: 'get' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -291,7 +383,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle an embedded text ResourceBlock response', async () => {
|
||||
const params = { resource: 'get' };
|
||||
const params = { param: 'get' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -323,7 +415,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle an embedded binary ResourceBlock response', async () => {
|
||||
const params = { resource: 'get' };
|
||||
const params = { param: 'get' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -365,7 +457,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle a mix of content block types', async () => {
|
||||
const params = { action: 'complex' };
|
||||
const params = { param: 'complex' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -408,7 +500,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should ignore unknown content block types', async () => {
|
||||
const params = { action: 'test' };
|
||||
const params = { param: 'test' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -434,7 +526,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle a complex mix of content block types', async () => {
|
||||
const params = { action: 'super-complex' };
|
||||
const params = { param: 'super-complex' };
|
||||
const sdkResponse: Part[] = [
|
||||
{
|
||||
functionResponse: {
|
||||
@@ -504,14 +596,14 @@ describe('DiscoveredMCPTool', () => {
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
const invocation = trustedTool.build({});
|
||||
const invocation = trustedTool.build({ param: 'mock' });
|
||||
expect(
|
||||
await invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if server is allowlisted', async () => {
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
invocation.constructor.allowlist.add(serverName);
|
||||
expect(
|
||||
await invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
@@ -520,7 +612,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
|
||||
it('should return false if tool is allowlisted', async () => {
|
||||
const toolAllowlistKey = `${serverName}.${serverToolName}`;
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
invocation.constructor.allowlist.add(toolAllowlistKey);
|
||||
expect(
|
||||
await invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
@@ -528,7 +620,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should return confirmation details if not trusted and not allowlisted', async () => {
|
||||
const invocation = tool.build({});
|
||||
const invocation = tool.build({ param: 'mock' });
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
@@ -551,7 +643,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should add server to allowlist on ProceedAlwaysServer', async () => {
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
@@ -575,7 +667,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
|
||||
it('should add tool to allowlist on ProceedAlwaysTool', async () => {
|
||||
const toolAllowlistKey = `${serverName}.${serverToolName}`;
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
@@ -598,7 +690,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle Cancel confirmation outcome', async () => {
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
@@ -625,7 +717,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
it('should handle ProceedOnce confirmation outcome', async () => {
|
||||
const invocation = tool.build({}) as any;
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
@@ -104,6 +104,28 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
// Determine if the response contains tool errors
|
||||
// This is needed because CallToolResults should return errors inside the response.
|
||||
// ref: https://modelcontextprotocol.io/specification/2025-06-18/schema#calltoolresult
|
||||
isMCPToolError(rawResponseParts: Part[]): boolean {
|
||||
const functionResponse = rawResponseParts?.[0]?.functionResponse;
|
||||
const response = functionResponse?.response;
|
||||
|
||||
interface McpError {
|
||||
isError?: boolean | string;
|
||||
}
|
||||
|
||||
if (response) {
|
||||
const error = (response as { error?: McpError })?.error;
|
||||
const isError = error?.isError;
|
||||
|
||||
if (error && (isError === true || isError === 'true')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async execute(): Promise<ToolResult> {
|
||||
const functionCalls: FunctionCall[] = [
|
||||
{
|
||||
@@ -113,6 +135,14 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
];
|
||||
|
||||
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
throw new Error(
|
||||
`MCP tool '${this.serverToolName}' reported tool error with response: ${JSON.stringify(rawResponseParts)}`,
|
||||
);
|
||||
}
|
||||
|
||||
const transformedParts = transformMcpContentToParts(rawResponseParts);
|
||||
|
||||
return {
|
||||
@@ -241,7 +271,7 @@ function transformResourceLinkBlock(block: McpResourceLinkBlock): Part {
|
||||
*/
|
||||
function transformMcpContentToParts(sdkResponse: Part[]): Part[] {
|
||||
const funcResponse = sdkResponse?.[0]?.functionResponse;
|
||||
const mcpContent = funcResponse?.response?.content as McpContentBlock[];
|
||||
const mcpContent = funcResponse?.response?.['content'] as McpContentBlock[];
|
||||
const toolName = funcResponse?.name || 'unknown tool';
|
||||
|
||||
if (!Array.isArray(mcpContent)) {
|
||||
@@ -278,8 +308,9 @@ function transformMcpContentToParts(sdkResponse: Part[]): Part[] {
|
||||
* @returns A formatted string representing the tool's output.
|
||||
*/
|
||||
function getStringifiedResultForDisplay(rawResponse: Part[]): string {
|
||||
const mcpContent = rawResponse?.[0]?.functionResponse?.response
|
||||
?.content as McpContentBlock[];
|
||||
const mcpContent = rawResponse?.[0]?.functionResponse?.response?.[
|
||||
'content'
|
||||
] as McpContentBlock[];
|
||||
|
||||
if (!Array.isArray(mcpContent)) {
|
||||
return '```json\n' + JSON.stringify(rawResponse, null, 2) + '\n```';
|
||||
|
||||
@@ -20,7 +20,6 @@ import * as Diff from 'diff';
|
||||
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
|
||||
import { tildeifyPath } from '../utils/paths.js';
|
||||
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
|
||||
const memoryToolSchemaData: FunctionDeclaration = {
|
||||
name: 'save_memory',
|
||||
@@ -396,15 +395,9 @@ export class MemoryTool
|
||||
);
|
||||
}
|
||||
|
||||
override validateToolParams(params: SaveMemoryParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
protected override validateToolParamValues(
|
||||
params: SaveMemoryParams,
|
||||
): string | null {
|
||||
if (params.fact.trim() === '') {
|
||||
return 'Parameter "fact" must be a non-empty string.';
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import fs from 'fs';
|
||||
import fsp from 'fs/promises';
|
||||
import { Config } from '../config/config.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { ToolInvocation, ToolResult } from './tools.js';
|
||||
|
||||
@@ -29,6 +30,7 @@ describe('ReadFileTool', () => {
|
||||
|
||||
const mockConfigInstance = {
|
||||
getFileService: () => new FileDiscoveryService(tempRootDir),
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
getTargetDir: () => tempRootDir,
|
||||
getWorkspaceContext: () => createMockWorkspaceContext(tempRootDir),
|
||||
} as unknown as Config;
|
||||
@@ -69,6 +71,15 @@ describe('ReadFileTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error if path is empty', () => {
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: '',
|
||||
};
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
/The 'absolute_path' parameter must be non-empty./,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error if offset is negative', () => {
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: path.join(tempRootDir, 'test.txt'),
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
*/
|
||||
|
||||
import path from 'path';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
@@ -74,6 +73,7 @@ class ReadFileToolInvocation extends BaseToolInvocation<
|
||||
const result = await processSingleFileContent(
|
||||
this.params.absolute_path,
|
||||
this.config.getTargetDir(),
|
||||
this.config.getFileSystemService(),
|
||||
this.params.offset,
|
||||
this.params.limit,
|
||||
);
|
||||
@@ -198,18 +198,14 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
);
|
||||
}
|
||||
|
||||
protected override validateToolParams(
|
||||
protected override validateToolParamValues(
|
||||
params: ReadFileToolParams,
|
||||
): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
const filePath = params.absolute_path;
|
||||
if (params.absolute_path.trim() === '') {
|
||||
return "The 'absolute_path' parameter must be non-empty.";
|
||||
}
|
||||
|
||||
const filePath = params.absolute_path;
|
||||
if (!path.isAbsolute(filePath)) {
|
||||
return `File path must be absolute, but was relative: ${filePath}. You must provide an absolute path.`;
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import fs from 'fs'; // Actual fs for setup
|
||||
import os from 'os';
|
||||
import { Config } from '../config/config.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
vi.mock('mime-types', () => {
|
||||
const lookup = (filename: string) => {
|
||||
@@ -59,6 +60,7 @@ describe('ReadManyFilesTool', () => {
|
||||
const fileService = new FileDiscoveryService(tempRootDir);
|
||||
const mockConfig = {
|
||||
getFileService: () => fileService,
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
|
||||
getFileFilteringOptions: () => ({
|
||||
respectGitIgnore: true,
|
||||
@@ -456,6 +458,7 @@ describe('ReadManyFilesTool', () => {
|
||||
const fileService = new FileDiscoveryService(tempDir1);
|
||||
const mockConfig = {
|
||||
getFileService: () => fileService,
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
getFileFilteringOptions: () => ({
|
||||
respectGitIgnore: true,
|
||||
respectGeminiIgnore: true,
|
||||
@@ -524,6 +527,43 @@ describe('ReadManyFilesTool', () => {
|
||||
expect(truncatedFileContent).toContain('L200');
|
||||
expect(truncatedFileContent).not.toContain('L2400');
|
||||
});
|
||||
|
||||
it('should read files with special characters like [] and () in the path', async () => {
|
||||
const filePath = 'src/app/[test]/(dashboard)/testing/components/code.tsx';
|
||||
createFile(filePath, 'Content of receive-detail');
|
||||
const params = { paths: [filePath] };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
const expectedPath = path.join(tempRootDir, filePath);
|
||||
expect(result.llmContent).toEqual([
|
||||
`--- ${expectedPath} ---
|
||||
|
||||
Content of receive-detail
|
||||
|
||||
`,
|
||||
]);
|
||||
expect(result.returnDisplay).toContain(
|
||||
'Successfully read and concatenated content from **1 file(s)**',
|
||||
);
|
||||
});
|
||||
|
||||
it('should read files with special characters in the name', async () => {
|
||||
createFile('file[1].txt', 'Content of file[1]');
|
||||
const params = { paths: ['file[1].txt'] };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
const expectedPath = path.join(tempRootDir, 'file[1].txt');
|
||||
expect(result.llmContent).toEqual([
|
||||
`--- ${expectedPath} ---
|
||||
|
||||
Content of file[1]
|
||||
|
||||
`,
|
||||
]);
|
||||
expect(result.returnDisplay).toContain(
|
||||
'Successfully read and concatenated content from **1 file(s)**',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Batch Processing', () => {
|
||||
|
||||
@@ -11,10 +11,10 @@ import {
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import { glob } from 'glob';
|
||||
import { glob, escape } from 'glob';
|
||||
import { getCurrentGeminiMdFilename } from './memoryTool.js';
|
||||
import {
|
||||
detectFileType,
|
||||
@@ -245,18 +245,27 @@ ${finalExclusionPatternsForDescription
|
||||
const workspaceDirs = this.config.getWorkspaceContext().getDirectories();
|
||||
|
||||
for (const dir of workspaceDirs) {
|
||||
const entriesInDir = await glob(
|
||||
searchPatterns.map((p) => p.replace(/\\/g, '/')),
|
||||
{
|
||||
cwd: dir,
|
||||
ignore: effectiveExcludes,
|
||||
nodir: true,
|
||||
dot: true,
|
||||
absolute: true,
|
||||
nocase: true,
|
||||
signal,
|
||||
},
|
||||
);
|
||||
const processedPatterns = [];
|
||||
for (const p of searchPatterns) {
|
||||
const normalizedP = p.replace(/\\/g, '/');
|
||||
const fullPath = path.join(dir, normalizedP);
|
||||
if (fs.existsSync(fullPath)) {
|
||||
processedPatterns.push(escape(normalizedP));
|
||||
} else {
|
||||
// The path does not exist or is not a file, so we treat it as a glob pattern.
|
||||
processedPatterns.push(normalizedP);
|
||||
}
|
||||
}
|
||||
|
||||
const entriesInDir = await glob(processedPatterns, {
|
||||
cwd: dir,
|
||||
ignore: effectiveExcludes,
|
||||
nodir: true,
|
||||
dot: true,
|
||||
absolute: true,
|
||||
nocase: true,
|
||||
signal,
|
||||
});
|
||||
for (const entry of entriesInDir) {
|
||||
allEntries.add(entry);
|
||||
}
|
||||
@@ -388,6 +397,7 @@ ${finalExclusionPatternsForDescription
|
||||
const fileReadResult = await processSingleFileContent(
|
||||
filePath,
|
||||
this.config.getTargetDir(),
|
||||
this.config.getFileSystemService(),
|
||||
);
|
||||
|
||||
if (fileReadResult.error) {
|
||||
@@ -626,19 +636,6 @@ Use this tool when the user's query implies needing the content of several files
|
||||
);
|
||||
}
|
||||
|
||||
protected override validateToolParams(
|
||||
params: ReadManyFilesParams,
|
||||
): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ReadManyFilesParams,
|
||||
): ToolInvocation<ReadManyFilesParams, ToolResult> {
|
||||
|
||||
@@ -61,6 +61,7 @@ describe('ShellTool', () => {
|
||||
name: 'Qwen-Coder',
|
||||
email: 'qwen-coder@alibabacloud.com',
|
||||
}),
|
||||
getShouldUseNodePtyShell: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Config;
|
||||
|
||||
shellTool = new ShellTool(mockConfig);
|
||||
@@ -151,13 +152,12 @@ describe('ShellTool', () => {
|
||||
const fullResult: ShellExecutionResult = {
|
||||
rawOutput: Buffer.from(result.output || ''),
|
||||
output: 'Success',
|
||||
stdout: 'Success',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
...result,
|
||||
};
|
||||
resolveExecutionPromise(fullResult);
|
||||
@@ -183,6 +183,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toContain('Background PIDs: 54322');
|
||||
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
|
||||
@@ -208,6 +211,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -231,6 +237,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -254,6 +263,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -267,13 +279,12 @@ describe('ShellTool', () => {
|
||||
resolveShellExecution({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
await promise;
|
||||
expect(mockShellExecutionService).toHaveBeenCalledWith(
|
||||
@@ -281,6 +292,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -295,16 +309,14 @@ describe('ShellTool', () => {
|
||||
error,
|
||||
exitCode: 1,
|
||||
output: 'err',
|
||||
stderr: 'err',
|
||||
rawOutput: Buffer.from('err'),
|
||||
stdout: '',
|
||||
signal: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
const result = await promise;
|
||||
// The final llmContent should contain the user's command, not the wrapper
|
||||
expect(result.llmContent).toContain('Error: wrapped command failed');
|
||||
expect(result.llmContent).not.toContain('pgrep');
|
||||
});
|
||||
@@ -344,13 +356,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
output: 'long output',
|
||||
rawOutput: Buffer.from('long output'),
|
||||
stdout: 'long output',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
const result = await promise;
|
||||
@@ -402,7 +413,6 @@ describe('ShellTool', () => {
|
||||
// First chunk, should be throttled.
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stdout',
|
||||
chunk: 'hello ',
|
||||
});
|
||||
expect(updateOutputMock).not.toHaveBeenCalled();
|
||||
@@ -413,24 +423,22 @@ describe('ShellTool', () => {
|
||||
// Send a second chunk. THIS event triggers the update with the CUMULATIVE content.
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stderr',
|
||||
chunk: 'world',
|
||||
chunk: 'hello world',
|
||||
});
|
||||
|
||||
// It should have been called once now with the combined output.
|
||||
expect(updateOutputMock).toHaveBeenCalledOnce();
|
||||
expect(updateOutputMock).toHaveBeenCalledWith('hello \nworld');
|
||||
expect(updateOutputMock).toHaveBeenCalledWith('hello world');
|
||||
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
await promise;
|
||||
});
|
||||
@@ -472,13 +480,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
await promise;
|
||||
});
|
||||
@@ -494,13 +501,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -513,6 +519,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -524,13 +533,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -542,6 +550,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -553,13 +564,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -571,6 +581,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -582,13 +595,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -599,6 +611,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -610,13 +625,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -627,6 +641,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -638,13 +655,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -656,6 +672,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -674,13 +693,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -691,6 +709,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -709,13 +730,12 @@ describe('ShellTool', () => {
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
});
|
||||
|
||||
await promise;
|
||||
@@ -727,6 +747,9 @@ describe('ShellTool', () => {
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -765,6 +788,20 @@ describe('ShellTool', () => {
|
||||
).toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
it('should return the windows description when on windows', () => {
|
||||
vi.mocked(os.platform).mockReturnValue('win32');
|
||||
const shellTool = new ShellTool(mockConfig);
|
||||
expect(shellTool.description).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should return the non-windows description when not on windows', () => {
|
||||
vi.mocked(os.platform).mockReturnValue('linux');
|
||||
const shellTool = new ShellTool(mockConfig);
|
||||
expect(shellTool.description).toMatchSnapshot();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
|
||||
@@ -19,7 +19,6 @@ import {
|
||||
ToolConfirmationOutcome,
|
||||
Kind,
|
||||
} from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { summarizeToolOutput } from '../utils/summarizer.js';
|
||||
import {
|
||||
@@ -102,6 +101,8 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
terminalColumns?: number,
|
||||
terminalRows?: number,
|
||||
): Promise<ToolResult> {
|
||||
const strippedCommand = stripShellWrapper(this.params.command);
|
||||
|
||||
@@ -145,13 +146,11 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
this.params.directory || '',
|
||||
);
|
||||
|
||||
let cumulativeStdout = '';
|
||||
let cumulativeStderr = '';
|
||||
|
||||
let cumulativeOutput = '';
|
||||
let lastUpdateTime = Date.now();
|
||||
let isBinaryStream = false;
|
||||
|
||||
const { result: resultPromise } = ShellExecutionService.execute(
|
||||
const { result: resultPromise } = await ShellExecutionService.execute(
|
||||
commandToExecute,
|
||||
cwd,
|
||||
(event: ShellOutputEvent) => {
|
||||
@@ -164,15 +163,9 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
|
||||
switch (event.type) {
|
||||
case 'data':
|
||||
if (isBinaryStream) break; // Don't process text if we are in binary mode
|
||||
if (event.stream === 'stdout') {
|
||||
cumulativeStdout += event.chunk;
|
||||
} else {
|
||||
cumulativeStderr += event.chunk;
|
||||
}
|
||||
currentDisplayOutput =
|
||||
cumulativeStdout +
|
||||
(cumulativeStderr ? `\n${cumulativeStderr}` : '');
|
||||
if (isBinaryStream) break;
|
||||
cumulativeOutput = event.chunk;
|
||||
currentDisplayOutput = cumulativeOutput;
|
||||
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
|
||||
shouldUpdate = true;
|
||||
}
|
||||
@@ -203,6 +196,9 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
},
|
||||
signal,
|
||||
this.config.getShouldUseNodePtyShell(),
|
||||
terminalColumns,
|
||||
terminalRows,
|
||||
);
|
||||
|
||||
const result = await resultPromise;
|
||||
@@ -234,7 +230,7 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
if (result.aborted) {
|
||||
llmContent = 'Command was cancelled by user before it could complete.';
|
||||
if (result.output.trim()) {
|
||||
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${result.output}`;
|
||||
llmContent += ` Below is the output before it was cancelled:\n${result.output}`;
|
||||
} else {
|
||||
llmContent += ' There was no output before it was cancelled.';
|
||||
}
|
||||
@@ -248,8 +244,7 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
llmContent = [
|
||||
`Command: ${this.params.command}`,
|
||||
`Directory: ${this.params.directory || '(root)'}`,
|
||||
`Stdout: ${result.stdout || '(empty)'}`,
|
||||
`Stderr: ${result.stderr || '(empty)'}`,
|
||||
`Output: ${result.output || '(empty)'}`,
|
||||
`Error: ${finalError}`, // Use the cleaned error string.
|
||||
`Exit Code: ${result.exitCode ?? '(none)'}`,
|
||||
`Signal: ${result.signal ?? '(none)'}`,
|
||||
@@ -345,18 +340,10 @@ Co-authored-by: ${gitCoAuthorSettings.name} <${gitCoAuthorSettings.email}>`;
|
||||
}
|
||||
}
|
||||
|
||||
export class ShellTool extends BaseDeclarativeTool<
|
||||
ShellToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
static Name: string = 'run_shell_command';
|
||||
private allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
ShellTool.Name,
|
||||
'Shell',
|
||||
`This tool executes a given shell command as \`bash -c <command>\`.
|
||||
function getShellToolDescription(): string {
|
||||
const platform = os.platform();
|
||||
const toolDescription = `
|
||||
${platform === 'win32' ? 'This tool executes a given shell command as `cmd.exe /c <command>`.' : 'This tool executes a given shell command as `bash -c <command>`. '}
|
||||
|
||||
**Background vs Foreground Execution:**
|
||||
You should decide whether commands should run in background or foreground based on their nature:
|
||||
@@ -375,7 +362,7 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
- Git operations: \`git commit\`, \`git push\`
|
||||
- Test runs: \`npm test\`, \`pytest\`
|
||||
|
||||
Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
|
||||
${platform === 'win32' ? '' : 'Command is executed as a subprocess that leads its own process group. Command process group can be terminated as `kill -- -PGID` or signaled as `kill -s SIGNAL -- -PGID`.'}
|
||||
|
||||
The following information is returned:
|
||||
|
||||
@@ -387,14 +374,38 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\``,
|
||||
Process Group PGID: Process group started or \`(none)\``;
|
||||
|
||||
return toolDescription;
|
||||
}
|
||||
|
||||
function getCommandDescription(): string {
|
||||
if (os.platform() === 'win32') {
|
||||
return 'Exact command to execute as `cmd.exe /c <command>`';
|
||||
} else {
|
||||
return 'Exact bash command to execute as `bash -c <command>`';
|
||||
}
|
||||
}
|
||||
|
||||
export class ShellTool extends BaseDeclarativeTool<
|
||||
ShellToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
static Name: string = 'run_shell_command';
|
||||
private allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
ShellTool.Name,
|
||||
'Shell',
|
||||
getShellToolDescription(),
|
||||
Kind.Execute,
|
||||
{
|
||||
type: 'object',
|
||||
properties: {
|
||||
command: {
|
||||
type: 'string',
|
||||
description: 'Exact bash command to execute as `bash -c <command>`',
|
||||
description: getCommandDescription(),
|
||||
},
|
||||
is_background: {
|
||||
type: 'boolean',
|
||||
@@ -419,7 +430,9 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
);
|
||||
}
|
||||
|
||||
override validateToolParams(params: ShellToolParams): string | null {
|
||||
protected override validateToolParamValues(
|
||||
params: ShellToolParams,
|
||||
): string | null {
|
||||
const commandCheck = isCommandAllowed(params.command, this.config);
|
||||
if (!commandCheck.allowed) {
|
||||
if (!commandCheck.reason) {
|
||||
@@ -430,13 +443,6 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
}
|
||||
return commandCheck.reason;
|
||||
}
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
if (!params.command.trim()) {
|
||||
return 'Command cannot be empty.';
|
||||
}
|
||||
|
||||
@@ -75,7 +75,9 @@ describe('TodoWriteTool', () => {
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toContain('must NOT have fewer than 1 characters');
|
||||
expect(result).toContain(
|
||||
'Each todo must have a non-empty "content" string',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject todos with empty id', () => {
|
||||
@@ -87,7 +89,7 @@ describe('TodoWriteTool', () => {
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toContain('non-empty "id"');
|
||||
expect(result).toContain('non-empty "id" string');
|
||||
});
|
||||
|
||||
it('should reject todos with invalid status', () => {
|
||||
@@ -103,7 +105,9 @@ describe('TodoWriteTool', () => {
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toContain('must be equal to one of the allowed values');
|
||||
expect(result).toContain(
|
||||
'Each todo must have a valid "status" (pending, in_progress, completed)',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject todos with duplicate IDs', () => {
|
||||
|
||||
@@ -17,7 +17,6 @@ import * as path from 'path';
|
||||
import * as process from 'process';
|
||||
|
||||
import { QWEN_DIR } from '../utils/paths.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
export interface TodoItem {
|
||||
@@ -248,7 +247,8 @@ When in doubt, use this tool. Being proactive with task management demonstrates
|
||||
const TODO_SUBDIR = 'todos';
|
||||
|
||||
function getTodoFilePath(sessionId?: string): string {
|
||||
const homeDir = process.env.HOME || process.env.USERPROFILE || process.cwd();
|
||||
const homeDir =
|
||||
process.env['HOME'] || process.env['USERPROFILE'] || process.cwd();
|
||||
const todoDir = path.join(homeDir, QWEN_DIR, TODO_SUBDIR);
|
||||
|
||||
// Use sessionId if provided, otherwise fall back to 'default'
|
||||
@@ -383,7 +383,7 @@ export async function readTodosForSession(
|
||||
export async function listTodoSessions(): Promise<string[]> {
|
||||
try {
|
||||
const homeDir =
|
||||
process.env.HOME || process.env.USERPROFILE || process.cwd();
|
||||
process.env['HOME'] || process.env['USERPROFILE'] || process.cwd();
|
||||
const todoDir = path.join(homeDir, QWEN_DIR, TODO_SUBDIR);
|
||||
const files = await fs.readdir(todoDir);
|
||||
return files
|
||||
@@ -415,14 +415,6 @@ export class TodoWriteTool extends BaseDeclarativeTool<
|
||||
}
|
||||
|
||||
override validateToolParams(params: TodoWriteParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
// Validate todos array
|
||||
if (!Array.isArray(params.todos)) {
|
||||
return 'Parameter "todos" must be an array.';
|
||||
|
||||
@@ -13,6 +13,7 @@ export enum ToolErrorType {
|
||||
UNKNOWN = 'unknown',
|
||||
UNHANDLED_EXCEPTION = 'unhandled_exception',
|
||||
TOOL_NOT_REGISTERED = 'tool_not_registered',
|
||||
EXECUTION_FAILED = 'execution_failed',
|
||||
|
||||
// File System Errors
|
||||
FILE_NOT_FOUND = 'file_not_found',
|
||||
|
||||
@@ -22,15 +22,17 @@ import { spawn } from 'node:child_process';
|
||||
import fs from 'node:fs';
|
||||
import { MockTool } from '../test-utils/tools.js';
|
||||
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
||||
|
||||
// Mock ./mcp-client.js to control its behavior within tool-registry tests
|
||||
vi.mock('./mcp-client.js', () => ({
|
||||
discoverMcpTools: mockDiscoverMcpTools,
|
||||
}));
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock node:child_process
|
||||
vi.mock('node:child_process', async () => {
|
||||
@@ -142,7 +144,6 @@ describe('ToolRegistry', () => {
|
||||
clear: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as any);
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -310,6 +311,10 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
const discoverSpy = vi.spyOn(
|
||||
McpClientManager.prototype,
|
||||
'discoverAllMcpTools',
|
||||
);
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
@@ -323,38 +328,7 @@ describe('ToolRegistry', () => {
|
||||
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
config.getPromptRegistry(),
|
||||
false,
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
config.getPromptRegistry(),
|
||||
false,
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(discoverSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,56 +5,47 @@
|
||||
*/
|
||||
|
||||
import { FunctionDeclaration } from '@google/genai';
|
||||
import { AnyDeclarativeTool, Kind, ToolResult, BaseTool } from './tools.js';
|
||||
import {
|
||||
AnyDeclarativeTool,
|
||||
Kind,
|
||||
ToolResult,
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
ToolInvocation,
|
||||
} from './tools.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { discoverMcpTools } from './mcp-client.js';
|
||||
import { connectAndDiscover } from './mcp-client.js';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
|
||||
class DiscoveredToolInvocation extends BaseToolInvocation<
|
||||
ToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
name: string,
|
||||
override readonly description: string,
|
||||
override readonly parameterSchema: Record<string, unknown>,
|
||||
private readonly toolName: string,
|
||||
params: ToolParams,
|
||||
) {
|
||||
const discoveryCmd = config.getToolDiscoveryCommand()!;
|
||||
const callCommand = config.getToolCallCommand()!;
|
||||
description += `
|
||||
|
||||
This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
|
||||
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
|
||||
Tool discovery and call commands can be configured in project or user settings.
|
||||
|
||||
When called, the tool call command is executed as a subprocess.
|
||||
On success, tool output is returned as a json string.
|
||||
Otherwise, the following information is returned:
|
||||
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
`;
|
||||
super(
|
||||
name,
|
||||
name,
|
||||
description,
|
||||
Kind.Other,
|
||||
parameterSchema,
|
||||
false, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
);
|
||||
super(params);
|
||||
}
|
||||
|
||||
async execute(params: ToolParams): Promise<ToolResult> {
|
||||
getDescription(): string {
|
||||
return `Calling discovered tool: ${this.toolName}`;
|
||||
}
|
||||
|
||||
async execute(
|
||||
_signal: AbortSignal,
|
||||
_updateOutput?: (output: string) => void,
|
||||
): Promise<ToolResult> {
|
||||
const callCommand = this.config.getToolCallCommand()!;
|
||||
const child = spawn(callCommand, [this.name]);
|
||||
child.stdin.write(JSON.stringify(params));
|
||||
const child = spawn(callCommand, [this.toolName]);
|
||||
child.stdin.write(JSON.stringify(this.params));
|
||||
child.stdin.end();
|
||||
|
||||
let stdout = '';
|
||||
@@ -124,12 +115,67 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
}
|
||||
}
|
||||
|
||||
export class DiscoveredTool extends BaseDeclarativeTool<
|
||||
ToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
name: string,
|
||||
override readonly description: string,
|
||||
override readonly parameterSchema: Record<string, unknown>,
|
||||
) {
|
||||
const discoveryCmd = config.getToolDiscoveryCommand()!;
|
||||
const callCommand = config.getToolCallCommand()!;
|
||||
description += `
|
||||
|
||||
This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
|
||||
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
|
||||
Tool discovery and call commands can be configured in project or user settings.
|
||||
|
||||
When called, the tool call command is executed as a subprocess.
|
||||
On success, tool output is returned as a json string.
|
||||
Otherwise, the following information is returned:
|
||||
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
`;
|
||||
super(
|
||||
name,
|
||||
name,
|
||||
description,
|
||||
Kind.Other,
|
||||
parameterSchema,
|
||||
false, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ToolParams,
|
||||
): ToolInvocation<ToolParams, ToolResult> {
|
||||
return new DiscoveredToolInvocation(this.config, this.name, params);
|
||||
}
|
||||
}
|
||||
|
||||
export class ToolRegistry {
|
||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private mcpClientManager: McpClientManager;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -184,14 +230,7 @@ export class ToolRegistry {
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await discoverMcpTools(
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -206,14 +245,14 @@ export class ToolRegistry {
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await discoverMcpTools(
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all MCP servers and re-discovers tools.
|
||||
*/
|
||||
async restartMcpServers(): Promise<void> {
|
||||
await this.discoverMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -233,9 +272,9 @@ export class ToolRegistry {
|
||||
const mcpServers = this.config.getMcpServers() ?? {};
|
||||
const serverConfig = mcpServers[serverName];
|
||||
if (serverConfig) {
|
||||
await discoverMcpTools(
|
||||
{ [serverName]: serverConfig },
|
||||
undefined,
|
||||
await connectAndDiscover(
|
||||
serverName,
|
||||
serverConfig,
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
|
||||
@@ -4,8 +4,119 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { hasCycleInSchema } from './tools.js'; // Added getStringifiedResultForDisplay
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
DeclarativeTool,
|
||||
hasCycleInSchema,
|
||||
Kind,
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
|
||||
class TestToolInvocation implements ToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
readonly params: object,
|
||||
private readonly executeFn: () => Promise<ToolResult>,
|
||||
) {}
|
||||
|
||||
getDescription(): string {
|
||||
return 'A test invocation';
|
||||
}
|
||||
|
||||
toolLocations() {
|
||||
return [];
|
||||
}
|
||||
|
||||
shouldConfirmExecute(): Promise<false> {
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
execute(): Promise<ToolResult> {
|
||||
return this.executeFn();
|
||||
}
|
||||
}
|
||||
|
||||
class TestTool extends DeclarativeTool<object, ToolResult> {
|
||||
private readonly buildFn: (params: object) => TestToolInvocation;
|
||||
|
||||
constructor(buildFn: (params: object) => TestToolInvocation) {
|
||||
super('test-tool', 'Test Tool', 'A tool for testing', Kind.Other, {});
|
||||
this.buildFn = buildFn;
|
||||
}
|
||||
|
||||
build(params: object): ToolInvocation<object, ToolResult> {
|
||||
return this.buildFn(params);
|
||||
}
|
||||
}
|
||||
|
||||
describe('DeclarativeTool', () => {
|
||||
describe('validateBuildAndExecute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
it('should return INVALID_TOOL_PARAMS error if build fails', async () => {
|
||||
const buildError = new Error('Invalid build parameters');
|
||||
const buildFn = vi.fn().mockImplementation(() => {
|
||||
throw buildError;
|
||||
});
|
||||
const tool = new TestTool(buildFn);
|
||||
const params = { foo: 'bar' };
|
||||
|
||||
const result = await tool.validateBuildAndExecute(params, abortSignal);
|
||||
|
||||
expect(buildFn).toHaveBeenCalledWith(params);
|
||||
expect(result).toEqual({
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${buildError.message}`,
|
||||
returnDisplay: buildError.message,
|
||||
error: {
|
||||
message: buildError.message,
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should return EXECUTION_FAILED error if execute fails', async () => {
|
||||
const executeError = new Error('Execution failed');
|
||||
const executeFn = vi.fn().mockRejectedValue(executeError);
|
||||
const invocation = new TestToolInvocation({}, executeFn);
|
||||
const buildFn = vi.fn().mockReturnValue(invocation);
|
||||
const tool = new TestTool(buildFn);
|
||||
const params = { foo: 'bar' };
|
||||
|
||||
const result = await tool.validateBuildAndExecute(params, abortSignal);
|
||||
|
||||
expect(buildFn).toHaveBeenCalledWith(params);
|
||||
expect(executeFn).toHaveBeenCalled();
|
||||
expect(result).toEqual({
|
||||
llmContent: `Error: Tool call execution failed. Reason: ${executeError.message}`,
|
||||
returnDisplay: executeError.message,
|
||||
error: {
|
||||
message: executeError.message,
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should return the result of execute on success', async () => {
|
||||
const successResult: ToolResult = {
|
||||
llmContent: 'Success!',
|
||||
returnDisplay: 'Success!',
|
||||
summary: 'Tool executed successfully',
|
||||
};
|
||||
const executeFn = vi.fn().mockResolvedValue(successResult);
|
||||
const invocation = new TestToolInvocation({}, executeFn);
|
||||
const buildFn = vi.fn().mockReturnValue(invocation);
|
||||
const tool = new TestTool(buildFn);
|
||||
const params = { foo: 'bar' };
|
||||
|
||||
const result = await tool.validateBuildAndExecute(params, abortSignal);
|
||||
|
||||
expect(buildFn).toHaveBeenCalledWith(params);
|
||||
expect(executeFn).toHaveBeenCalled();
|
||||
expect(result).toEqual(successResult);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasCycleInSchema', () => {
|
||||
it('should detect a simple direct cycle', () => {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import { FunctionDeclaration, PartListUnion } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { DiffUpdateResult } from '../ide/ideContext.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
|
||||
/**
|
||||
* Represents a validated and ready-to-execute tool call.
|
||||
@@ -86,42 +87,6 @@ export abstract class BaseToolInvocation<
|
||||
*/
|
||||
export type AnyToolInvocation = ToolInvocation<object, ToolResult>;
|
||||
|
||||
/**
|
||||
* An adapter that wraps the legacy `Tool` interface to make it compatible
|
||||
* with the new `ToolInvocation` pattern.
|
||||
*/
|
||||
export class LegacyToolInvocation<
|
||||
TParams extends object,
|
||||
TResult extends ToolResult,
|
||||
> implements ToolInvocation<TParams, TResult>
|
||||
{
|
||||
constructor(
|
||||
private readonly legacyTool: BaseTool<TParams, TResult>,
|
||||
readonly params: TParams,
|
||||
) {}
|
||||
|
||||
getDescription(): string {
|
||||
return this.legacyTool.getDescription(this.params);
|
||||
}
|
||||
|
||||
toolLocations(): ToolLocation[] {
|
||||
return this.legacyTool.toolLocations(this.params);
|
||||
}
|
||||
|
||||
shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return this.legacyTool.shouldConfirmExecute(this.params, abortSignal);
|
||||
}
|
||||
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
): Promise<TResult> {
|
||||
return this.legacyTool.execute(this.params, signal, updateOutput);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for a tool builder that validates parameters and creates invocations.
|
||||
*/
|
||||
@@ -206,7 +171,7 @@ export abstract class DeclarativeTool<
|
||||
* @param params The raw parameters from the model.
|
||||
* @returns An error message string if invalid, null otherwise.
|
||||
*/
|
||||
protected validateToolParams(_params: TParams): string | null {
|
||||
validateToolParams(_params: TParams): string | null {
|
||||
// Base implementation can be extended by subclasses.
|
||||
return null;
|
||||
}
|
||||
@@ -236,6 +201,64 @@ export abstract class DeclarativeTool<
|
||||
const invocation = this.build(params);
|
||||
return invocation.execute(signal, updateOutput);
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to `build` but never throws.
|
||||
* @param params The raw, untrusted parameters from the model.
|
||||
* @returns A `ToolInvocation` instance.
|
||||
*/
|
||||
private silentBuild(
|
||||
params: TParams,
|
||||
): ToolInvocation<TParams, TResult> | Error {
|
||||
try {
|
||||
return this.build(params);
|
||||
} catch (e) {
|
||||
if (e instanceof Error) {
|
||||
return e;
|
||||
}
|
||||
return new Error(String(e));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method that builds and executes the tool in one step.
|
||||
* Never throws.
|
||||
* @param params The raw, untrusted parameters from the model.
|
||||
* @params abortSignal a signal to abort.
|
||||
* @returns The result of the tool execution.
|
||||
*/
|
||||
async validateBuildAndExecute(
|
||||
params: TParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const invocationOrError = this.silentBuild(params);
|
||||
if (invocationOrError instanceof Error) {
|
||||
const errorMessage = invocationOrError.message;
|
||||
return {
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${errorMessage}`,
|
||||
returnDisplay: errorMessage,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
return await invocationOrError.execute(abortSignal);
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
llmContent: `Error: Tool call execution failed. Reason: ${errorMessage}`,
|
||||
returnDisplay: errorMessage,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -256,6 +279,23 @@ export abstract class BaseDeclarativeTool<
|
||||
return this.createInvocation(params);
|
||||
}
|
||||
|
||||
override validateToolParams(params: TParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
return this.validateToolParamValues(params);
|
||||
}
|
||||
|
||||
protected validateToolParamValues(_params: TParams): string | null {
|
||||
// Base implementation can be extended by subclasses.
|
||||
return null;
|
||||
}
|
||||
|
||||
protected abstract createInvocation(
|
||||
params: TParams,
|
||||
): ToolInvocation<TParams, TResult>;
|
||||
@@ -266,116 +306,6 @@ export abstract class BaseDeclarativeTool<
|
||||
*/
|
||||
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
|
||||
|
||||
/**
|
||||
* Base implementation for tools with common functionality
|
||||
* @deprecated Use `DeclarativeTool` for new tools.
|
||||
*/
|
||||
export abstract class BaseTool<
|
||||
TParams extends object,
|
||||
TResult extends ToolResult = ToolResult,
|
||||
> extends DeclarativeTool<TParams, TResult> {
|
||||
/**
|
||||
* Creates a new instance of BaseTool
|
||||
* @param name Internal name of the tool (used for API calls)
|
||||
* @param displayName User-friendly display name of the tool
|
||||
* @param description Description of what the tool does
|
||||
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
|
||||
* @param canUpdateOutput Whether the tool supports live (streaming) output
|
||||
* @param parameterSchema JSON Schema defining the parameters
|
||||
*/
|
||||
constructor(
|
||||
override readonly name: string,
|
||||
override readonly displayName: string,
|
||||
override readonly description: string,
|
||||
override readonly kind: Kind,
|
||||
override readonly parameterSchema: unknown,
|
||||
override readonly isOutputMarkdown: boolean = true,
|
||||
override readonly canUpdateOutput: boolean = false,
|
||||
) {
|
||||
super(
|
||||
name,
|
||||
displayName,
|
||||
description,
|
||||
kind,
|
||||
parameterSchema,
|
||||
isOutputMarkdown,
|
||||
canUpdateOutput,
|
||||
);
|
||||
}
|
||||
|
||||
build(params: TParams): ToolInvocation<TParams, TResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
throw new Error(validationError);
|
||||
}
|
||||
return new LegacyToolInvocation(this, params);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the tool
|
||||
* This is a placeholder implementation and should be overridden
|
||||
* Should be called from both `shouldConfirmExecute` and `execute`
|
||||
* `shouldConfirmExecute` should return false immediately if invalid
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override validateToolParams(params: TParams): string | null {
|
||||
// Implementation would typically use a JSON Schema validator
|
||||
// This is a placeholder that should be implemented by derived classes
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a pre-execution description of the tool operation
|
||||
* Default implementation that should be overridden by derived classes
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A markdown string describing what the tool will do
|
||||
*/
|
||||
getDescription(params: TParams): string {
|
||||
return JSON.stringify(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if the tool should prompt for confirmation before execution
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns Whether or not execute should be confirmed by the user.
|
||||
*/
|
||||
shouldConfirmExecute(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
params: TParams,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines what file system paths the tool will affect
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A list of such paths
|
||||
*/
|
||||
toolLocations(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
params: TParams,
|
||||
): ToolLocation[] {
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstract method to execute the tool with the given parameters
|
||||
* Must be implemented by derived classes
|
||||
* @param params Parameters for the tool execution
|
||||
* @param signal AbortSignal for tool cancellation
|
||||
* @returns Result of the tool execution
|
||||
*/
|
||||
abstract execute(
|
||||
params: TParams,
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
): Promise<TResult>;
|
||||
}
|
||||
|
||||
export interface ToolResult {
|
||||
/**
|
||||
* A short, one-line summary of the tool's action and result.
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
@@ -211,16 +210,9 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
}
|
||||
}
|
||||
|
||||
protected override validateToolParams(
|
||||
protected override validateToolParamValues(
|
||||
params: WebFetchToolParams,
|
||||
): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
if (!params.url || params.url.trim() === '') {
|
||||
return "The 'url' parameter cannot be empty.";
|
||||
}
|
||||
|
||||
166
packages/core/src/tools/web-search.test.ts
Normal file
166
packages/core/src/tools/web-search.test.ts
Normal file
@@ -0,0 +1,166 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { WebSearchTool, WebSearchToolParams } from './web-search.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
|
||||
// Mock GeminiClient and Config constructor
|
||||
vi.mock('../core/client.js');
|
||||
vi.mock('../config/config.js');
|
||||
|
||||
// Mock global fetch
|
||||
const mockFetch = vi.fn();
|
||||
global.fetch = mockFetch;
|
||||
|
||||
describe('WebSearchTool', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
let mockGeminiClient: GeminiClient;
|
||||
let tool: WebSearchTool;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
const mockConfigInstance = {
|
||||
getGeminiClient: () => mockGeminiClient,
|
||||
getProxy: () => undefined,
|
||||
getTavilyApiKey: () => 'test-api-key', // Add the missing method
|
||||
} as unknown as Config;
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
tool = new WebSearchTool(mockConfigInstance);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('build', () => {
|
||||
it('should return an invocation for a valid query', () => {
|
||||
const params: WebSearchToolParams = { query: 'test query' };
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation).toBeDefined();
|
||||
expect(invocation.params).toEqual(params);
|
||||
});
|
||||
|
||||
it('should throw an error for an empty query', () => {
|
||||
const params: WebSearchToolParams = { query: '' };
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
"The 'query' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error for a query with only whitespace', () => {
|
||||
const params: WebSearchToolParams = { query: ' ' };
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
"The 'query' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
it('should return a description of the search', () => {
|
||||
const params: WebSearchToolParams = { query: 'test query' };
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation.getDescription()).toBe(
|
||||
'Searching the web for: "test query"',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
it('should return search results for a successful query', async () => {
|
||||
const params: WebSearchToolParams = { query: 'successful query' };
|
||||
|
||||
// Mock the fetch response
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
answer: 'Here are your results.',
|
||||
results: [],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).toBe(
|
||||
'Web search results for "successful query":\n\nHere are your results.',
|
||||
);
|
||||
expect(result.returnDisplay).toBe(
|
||||
'Search results for "successful query" returned.',
|
||||
);
|
||||
expect(result.sources).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle no search results found', async () => {
|
||||
const params: WebSearchToolParams = { query: 'no results query' };
|
||||
|
||||
// Mock the fetch response
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
answer: '',
|
||||
results: [],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).toBe(
|
||||
'No search results or information found for query: "no results query"',
|
||||
);
|
||||
expect(result.returnDisplay).toBe('No information found.');
|
||||
});
|
||||
|
||||
it('should handle API errors gracefully', async () => {
|
||||
const params: WebSearchToolParams = { query: 'error query' };
|
||||
|
||||
// Mock the fetch to reject
|
||||
mockFetch.mockRejectedValueOnce(new Error('API Failure'));
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).toContain('Error:');
|
||||
expect(result.llmContent).toContain('API Failure');
|
||||
expect(result.returnDisplay).toBe('Error performing web search.');
|
||||
});
|
||||
|
||||
it('should correctly format results with sources', async () => {
|
||||
const params: WebSearchToolParams = { query: 'grounding query' };
|
||||
|
||||
// Mock the fetch response
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
answer: 'This is a test response.',
|
||||
results: [
|
||||
{ title: 'Example Site', url: 'https://example.com' },
|
||||
{ title: 'Google', url: 'https://google.com' },
|
||||
],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
const expectedLlmContent = `Web search results for "grounding query":
|
||||
|
||||
This is a test response.
|
||||
|
||||
Sources:
|
||||
[1] Example Site (https://example.com)
|
||||
[2] Google (https://google.com)`;
|
||||
|
||||
expect(result.llmContent).toBe(expectedLlmContent);
|
||||
expect(result.returnDisplay).toBe(
|
||||
'Search results for "grounding query" returned.',
|
||||
);
|
||||
expect(result.sources).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,8 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTool, Kind, ToolResult } from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
@@ -40,71 +46,24 @@ export interface WebSearchToolResult extends ToolResult {
|
||||
sources?: Array<{ title: string; url: string }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A tool to perform web searches using Tavily API.
|
||||
*/
|
||||
export class WebSearchTool extends BaseTool<
|
||||
class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
WebSearchToolParams,
|
||||
WebSearchToolResult
|
||||
> {
|
||||
static readonly Name: string = 'web_search';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WebSearchTool.Name,
|
||||
'TavilySearch',
|
||||
'Performs a web search using the Tavily API and returns a concise answer with sources. Requires the TAVILY_API_KEY environment variable.',
|
||||
Kind.Search,
|
||||
{
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: {
|
||||
type: 'string',
|
||||
description: 'The search query to find information on the web.',
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the WebSearchTool.
|
||||
* @param params The parameters to validate
|
||||
* @returns An error message string if validation fails, null if valid
|
||||
*/
|
||||
validateParams(params: WebSearchToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
if (!params.query || params.query.trim() === '') {
|
||||
return "The 'query' parameter cannot be empty.";
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
override getDescription(params: WebSearchToolParams): string {
|
||||
return `Searching the web for: "${params.query}"`;
|
||||
}
|
||||
|
||||
async execute(
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WebSearchToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<WebSearchToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
|
||||
returnDisplay: validationError,
|
||||
};
|
||||
}
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
const apiKey = this.config.getTavilyApiKey() || process.env.TAVILY_API_KEY;
|
||||
override getDescription(): string {
|
||||
return `Searching the web for: "${this.params.query}"`;
|
||||
}
|
||||
|
||||
async execute(signal: AbortSignal): Promise<WebSearchToolResult> {
|
||||
const apiKey =
|
||||
this.config.getTavilyApiKey() || process.env['TAVILY_API_KEY'];
|
||||
if (!apiKey) {
|
||||
return {
|
||||
llmContent:
|
||||
@@ -115,8 +74,6 @@ export class WebSearchTool extends BaseTool<
|
||||
}
|
||||
|
||||
try {
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 15000);
|
||||
const response = await fetch('https://api.tavily.com/search', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -124,14 +81,13 @@ export class WebSearchTool extends BaseTool<
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_key: apiKey,
|
||||
query: params.query,
|
||||
query: this.params.query,
|
||||
search_depth: 'advanced',
|
||||
max_results: 5,
|
||||
include_answer: true,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
signal,
|
||||
});
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => '');
|
||||
@@ -166,18 +122,18 @@ export class WebSearchTool extends BaseTool<
|
||||
|
||||
if (!content.trim()) {
|
||||
return {
|
||||
llmContent: `No search results or information found for query: "${params.query}"`,
|
||||
llmContent: `No search results or information found for query: "${this.params.query}"`,
|
||||
returnDisplay: 'No information found.',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: `Web search results for "${params.query}":\n\n${content}`,
|
||||
returnDisplay: `Search results for "${params.query}" returned.`,
|
||||
llmContent: `Web search results for "${this.params.query}":\n\n${content}`,
|
||||
returnDisplay: `Search results for "${this.params.query}" returned.`,
|
||||
sources,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = `Error during web search for query "${params.query}": ${getErrorMessage(
|
||||
const errorMessage = `Error during web search for query "${this.params.query}": ${getErrorMessage(
|
||||
error,
|
||||
)}`;
|
||||
console.error(errorMessage, error);
|
||||
@@ -188,3 +144,52 @@ export class WebSearchTool extends BaseTool<
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A tool to perform web searches using Google Search via the Gemini API.
|
||||
*/
|
||||
export class WebSearchTool extends BaseDeclarativeTool<
|
||||
WebSearchToolParams,
|
||||
WebSearchToolResult
|
||||
> {
|
||||
static readonly Name: string = 'web_search';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WebSearchTool.Name,
|
||||
'TavilySearch',
|
||||
'Performs a web search using the Tavily API and returns a concise answer with sources. Requires the TAVILY_API_KEY environment variable.',
|
||||
Kind.Search,
|
||||
{
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: {
|
||||
type: 'string',
|
||||
description: 'The search query to find information on the web.',
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the WebSearchTool.
|
||||
* @param params The parameters to validate
|
||||
* @returns An error message string if validation fails, null if valid
|
||||
*/
|
||||
protected override validateToolParamValues(
|
||||
params: WebSearchToolParams,
|
||||
): string | null {
|
||||
if (!params.query || params.query.trim() === '') {
|
||||
return "The 'query' parameter cannot be empty.";
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: WebSearchToolParams,
|
||||
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
|
||||
return new WebSearchToolInvocation(this.config, params);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,11 @@ import {
|
||||
vi,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { WriteFileTool, WriteFileToolParams } from './write-file.js';
|
||||
import {
|
||||
getCorrectedFileContent,
|
||||
WriteFileTool,
|
||||
WriteFileToolParams,
|
||||
} from './write-file.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import {
|
||||
FileDiff,
|
||||
@@ -33,6 +37,7 @@ import {
|
||||
CorrectedEditResult,
|
||||
} from '../utils/editCorrector.js';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
const rootDir = path.resolve(os.tmpdir(), 'qwen-code-test-root');
|
||||
|
||||
@@ -51,11 +56,13 @@ vi.mocked(ensureCorrectFileContent).mockImplementation(
|
||||
);
|
||||
|
||||
// Mock Config
|
||||
const fsService = new StandardFileSystemService();
|
||||
const mockConfigInternal = {
|
||||
getTargetDir: () => rootDir,
|
||||
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
|
||||
setApprovalMode: vi.fn(),
|
||||
getGeminiClient: vi.fn(), // Initialize as a plain mock function
|
||||
getFileSystemService: () => fsService,
|
||||
getIdeClient: vi.fn(),
|
||||
getIdeMode: vi.fn(() => false),
|
||||
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
|
||||
@@ -174,74 +181,67 @@ describe('WriteFileTool', () => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
it('should return null for valid absolute path within root', () => {
|
||||
describe('build', () => {
|
||||
it('should return an invocation for a valid absolute path within root', () => {
|
||||
const params = {
|
||||
file_path: path.join(rootDir, 'test.txt'),
|
||||
content: 'hello',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toBeNull();
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation).toBeDefined();
|
||||
expect(invocation.params).toEqual(params);
|
||||
});
|
||||
|
||||
it('should return error for relative path', () => {
|
||||
it('should throw an error for a relative path', () => {
|
||||
const params = { file_path: 'test.txt', content: 'hello' };
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
/File path must be absolute/,
|
||||
);
|
||||
expect(() => tool.build(params)).toThrow(/File path must be absolute/);
|
||||
});
|
||||
|
||||
it('should return error for path outside root', () => {
|
||||
it('should throw an error for a path outside root', () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = {
|
||||
file_path: outsidePath,
|
||||
content: 'hello',
|
||||
};
|
||||
const error = tool.validateToolParams(params);
|
||||
expect(error).toContain(
|
||||
'File path must be within one of the workspace directories',
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
/File path must be within one of the workspace directories/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is a directory', () => {
|
||||
it('should throw an error if path is a directory', () => {
|
||||
const dirAsFilePath = path.join(rootDir, 'a_directory');
|
||||
fs.mkdirSync(dirAsFilePath);
|
||||
const params = {
|
||||
file_path: dirAsFilePath,
|
||||
content: 'hello',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
`Path is a directory, not a file: ${dirAsFilePath}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if the content is null', () => {
|
||||
it('should throw an error if the content is null', () => {
|
||||
const dirAsFilePath = path.join(rootDir, 'a_directory');
|
||||
fs.mkdirSync(dirAsFilePath);
|
||||
const params = {
|
||||
file_path: dirAsFilePath,
|
||||
content: null,
|
||||
} as unknown as WriteFileToolParams; // Intentionally non-conforming
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
`params/content must be string`,
|
||||
);
|
||||
expect(() => tool.build(params)).toThrow('params/content must be string');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
it('should return error if the file_path is empty', () => {
|
||||
it('should throw error if the file_path is empty', () => {
|
||||
const dirAsFilePath = path.join(rootDir, 'a_directory');
|
||||
fs.mkdirSync(dirAsFilePath);
|
||||
const params = {
|
||||
file_path: '',
|
||||
content: '',
|
||||
};
|
||||
expect(tool.getDescription(params)).toMatch(
|
||||
`Model did not provide valid parameters for write file tool, missing or empty "file_path"`,
|
||||
);
|
||||
expect(() => tool.build(params)).toThrow(`Missing or empty "file_path"`);
|
||||
});
|
||||
});
|
||||
|
||||
describe('_getCorrectedFileContent', () => {
|
||||
describe('getCorrectedFileContent', () => {
|
||||
it('should call ensureCorrectFileContent for a new file', async () => {
|
||||
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
||||
const proposedContent = 'Proposed new content.';
|
||||
@@ -250,8 +250,8 @@ describe('WriteFileTool', () => {
|
||||
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
const result = await getCorrectedFileContent(
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
@@ -287,8 +287,8 @@ describe('WriteFileTool', () => {
|
||||
occurrences: 1,
|
||||
} as CorrectedEditResult);
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
const result = await getCorrectedFileContent(
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
@@ -319,19 +319,18 @@ describe('WriteFileTool', () => {
|
||||
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Permission denied');
|
||||
const originalReadFileSync = fs.readFileSync;
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
|
||||
throw readError;
|
||||
});
|
||||
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() =>
|
||||
Promise.reject(readError),
|
||||
);
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
const result = await getCorrectedFileContent(
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
|
||||
expect(fsService.readTextFile).toHaveBeenCalledWith(filePath);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(proposedContent);
|
||||
@@ -342,25 +341,12 @@ describe('WriteFileTool', () => {
|
||||
code: undefined,
|
||||
});
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
it('should return false if params are invalid (relative path)', async () => {
|
||||
const params = { file_path: 'relative.txt', content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if params are invalid (outside root)', async () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = { file_path: outsidePath, content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if _getCorrectedFileContent returns an error', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_error_file.txt');
|
||||
@@ -368,15 +354,14 @@ describe('WriteFileTool', () => {
|
||||
fs.writeFileSync(filePath, 'original', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Simulated read error for confirmation');
|
||||
const originalReadFileSync = fs.readFileSync;
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
|
||||
throw readError;
|
||||
});
|
||||
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() =>
|
||||
Promise.reject(readError),
|
||||
);
|
||||
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
|
||||
@@ -387,8 +372,8 @@ describe('WriteFileTool', () => {
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); // Ensure this mock is active
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = (await invocation.shouldConfirmExecute(
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
@@ -430,8 +415,8 @@ describe('WriteFileTool', () => {
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = (await invocation.shouldConfirmExecute(
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
@@ -461,45 +446,20 @@ describe('WriteFileTool', () => {
|
||||
|
||||
describe('execute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
it('should return error if params are invalid (relative path)', async () => {
|
||||
const params = { file_path: 'relative.txt', content: 'test' };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toContain(
|
||||
'Could not write file due to invalid parameters:',
|
||||
);
|
||||
expect(result.returnDisplay).toMatch(/File path must be absolute/);
|
||||
expect(result.error).toEqual({
|
||||
message: 'File path must be absolute: relative.txt',
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error if params are invalid (path outside root)', async () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = { file_path: outsidePath, content: 'test' };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toContain(
|
||||
'Could not write file due to invalid parameters:',
|
||||
);
|
||||
expect(result.returnDisplay).toContain(
|
||||
'File path must be within one of the workspace directories',
|
||||
);
|
||||
expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS);
|
||||
});
|
||||
|
||||
it('should return error if _getCorrectedFileContent returns an error during execute', async () => {
|
||||
const filePath = path.join(rootDir, 'execute_error_file.txt');
|
||||
const params = { file_path: filePath, content: 'test content' };
|
||||
fs.writeFileSync(filePath, 'original', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Simulated read error for execute');
|
||||
const originalReadFileSync = fs.readFileSync;
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
|
||||
throw readError;
|
||||
vi.spyOn(fsService, 'readTextFile').mockImplementationOnce(() => {
|
||||
const readError = new Error('Simulated read error for execute');
|
||||
return Promise.reject(readError);
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toContain('Error checking existing file:');
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
expect(result.llmContent).toContain('Error checking existing file');
|
||||
expect(result.returnDisplay).toMatch(
|
||||
/Error checking existing file: Simulated read error for execute/,
|
||||
);
|
||||
@@ -509,7 +469,6 @@ describe('WriteFileTool', () => {
|
||||
type: ToolErrorType.FILE_WRITE_FAILURE,
|
||||
});
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
|
||||
@@ -520,11 +479,9 @@ describe('WriteFileTool', () => {
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
|
||||
if (
|
||||
typeof confirmDetails === 'object' &&
|
||||
'onConfirm' in confirmDetails &&
|
||||
@@ -533,7 +490,7 @@ describe('WriteFileTool', () => {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
@@ -544,7 +501,8 @@ describe('WriteFileTool', () => {
|
||||
/Successfully created and wrote to new file/,
|
||||
);
|
||||
expect(fs.existsSync(filePath)).toBe(true);
|
||||
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedContent);
|
||||
const writtenContent = await fsService.readTextFile(filePath);
|
||||
expect(writtenContent).toBe(correctedContent);
|
||||
const display = result.returnDisplay as FileDiff;
|
||||
expect(display.fileName).toBe('execute_new_corrected_file.txt');
|
||||
expect(display.fileDiff).toMatch(
|
||||
@@ -578,11 +536,9 @@ describe('WriteFileTool', () => {
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
|
||||
if (
|
||||
typeof confirmDetails === 'object' &&
|
||||
'onConfirm' in confirmDetails &&
|
||||
@@ -591,7 +547,7 @@ describe('WriteFileTool', () => {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
@@ -605,7 +561,8 @@ describe('WriteFileTool', () => {
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
||||
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
|
||||
const writtenContent = await fsService.readTextFile(filePath);
|
||||
expect(writtenContent).toBe(correctedProposedContent);
|
||||
const display = result.returnDisplay as FileDiff;
|
||||
expect(display.fileName).toBe('execute_existing_corrected_file.txt');
|
||||
expect(display.fileDiff).toMatch(
|
||||
@@ -623,11 +580,9 @@ describe('WriteFileTool', () => {
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(content); // Ensure this mock is active
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
const invocation = tool.build(params);
|
||||
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
|
||||
if (
|
||||
typeof confirmDetails === 'object' &&
|
||||
'onConfirm' in confirmDetails &&
|
||||
@@ -636,7 +591,7 @@ describe('WriteFileTool', () => {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
await tool.execute(params, abortSignal);
|
||||
await invocation.execute(abortSignal);
|
||||
|
||||
expect(fs.existsSync(dirPath)).toBe(true);
|
||||
expect(fs.statSync(dirPath).isDirectory()).toBe(true);
|
||||
@@ -654,7 +609,8 @@ describe('WriteFileTool', () => {
|
||||
content,
|
||||
modified_by_user: true,
|
||||
};
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).toMatch(/User modified the `content`/);
|
||||
});
|
||||
@@ -669,7 +625,8 @@ describe('WriteFileTool', () => {
|
||||
content,
|
||||
modified_by_user: false,
|
||||
};
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).not.toMatch(/User modified the `content`/);
|
||||
});
|
||||
@@ -683,7 +640,8 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
content,
|
||||
};
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.llmContent).not.toMatch(/User modified the `content`/);
|
||||
});
|
||||
@@ -695,7 +653,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: path.join(rootDir, 'file.txt'),
|
||||
content: 'test content',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toBeNull();
|
||||
expect(() => tool.build(params)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should reject paths outside workspace root', () => {
|
||||
@@ -703,24 +661,9 @@ describe('WriteFileTool', () => {
|
||||
file_path: '/etc/passwd',
|
||||
content: 'malicious',
|
||||
};
|
||||
const error = tool.validateToolParams(params);
|
||||
expect(error).toContain(
|
||||
'File path must be within one of the workspace directories',
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
/File path must be within one of the workspace directories/,
|
||||
);
|
||||
expect(error).toContain(rootDir);
|
||||
});
|
||||
|
||||
it('should provide clear error message with workspace directories', () => {
|
||||
const outsidePath = path.join(tempDir, 'outside-root.txt');
|
||||
const params = {
|
||||
file_path: outsidePath,
|
||||
content: 'test',
|
||||
};
|
||||
const error = tool.validateToolParams(params);
|
||||
expect(error).toContain(
|
||||
'File path must be within one of the workspace directories',
|
||||
);
|
||||
expect(error).toContain(rootDir);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -731,50 +674,50 @@ describe('WriteFileTool', () => {
|
||||
const filePath = path.join(rootDir, 'permission_denied_file.txt');
|
||||
const content = 'test content';
|
||||
|
||||
// Mock writeFileSync to throw EACCES error
|
||||
const originalWriteFileSync = fs.writeFileSync;
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
|
||||
// Mock FileSystemService writeTextFile to throw EACCES error
|
||||
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
|
||||
const error = new Error('Permission denied') as NodeJS.ErrnoException;
|
||||
error.code = 'EACCES';
|
||||
throw error;
|
||||
return Promise.reject(error);
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.PERMISSION_DENIED);
|
||||
expect(result.llmContent).toContain(
|
||||
`Permission denied writing to file: ${filePath} (EACCES)`,
|
||||
);
|
||||
expect(result.returnDisplay).toContain('Permission denied');
|
||||
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
|
||||
expect(result.returnDisplay).toContain(
|
||||
`Permission denied writing to file: ${filePath} (EACCES)`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return NO_SPACE_LEFT error when write fails with ENOSPC', async () => {
|
||||
const filePath = path.join(rootDir, 'no_space_file.txt');
|
||||
const content = 'test content';
|
||||
|
||||
// Mock writeFileSync to throw ENOSPC error
|
||||
const originalWriteFileSync = fs.writeFileSync;
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
|
||||
// Mock FileSystemService writeTextFile to throw ENOSPC error
|
||||
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
|
||||
const error = new Error(
|
||||
'No space left on device',
|
||||
) as NodeJS.ErrnoException;
|
||||
error.code = 'ENOSPC';
|
||||
throw error;
|
||||
return Promise.reject(error);
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.NO_SPACE_LEFT);
|
||||
expect(result.llmContent).toContain(
|
||||
`No space left on device: ${filePath} (ENOSPC)`,
|
||||
);
|
||||
expect(result.returnDisplay).toContain('No space left');
|
||||
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
|
||||
expect(result.returnDisplay).toContain(
|
||||
`No space left on device: ${filePath} (ENOSPC)`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return TARGET_IS_DIRECTORY error when write fails with EISDIR', async () => {
|
||||
@@ -790,25 +733,26 @@ describe('WriteFileTool', () => {
|
||||
return originalExistsSync(path as string);
|
||||
});
|
||||
|
||||
// Mock writeFileSync to throw EISDIR error
|
||||
const originalWriteFileSync = fs.writeFileSync;
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
|
||||
// Mock FileSystemService writeTextFile to throw EISDIR error
|
||||
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
|
||||
const error = new Error('Is a directory') as NodeJS.ErrnoException;
|
||||
error.code = 'EISDIR';
|
||||
throw error;
|
||||
return Promise.reject(error);
|
||||
});
|
||||
|
||||
const params = { file_path: dirPath, content };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.TARGET_IS_DIRECTORY);
|
||||
expect(result.llmContent).toContain(
|
||||
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
|
||||
);
|
||||
expect(result.returnDisplay).toContain('Target is a directory');
|
||||
expect(result.returnDisplay).toContain(
|
||||
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
|
||||
);
|
||||
|
||||
vi.spyOn(fs, 'existsSync').mockImplementation(originalExistsSync);
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
|
||||
});
|
||||
|
||||
it('should return FILE_WRITE_FAILURE for generic write errors', async () => {
|
||||
@@ -818,19 +762,22 @@ describe('WriteFileTool', () => {
|
||||
// Ensure fs.existsSync is not mocked for this test
|
||||
vi.restoreAllMocks();
|
||||
|
||||
// Mock writeFileSync to throw generic error
|
||||
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
|
||||
throw new Error('Generic write error');
|
||||
});
|
||||
// Mock FileSystemService writeTextFile to throw generic error
|
||||
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() =>
|
||||
Promise.reject(new Error('Generic write error')),
|
||||
);
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE);
|
||||
expect(result.llmContent).toContain(
|
||||
'Error writing to file: Generic write error',
|
||||
);
|
||||
expect(result.returnDisplay).toContain('Generic write error');
|
||||
expect(result.returnDisplay).toContain(
|
||||
'Error writing to file: Generic write error',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,17 +9,18 @@ import path from 'path';
|
||||
import * as Diff from 'diff';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolResult,
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
FileDiff,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolCallConfirmationDetails,
|
||||
Kind,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolInvocation,
|
||||
ToolLocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import {
|
||||
@@ -67,113 +68,101 @@ interface GetCorrectedFileContentResult {
|
||||
error?: { message: string; code?: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of the WriteFile tool logic
|
||||
*/
|
||||
export class WriteFileTool
|
||||
extends BaseTool<WriteFileToolParams, ToolResult>
|
||||
implements ModifiableDeclarativeTool<WriteFileToolParams>
|
||||
{
|
||||
static readonly Name: string = 'write_file';
|
||||
export async function getCorrectedFileContent(
|
||||
config: Config,
|
||||
filePath: string,
|
||||
proposedContent: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<GetCorrectedFileContentResult> {
|
||||
let originalContent = '';
|
||||
let fileExists = false;
|
||||
let correctedContent = proposedContent;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
'WriteFile',
|
||||
`Writes content to a specified file in the local filesystem.
|
||||
try {
|
||||
originalContent = await config
|
||||
.getFileSystemService()
|
||||
.readTextFile(filePath);
|
||||
fileExists = true; // File exists and was read
|
||||
} catch (err) {
|
||||
if (isNodeError(err) && err.code === 'ENOENT') {
|
||||
fileExists = false;
|
||||
originalContent = '';
|
||||
} else {
|
||||
// File exists but could not be read (permissions, etc.)
|
||||
fileExists = true; // Mark as existing but problematic
|
||||
originalContent = ''; // Can't use its content
|
||||
const error = {
|
||||
message: getErrorMessage(err),
|
||||
code: isNodeError(err) ? err.code : undefined,
|
||||
};
|
||||
// Return early as we can't proceed with content correction meaningfully
|
||||
return { originalContent, correctedContent, fileExists, error };
|
||||
}
|
||||
}
|
||||
|
||||
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
|
||||
Kind.Edit,
|
||||
// If readError is set, we have returned.
|
||||
// So, file was either read successfully (fileExists=true, originalContent set)
|
||||
// or it was ENOENT (fileExists=false, originalContent='').
|
||||
|
||||
if (fileExists) {
|
||||
// This implies originalContent is available
|
||||
const { params: correctedParams } = await ensureCorrectEdit(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
properties: {
|
||||
file_path: {
|
||||
description:
|
||||
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
|
||||
type: 'string',
|
||||
},
|
||||
content: {
|
||||
description: 'The content to write to the file.',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['file_path', 'content'],
|
||||
type: 'object',
|
||||
old_string: originalContent, // Treat entire current content as old_string
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
} else {
|
||||
// This implies new file (ENOENT)
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
return { originalContent, correctedContent, fileExists };
|
||||
}
|
||||
|
||||
override toolLocations(params: WriteFileToolParams): ToolLocation[] {
|
||||
return [{ path: params.file_path }];
|
||||
class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
WriteFileToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WriteFileToolParams,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
override validateToolParams(params: WriteFileToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(
|
||||
this.schema.parametersJsonSchema,
|
||||
params,
|
||||
);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
const filePath = params.file_path;
|
||||
if (!path.isAbsolute(filePath)) {
|
||||
return `File path must be absolute: ${filePath}`;
|
||||
}
|
||||
|
||||
const workspaceContext = this.config.getWorkspaceContext();
|
||||
if (!workspaceContext.isPathWithinWorkspace(filePath)) {
|
||||
const directories = workspaceContext.getDirectories();
|
||||
return `File path must be within one of the workspace directories: ${directories.join(', ')}`;
|
||||
}
|
||||
|
||||
try {
|
||||
// This check should be performed only if the path exists.
|
||||
// If it doesn't exist, it's a new file, which is valid for writing.
|
||||
if (fs.existsSync(filePath)) {
|
||||
const stats = fs.lstatSync(filePath);
|
||||
if (stats.isDirectory()) {
|
||||
return `Path is a directory, not a file: ${filePath}`;
|
||||
}
|
||||
}
|
||||
} catch (statError: unknown) {
|
||||
// If fs.existsSync is true but lstatSync fails (e.g., permissions, race condition where file is deleted)
|
||||
// this indicates an issue with accessing the path that should be reported.
|
||||
return `Error accessing path properties for validation: ${filePath}. Reason: ${statError instanceof Error ? statError.message : String(statError)}`;
|
||||
}
|
||||
|
||||
return null;
|
||||
override toolLocations(): ToolLocation[] {
|
||||
return [{ path: this.params.file_path }];
|
||||
}
|
||||
|
||||
override getDescription(params: WriteFileToolParams): string {
|
||||
if (!params.file_path) {
|
||||
return `Model did not provide valid parameters for write file tool, missing or empty "file_path"`;
|
||||
}
|
||||
override getDescription(): string {
|
||||
const relativePath = makeRelative(
|
||||
params.file_path,
|
||||
this.params.file_path,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
return `Writing to ${shortenPath(relativePath)}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the confirmation prompt for the WriteFile tool.
|
||||
*/
|
||||
override async shouldConfirmExecute(
|
||||
params: WriteFileToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
this.params.file_path,
|
||||
this.params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -184,10 +173,10 @@ export class WriteFileTool
|
||||
|
||||
const { originalContent, correctedContent } = correctedContentResult;
|
||||
const relativePath = makeRelative(
|
||||
params.file_path,
|
||||
this.params.file_path,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
const fileName = path.basename(params.file_path);
|
||||
const fileName = path.basename(this.params.file_path);
|
||||
|
||||
const fileDiff = Diff.createPatch(
|
||||
fileName,
|
||||
@@ -202,14 +191,14 @@ export class WriteFileTool
|
||||
const ideConfirmation =
|
||||
this.config.getIdeMode() &&
|
||||
ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected
|
||||
? ideClient.openDiff(params.file_path, correctedContent)
|
||||
? ideClient.openDiff(this.params.file_path, correctedContent)
|
||||
: undefined;
|
||||
|
||||
const confirmationDetails: ToolEditConfirmationDetails = {
|
||||
type: 'edit',
|
||||
title: `Confirm Write: ${shortenPath(relativePath)}`,
|
||||
fileName,
|
||||
filePath: params.file_path,
|
||||
filePath: this.params.file_path,
|
||||
fileDiff,
|
||||
originalContent,
|
||||
newContent: correctedContent,
|
||||
@@ -221,7 +210,7 @@ export class WriteFileTool
|
||||
if (ideConfirmation) {
|
||||
const result = await ideConfirmation;
|
||||
if (result.status === 'accepted' && result.content) {
|
||||
params.content = result.content;
|
||||
this.params.content = result.content;
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -230,32 +219,20 @@ export class WriteFileTool
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: WriteFileToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: `Could not write file due to invalid parameters: ${validationError}`,
|
||||
returnDisplay: validationError,
|
||||
error: {
|
||||
message: validationError,
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
async execute(abortSignal: AbortSignal): Promise<ToolResult> {
|
||||
const { file_path, content, ai_proposed_content, modified_by_user } =
|
||||
this.params;
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
file_path,
|
||||
content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
const errDetails = correctedContentResult.error;
|
||||
const errorMsg = errDetails.code
|
||||
? `Error checking existing file '${params.file_path}': ${errDetails.message} (${errDetails.code})`
|
||||
? `Error checking existing file '${file_path}': ${errDetails.message} (${errDetails.code})`
|
||||
: `Error checking existing file: ${errDetails.message}`;
|
||||
return {
|
||||
llmContent: errorMsg,
|
||||
@@ -280,15 +257,17 @@ export class WriteFileTool
|
||||
!correctedContentResult.fileExists);
|
||||
|
||||
try {
|
||||
const dirName = path.dirname(params.file_path);
|
||||
const dirName = path.dirname(file_path);
|
||||
if (!fs.existsSync(dirName)) {
|
||||
fs.mkdirSync(dirName, { recursive: true });
|
||||
}
|
||||
|
||||
fs.writeFileSync(params.file_path, fileContent, 'utf8');
|
||||
await this.config
|
||||
.getFileSystemService()
|
||||
.writeTextFile(file_path, fileContent);
|
||||
|
||||
// Generate diff for display result
|
||||
const fileName = path.basename(params.file_path);
|
||||
const fileName = path.basename(file_path);
|
||||
// If there was a readError, originalContent in correctedContentResult is '',
|
||||
// but for the diff, we want to show the original content as it was before the write if possible.
|
||||
// However, if it was unreadable, currentContentForDiff will be empty.
|
||||
@@ -305,23 +284,22 @@ export class WriteFileTool
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
|
||||
const originallyProposedContent =
|
||||
params.ai_proposed_content || params.content;
|
||||
const originallyProposedContent = ai_proposed_content || content;
|
||||
const diffStat = getDiffStat(
|
||||
fileName,
|
||||
currentContentForDiff,
|
||||
originallyProposedContent,
|
||||
params.content,
|
||||
content,
|
||||
);
|
||||
|
||||
const llmSuccessMessageParts = [
|
||||
isNewFile
|
||||
? `Successfully created and wrote to new file: ${params.file_path}.`
|
||||
: `Successfully overwrote file: ${params.file_path}.`,
|
||||
? `Successfully created and wrote to new file: ${file_path}.`
|
||||
: `Successfully overwrote file: ${file_path}.`,
|
||||
];
|
||||
if (params.modified_by_user) {
|
||||
if (modified_by_user) {
|
||||
llmSuccessMessageParts.push(
|
||||
`User modified the \`content\` to be: ${params.content}`,
|
||||
`User modified the \`content\` to be: ${content}`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -334,8 +312,8 @@ export class WriteFileTool
|
||||
};
|
||||
|
||||
const lines = fileContent.split('\n').length;
|
||||
const mimetype = getSpecificMimeType(params.file_path);
|
||||
const extension = path.extname(params.file_path); // Get extension
|
||||
const mimetype = getSpecificMimeType(file_path);
|
||||
const extension = path.extname(file_path); // Get extension
|
||||
if (isNewFile) {
|
||||
recordFileOperationMetric(
|
||||
this.config,
|
||||
@@ -367,17 +345,17 @@ export class WriteFileTool
|
||||
|
||||
if (isNodeError(error)) {
|
||||
// Handle specific Node.js errors with their error codes
|
||||
errorMsg = `Error writing to file '${params.file_path}': ${error.message} (${error.code})`;
|
||||
errorMsg = `Error writing to file '${file_path}': ${error.message} (${error.code})`;
|
||||
|
||||
// Log specific error types for better debugging
|
||||
if (error.code === 'EACCES') {
|
||||
errorMsg = `Permission denied writing to file: ${params.file_path} (${error.code})`;
|
||||
errorMsg = `Permission denied writing to file: ${file_path} (${error.code})`;
|
||||
errorType = ToolErrorType.PERMISSION_DENIED;
|
||||
} else if (error.code === 'ENOSPC') {
|
||||
errorMsg = `No space left on device: ${params.file_path} (${error.code})`;
|
||||
errorMsg = `No space left on device: ${file_path} (${error.code})`;
|
||||
errorType = ToolErrorType.NO_SPACE_LEFT;
|
||||
} else if (error.code === 'EISDIR') {
|
||||
errorMsg = `Target is a directory, not a file: ${params.file_path} (${error.code})`;
|
||||
errorMsg = `Target is a directory, not a file: ${file_path} (${error.code})`;
|
||||
errorType = ToolErrorType.TARGET_IS_DIRECTORY;
|
||||
}
|
||||
|
||||
@@ -401,63 +379,84 @@ export class WriteFileTool
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async _getCorrectedFileContent(
|
||||
filePath: string,
|
||||
proposedContent: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<GetCorrectedFileContentResult> {
|
||||
let originalContent = '';
|
||||
let fileExists = false;
|
||||
let correctedContent = proposedContent;
|
||||
/**
|
||||
* Implementation of the WriteFile tool logic
|
||||
*/
|
||||
export class WriteFileTool
|
||||
extends BaseDeclarativeTool<WriteFileToolParams, ToolResult>
|
||||
implements ModifiableDeclarativeTool<WriteFileToolParams>
|
||||
{
|
||||
static readonly Name: string = 'write_file';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
'WriteFile',
|
||||
`Writes content to a specified file in the local filesystem.
|
||||
|
||||
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
|
||||
Kind.Edit,
|
||||
{
|
||||
properties: {
|
||||
file_path: {
|
||||
description:
|
||||
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
|
||||
type: 'string',
|
||||
},
|
||||
content: {
|
||||
description: 'The content to write to the file.',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['file_path', 'content'],
|
||||
type: 'object',
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
protected override validateToolParamValues(
|
||||
params: WriteFileToolParams,
|
||||
): string | null {
|
||||
const filePath = params.file_path;
|
||||
|
||||
if (!filePath) {
|
||||
return `Missing or empty "file_path"`;
|
||||
}
|
||||
|
||||
if (!path.isAbsolute(filePath)) {
|
||||
return `File path must be absolute: ${filePath}`;
|
||||
}
|
||||
|
||||
const workspaceContext = this.config.getWorkspaceContext();
|
||||
if (!workspaceContext.isPathWithinWorkspace(filePath)) {
|
||||
const directories = workspaceContext.getDirectories();
|
||||
return `File path must be within one of the workspace directories: ${directories.join(
|
||||
', ',
|
||||
)}`;
|
||||
}
|
||||
|
||||
try {
|
||||
originalContent = fs.readFileSync(filePath, 'utf8');
|
||||
fileExists = true; // File exists and was read
|
||||
} catch (err) {
|
||||
if (isNodeError(err) && err.code === 'ENOENT') {
|
||||
fileExists = false;
|
||||
originalContent = '';
|
||||
} else {
|
||||
// File exists but could not be read (permissions, etc.)
|
||||
fileExists = true; // Mark as existing but problematic
|
||||
originalContent = ''; // Can't use its content
|
||||
const error = {
|
||||
message: getErrorMessage(err),
|
||||
code: isNodeError(err) ? err.code : undefined,
|
||||
};
|
||||
// Return early as we can't proceed with content correction meaningfully
|
||||
return { originalContent, correctedContent, fileExists, error };
|
||||
if (fs.existsSync(filePath)) {
|
||||
const stats = fs.lstatSync(filePath);
|
||||
if (stats.isDirectory()) {
|
||||
return `Path is a directory, not a file: ${filePath}`;
|
||||
}
|
||||
}
|
||||
} catch (statError: unknown) {
|
||||
return `Error accessing path properties for validation: ${filePath}. Reason: ${
|
||||
statError instanceof Error ? statError.message : String(statError)
|
||||
}`;
|
||||
}
|
||||
|
||||
// If readError is set, we have returned.
|
||||
// So, file was either read successfully (fileExists=true, originalContent set)
|
||||
// or it was ENOENT (fileExists=false, originalContent='').
|
||||
return null;
|
||||
}
|
||||
|
||||
if (fileExists) {
|
||||
// This implies originalContent is available
|
||||
const { params: correctedParams } = await ensureCorrectEdit(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent, // Treat entire current content as old_string
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
} else {
|
||||
// This implies new file (ENOENT)
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
return { originalContent, correctedContent, fileExists };
|
||||
protected createInvocation(
|
||||
params: WriteFileToolParams,
|
||||
): ToolInvocation<WriteFileToolParams, ToolResult> {
|
||||
return new WriteFileToolInvocation(this.config, params);
|
||||
}
|
||||
|
||||
getModifyContext(
|
||||
@@ -466,7 +465,8 @@ export class WriteFileTool
|
||||
return {
|
||||
getFilePath: (params: WriteFileToolParams) => params.file_path,
|
||||
getCurrentContent: async (params: WriteFileToolParams) => {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
@@ -474,7 +474,8 @@ export class WriteFileTool
|
||||
return correctedContentResult.originalContent;
|
||||
},
|
||||
getProposedContent: async (params: WriteFileToolParams) => {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
|
||||
Reference in New Issue
Block a user