Merge tag 'v0.1.15' into feature/yiheng/sync-gemini-cli-0.1.15

奕桁
2025-08-01 23:06:11 +08:00
340 changed files with 36528 additions and 22931 deletions

View File

@@ -608,6 +608,19 @@ describe('EditTool', () => {
/User modified the `new_string` content/,
);
});
it('should return error if old_string and new_string are identical', async () => {
const initialContent = 'This is some identical text.';
fs.writeFileSync(filePath, initialContent, 'utf8');
const params: EditToolParams = {
file_path: filePath,
old_string: 'identical',
new_string: 'identical',
};
const result = await tool.execute(params, new AbortController().signal);
expect(result.llmContent).toMatch(/No changes to apply/);
expect(result.returnDisplay).toMatch(/No changes to apply/);
});
});
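For orientation, here is a minimal sketch of the parameter shape these edit tests drive. The field names come from the hunks in this diff; the file path, strings, and replacement count are invented for illustration.

```ts
// Hypothetical params for a multi-occurrence replace (sketch only).
const params = {
  file_path: '/tmp/example.ts',    // absolute path, as the tool requires
  old_string: 'const value = 1;',  // exact literal text, never escaped
  new_string: 'const value = 2;',  // must differ from old_string, or the tool
                                   // now answers "No changes to apply"
  expected_replacements: 3,        // must equal the number of exact matches
};
```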
describe('getDescription', () => {

View File

@@ -9,9 +9,11 @@ import * as path from 'path';
import * as Diff from 'diff';
import {
BaseTool,
Icon,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolLocation,
ToolResult,
ToolResultDisplay,
} from './tools.js';
@@ -89,6 +91,7 @@ Expectation for required parameters:
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
Icon.Pencil,
{
properties: {
file_path: {
@@ -141,6 +144,15 @@ Expectation for required parameters:
return null;
}
/**
* Determines any file locations affected by the tool execution
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(params: EditToolParams): ToolLocation[] {
return [{ path: params.file_path }];
}
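A hedged sketch of how a caller might consume the new toolLocations hook, for example to surface the files a pending edit would touch before confirmation. The tool and param types here are simplified stand-ins, not the repository's real interfaces.

```ts
// Sketch: describe the files a pending tool call would touch (types simplified).
function describePendingEdit(
  tool: { toolLocations(p: { file_path: string }): Array<{ path: string }> },
  params: { file_path: string; old_string: string; new_string: string },
): string[] {
  return tool.toolLocations(params).map((loc) => `Pending edit touches: ${loc.path}`);
}
```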
private _applyReplacement(
currentContent: string | null,
oldString: string,
@@ -197,7 +209,7 @@ Expectation for required parameters:
// Creating a new file
isNewFile = true;
} else if (!fileExists) {
// Trying to edit a non-existent file (and old_string is not empty)
// Trying to edit a nonexistent file (and old_string is not empty)
error = {
display: `File not found. Cannot apply edit. Use an empty old_string to create a new file.`,
raw: `File not found: ${params.file_path}`,
@@ -227,12 +239,17 @@ Expectation for required parameters:
raw: `Failed to edit, 0 occurrences found for old_string in ${params.file_path}. No edits made. The exact text in old_string was not found. Ensure you're not escaping content incorrectly and check whitespace, indentation, and context. Use ${ReadFileTool.Name} tool to verify.`,
};
} else if (occurrences !== expectedReplacements) {
const occurenceTerm =
const occurrenceTerm =
expectedReplacements === 1 ? 'occurrence' : 'occurrences';
error = {
display: `Failed to edit, expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences}.`,
raw: `Failed to edit, Expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences} for old_string in file: ${params.file_path}`,
display: `Failed to edit, expected ${expectedReplacements} ${occurrenceTerm} but found ${occurrences}.`,
raw: `Failed to edit, Expected ${expectedReplacements} ${occurrenceTerm} but found ${occurrences} for old_string in file: ${params.file_path}`,
};
} else if (finalOldString === finalNewString) {
error = {
display: `No changes to apply. The old_string and new_string are identical.`,
raw: `No changes to apply. The old_string and new_string are identical in file: ${params.file_path}`,
};
}
} else {
@@ -306,6 +323,8 @@ Expectation for required parameters:
title: `Confirm Edit: ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`,
fileName,
fileDiff,
originalContent: editData.currentContent,
newContent: editData.newContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
@@ -394,7 +413,12 @@ Expectation for required parameters:
'Proposed',
DEFAULT_DIFF_OPTIONS,
);
displayResult = { fileDiff, fileName };
displayResult = {
fileDiff,
fileName,
originalContent: editData.currentContent,
newContent: editData.newContent,
};
}
const llmSuccessMessageParts = [

View File

@@ -150,11 +150,19 @@ describe('GlobTool', () => {
expect(typeof llmContent).toBe('string');
const filesListed = llmContent
.substring(llmContent.indexOf(':') + 1)
.trim()
.split('\n');
expect(filesListed[0]).toContain(path.join(tempRootDir, 'newer.sortme'));
expect(filesListed[1]).toContain(path.join(tempRootDir, 'older.sortme'));
.split(/\r?\n/)
.slice(1)
.map((line) => line.trim())
.filter(Boolean);
expect(filesListed).toHaveLength(2);
expect(path.resolve(filesListed[0])).toBe(
path.resolve(tempRootDir, 'newer.sortme'),
);
expect(path.resolve(filesListed[1])).toBe(
path.resolve(tempRootDir, 'older.sortme'),
);
});
});
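For context on the ordering this test asserts (newest first), a minimal sketch of producing such a listing with the glob package. The pattern handling is simplified and statSync is only a stand-in for the stat data the real tool keeps on its glob entries.

```ts
import { glob } from 'glob';
import { statSync } from 'node:fs';

// Sketch: absolute paths matching a pattern, newest modification time first.
async function findSorted(pattern: string, cwd: string): Promise<string[]> {
  const paths = await glob(pattern, { cwd, nodir: true, absolute: true });
  return paths
    .map((p) => ({ p, mtimeMs: statSync(p).mtimeMs }))
    .sort((a, b) => b.mtimeMs - a.mtimeMs)
    .map((entry) => entry.p);
}
```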

View File

@@ -8,7 +8,7 @@ import fs from 'fs';
import path from 'path';
import { glob } from 'glob';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { shortenPath, makeRelative } from '../utils/paths.js';
import { isWithinRoot } from '../utils/fileUtils.js';
@@ -86,6 +86,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
GlobTool.Name,
'FindFiles',
'Efficiently finds files matching specific glob patterns (e.g., `src/**/*.ts`, `**/*.md`), returning absolute paths sorted by modification time (newest first). Ideal for quickly locating files based on their name or path structure, especially in large codebases.',
Icon.FileSearch,
{
properties: {
pattern: {
@@ -199,7 +200,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
this.config.getFileFilteringRespectGitIgnore();
const fileDiscovery = this.config.getFileService();
const entries = (await glob(params.pattern, {
const entries = await glob(params.pattern, {
cwd: searchDirAbsolute,
withFileTypes: true,
nodir: true,
@@ -209,7 +210,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
ignore: ['**/node_modules/**', '**/.git/**'],
follow: false,
signal,
})) as GlobPath[];
});
// Apply git-aware filtering if enabled and in git repository
let filteredEntries = entries;

View File

@@ -17,7 +17,7 @@ vi.mock('child_process', () => ({
on: (event: string, cb: (...args: unknown[]) => void) => {
if (event === 'error' || event === 'close') {
// Simulate command not found or error for git grep and system grep
// to force fallback to JS implementation.
// to force it to fall back to JS implementation.
setTimeout(() => cb(1), 0); // cb(1) for error/close
}
},
@@ -125,7 +125,9 @@ describe('GrepTool', () => {
expect(result.llmContent).toContain('File: fileA.txt');
expect(result.llmContent).toContain('L1: hello world');
expect(result.llmContent).toContain('L2: second line with world');
expect(result.llmContent).toContain('File: sub/fileC.txt');
expect(result.llmContent).toContain(
`File: ${path.join('sub', 'fileC.txt')}`,
);
expect(result.llmContent).toContain('L1: another world in sub dir');
expect(result.returnDisplay).toBe('Found 3 matches');
});
@@ -235,7 +237,7 @@ describe('GrepTool', () => {
it('should generate correct description with pattern and path', () => {
const params: GrepToolParams = {
pattern: 'testPattern',
path: 'src/app',
path: path.join('src', 'app'),
};
// The path will be relative to the tempRootDir, so we check for containment.
expect(grepTool.getDescription(params)).toContain("'testPattern' within");
@@ -248,12 +250,14 @@ describe('GrepTool', () => {
const params: GrepToolParams = {
pattern: 'testPattern',
include: '*.ts',
path: 'src/app',
path: path.join('src', 'app'),
};
expect(grepTool.getDescription(params)).toContain(
"'testPattern' in *.ts within",
);
expect(grepTool.getDescription(params)).toContain('src/app');
expect(grepTool.getDescription(params)).toContain(
path.join('src', 'app'),
);
});
it('should use ./ for root path in description', () => {

View File

@@ -9,8 +9,8 @@ import fsPromises from 'fs/promises';
import path from 'path';
import { EOL } from 'os';
import { spawn } from 'child_process';
import { globStream } from 'glob';
import { BaseTool, ToolResult } from './tools.js';
import { globIterate } from 'glob';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
@@ -62,6 +62,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
GrepTool.Name,
'SearchText',
'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
Icon.Regex,
{
properties: {
pattern: {
@@ -498,7 +499,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
'.hg/**',
]; // Use glob patterns for ignores here
const filesStream = globStream(globPattern, {
const filesIterator = globIterate(globPattern, {
cwd: absolutePath,
dot: true,
ignore: ignorePatterns,
@@ -510,7 +511,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
const regex = new RegExp(pattern, 'i');
const allMatches: GrepMatch[] = [];
for await (const filePath of filesStream) {
for await (const filePath of filesIterator) {
const fileAbsolutePath = filePath as string;
try {
const content = await fsPromises.readFile(fileAbsolutePath, 'utf8');

View File

@@ -6,11 +6,11 @@
import fs from 'fs';
import path from 'path';
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config } from '../config/config.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
import { isWithinRoot } from '../utils/fileUtils.js';
/**
@@ -28,9 +28,12 @@ export interface LSToolParams {
ignore?: string[];
/**
* Whether to respect .gitignore patterns (optional, defaults to true)
* Whether to respect .gitignore and .geminiignore patterns (optional, defaults to true)
*/
respect_git_ignore?: boolean;
file_filtering_options?: {
respect_git_ignore?: boolean;
respect_gemini_ignore?: boolean;
};
}
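A small, hypothetical params object for the widened interface above; the path and values are invented, and both flags default to true when omitted.

```ts
// Sketch: list a directory honoring .gitignore but not .geminiignore.
const params: LSToolParams = {
  path: '/repo/packages/core/src',
  ignore: ['*.test.ts'],
  file_filtering_options: {
    respect_git_ignore: true,     // skip entries matched by .gitignore
    respect_gemini_ignore: false, // include entries matched by .geminiignore
  },
};
```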
/**
@@ -74,6 +77,7 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
LSTool.Name,
'ReadFolder',
'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
Icon.Folder,
{
properties: {
path: {
@@ -88,10 +92,22 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
},
type: Type.ARRAY,
},
respect_git_ignore: {
file_filtering_options: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
type: Type.OBJECT,
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: Type.BOOLEAN,
},
},
},
},
required: ['path'],
@@ -198,14 +214,25 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
const files = fs.readdirSync(params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
// Get centralized file discovery service
const respectGitIgnore =
params.respect_git_ignore ??
this.config.getFileFilteringRespectGitIgnore();
const fileDiscovery = this.config.getFileService();
const entries: FileEntry[] = [];
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
if (files.length === 0) {
// Changed error message to be more neutral for LLM
@@ -226,14 +253,21 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
fullPath,
);
// Check if this file should be git-ignored (only in git repositories)
// Check if this file should be ignored based on git or gemini ignore rules
if (
respectGitIgnore &&
fileFilteringOptions.respectGitIgnore &&
fileDiscovery.shouldGitIgnoreFile(relativePath)
) {
gitIgnoredCount++;
continue;
}
if (
fileFilteringOptions.respectGeminiIgnore &&
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
) {
geminiIgnoredCount++;
continue;
}
try {
const stats = fs.statSync(fullPath);
@@ -264,13 +298,21 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
.join('\n');
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
resultMessage += `\n\n(${gitIgnoredCount} items were git-ignored)`;
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
}
if (geminiIgnoredCount > 0) {
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
}
if (ignoredMessages.length > 0) {
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
}
let displayMessage = `Listed ${entries.length} item(s).`;
if (gitIgnoredCount > 0) {
displayMessage += ` (${gitIgnoredCount} git-ignored)`;
if (ignoredMessages.length > 0) {
displayMessage += ` (${ignoredMessages.join(', ')})`;
}
return {

View File

@@ -9,18 +9,23 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/
import {
populateMcpServerCommand,
createTransport,
generateValidName,
isEnabled,
discoverTools,
discoverPrompts,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
describe('mcp-client', () => {
afterEach(() => {
@@ -47,6 +52,77 @@ describe('mcp-client', () => {
});
});
describe('discoverPrompts', () => {
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
} as unknown as PromptRegistry;
it('should discover and log prompts', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [
{ name: 'prompt1', description: 'desc1' },
{ name: 'prompt2' },
],
});
const mockedClient = {
request: mockRequest,
} as unknown as ClientLib.Client;
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockRequest).toHaveBeenCalledWith(
{ method: 'prompts/list', params: {} },
expect.anything(),
);
});
it('should do nothing if no prompts are discovered', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockedClient = {
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {
// no-op
});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should log an error if discovery fails', async () => {
const testError = new Error('test error');
testError.message = 'test error';
const mockRequest = vi.fn().mockRejectedValue(testError);
const mockedClient = {
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {
// no-op
});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering prompts from test-server: ${testError.message}`,
);
consoleErrorSpy.mockRestore();
});
});
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
@@ -83,7 +159,7 @@ describe('mcp-client', () => {
describe('should connect via httpUrl', () => {
it('without headers', async () => {
const transport = createTransport(
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test-server',
@@ -97,7 +173,7 @@ describe('mcp-client', () => {
});
it('with headers', async () => {
const transport = createTransport(
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test-server',
@@ -118,7 +194,7 @@ describe('mcp-client', () => {
describe('should connect via url', () => {
it('without headers', async () => {
const transport = createTransport(
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
@@ -131,7 +207,7 @@ describe('mcp-client', () => {
});
it('with headers', async () => {
const transport = createTransport(
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
@@ -150,10 +226,10 @@ describe('mcp-client', () => {
});
});
it('should connect via command', () => {
it('should connect via command', async () => {
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
createTransport(
await createTransport(
'test-server',
{
command: 'test-command',
@@ -172,91 +248,62 @@ describe('mcp-client', () => {
stderr: 'pipe',
});
});
});
describe('generateValidName', () => {
it('should return a valid name for a simple function', () => {
const funcDecl = { name: 'myFunction' };
const serverName = 'myServer';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('myServer__myFunction');
});
it('should prepend the server name', () => {
const funcDecl = { name: 'anotherFunction' };
const serverName = 'production-server';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('production-server__anotherFunction');
});
describe('useGoogleCredentialProvider', () => {
it('should use GoogleCredentialProvider when specified', async () => {
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test-server',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
},
},
false,
);
it('should replace invalid characters with underscores', () => {
const funcDecl = { name: 'invalid-name with spaces' };
const serverName = 'test_server';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('test_server__invalid-name_with_spaces');
});
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const authProvider = (transport as any)._authProvider;
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
});
it('should truncate long names', () => {
const funcDecl = {
name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit',
};
const serverName = 'a_long_server_name';
const result = generateValidName(funcDecl, serverName);
expect(result.length).toBe(63);
expect(result).toBe(
'a_long_server_name__a_very_l___will_definitely_exceed_the_limit',
);
});
it('should use GoogleCredentialProvider with SSE transport', async () => {
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
},
},
false,
);
it('should handle names with only invalid characters', () => {
const funcDecl = { name: '!@#$%^&*()' };
const serverName = 'special-chars';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('special-chars____________');
});
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const authProvider = (transport as any)._authProvider;
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
});
it('should handle names that are already valid', () => {
const funcDecl = { name: 'already_valid' };
const serverName = 'validator';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('validator__already_valid');
});
it('should handle names with leading/trailing invalid characters', () => {
const funcDecl = { name: '-_invalid-_' };
const serverName = 'trim-test';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe('trim-test__-_invalid-_');
});
it('should handle names that are exactly 63 characters long', () => {
const longName = 'a'.repeat(45);
const funcDecl = { name: longName };
const serverName = 'server';
const result = generateValidName(funcDecl, serverName);
expect(result).toBe(`server__${longName}`);
expect(result.length).toBe(53);
});
it('should handle names that are exactly 64 characters long', () => {
const longName = 'a'.repeat(55);
const funcDecl = { name: longName };
const serverName = 'server';
const result = generateValidName(funcDecl, serverName);
expect(result.length).toBe(63);
expect(result).toBe(
'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
);
});
it('should handle names that are longer than 64 characters', () => {
const longName = 'a'.repeat(100);
const funcDecl = { name: longName };
const serverName = 'long-server';
const result = generateValidName(funcDecl, serverName);
expect(result.length).toBe(63);
expect(result).toBe(
'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
);
it('should throw an error if no URL is provided with GoogleCredentialProvider', async () => {
await expect(
createTransport(
'test-server',
{
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
},
},
false,
),
).rejects.toThrow(
'No URL configured for Google Credentials MCP server',
);
});
});
});
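For reference, a hedged sketch of the server configuration shape these Google-credential tests construct; the URL and scope below are placeholders.

```ts
// Sketch: an MCP server entry authenticated via Application Default Credentials.
const serverConfig: MCPServerConfig = {
  httpUrl: 'https://mcp.example.com/mcp', // or `url` for an SSE endpoint
  authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
  oauth: { scopes: ['https://www.googleapis.com/auth/cloud-platform'] },
};
```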
describe('isEnabled', () => {

View File

@@ -15,14 +15,32 @@ import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
Prompt,
ListPromptsResultSchema,
GetPromptResult,
GetPromptResultSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
import { FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
export type DiscoveredMCPPrompt = Prompt & {
serverName: string;
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
};
/**
* Enum representing the connection status of an MCP server
*/
@@ -50,13 +68,18 @@ export enum MCPDiscoveryState {
/**
* Map to track the status of each MCP server within the core package
*/
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
const serverStatuses: Map<string, MCPServerStatus> = new Map();
/**
* Track the overall MCP discovery state
*/
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
/**
* Map to track which MCP servers have been discovered to require OAuth
*/
export const mcpServerRequiresOAuth: Map<string, boolean> = new Map();
/**
* Event listeners for MCP server status changes
*/
@@ -94,7 +117,7 @@ function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
mcpServerStatusesInternal.set(serverName, status);
serverStatuses.set(serverName, status);
// Notify all listeners
for (const listener of statusChangeListeners) {
listener(serverName, status);
@@ -105,16 +128,14 @@ function updateMCPServerStatus(
* Get the current status of an MCP server
*/
export function getMCPServerStatus(serverName: string): MCPServerStatus {
return (
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
);
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
}
/**
* Get all MCP server statuses
*/
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
return new Map(mcpServerStatusesInternal);
return new Map(serverStatuses);
}
/**
@@ -124,6 +145,165 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
return mcpDiscoveryState;
}
/**
* Parse www-authenticate header to extract OAuth metadata URI.
*
* @param wwwAuthenticate The www-authenticate header value
* @returns The resource metadata URI if found, null otherwise
*/
function _parseWWWAuthenticate(wwwAuthenticate: string): string | null {
// Parse header like: Bearer realm="MCP Server", resource_metadata_uri="https://..."
const resourceMetadataMatch = wwwAuthenticate.match(
/resource_metadata_uri="([^"]+)"/,
);
return resourceMetadataMatch ? resourceMetadataMatch[1] : null;
}
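A tiny worked example of the header shape this helper targets; the URL is made up.

```ts
// Sketch: the header form the resource_metadata_uri regex above matches.
const header =
  'Bearer realm="MCP Server", resource_metadata_uri="https://mcp.example.com/.well-known/oauth-protected-resource"';
const uri = header.match(/resource_metadata_uri="([^"]+)"/)?.[1];
// uri === 'https://mcp.example.com/.well-known/oauth-protected-resource'
```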
/**
* Extract WWW-Authenticate header from error message string.
* This is a more robust approach than regex matching.
*
* @param errorString The error message string
* @returns The www-authenticate header value if found, null otherwise
*/
function extractWWWAuthenticateHeader(errorString: string): string | null {
// Try multiple patterns to extract the header
const patterns = [
/www-authenticate:\s*([^\n\r]+)/i,
/WWW-Authenticate:\s*([^\n\r]+)/i,
/"www-authenticate":\s*"([^"]+)"/i,
/'www-authenticate':\s*'([^']+)'/i,
];
for (const pattern of patterns) {
const match = errorString.match(pattern);
if (match) {
return match[1].trim();
}
}
return null;
}
/**
* Handle automatic OAuth discovery and authentication for a server.
*
* @param mcpServerName The name of the MCP server
* @param mcpServerConfig The MCP server configuration
* @param wwwAuthenticate The www-authenticate header value
* @returns True if OAuth was successfully configured and authenticated, false otherwise
*/
async function handleAutomaticOAuth(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
wwwAuthenticate: string,
): Promise<boolean> {
try {
console.log(`🔐 '${mcpServerName}' requires OAuth authentication`);
// Always try to parse the resource metadata URI from the www-authenticate header
let oauthConfig;
const resourceMetadataUri =
OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate);
if (resourceMetadataUri) {
oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri);
} else if (mcpServerConfig.url) {
// Fallback: try to discover OAuth config from the base URL for SSE
const sseUrl = new URL(mcpServerConfig.url);
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
} else if (mcpServerConfig.httpUrl) {
// Fallback: try to discover OAuth config from the base URL for HTTP
const httpUrl = new URL(mcpServerConfig.httpUrl);
const baseUrl = `${httpUrl.protocol}//${httpUrl.host}`;
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
}
if (!oauthConfig) {
console.error(
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
);
return false;
}
// OAuth configuration discovered - proceed with authentication
// Create OAuth configuration for authentication
const oauthAuthConfig = {
enabled: true,
authorizationUrl: oauthConfig.authorizationUrl,
tokenUrl: oauthConfig.tokenUrl,
scopes: oauthConfig.scopes || [],
};
// Perform OAuth authentication
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig);
console.log(
`OAuth authentication successful for server '${mcpServerName}'`,
);
return true;
} catch (error) {
console.error(
`Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`,
);
return false;
}
}
/**
* Create a transport with OAuth token for the given server configuration.
*
* @param mcpServerName The name of the MCP server
* @param mcpServerConfig The MCP server configuration
* @param accessToken The OAuth access token
* @returns The transport with OAuth token, or null if creation fails
*/
async function createTransportWithOAuth(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
accessToken: string,
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
try {
if (mcpServerConfig.httpUrl) {
// Create HTTP transport with OAuth token
const oauthTransportOptions: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: {
...mcpServerConfig.headers,
Authorization: `Bearer ${accessToken}`,
},
},
};
return new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
oauthTransportOptions,
);
} else if (mcpServerConfig.url) {
// Create SSE transport with OAuth token in Authorization header
return new SSEClientTransport(new URL(mcpServerConfig.url), {
requestInit: {
headers: {
...mcpServerConfig.headers,
Authorization: `Bearer ${accessToken}`,
},
},
});
}
return null;
} catch (error) {
console.error(
`Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`,
);
return null;
}
}
/**
* Discovers tools from all configured MCP servers and registers them with the tool registry.
* It orchestrates the connection and discovery process for each server defined in the
@@ -138,6 +318,7 @@ export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
@@ -150,6 +331,7 @@ export async function discoverMcpTools(
mcpServerName,
mcpServerConfig,
toolRegistry,
promptRegistry,
debugMode,
),
);
@@ -193,6 +375,7 @@ export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -205,11 +388,11 @@ export async function connectAndDiscover(
);
try {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
mcpClient.onerror = (error) => {
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
const tools = await discoverTools(
mcpServerName,
@@ -224,7 +407,11 @@ export async function connectAndDiscover(
throw error;
}
} catch (error) {
console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
console.error(
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
error,
)}`,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
}
@@ -259,32 +446,109 @@ export async function discoverTools(
continue;
}
const toolNameForModel = generateValidName(funcDecl, mcpServerName);
sanitizeParameters(funcDecl.parameters);
discoveredTools.push(
new DiscoveredMCPTool(
mcpCallableTool,
mcpServerName,
toolNameForModel,
funcDecl.description ?? '',
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
funcDecl.name!,
funcDecl.description ?? '',
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
mcpServerConfig.trust,
),
);
}
if (discoveredTools.length === 0) {
throw Error('No enabled tools found');
}
return discoveredTools;
} catch (error) {
throw new Error(`Error discovering tools: ${error}`);
}
}
/**
* Discovers and logs prompts from a connected MCP client.
* It retrieves prompt declarations from the client and logs their names.
*
* @param mcpServerName The name of the MCP server.
* @param mcpClient The active MCP client instance.
*/
export async function discoverPrompts(
mcpServerName: string,
mcpClient: Client,
promptRegistry: PromptRegistry,
): Promise<void> {
try {
const response = await mcpClient.request(
{ method: 'prompts/list', params: {} },
ListPromptsResultSchema,
);
for (const prompt of response.prompts) {
promptRegistry.registerPrompt({
...prompt,
serverName: mcpServerName,
invoke: (params: Record<string, unknown>) =>
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
});
}
} catch (error) {
// It's okay if this fails, not all servers will have prompts.
// Don't log an error if the method is not found, which is a common case.
if (
error instanceof Error &&
!error.message?.includes('Method not found')
) {
console.error(
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
error,
)}`,
);
}
}
}
/**
* Invokes a prompt on a connected MCP client.
*
* @param mcpServerName The name of the MCP server.
* @param mcpClient The active MCP client instance.
* @param promptName The name of the prompt to invoke.
* @param promptParams The parameters to pass to the prompt.
* @returns A promise that resolves to the result of the prompt invocation.
*/
export async function invokeMcpPrompt(
mcpServerName: string,
mcpClient: Client,
promptName: string,
promptParams: Record<string, unknown>,
): Promise<GetPromptResult> {
try {
const response = await mcpClient.request(
{
method: 'prompts/get',
params: {
name: promptName,
arguments: promptParams,
},
},
GetPromptResultSchema,
);
return response;
} catch (error) {
if (
error instanceof Error &&
!error.message?.includes('Method not found')
) {
console.error(
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
error,
)}`,
);
}
throw error;
}
}
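A hedged sketch of how a registered prompt flows back through invoke(); the registry lookup method and the prompt argument names are assumptions, not APIs confirmed by this diff.

```ts
// Sketch: fetch a discovered prompt and flatten its text messages (lookup API assumed).
async function runPrompt(
  registry: { getPrompt(name: string): DiscoveredMCPPrompt | undefined },
  name: string,
): Promise<string> {
  const prompt = registry.getPrompt(name);
  if (!prompt) throw new Error(`Unknown prompt: ${name}`);
  const result = await prompt.invoke({ topic: 'release notes' }); // prompts/get under the hood
  return result.messages
    .map((m) => (m.content.type === 'text' ? m.content.text : ''))
    .join('\n');
}
```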
/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
@@ -318,7 +582,7 @@ export async function connectToMcpServer(
}
try {
const transport = createTransport(
const transport = await createTransport(
mcpServerName,
mcpServerConfig,
debugMode,
@@ -333,40 +597,419 @@ export async function connectToMcpServer(
throw error;
}
} catch (error) {
// Create a safe config object that excludes sensitive information
const safeConfig = {
command: mcpServerConfig.command,
url: mcpServerConfig.url,
httpUrl: mcpServerConfig.httpUrl,
cwd: mcpServerConfig.cwd,
timeout: mcpServerConfig.timeout,
trust: mcpServerConfig.trust,
// Exclude args, env, and headers which may contain sensitive data
};
// Check if this is a 401 error that might indicate OAuth is required
const errorString = String(error);
if (
errorString.includes('401') &&
(mcpServerConfig.httpUrl || mcpServerConfig.url)
) {
mcpServerRequiresOAuth.set(mcpServerName, true);
// Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured
// For SSE servers, we should not trigger new OAuth flows automatically
const shouldTriggerOAuth =
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
let errorString =
`failed to start or connect to MCP server '${mcpServerName}' ` +
`${JSON.stringify(safeConfig)}; \n${error}`;
if (process.env.SANDBOX) {
errorString += `\nMake sure it is available in the sandbox`;
if (!shouldTriggerOAuth) {
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
clientId: credentials.clientId,
},
);
if (hasStoredTokens) {
console.log(
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
);
} else {
console.log(
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
`Please authenticate using: /mcp auth ${mcpServerName}`,
);
}
}
throw new Error(
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
`Please authenticate using: /mcp auth ${mcpServerName}`,
);
}
// Try to extract www-authenticate header from the error
let wwwAuthenticate = extractWWWAuthenticateHeader(errorString);
// If we didn't get the header from the error string, try to get it from the server
if (!wwwAuthenticate && mcpServerConfig.url) {
console.log(
`No www-authenticate header in error, trying to fetch it from server...`,
);
try {
const response = await fetch(mcpServerConfig.url, {
method: 'HEAD',
headers: {
Accept: 'text/event-stream',
},
signal: AbortSignal.timeout(5000),
});
if (response.status === 401) {
wwwAuthenticate = response.headers.get('www-authenticate');
if (wwwAuthenticate) {
console.log(
`Found www-authenticate header from server: ${wwwAuthenticate}`,
);
}
}
} catch (fetchError) {
console.debug(
`Failed to fetch www-authenticate header: ${getErrorMessage(fetchError)}`,
);
}
}
if (wwwAuthenticate) {
console.log(
`Received 401 with www-authenticate header: ${wwwAuthenticate}`,
);
// Try automatic OAuth discovery and authentication
const oauthSuccess = await handleAutomaticOAuth(
mcpServerName,
mcpServerConfig,
wwwAuthenticate,
);
if (oauthSuccess) {
// Retry connection with OAuth token
console.log(
`Retrying connection to '${mcpServerName}' with OAuth token...`,
);
// Get the valid token - we need to create a proper OAuth config
// The token should already be available from the authentication process
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
clientId: credentials.clientId,
},
);
if (accessToken) {
// Create transport with OAuth token
const oauthTransport = await createTransportWithOAuth(
mcpServerName,
mcpServerConfig,
accessToken,
);
if (oauthTransport) {
try {
await mcpClient.connect(oauthTransport, {
timeout:
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
// Connection successful with OAuth
return mcpClient;
} catch (retryError) {
console.error(
`Failed to connect with OAuth token: ${getErrorMessage(
retryError,
)}`,
);
throw retryError;
}
} else {
console.error(
`Failed to create OAuth transport for server '${mcpServerName}'`,
);
throw new Error(
`Failed to create OAuth transport for server '${mcpServerName}'`,
);
}
} else {
console.error(
`Failed to get OAuth token for server '${mcpServerName}'`,
);
throw new Error(
`Failed to get OAuth token for server '${mcpServerName}'`,
);
}
} else {
console.error(
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
);
throw new Error(
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
);
}
} else {
console.error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
);
throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
);
}
} else {
// No www-authenticate header found, but we got a 401
// Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured
// For SSE servers, we should not trigger new OAuth flows automatically
const shouldTryDiscovery =
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
if (!shouldTryDiscovery) {
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
clientId: credentials.clientId,
},
);
if (hasStoredTokens) {
console.log(
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
);
} else {
console.log(
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
`Please authenticate using: /mcp auth ${mcpServerName}`,
);
}
}
throw new Error(
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
`Please authenticate using: /mcp auth ${mcpServerName}`,
);
}
// For SSE servers, try to discover OAuth configuration from the base URL
console.log(`🔍 Attempting OAuth discovery for '${mcpServerName}'...`);
if (mcpServerConfig.url) {
const sseUrl = new URL(mcpServerConfig.url);
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
try {
// Try to discover OAuth configuration from the base URL
const oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
if (oauthConfig) {
console.log(
`Discovered OAuth configuration from base URL for server '${mcpServerName}'`,
);
// Create OAuth configuration for authentication
const oauthAuthConfig = {
enabled: true,
authorizationUrl: oauthConfig.authorizationUrl,
tokenUrl: oauthConfig.tokenUrl,
scopes: oauthConfig.scopes || [],
};
// Perform OAuth authentication
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(
mcpServerName,
oauthAuthConfig,
);
// Retry connection with OAuth token
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
clientId: credentials.clientId,
},
);
if (accessToken) {
// Create transport with OAuth token
const oauthTransport = await createTransportWithOAuth(
mcpServerName,
mcpServerConfig,
accessToken,
);
if (oauthTransport) {
try {
await mcpClient.connect(oauthTransport, {
timeout:
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
// Connection successful with OAuth
return mcpClient;
} catch (retryError) {
console.error(
`Failed to connect with OAuth token: ${getErrorMessage(
retryError,
)}`,
);
throw retryError;
}
} else {
console.error(
`Failed to create OAuth transport for server '${mcpServerName}'`,
);
throw new Error(
`Failed to create OAuth transport for server '${mcpServerName}'`,
);
}
} else {
console.error(
`Failed to get OAuth token for server '${mcpServerName}'`,
);
throw new Error(
`Failed to get OAuth token for server '${mcpServerName}'`,
);
}
} else {
console.error(
`Failed to get stored credentials for server '${mcpServerName}'`,
);
throw new Error(
`Failed to get stored credentials for server '${mcpServerName}'`,
);
}
} else {
console.error(
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
);
throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
);
}
} catch (discoveryError) {
console.error(
`❌ OAuth discovery failed for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
);
throw discoveryError;
}
} else {
console.error(
`❌ '${mcpServerName}' requires authentication but no OAuth configuration found`,
);
throw new Error(
`MCP server '${mcpServerName}' requires authentication. Please configure OAuth or check server settings.`,
);
}
}
} else {
// Handle other connection errors
// Create a concise error message
const errorMessage = (error as Error).message || String(error);
const isNetworkError =
errorMessage.includes('ENOTFOUND') ||
errorMessage.includes('ECONNREFUSED');
let conciseError: string;
if (isNetworkError) {
conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`;
} else {
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
}
if (process.env.SANDBOX) {
conciseError += ` (check sandbox availability)`;
}
throw new Error(conciseError);
}
throw new Error(errorString);
}
}
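Condensing the branch above, a hedged outline of the 401-handling order as a standalone sketch; helper names mirror functions in this file, but the flow is simplified and the final authenticate-and-reconnect step is only noted in a comment.

```ts
// Sketch: simplified decision order for a 401 seen during the initial connect.
async function handle401Sketch(
  serverName: string,
  config: MCPServerConfig,
  errorText: string,
): Promise<boolean /* true = retry with an OAuth token */> {
  mcpServerRequiresOAuth.set(serverName, true);
  // Only HTTP servers, or servers with OAuth explicitly enabled, auto-trigger discovery;
  // plain SSE servers are directed to `/mcp auth <server>` instead.
  if (!config.httpUrl && !config.oauth?.enabled) return false;
  // Prefer the www-authenticate header carried by the error, then fall back to
  // .well-known discovery against the server's base URL.
  const header = extractWWWAuthenticateHeader(errorText);
  if (header) return handleAutomaticOAuth(serverName, config, header);
  const base = config.url ?? config.httpUrl;
  const oauthConfig = base
    ? await OAuthUtils.discoverOAuthConfig(new URL(base).origin)
    : undefined;
  return Boolean(oauthConfig); // the real code then authenticates and reconnects
}
```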
/** Visible for Testing */
export function createTransport(
export async function createTransport(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
debugMode: boolean,
): Transport {
): Promise<Transport> {
if (
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
) {
const provider = new GoogleCredentialProvider(mcpServerConfig);
const transportOptions:
| StreamableHTTPClientTransportOptions
| SSEClientTransportOptions = {
authProvider: provider,
};
if (mcpServerConfig.httpUrl) {
return new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
transportOptions,
);
} else if (mcpServerConfig.url) {
return new SSEClientTransport(
new URL(mcpServerConfig.url),
transportOptions,
);
}
throw new Error('No URL configured for Google Credentials MCP server');
}
// Check if we have OAuth configuration or stored tokens
let accessToken: string | null = null;
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
if (hasOAuthConfig && mcpServerConfig.oauth) {
accessToken = await MCPOAuthProvider.getValidToken(
mcpServerName,
mcpServerConfig.oauth,
);
if (!accessToken) {
console.error(
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
`Please authenticate using the /mcp auth command.`,
);
throw new Error(
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
`Please authenticate using the /mcp auth command.`,
);
}
} else {
// Check if we have stored OAuth tokens for this server (from previous authentication)
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
if (credentials) {
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, {
// Pass client ID if available
clientId: credentials.clientId,
});
if (accessToken) {
hasOAuthConfig = true;
console.log(`Found stored OAuth token for server '${mcpServerName}'`);
}
}
}
if (mcpServerConfig.httpUrl) {
const transportOptions: StreamableHTTPClientTransportOptions = {};
if (mcpServerConfig.headers) {
// Set up headers with OAuth token if available
if (hasOAuthConfig && accessToken) {
transportOptions.requestInit = {
headers: {
...mcpServerConfig.headers,
Authorization: `Bearer ${accessToken}`,
},
};
} else if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
};
}
return new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
transportOptions,
@@ -375,11 +1018,21 @@ export function createTransport(
if (mcpServerConfig.url) {
const transportOptions: SSEClientTransportOptions = {};
if (mcpServerConfig.headers) {
// Set up headers with OAuth token if available
if (hasOAuthConfig && accessToken) {
transportOptions.requestInit = {
headers: {
...mcpServerConfig.headers,
Authorization: `Bearer ${accessToken}`,
},
};
} else if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
};
}
return new SSEClientTransport(
new URL(mcpServerConfig.url),
transportOptions,
@@ -411,26 +1064,6 @@ export function createTransport(
);
}
/** Visible for testing */
export function generateValidName(
funcDecl: FunctionDeclaration,
mcpServerName: string,
) {
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
// Prepend MCP server name to avoid conflicts with other tools
validToolname = mcpServerName + '__' + validToolname;
// If longer than 63 characters, replace middle with '___'
// (Gemini API says max length 64, but actual limit seems to be 63)
if (validToolname.length > 63) {
validToolname =
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
}
return validToolname;
}
/** Visible for testing */
export function isEnabled(
funcDecl: FunctionDeclaration,

View File

@@ -14,7 +14,7 @@ import {
afterEach,
Mocked,
} from 'vitest';
import { DiscoveredMCPTool } from './mcp-tool.js'; // Added getStringifiedResultForDisplay
import { DiscoveredMCPTool, generateValidName } from './mcp-tool.js'; // Added getStringifiedResultForDisplay
import { ToolResult, ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome
import { CallableTool, Part } from '@google/genai';
@@ -29,9 +29,42 @@ const mockCallableToolInstance: Mocked<CallableTool> = {
// Add other methods if DiscoveredMCPTool starts using them
};
describe('generateValidName', () => {
it('should return a valid name for a simple function', () => {
expect(generateValidName('myFunction')).toBe('myFunction');
});
it('should replace invalid characters with underscores', () => {
expect(generateValidName('invalid-name with spaces')).toBe(
'invalid-name_with_spaces',
);
});
it('should truncate long names', () => {
expect(generateValidName('x'.repeat(80))).toBe(
'xxxxxxxxxxxxxxxxxxxxxxxxxxxx___xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
);
});
it('should handle names with only invalid characters', () => {
expect(generateValidName('!@#$%^&*()')).toBe('__________');
});
it('should handle names that are exactly 63 characters long', () => {
expect(generateValidName('a'.repeat(63)).length).toBe(63);
});
it('should handle names that are exactly 64 characters long', () => {
expect(generateValidName('a'.repeat(64)).length).toBe(63);
});
it('should handle names that are longer than 64 characters', () => {
expect(generateValidName('a'.repeat(80)).length).toBe(63);
});
});
describe('DiscoveredMCPTool', () => {
const serverName = 'mock-mcp-server';
const toolNameForModel = 'test-mcp-tool-for-model';
const serverToolName = 'actual-server-tool-name';
const baseDescription = 'A test MCP tool.';
const inputSchema: Record<string, unknown> = {
@@ -52,46 +85,32 @@ describe('DiscoveredMCPTool', () => {
});
describe('constructor', () => {
it('should set properties correctly (non-generic server)', () => {
it('should set properties correctly', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName, // serverName is 'mock-mcp-server', not 'mcp'
toolNameForModel,
serverName,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
expect(tool.name).toBe(toolNameForModel);
expect(tool.schema.name).toBe(toolNameForModel);
expect(tool.name).toBe(serverToolName);
expect(tool.schema.name).toBe(serverToolName);
expect(tool.schema.description).toBe(baseDescription);
expect(tool.schema.parameters).toEqual(inputSchema);
expect(tool.schema.parameters).toBeUndefined();
expect(tool.schema.parametersJsonSchema).toEqual(inputSchema);
expect(tool.serverToolName).toBe(serverToolName);
expect(tool.timeout).toBeUndefined();
});
it('should set properties correctly (generic "mcp" server)', () => {
const genericServerName = 'mcp';
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
genericServerName, // serverName is 'mcp'
toolNameForModel,
baseDescription,
inputSchema,
serverToolName,
);
expect(tool.schema.description).toBe(baseDescription);
});
it('should accept and store a custom timeout', () => {
const customTimeout = 5000;
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
customTimeout,
);
expect(tool.timeout).toBe(customTimeout);
@@ -103,10 +122,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const params = { param: 'testValue' };
const mockToolSuccessResultObject = {
@@ -143,10 +161,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const params = { param: 'testValue' };
const mockMcpToolResponsePartsEmpty: Part[] = [];
@@ -159,10 +176,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const params = { param: 'failCase' };
const expectedError = new Error('MCP call failed');
@@ -179,10 +195,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
undefined,
true,
);
@@ -196,10 +211,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
@@ -212,10 +226,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
@@ -226,10 +239,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const confirmation = await tool.shouldConfirmExecute(
{},
@@ -257,10 +269,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const confirmation = await tool.shouldConfirmExecute(
{},
@@ -288,10 +299,9 @@ describe('DiscoveredMCPTool', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
toolNameForModel,
serverToolName,
baseDescription,
inputSchema,
serverToolName,
);
const toolAllowlistKey = `${serverName}.${serverToolName}`;
const confirmation = await tool.shouldConfirmExecute(
@@ -315,5 +325,77 @@ describe('DiscoveredMCPTool', () => {
);
}
});
it('should handle Cancel confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
if (
confirmation &&
typeof confirmation === 'object' &&
'onConfirm' in confirmation &&
typeof confirmation.onConfirm === 'function'
) {
// Cancel should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(
(DiscoveredMCPTool as any).allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
} else {
throw new Error(
'Confirmation details or onConfirm not in expected format',
);
}
});
it('should handle ProceedOnce confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
if (
confirmation &&
typeof confirmation === 'object' &&
'onConfirm' in confirmation &&
typeof confirmation.onConfirm === 'function'
) {
// ProceedOnce should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(
(DiscoveredMCPTool as any).allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
} else {
throw new Error(
'Confirmation details or onConfirm not in expected format',
);
}
});
});
});

View File

@@ -10,8 +10,15 @@ import {
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolMcpConfirmationDetails,
Icon,
} from './tools.js';
import { CallableTool, Part, FunctionCall, Schema } from '@google/genai';
import {
CallableTool,
Part,
FunctionCall,
FunctionDeclaration,
Type,
} from '@google/genai';
type ToolParams = Record<string, unknown>;
@@ -21,23 +28,49 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly name: string,
readonly description: string,
readonly parameterSchema: Schema,
readonly serverToolName: string,
description: string,
readonly parameterSchemaJson: unknown,
readonly timeout?: number,
readonly trust?: boolean,
nameOverride?: string,
) {
super(
name,
nameOverride ?? generateValidName(serverToolName),
`${serverToolName} (${serverName} MCP Server)`,
description,
parameterSchema,
Icon.Hammer,
{ type: Type.OBJECT }, // this is a dummy Schema for MCP; it will not be used to construct the FunctionDeclaration
true, // isOutputMarkdown
false, // canUpdateOutput
);
}
asFullyQualifiedTool(): DiscoveredMCPTool {
return new DiscoveredMCPTool(
this.mcpTool,
this.serverName,
this.serverToolName,
this.description,
this.parameterSchemaJson,
this.timeout,
this.trust,
`${this.serverName}__${this.serverToolName}`,
);
}
/**
* Overrides the base schema to use parametersJsonSchema when building
* FunctionDeclaration
*/
override get schema(): FunctionDeclaration {
return {
name: this.name,
description: this.description,
parametersJsonSchema: this.parameterSchemaJson,
};
}
async shouldConfirmExecute(
_params: ToolParams,
_abortSignal: AbortSignal,
@@ -53,7 +86,7 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
DiscoveredMCPTool.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allow listed
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
@@ -146,3 +179,17 @@ function getStringifiedResultForDisplay(result: Part[]) {
return '```json\n' + JSON.stringify(processedResults, null, 2) + '\n```';
}
/** Visible for testing */
export function generateValidName(name: string) {
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
let validToolname = name.replace(/[^a-zA-Z0-9_.-]/g, '_');
// If longer than 63 characters, replace middle with '___'
// (Gemini API says max length 64, but actual limit seems to be 63)
if (validToolname.length > 63) {
validToolname =
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
}
return validToolname;
}
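For a rough sense of what the sanitization above produces, here is a small illustrative snippet (the tool names and the import path are invented for the example; the outputs follow directly from the regex and the 63-character rule shown above, and the fully qualified variant from asFullyQualifiedTool simply prefixes `${serverName}__`):

import { generateValidName } from './mcp-tool.js'; // import path assumed for illustration

generateValidName('list issues');    // 'list_issues'    – space replaced with '_'
generateValidName('repo/search@v2'); // 'repo_search_v2' – '/' and '@' replaced with '_'
generateValidName('a'.repeat(70));   // 28 leading 'a's + '___' + 32 trailing 'a's = 63 chars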

View File

@@ -87,7 +87,7 @@ describe('MemoryTool', () => {
describe('performAddMemoryEntry (static method)', () => {
const testFilePath = path.join(
'/mock/home',
'.qwen',
'.gemini',
DEFAULT_CONTEXT_FILENAME, // Use the default for basic tests
);
@@ -207,7 +207,7 @@ describe('MemoryTool', () => {
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join(
'/mock/home',
'.qwen',
'.gemini',
getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test
);

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { FunctionDeclaration, Type } from '@google/genai';
import * as fs from 'fs/promises';
import * as path from 'path';
@@ -46,8 +46,8 @@ Do NOT use this tool:
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".
`;
export const GEMINI_CONFIG_DIR = '.qwen';
export const DEFAULT_CONTEXT_FILENAME = 'QWEN.md';
export const GEMINI_CONFIG_DIR = '.gemini';
export const DEFAULT_CONTEXT_FILENAME = 'GEMINI.md';
export const MEMORY_SECTION_HEADER = '## Gemini Added Memories';
// This variable will hold the currently configured filename for GEMINI.md context files.
@@ -105,6 +105,7 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Icon.LightBulb,
memoryToolSchemaData.parameters as Record<string, unknown>,
);
}

View File

@@ -4,15 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
modifyWithEditor,
ModifyContext,
@@ -21,6 +13,7 @@ import {
} from './modifiable-tool.js';
import { EditorType } from '../utils/editor.js';
import fs from 'fs';
import fsp from 'fs/promises';
import os from 'os';
import * as path from 'path';
@@ -36,9 +29,6 @@ vi.mock('diff', () => ({
createPatch: mockCreatePatch,
}));
vi.mock('fs');
vi.mock('os');
interface TestParams {
filePath: string;
someOtherParam: string;
@@ -46,7 +36,7 @@ interface TestParams {
}
describe('modifyWithEditor', () => {
let tempDir: string;
let testProjectDir: string;
let mockModifyContext: ModifyContext<TestParams>;
let mockParams: TestParams;
let currentContent: string;
@@ -54,17 +44,19 @@ describe('modifyWithEditor', () => {
let modifiedContent: string;
let abortSignal: AbortSignal;
beforeEach(() => {
beforeEach(async () => {
vi.resetAllMocks();
tempDir = '/tmp/test-dir';
testProjectDir = await fsp.mkdtemp(
path.join(os.tmpdir(), 'modifiable-tool-test-'),
);
abortSignal = new AbortController().signal;
currentContent = 'original content\nline 2\nline 3';
proposedContent = 'modified content\nline 2\nline 3';
modifiedContent = 'user modified content\nline 2\nline 3\nnew line';
mockParams = {
filePath: path.join(tempDir, 'test.txt'),
filePath: path.join(testProjectDir, 'test.txt'),
someOtherParam: 'value',
};
@@ -81,26 +73,18 @@ describe('modifyWithEditor', () => {
})),
};
(os.tmpdir as Mock).mockReturnValue(tempDir);
(fs.existsSync as Mock).mockReturnValue(true);
(fs.mkdirSync as Mock).mockImplementation(() => undefined);
(fs.writeFileSync as Mock).mockImplementation(() => {});
(fs.unlinkSync as Mock).mockImplementation(() => {});
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
if (filePath.includes('-new-')) {
return modifiedContent;
}
return currentContent;
mockOpenDiff.mockImplementation(async (_oldPath, newPath) => {
await fsp.writeFile(newPath, modifiedContent, 'utf8');
});
mockCreatePatch.mockReturnValue('mock diff content');
mockOpenDiff.mockResolvedValue(undefined);
});
afterEach(() => {
afterEach(async () => {
vi.restoreAllMocks();
await fsp.rm(testProjectDir, { recursive: true, force: true });
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
await fsp.rm(diffDir, { recursive: true, force: true });
});
describe('successful modification', () => {
@@ -120,38 +104,8 @@ describe('modifyWithEditor', () => {
);
expect(mockModifyContext.getFilePath).toHaveBeenCalledWith(mockParams);
expect(fs.writeFileSync).toHaveBeenCalledTimes(2);
expect(fs.writeFileSync).toHaveBeenNthCalledWith(
1,
expect.stringContaining(
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
),
currentContent,
'utf8',
);
expect(fs.writeFileSync).toHaveBeenNthCalledWith(
2,
expect.stringContaining(
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
),
proposedContent,
'utf8',
);
expect(mockOpenDiff).toHaveBeenCalledWith(
expect.stringContaining('-old-'),
expect.stringContaining('-new-'),
'vscode',
);
expect(fs.readFileSync).toHaveBeenCalledWith(
expect.stringContaining('-old-'),
'utf8',
);
expect(fs.readFileSync).toHaveBeenCalledWith(
expect.stringContaining('-new-'),
'utf8',
);
expect(mockOpenDiff).toHaveBeenCalledOnce();
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
expect(mockModifyContext.createUpdatedParams).toHaveBeenCalledWith(
currentContent,
@@ -171,15 +125,9 @@ describe('modifyWithEditor', () => {
}),
);
expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
expect(fs.unlinkSync).toHaveBeenNthCalledWith(
1,
expect.stringContaining('-old-'),
);
expect(fs.unlinkSync).toHaveBeenNthCalledWith(
2,
expect.stringContaining('-new-'),
);
// Check that temp files are deleted.
await expect(fsp.access(oldFilePath)).rejects.toThrow();
await expect(fsp.access(newFilePath)).rejects.toThrow();
expect(result).toEqual({
updatedParams: {
@@ -192,7 +140,8 @@ describe('modifyWithEditor', () => {
});
it('should create temp directory if it does not exist', async () => {
(fs.existsSync as Mock).mockReturnValue(false);
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
await fsp.rm(diffDir, { recursive: true, force: true }).catch(() => {});
await modifyWithEditor(
mockParams,
@@ -201,14 +150,15 @@ describe('modifyWithEditor', () => {
abortSignal,
);
expect(fs.mkdirSync).toHaveBeenCalledWith(
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
{ recursive: true },
);
const stats = await fsp.stat(diffDir);
expect(stats.isDirectory()).toBe(true);
});
it('should not create temp directory if it already exists', async () => {
(fs.existsSync as Mock).mockReturnValue(true);
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
await fsp.mkdir(diffDir, { recursive: true });
const mkdirSpy = vi.spyOn(fs, 'mkdirSync');
await modifyWithEditor(
mockParams,
@@ -217,18 +167,15 @@ describe('modifyWithEditor', () => {
abortSignal,
);
expect(fs.mkdirSync).not.toHaveBeenCalled();
expect(mkdirSpy).not.toHaveBeenCalled();
mkdirSpy.mockRestore();
});
});
it('should handle missing old temp file gracefully', async () => {
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
if (filePath.includes('-old-')) {
const error = new Error('ENOENT: no such file or directory');
(error as NodeJS.ErrnoException).code = 'ENOENT';
throw error;
}
return modifiedContent;
mockOpenDiff.mockImplementation(async (oldPath, newPath) => {
await fsp.writeFile(newPath, modifiedContent, 'utf8');
await fsp.unlink(oldPath);
});
const result = await modifyWithEditor(
@@ -255,13 +202,8 @@ describe('modifyWithEditor', () => {
});
it('should handle missing new temp file gracefully', async () => {
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
if (filePath.includes('-new-')) {
const error = new Error('ENOENT: no such file or directory');
(error as NodeJS.ErrnoException).code = 'ENOENT';
throw error;
}
return currentContent;
mockOpenDiff.mockImplementation(async (_oldPath, newPath) => {
await fsp.unlink(newPath);
});
const result = await modifyWithEditor(
@@ -291,6 +233,8 @@ describe('modifyWithEditor', () => {
const editorError = new Error('Editor failed to open');
mockOpenDiff.mockRejectedValue(editorError);
const writeSpy = vi.spyOn(fs, 'writeFileSync');
await expect(
modifyWithEditor(
mockParams,
@@ -300,14 +244,21 @@ describe('modifyWithEditor', () => {
),
).rejects.toThrow('Editor failed to open');
expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
expect(writeSpy).toHaveBeenCalledTimes(2);
const oldFilePath = writeSpy.mock.calls[0][0] as string;
const newFilePath = writeSpy.mock.calls[1][0] as string;
await expect(fsp.access(oldFilePath)).rejects.toThrow();
await expect(fsp.access(newFilePath)).rejects.toThrow();
writeSpy.mockRestore();
});
it('should handle temp file cleanup errors gracefully', async () => {
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
(fs.unlinkSync as Mock).mockImplementation((_filePath: string) => {
vi.spyOn(fs, 'unlinkSync').mockImplementation(() => {
throw new Error('Failed to delete file');
});
@@ -327,7 +278,11 @@ describe('modifyWithEditor', () => {
});
it('should create temp files with correct naming with extension', async () => {
const testFilePath = path.join(tempDir, 'subfolder', 'test-file.txt');
const testFilePath = path.join(
testProjectDir,
'subfolder',
'test-file.txt',
);
mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);
await modifyWithEditor(
@@ -337,20 +292,18 @@ describe('modifyWithEditor', () => {
abortSignal,
);
const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
expect(writeFileCalls).toHaveLength(2);
const oldFilePath = writeFileCalls[0][0];
const newFilePath = writeFileCalls[1][0];
expect(mockOpenDiff).toHaveBeenCalledOnce();
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+\.txt$/);
expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+\.txt$/);
expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
expect(path.dirname(oldFilePath)).toBe(diffDir);
expect(path.dirname(newFilePath)).toBe(diffDir);
});
it('should create temp files with correct naming without extension', async () => {
const testFilePath = path.join(tempDir, 'subfolder', 'test-file');
const testFilePath = path.join(testProjectDir, 'subfolder', 'test-file');
mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);
await modifyWithEditor(
@@ -360,16 +313,14 @@ describe('modifyWithEditor', () => {
abortSignal,
);
const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
expect(writeFileCalls).toHaveLength(2);
const oldFilePath = writeFileCalls[0][0];
const newFilePath = writeFileCalls[1][0];
expect(mockOpenDiff).toHaveBeenCalledOnce();
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+$/);
expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+$/);
expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
expect(path.dirname(oldFilePath)).toBe(diffDir);
expect(path.dirname(newFilePath)).toBe(diffDir);
});
});

View File

@@ -4,54 +4,37 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach, Mock } from 'vitest';
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { ReadFileTool, ReadFileToolParams } from './read-file.js';
import * as fileUtils from '../utils/fileUtils.js';
import path from 'path';
import os from 'os';
import fs from 'fs'; // For actual fs operations in setup
import fs from 'fs';
import fsp from 'fs/promises';
import { Config } from '../config/config.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
// Mock fileUtils.processSingleFileContent
vi.mock('../utils/fileUtils', async () => {
const actualFileUtils =
await vi.importActual<typeof fileUtils>('../utils/fileUtils');
return {
...actualFileUtils, // Spread actual implementations
processSingleFileContent: vi.fn(), // Mock specific function
};
});
const mockProcessSingleFileContent = fileUtils.processSingleFileContent as Mock;
describe('ReadFileTool', () => {
let tempRootDir: string;
let tool: ReadFileTool;
const abortSignal = new AbortController().signal;
beforeEach(() => {
beforeEach(async () => {
// Create a unique temporary root directory for each test run
tempRootDir = fs.mkdtempSync(
tempRootDir = await fsp.mkdtemp(
path.join(os.tmpdir(), 'read-file-tool-root-'),
);
fs.writeFileSync(
path.join(tempRootDir, '.geminiignore'),
['foo.*'].join('\n'),
);
const fileService = new FileDiscoveryService(tempRootDir);
const mockConfigInstance = {
getFileService: () => fileService,
getFileService: () => new FileDiscoveryService(tempRootDir),
getTargetDir: () => tempRootDir,
} as unknown as Config;
tool = new ReadFileTool(mockConfigInstance);
mockProcessSingleFileContent.mockReset();
});
afterEach(() => {
afterEach(async () => {
// Clean up the temporary root directory
if (fs.existsSync(tempRootDir)) {
fs.rmSync(tempRootDir, { recursive: true, force: true });
await fsp.rm(tempRootDir, { recursive: true, force: true });
}
});
@@ -129,9 +112,9 @@ describe('ReadFileTool', () => {
it('should return a shortened, relative path', () => {
const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt');
const params: ReadFileToolParams = { absolute_path: filePath };
// Assuming tempRootDir is something like /tmp/read-file-tool-root-XXXXXX
// The relative path would be sub/dir/file.txt
expect(tool.getDescription(params)).toBe('sub/dir/file.txt');
expect(tool.getDescription(params)).toBe(
path.join('sub', 'dir', 'file.txt'),
);
});
it('should return . if path is the root directory', () => {
@@ -142,111 +125,140 @@ describe('ReadFileTool', () => {
describe('execute', () => {
it('should return validation error if params are invalid', async () => {
const params: ReadFileToolParams = { absolute_path: 'relative/path.txt' };
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toBe(
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
);
expect(result.returnDisplay).toBe(
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
);
const params: ReadFileToolParams = {
absolute_path: 'relative/path.txt',
};
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent:
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
returnDisplay:
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
});
});
it('should return error from processSingleFileContent if it fails', async () => {
const filePath = path.join(tempRootDir, 'error.txt');
it('should return error if file does not exist', async () => {
const filePath = path.join(tempRootDir, 'nonexistent.txt');
const params: ReadFileToolParams = { absolute_path: filePath };
const errorMessage = 'Simulated read error';
mockProcessSingleFileContent.mockResolvedValue({
llmContent: `Error reading file ${filePath}: ${errorMessage}`,
returnDisplay: `Error reading file ${filePath}: ${errorMessage}`,
error: errorMessage,
});
const result = await tool.execute(params, abortSignal);
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
filePath,
tempRootDir,
undefined,
undefined,
);
expect(result.llmContent).toContain(errorMessage);
expect(result.returnDisplay).toContain(errorMessage);
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `File not found: ${filePath}`,
returnDisplay: 'File not found.',
});
});
it('should return success result for a text file', async () => {
const filePath = path.join(tempRootDir, 'textfile.txt');
const fileContent = 'This is a test file.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
mockProcessSingleFileContent.mockResolvedValue({
llmContent: fileContent,
returnDisplay: `Read text file: ${path.basename(filePath)}`,
});
const result = await tool.execute(params, abortSignal);
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
filePath,
tempRootDir,
undefined,
undefined,
);
expect(result.llmContent).toBe(fileContent);
expect(result.returnDisplay).toBe(
`Read text file: ${path.basename(filePath)}`,
);
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: fileContent,
returnDisplay: '',
});
});
it('should return success result for an image file', async () => {
// A minimal 1x1 transparent PNG file.
const pngContent = Buffer.from([
137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0,
1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65,
84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73,
69, 78, 68, 174, 66, 96, 130,
]);
const filePath = path.join(tempRootDir, 'image.png');
const imageData = {
inlineData: { mimeType: 'image/png', data: 'base64...' },
};
await fsp.writeFile(filePath, pngContent);
const params: ReadFileToolParams = { absolute_path: filePath };
mockProcessSingleFileContent.mockResolvedValue({
llmContent: imageData,
returnDisplay: `Read image file: ${path.basename(filePath)}`,
});
const result = await tool.execute(params, abortSignal);
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
filePath,
tempRootDir,
undefined,
undefined,
);
expect(result.llmContent).toEqual(imageData);
expect(result.returnDisplay).toBe(
`Read image file: ${path.basename(filePath)}`,
);
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: pngContent.toString('base64'),
},
},
returnDisplay: `Read image file: image.png`,
});
});
it('should pass offset and limit to processSingleFileContent', async () => {
it('should treat a non-image file with image extension as an image', async () => {
const filePath = path.join(tempRootDir, 'fake-image.png');
const fileContent = 'This is not a real png.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: Buffer.from(fileContent).toString('base64'),
},
},
returnDisplay: `Read image file: fake-image.png`,
});
});
it('should pass offset and limit to read a slice of a text file', async () => {
const filePath = path.join(tempRootDir, 'paginated.txt');
const fileContent = Array.from(
{ length: 20 },
(_, i) => `Line ${i + 1}`,
).join('\n');
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = {
absolute_path: filePath,
offset: 10,
limit: 5,
offset: 5, // Start from line 6
limit: 3,
};
mockProcessSingleFileContent.mockResolvedValue({
llmContent: 'some lines',
returnDisplay: 'Read text file (paginated)',
});
await tool.execute(params, abortSignal);
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
filePath,
tempRootDir,
10,
5,
);
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: [
'[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]',
'Line 6',
'Line 7',
'Line 8',
].join('\n'),
returnDisplay: '(truncated)',
});
});
it('should return error if path is ignored by a .geminiignore pattern', async () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'foo.bar'),
};
const result = await tool.execute(params, abortSignal);
expect(result.returnDisplay).toContain('foo.bar');
expect(result.returnDisplay).not.toContain('foo.baz');
describe('with .geminiignore', () => {
beforeEach(async () => {
await fsp.writeFile(
path.join(tempRootDir, '.geminiignore'),
['foo.*', 'ignored/'].join('\n'),
);
});
it('should return error if path is ignored by a .geminiignore pattern', async () => {
const ignoredFilePath = path.join(tempRootDir, 'foo.bar');
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: ignoredFilePath,
};
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
returnDisplay: expectedError,
});
});
it('should return error if path is in an ignored directory', async () => {
const ignoredDirPath = path.join(tempRootDir, 'ignored');
await fsp.mkdir(ignoredDirPath);
const filePath = path.join(ignoredDirPath, 'somefile.txt');
await fsp.writeFile(filePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: filePath,
};
const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
returnDisplay: expectedError,
});
});
});
});
});

View File

@@ -7,7 +7,7 @@
import path from 'path';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import {
isWithinRoot,
@@ -51,6 +51,7 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
ReadFileTool.Name,
'ReadFile',
'Reads and returns the content of a specified file from the local filesystem. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.',
Icon.FileSearch,
{
properties: {
absolute_path: {
@@ -118,6 +119,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
return shortenPath(relativePath);
}
toolLocations(params: ReadFileToolParams): ToolLocation[] {
return [{ path: params.absolute_path, line: params.offset }];
}
async execute(
params: ReadFileToolParams,
_signal: AbortSignal,

View File

@@ -58,10 +58,13 @@ describe('ReadManyFilesTool', () => {
const fileService = new FileDiscoveryService(tempRootDir);
const mockConfig = {
getFileService: () => fileService,
getFileFilteringRespectGitIgnore: () => true,
getFileFilteringOptions: () => ({
respectGitIgnore: true,
respectGeminiIgnore: true,
}),
getTargetDir: () => tempRootDir,
} as Partial<Config> as Config;
tool = new ReadManyFilesTool(mockConfig);
mockReadFileFn = mockControl.mockReadFile;
@@ -269,7 +272,7 @@ describe('ReadManyFilesTool', () => {
);
});
it('should handle non-existent specific files gracefully', async () => {
it('should handle nonexistent specific files gracefully', async () => {
const params = { paths: ['nonexistent-file.txt'] };
const result = await tool.execute(params, new AbortController().signal);
expect(result.llmContent).toEqual([

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import * as path from 'path';
@@ -17,7 +17,7 @@ import {
getSpecificMimeType,
} from '../utils/fileUtils.js';
import { PartListUnion, Schema, Type } from '@google/genai';
import { Config } from '../config/config.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
import {
recordFileOperationMetric,
FileOperation,
@@ -62,9 +62,12 @@ export interface ReadManyFilesParams {
useDefaultExcludes?: boolean;
/**
* Optional. Whether to respect .gitignore patterns. Defaults to true.
* Whether to respect .gitignore and .geminiignore patterns (optional, defaults to true)
*/
respect_git_ignore?: boolean;
file_filtering_options?: {
respect_git_ignore?: boolean;
respect_gemini_ignore?: boolean;
};
}
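As a minimal sketch of the new parameter shape (the values and the import path are illustrative, not defaults taken from the codebase), a caller might build params like this:

import { ReadManyFilesParams } from './read-many-files.js'; // import path assumed for illustration

const params: ReadManyFilesParams = {
  paths: ['src/**/*.ts', 'docs/README.md'],
  exclude: ['**/*.test.ts'],
  useDefaultExcludes: true,
  file_filtering_options: {
    respect_git_ignore: true,
    respect_gemini_ignore: false, // read matches even if .geminiignore would skip them
  },
};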
/**
@@ -125,8 +128,6 @@ export class ReadManyFilesTool extends BaseTool<
> {
static readonly Name: string = 'read_many_files';
private readonly geminiIgnorePatterns: string[] = [];
constructor(private config: Config) {
const parameterSchema: Schema = {
type: Type.OBJECT,
@@ -173,11 +174,22 @@ export class ReadManyFilesTool extends BaseTool<
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
default: true,
},
respect_git_ignore: {
type: Type.BOOLEAN,
file_filtering_options: {
description:
'Optional. Whether to respect .gitignore patterns when discovering files. Only available in git repositories. Defaults to true.',
default: true,
'Whether to respect ignore patterns from .gitignore or .geminiignore',
type: Type.OBJECT,
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: Type.BOOLEAN,
},
},
},
},
required: ['paths'],
@@ -196,11 +208,9 @@ This tool is useful when you need to understand or analyze a collection of files
- When the user asks to "read all files in X directory" or "show me the content of all Y files".
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Icon.FileSearch,
parameterSchema,
);
this.geminiIgnorePatterns = config
.getFileService()
.getGeminiIgnorePatterns();
}
validateParams(params: ReadManyFilesParams): string | null {
@@ -218,17 +228,19 @@ Use this tool when the user's query implies needing the content of several files
// Determine the final list of exclusion patterns exactly as in execute method
const paramExcludes = params.exclude || [];
const paramUseDefaultExcludes = params.useDefaultExcludes !== false;
const geminiIgnorePatterns = this.config
.getFileService()
.getGeminiIgnorePatterns();
const finalExclusionPatternsForDescription: string[] =
paramUseDefaultExcludes
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...this.geminiIgnorePatterns]
: [...paramExcludes, ...this.geminiIgnorePatterns];
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...geminiIgnorePatterns]
: [...paramExcludes, ...geminiIgnorePatterns];
let excludeDesc = `Excluding: ${finalExclusionPatternsForDescription.length > 0 ? `patterns like \`${finalExclusionPatternsForDescription.slice(0, 2).join('`, `')}${finalExclusionPatternsForDescription.length > 2 ? '...`' : '`'}` : 'none specified'}`;
// Add a note if .geminiignore patterns contributed to the final list of exclusions
if (this.geminiIgnorePatterns.length > 0) {
const geminiPatternsInEffect = this.geminiIgnorePatterns.filter((p) =>
if (geminiIgnorePatterns.length > 0) {
const geminiPatternsInEffect = geminiIgnorePatterns.filter((p) =>
finalExclusionPatternsForDescription.includes(p),
).length;
if (geminiPatternsInEffect > 0) {
@@ -256,12 +268,19 @@ Use this tool when the user's query implies needing the content of several files
include = [],
exclude = [],
useDefaultExcludes = true,
respect_git_ignore = true,
} = params;
const respectGitIgnore =
respect_git_ignore ?? this.config.getFileFilteringRespectGitIgnore();
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore, // fall back to the config-level default
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore, // fall back to the config-level default
};
// Get centralized file discovery service
const fileDiscovery = this.config.getFileService();
@@ -271,8 +290,8 @@ Use this tool when the user's query implies needing the content of several files
const contentParts: PartListUnion = [];
const effectiveExcludes = useDefaultExcludes
? [...DEFAULT_EXCLUDES, ...exclude, ...this.geminiIgnorePatterns]
: [...exclude, ...this.geminiIgnorePatterns];
? [...DEFAULT_EXCLUDES, ...exclude]
: [...exclude];
const searchPatterns = [...inputPatterns, ...include];
if (searchPatterns.length === 0) {
@@ -283,7 +302,8 @@ Use this tool when the user's query implies needing the content of several files
}
try {
const entries = await glob(searchPatterns, {
const patterns = searchPatterns.map((p) => p.replace(/\\/g, '/'));
const entries: string[] = await glob(patterns, {
cwd: this.config.getTargetDir(),
ignore: effectiveExcludes,
nodir: true,
@@ -291,20 +311,39 @@ Use this tool when the user's query implies needing the content of several files
absolute: true,
nocase: true,
signal,
withFileTypes: false,
});
const filteredEntries = respectGitIgnore
const gitFilteredEntries = fileFilteringOptions.respectGitIgnore
? fileDiscovery
.filterFiles(
entries.map((p) => path.relative(this.config.getTargetDir(), p)),
{
respectGitIgnore,
respectGitIgnore: true,
respectGeminiIgnore: false,
},
)
.map((p) => path.resolve(this.config.getTargetDir(), p))
: entries;
// Apply gemini ignore filtering if enabled
const finalFilteredEntries = fileFilteringOptions.respectGeminiIgnore
? fileDiscovery
.filterFiles(
gitFilteredEntries.map((p) =>
path.relative(this.config.getTargetDir(), p),
),
{
respectGitIgnore: false,
respectGeminiIgnore: true,
},
)
.map((p) => path.resolve(this.config.getTargetDir(), p))
: gitFilteredEntries;
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
for (const absoluteFilePath of entries) {
// Security check: ensure the glob library didn't return something outside targetDir.
if (!absoluteFilePath.startsWith(this.config.getTargetDir())) {
@@ -316,11 +355,23 @@ Use this tool when the user's query implies needing the content of several files
}
// Check if this file was filtered out by git ignore
if (respectGitIgnore && !filteredEntries.includes(absoluteFilePath)) {
if (
fileFilteringOptions.respectGitIgnore &&
!gitFilteredEntries.includes(absoluteFilePath)
) {
gitIgnoredCount++;
continue;
}
// Check if this file was filtered out by gemini ignore
if (
fileFilteringOptions.respectGeminiIgnore &&
!finalFilteredEntries.includes(absoluteFilePath)
) {
geminiIgnoredCount++;
continue;
}
filesToConsider.add(absoluteFilePath);
}
@@ -328,7 +379,15 @@ Use this tool when the user's query implies needing the content of several files
if (gitIgnoredCount > 0) {
skippedFiles.push({
path: `${gitIgnoredCount} file(s)`,
reason: 'ignored',
reason: 'git ignored',
});
}
// Add info about gemini-ignored files if any were filtered
if (geminiIgnoredCount > 0) {
skippedFiles.push({
path: `${geminiIgnoredCount} file(s)`,
reason: 'gemini ignored',
});
}
} catch (error) {
@@ -345,7 +404,7 @@ Use this tool when the user's query implies needing the content of several files
.relative(this.config.getTargetDir(), filePath)
.replace(/\\/g, '/');
const fileType = detectFileType(filePath);
const fileType = await detectFileType(filePath);
if (fileType === 'image' || fileType === 'pdf') {
const fileExtension = path.extname(filePath).toLowerCase();

View File

@@ -4,429 +4,384 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { expect, describe, it, vi, beforeEach } from 'vitest';
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
const mockShellExecutionService = vi.hoisted(() => vi.fn());
vi.mock('../services/shellExecutionService.js', () => ({
ShellExecutionService: { execute: mockShellExecutionService },
}));
vi.mock('fs');
vi.mock('os');
vi.mock('crypto');
vi.mock('../utils/summarizer.js');
import { isCommandAllowed } from '../utils/shell-utils.js';
import { ShellTool } from './shell.js';
import { Config } from '../config/config.js';
import { type Config } from '../config/config.js';
import {
type ShellExecutionResult,
type ShellOutputEvent,
} from '../services/shellExecutionService.js';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import * as crypto from 'crypto';
import * as summarizer from '../utils/summarizer.js';
import { GeminiClient } from '../core/client.js';
import { ToolConfirmationOutcome } from './tools.js';
import { OUTPUT_UPDATE_INTERVAL_MS } from './shell.js';
describe('ShellTool', () => {
it('should allow a command if no restrictions are provided', async () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
} as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ls -l');
expect(result.allowed).toBe(true);
});
it('should allow a command if it is in the allowed list', async () => {
const config = {
getCoreTools: () => ['ShellTool(ls -l)'],
getExcludeTools: () => undefined,
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ls -l');
expect(result.allowed).toBe(true);
});
it('should block a command if it is not in the allowed list', async () => {
const config = {
getCoreTools: () => ['ShellTool(ls -l)'],
getExcludeTools: () => undefined,
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is not in the allowed commands list",
);
});
it('should block a command if it is in the blocked list', async () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => ['ShellTool(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should allow a command if it is not in the blocked list', async () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => ['ShellTool(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ls -l');
expect(result.allowed).toBe(true);
});
it('should block a command if it is in both the allowed and blocked lists', async () => {
const config = {
getCoreTools: () => ['ShellTool(rm -rf /)'],
getExcludeTools: () => ['ShellTool(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should allow any command when ShellTool is in coreTools without specific commands', async () => {
const config = {
getCoreTools: () => ['ShellTool'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(true);
});
it('should block any command when ShellTool is in excludeTools without specific commands', async () => {
const config = {
getCoreTools: () => [],
getExcludeTools: () => ['ShellTool'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
'Shell tool is globally disabled in configuration',
);
});
it('should allow a command if it is in the allowed list using the public-facing name', async () => {
const config = {
getCoreTools: () => ['run_shell_command(ls -l)'],
getExcludeTools: () => undefined,
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ls -l');
expect(result.allowed).toBe(true);
});
it('should block a command if it is in the blocked list using the public-facing name', async () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => ['run_shell_command(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should block any command when ShellTool is in excludeTools using the public-facing name', async () => {
const config = {
getCoreTools: () => [],
getExcludeTools: () => ['run_shell_command'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
'Shell tool is globally disabled in configuration',
);
});
it('should block any command if coreTools contains an empty ShellTool command list using the public-facing name', async () => {
const config = {
getCoreTools: () => ['run_shell_command()'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'any command' is not in the allowed commands list",
);
});
it('should block any command if coreTools contains an empty ShellTool command list', async () => {
const config = {
getCoreTools: () => ['ShellTool()'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'any command' is not in the allowed commands list",
);
});
it('should block a command with extra whitespace if it is in the blocked list', async () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => ['ShellTool(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed(' rm -rf / ');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should allow any command when ShellTool is present with specific commands', async () => {
const config = {
getCoreTools: () => ['ShellTool', 'ShellTool(ls)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('any command');
expect(result.allowed).toBe(true);
});
it('should block a command on the blocklist even with a wildcard allow', async () => {
const config = {
getCoreTools: () => ['ShellTool'],
getExcludeTools: () => ['ShellTool(rm -rf /)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should allow a command that starts with an allowed command prefix', async () => {
const config = {
getCoreTools: () => ['ShellTool(gh issue edit)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed(
'gh issue edit 1 --add-label "kind/feature"',
);
expect(result.allowed).toBe(true);
});
it('should allow a command that starts with an allowed command prefix using the public-facing name', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue edit)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed(
'gh issue edit 1 --add-label "kind/feature"',
);
expect(result.allowed).toBe(true);
});
it('should not allow a command that starts with an allowed command prefix but is chained with another command', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue edit)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue edit&&rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is not in the allowed commands list",
);
});
it('should not allow a command that is a prefix of an allowed command', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue edit)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'gh issue' is not in the allowed commands list",
);
});
it('should not allow a command that is a prefix of a blocked command', async () => {
const config = {
getCoreTools: () => [],
getExcludeTools: () => ['run_shell_command(gh issue edit)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue');
expect(result.allowed).toBe(true);
});
it('should not allow a command that is chained with a pipe', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue list)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue list | rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is not in the allowed commands list",
);
});
it('should not allow a command that is chained with a semicolon', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue list)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue list; rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is not in the allowed commands list",
);
});
it('should block a chained command if any part is blocked', async () => {
const config = {
getCoreTools: () => ['run_shell_command(echo "hello")'],
getExcludeTools: () => ['run_shell_command(rm)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('echo "hello" && rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should block a command if its prefix is on the blocklist, even if the command itself is on the allowlist', async () => {
const config = {
getCoreTools: () => ['run_shell_command(git push)'],
getExcludeTools: () => ['run_shell_command(git)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('git push');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'git push' is blocked by configuration",
);
});
it('should be case-sensitive in its matching', async () => {
const config = {
getCoreTools: () => ['run_shell_command(echo)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ECHO "hello"');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
'Command \'ECHO "hello"\' is not in the allowed commands list',
);
});
it('should correctly handle commands with extra whitespace around chaining operators', async () => {
const config = {
getCoreTools: () => ['run_shell_command(ls -l)'],
getExcludeTools: () => ['run_shell_command(rm)'],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('ls -l ; rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is blocked by configuration",
);
});
it('should allow a chained command if all parts are allowed', async () => {
const config = {
getCoreTools: () => [
'run_shell_command(echo)',
'run_shell_command(ls -l)',
],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('echo "hello" && ls -l');
expect(result.allowed).toBe(true);
});
it('should allow a command with command substitution using backticks', async () => {
const config = {
getCoreTools: () => ['run_shell_command(echo)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('echo `rm -rf /`');
expect(result.allowed).toBe(true);
});
it('should block a command with command substitution using $()', async () => {
const config = {
getCoreTools: () => ['run_shell_command(echo)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('echo $(rm -rf /)');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
'Command substitution using $() is not allowed for security reasons',
);
});
it('should allow a command with I/O redirection', async () => {
const config = {
getCoreTools: () => ['run_shell_command(echo)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('echo "hello" > file.txt');
expect(result.allowed).toBe(true);
});
it('should not allow a command that is chained with a double pipe', async () => {
const config = {
getCoreTools: () => ['run_shell_command(gh issue list)'],
getExcludeTools: () => [],
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.isCommandAllowed('gh issue list || rm -rf /');
expect(result.allowed).toBe(false);
expect(result.reason).toBe(
"Command 'rm -rf /' is not in the allowed commands list",
);
});
});
describe('ShellTool Bug Reproduction', () => {
let shellTool: ShellTool;
let config: Config;
let mockConfig: Config;
let mockShellOutputCallback: (event: ShellOutputEvent) => void;
let resolveExecutionPromise: (result: ShellExecutionResult) => void;
beforeEach(() => {
config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
getDebugMode: () => false,
getGeminiClient: () => ({}) as GeminiClient,
getTargetDir: () => '.',
vi.clearAllMocks();
mockConfig = {
getCoreTools: vi.fn().mockReturnValue([]),
getExcludeTools: vi.fn().mockReturnValue([]),
getDebugMode: vi.fn().mockReturnValue(false),
getTargetDir: vi.fn().mockReturnValue('/test/dir'),
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),
getGeminiClient: vi.fn(),
} as unknown as Config;
shellTool = new ShellTool(config);
});
it('should not let the summarizer override the return display', async () => {
const summarizeSpy = vi
.spyOn(summarizer, 'summarizeToolOutput')
.mockResolvedValue('summarized output');
shellTool = new ShellTool(mockConfig);
const abortSignal = new AbortController().signal;
const result = await shellTool.execute(
{ command: 'echo "hello"' },
abortSignal,
vi.mocked(os.platform).mockReturnValue('linux');
vi.mocked(os.tmpdir).mockReturnValue('/tmp');
(vi.mocked(crypto.randomBytes) as Mock).mockReturnValue(
Buffer.from('abcdef', 'hex'),
);
expect(result.returnDisplay).toBe('hello\n');
expect(result.llmContent).toBe('summarized output');
expect(summarizeSpy).toHaveBeenCalled();
// Capture the output callback to simulate streaming events from the service
mockShellExecutionService.mockImplementation((_cmd, _cwd, callback) => {
mockShellOutputCallback = callback;
return {
pid: 12345,
result: new Promise((resolve) => {
resolveExecutionPromise = resolve;
}),
};
});
});
describe('isCommandAllowed', () => {
it('should allow a command if no restrictions are provided', () => {
(mockConfig.getCoreTools as Mock).mockReturnValue(undefined);
(mockConfig.getExcludeTools as Mock).mockReturnValue(undefined);
expect(isCommandAllowed('ls -l', mockConfig).allowed).toBe(true);
});
it('should block a command with command substitution using $()', () => {
expect(isCommandAllowed('echo $(rm -rf /)', mockConfig).allowed).toBe(
false,
);
});
});
describe('validateToolParams', () => {
it('should return null for a valid command', () => {
expect(shellTool.validateToolParams({ command: 'ls -l' })).toBeNull();
});
it('should return an error for an empty command', () => {
expect(shellTool.validateToolParams({ command: ' ' })).toBe(
'Command cannot be empty.',
);
});
it('should return an error for a non-existent directory', () => {
vi.mocked(fs.existsSync).mockReturnValue(false);
expect(
shellTool.validateToolParams({ command: 'ls', directory: 'rel/path' }),
).toBe('Directory must exist.');
});
});
describe('execute', () => {
const mockAbortSignal = new AbortController().signal;
const resolveShellExecution = (
result: Partial<ShellExecutionResult> = {},
) => {
const fullResult: ShellExecutionResult = {
rawOutput: Buffer.from(result.output || ''),
output: 'Success',
stdout: 'Success',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
...result,
};
resolveExecutionPromise(fullResult);
};
it('should wrap command on linux and parse pgrep output', async () => {
const promise = shellTool.execute(
{ command: 'my-command &' },
mockAbortSignal,
);
resolveShellExecution({ pid: 54321 });
vi.mocked(fs.existsSync).mockReturnValue(true);
vi.mocked(fs.readFileSync).mockReturnValue('54321\n54322\n'); // Service PID and background PID
const result = await promise;
const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp');
const wrappedCommand = `{ my-command & }; __code=$?; pgrep -g 0 >${tmpFile} 2>&1; exit $__code;`;
expect(mockShellExecutionService).toHaveBeenCalledWith(
wrappedCommand,
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
expect(result.llmContent).toContain('Background PIDs: 54322');
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
});
it('should not wrap command on windows', async () => {
vi.mocked(os.platform).mockReturnValue('win32');
const promise = shellTool.execute({ command: 'dir' }, mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
'dir',
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should format error messages correctly', async () => {
const error = new Error('wrapped command failed');
const promise = shellTool.execute(
{ command: 'user-command' },
mockAbortSignal,
);
resolveShellExecution({
error,
exitCode: 1,
output: 'err',
stderr: 'err',
rawOutput: Buffer.from('err'),
stdout: '',
signal: null,
aborted: false,
pid: 12345,
});
const result = await promise;
// The final llmContent should contain the user's command, not the wrapper
expect(result.llmContent).toContain('Error: wrapped command failed');
expect(result.llmContent).not.toContain('pgrep');
});
it('should summarize output when configured', async () => {
(mockConfig.getSummarizeToolOutputConfig as Mock).mockReturnValue({
[shellTool.name]: { tokenBudget: 1000 },
});
vi.mocked(summarizer.summarizeToolOutput).mockResolvedValue(
'summarized output',
);
const promise = shellTool.execute({ command: 'ls' }, mockAbortSignal);
resolveExecutionPromise({
output: 'long output',
rawOutput: Buffer.from('long output'),
stdout: 'long output',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
const result = await promise;
expect(summarizer.summarizeToolOutput).toHaveBeenCalledWith(
expect.any(String),
mockConfig.getGeminiClient(),
mockAbortSignal,
1000,
);
expect(result.llmContent).toBe('summarized output');
expect(result.returnDisplay).toBe('long output');
});
it('should clean up the temp file on synchronous execution error', async () => {
const error = new Error('sync spawn error');
mockShellExecutionService.mockImplementation(() => {
throw error;
});
vi.mocked(fs.existsSync).mockReturnValue(true); // Pretend the file exists
await expect(
shellTool.execute({ command: 'a-command' }, mockAbortSignal),
).rejects.toThrow(error);
const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp');
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
});
describe('Streaming to `updateOutput`', () => {
let updateOutputMock: Mock;
beforeEach(() => {
vi.useFakeTimers({ toFake: ['Date'] });
updateOutputMock = vi.fn();
});
afterEach(() => {
vi.useRealTimers();
});
it('should throttle text output updates', async () => {
const promise = shellTool.execute(
{ command: 'stream' },
mockAbortSignal,
updateOutputMock,
);
// First chunk, should be throttled.
mockShellOutputCallback({
type: 'data',
stream: 'stdout',
chunk: 'hello ',
});
expect(updateOutputMock).not.toHaveBeenCalled();
// Advance time past the throttle interval.
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
// Send a second chunk. THIS event triggers the update with the CUMULATIVE content.
mockShellOutputCallback({
type: 'data',
stream: 'stderr',
chunk: 'world',
});
// It should have been called once now with the combined output.
expect(updateOutputMock).toHaveBeenCalledOnce();
expect(updateOutputMock).toHaveBeenCalledWith('hello \nworld');
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
});
it('should immediately show binary detection message and throttle progress', async () => {
const promise = shellTool.execute(
{ command: 'cat img' },
mockAbortSignal,
updateOutputMock,
);
mockShellOutputCallback({ type: 'binary_detected' });
expect(updateOutputMock).toHaveBeenCalledOnce();
expect(updateOutputMock).toHaveBeenCalledWith(
'[Binary output detected. Halting stream...]',
);
mockShellOutputCallback({
type: 'binary_progress',
bytesReceived: 1024,
});
expect(updateOutputMock).toHaveBeenCalledOnce();
// Advance time past the throttle interval.
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
// Send a SECOND progress event. This one will trigger the flush.
mockShellOutputCallback({
type: 'binary_progress',
bytesReceived: 2048,
});
// Now it should be called a second time with the latest progress.
expect(updateOutputMock).toHaveBeenCalledTimes(2);
expect(updateOutputMock).toHaveBeenLastCalledWith(
'[Receiving binary output... 2.0 KB received]',
);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
});
});
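// A standalone sketch of the throttling pattern exercised by the tests above.
// The names below are illustrative and not part of ShellTool; they only mirror
// the "accumulate, then flush once the interval has elapsed" behaviour.
const sketchIntervalMs = 1000; // assumed to mirror OUTPUT_UPDATE_INTERVAL_MS
function makeThrottledUpdater(update: (output: string) => void) {
  let cumulative = '';
  let lastUpdateTime = Date.now();
  return (chunk: string) => {
    cumulative += chunk;
    // Flush the cumulative output only after the throttle interval has passed.
    if (Date.now() - lastUpdateTime > sketchIntervalMs) {
      update(cumulative);
      lastUpdateTime = Date.now();
    }
  };
}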
});
describe('shouldConfirmExecute', () => {
it('should request confirmation for a new command and allowlist it on "Always"', async () => {
const params = { command: 'npm install' };
const confirmation = await shellTool.shouldConfirmExecute(
params,
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
expect(confirmation && confirmation.type).toBe('exec');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await (confirmation as any).onConfirm(
ToolConfirmationOutcome.ProceedAlways,
);
// Should now be allowlisted
const secondConfirmation = await shellTool.shouldConfirmExecute(
{ command: 'npm test' },
new AbortController().signal,
);
expect(secondConfirmation).toBe(false);
});
it('should skip confirmation if validation fails', async () => {
const confirmation = await shellTool.shouldConfirmExecute(
{ command: '' },
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
});
});
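// A minimal sketch of the allowlist gate covered by the confirmation tests
// above. The names below are illustrative assumptions; the real logic lives in
// ShellTool.shouldConfirmExecute.
const allowlistSketch = new Set<string>();
function rootsNeedingConfirmation(rootCommands: string[]): string[] {
  // Only roots never approved with "Proceed always" still require a prompt.
  return [...new Set(rootCommands)].filter((root) => !allowlistSketch.has(root));
}
function recordProceedAlways(rootCommands: string[]): void {
  rootCommands.forEach((root) => allowlistSketch.add(root));
}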

View File

@@ -15,25 +15,34 @@ import {
ToolCallConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import stripAnsi from 'strip-ansi';
import { summarizeToolOutput } from '../utils/summarizer.js';
import {
ShellExecutionService,
ShellOutputEvent,
} from '../services/shellExecutionService.js';
import { formatMemoryUsage } from '../utils/formatters.js';
import {
getCommandRoots,
isCommandAllowed,
stripShellWrapper,
} from '../utils/shell-utils.js';
export const OUTPUT_UPDATE_INTERVAL_MS = 1000;
export interface ShellToolParams {
command: string;
description?: string;
directory?: string;
}
import { spawn } from 'child_process';
import { summarizeToolOutput } from '../utils/summarizer.js';
const OUTPUT_UPDATE_INTERVAL_MS = 1000;
export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
static Name: string = 'run_shell_command';
private whitelist: Set<string> = new Set();
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
super(
@@ -41,17 +50,18 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
'Shell',
`This tool executes a given shell command as \`bash -c <command>\`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
The following information is returned:
The following information is returned:
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\``,
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\``,
Icon.Terminal,
{
type: Type.OBJECT,
properties: {
@@ -91,131 +101,8 @@ Process Group PGID: Process group started or \`(none)\``,
return description;
}
/**
* Extracts the root command from a given shell command string.
* This is used to identify the base command for permission checks.
*
* @param command The shell command string to parse
* @returns The root command name, or undefined if it cannot be determined
* @example getCommandRoot("ls -la /tmp") returns "ls"
* @example getCommandRoot("git status && npm test") returns "git"
*/
getCommandRoot(command: string): string | undefined {
return command
.trim() // remove leading and trailing whitespace
.replace(/[{}()]/g, '') // remove all grouping operators
.split(/[\s;&|]+/)[0] // split on any whitespace or separator or chaining operators and take first part
?.split(/[/\\]/) // split on any path separators (or return undefined if previous line was undefined)
.pop(); // take last part and return command root (or undefined if previous line was empty)
}
/**
* Determines whether a given shell command is allowed to execute based on
* the tool's configuration including allowlists and blocklists.
*
* @param command The shell command string to validate
* @returns An object with 'allowed' boolean and optional 'reason' string if not allowed
*/
isCommandAllowed(command: string): { allowed: boolean; reason?: string } {
// 0. Disallow command substitution
if (command.includes('$(')) {
return {
allowed: false,
reason:
'Command substitution using $() is not allowed for security reasons',
};
}
const SHELL_TOOL_NAMES = [ShellTool.name, ShellTool.Name];
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
/**
* Checks if a command string starts with a given prefix, ensuring it's a
* whole word match (i.e., followed by a space or it's an exact match).
* e.g., `isPrefixedBy('npm install', 'npm')` -> true
* e.g., `isPrefixedBy('npm', 'npm')` -> true
* e.g., `isPrefixedBy('npminstall', 'npm')` -> false
*/
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
if (!cmd.startsWith(prefix)) {
return false;
}
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
};
/**
* Extracts and normalizes shell commands from a list of tool strings.
* e.g., 'ShellTool("ls -l")' becomes 'ls -l'
*/
const extractCommands = (tools: string[]): string[] =>
tools.flatMap((tool) => {
for (const toolName of SHELL_TOOL_NAMES) {
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
return [normalize(tool.slice(toolName.length + 1, -1))];
}
}
return [];
});
const coreTools = this.config.getCoreTools() || [];
const excludeTools = this.config.getExcludeTools() || [];
// 1. Check if the shell tool is globally disabled.
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
return {
allowed: false,
reason: 'Shell tool is globally disabled in configuration',
};
}
const blockedCommands = new Set(extractCommands(excludeTools));
const allowedCommands = new Set(extractCommands(coreTools));
const hasSpecificAllowedCommands = allowedCommands.size > 0;
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
coreTools.includes(name),
);
const commandsToValidate = command.split(/&&|\|\||\||;/).map(normalize);
const blockedCommandsArr = [...blockedCommands];
for (const cmd of commandsToValidate) {
// 2. Check if the command is on the blocklist.
const isBlocked = blockedCommandsArr.some((blocked) =>
isPrefixedBy(cmd, blocked),
);
if (isBlocked) {
return {
allowed: false,
reason: `Command '${cmd}' is blocked by configuration`,
};
}
// 3. If in strict allow-list mode, check if the command is permitted.
const isStrictAllowlist =
hasSpecificAllowedCommands && !isWildcardAllowed;
const allowedCommandsArr = [...allowedCommands];
if (isStrictAllowlist) {
const isAllowed = allowedCommandsArr.some((allowed) =>
isPrefixedBy(cmd, allowed),
);
if (!isAllowed) {
return {
allowed: false,
reason: `Command '${cmd}' is not in the allowed commands list`,
};
}
}
}
// 4. If all checks pass, the command is allowed.
return { allowed: true };
}
validateToolParams(params: ShellToolParams): string | null {
const commandCheck = this.isCommandAllowed(params.command);
const commandCheck = isCommandAllowed(params.command, this.config);
if (!commandCheck.allowed) {
if (!commandCheck.reason) {
console.error(
@@ -232,7 +119,7 @@ Process Group PGID: Process group started or \`(none)\``,
if (!params.command.trim()) {
return 'Command cannot be empty.';
}
if (!this.getCommandRoot(params.command)) {
if (getCommandRoots(params.command).length === 0) {
return 'Could not identify command root to obtain permission from user.';
}
if (params.directory) {
@@ -257,18 +144,25 @@ Process Group PGID: Process group started or \`(none)\``,
if (this.validateToolParams(params)) {
return false; // skip confirmation, execute call will fail immediately
}
const rootCommand = this.getCommandRoot(params.command)!; // must be non-empty string post-validation
if (this.whitelist.has(rootCommand)) {
const command = stripShellWrapper(params.command);
const rootCommands = [...new Set(getCommandRoots(command))];
const commandsToConfirm = rootCommands.filter(
(command) => !this.allowlist.has(command),
);
if (commandsToConfirm.length === 0) {
return false; // already approved and allowlisted
}
const confirmationDetails: ToolExecuteConfirmationDetails = {
type: 'exec',
title: 'Confirm Shell Command',
command: params.command,
rootCommand,
rootCommand: commandsToConfirm.join(', '),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.whitelist.add(rootCommand);
commandsToConfirm.forEach((command) => this.allowlist.add(command));
}
},
};
@@ -277,21 +171,22 @@ Process Group PGID: Process group started or \`(none)\``,
async execute(
params: ShellToolParams,
abortSignal: AbortSignal,
updateOutput?: (chunk: string) => void,
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
const strippedCommand = stripShellWrapper(params.command);
const validationError = this.validateToolParams({
...params,
command: strippedCommand,
});
if (validationError) {
return {
llmContent: [
`Command rejected: ${params.command}`,
`Reason: ${validationError}`,
].join('\n'),
returnDisplay: `Error: ${validationError}`,
llmContent: validationError,
returnDisplay: validationError,
};
}
if (abortSignal.aborted) {
if (signal.aborted) {
return {
llmContent: 'Command was cancelled by user before it could start.',
returnDisplay: 'Command cancelled by user.',
@@ -304,200 +199,182 @@ Process Group PGID: Process group started or \`(none)\``,
.toString('hex')}.tmp`;
const tempFilePath = path.join(os.tmpdir(), tempFileName);
// pgrep is not available on Windows, so we can't get background PIDs
const command = isWindows
? params.command
: (() => {
// wrap command to append subprocess pids (via pgrep) to temporary file
let command = params.command.trim();
if (!command.endsWith('&')) command += ';';
return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`;
})();
// spawn command in specified directory (or project root if not specified)
const shell = isWindows
? spawn('cmd.exe', ['/c', command], {
stdio: ['ignore', 'pipe', 'pipe'],
// detached: true, // ensure subprocess starts its own process group (esp. in Linux)
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
})
: spawn('bash', ['-c', command], {
stdio: ['ignore', 'pipe', 'pipe'],
detached: true, // ensure subprocess starts its own process group (esp. in Linux)
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
});
let exited = false;
let stdout = '';
let output = '';
let lastUpdateTime = Date.now();
const appendOutput = (str: string) => {
output += str;
if (
updateOutput &&
Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS
) {
updateOutput(output);
lastUpdateTime = Date.now();
}
};
shell.stdout.on('data', (data: Buffer) => {
// continue to consume post-exit for background processes
// removing listeners can overflow OS buffer and block subprocesses
// destroying (e.g. shell.stdout.destroy()) can terminate subprocesses via SIGPIPE
if (!exited) {
const str = stripAnsi(data.toString());
stdout += str;
appendOutput(str);
}
});
let stderr = '';
shell.stderr.on('data', (data: Buffer) => {
if (!exited) {
const str = stripAnsi(data.toString());
stderr += str;
appendOutput(str);
}
});
let error: Error | null = null;
shell.on('error', (err: Error) => {
error = err;
// remove wrapper from user's command in error message
error.message = error.message.replace(command, params.command);
});
let code: number | null = null;
let processSignal: NodeJS.Signals | null = null;
const exitHandler = (
_code: number | null,
_signal: NodeJS.Signals | null,
) => {
exited = true;
code = _code;
processSignal = _signal;
};
shell.on('exit', exitHandler);
const abortHandler = async () => {
if (shell.pid && !exited) {
if (os.platform() === 'win32') {
// For Windows, use taskkill to kill the process tree
spawn('taskkill', ['/pid', shell.pid.toString(), '/f', '/t']);
} else {
try {
// attempt to SIGTERM process group (negative PID)
// fall back to SIGKILL (to group) after 200ms
process.kill(-shell.pid, 'SIGTERM');
await new Promise((resolve) => setTimeout(resolve, 200));
if (shell.pid && !exited) {
process.kill(-shell.pid, 'SIGKILL');
}
} catch (_e) {
// if group kill fails, fall back to killing just the main process
try {
if (shell.pid) {
shell.kill('SIGKILL');
}
} catch (_e) {
console.error(`failed to kill shell process ${shell.pid}: ${_e}`);
}
}
}
}
};
abortSignal.addEventListener('abort', abortHandler);
// wait for the shell to exit
try {
await new Promise((resolve) => shell.on('exit', resolve));
// pgrep is not available on Windows, so we can't get background PIDs
const commandToExecute = isWindows
? strippedCommand
: (() => {
// wrap command to append subprocess pids (via pgrep) to temporary file
let command = strippedCommand.trim();
if (!command.endsWith('&')) command += ';';
return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`;
})();
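// For example (command and temp path here are illustrative), `npm test` on a
// POSIX platform is executed as:
//   { npm test; }; __code=$?; pgrep -g 0 >/tmp/shell_pgrep_ab12cd.tmp 2>&1; exit $__code;
// which preserves the user command's exit code while recording the PIDs of any
// processes still running in the same process group.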
const cwd = path.resolve(
this.config.getTargetDir(),
params.directory || '',
);
let cumulativeStdout = '';
let cumulativeStderr = '';
let lastUpdateTime = Date.now();
let isBinaryStream = false;
const { result: resultPromise } = ShellExecutionService.execute(
commandToExecute,
cwd,
(event: ShellOutputEvent) => {
if (!updateOutput) {
return;
}
let currentDisplayOutput = '';
let shouldUpdate = false;
switch (event.type) {
case 'data':
if (isBinaryStream) break; // Don't process text if we are in binary mode
if (event.stream === 'stdout') {
cumulativeStdout += event.chunk;
} else {
cumulativeStderr += event.chunk;
}
currentDisplayOutput =
cumulativeStdout +
(cumulativeStderr ? `\n${cumulativeStderr}` : '');
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
shouldUpdate = true;
}
break;
case 'binary_detected':
isBinaryStream = true;
currentDisplayOutput =
'[Binary output detected. Halting stream...]';
shouldUpdate = true;
break;
case 'binary_progress':
isBinaryStream = true;
currentDisplayOutput = `[Receiving binary output... ${formatMemoryUsage(
event.bytesReceived,
)} received]`;
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
shouldUpdate = true;
}
break;
default: {
throw new Error('An unhandled ShellOutputEvent was found.');
}
}
if (shouldUpdate) {
updateOutput(currentDisplayOutput);
lastUpdateTime = Date.now();
}
},
signal,
);
const result = await resultPromise;
const backgroundPIDs: number[] = [];
if (os.platform() !== 'win32') {
if (fs.existsSync(tempFilePath)) {
const pgrepLines = fs
.readFileSync(tempFilePath, 'utf8')
.split('\n')
.filter(Boolean);
for (const line of pgrepLines) {
if (!/^\d+$/.test(line)) {
console.error(`pgrep: ${line}`);
}
const pid = Number(line);
if (pid !== result.pid) {
backgroundPIDs.push(pid);
}
}
} else {
if (!signal.aborted) {
console.error('missing pgrep output');
}
}
}
let llmContent = '';
if (result.aborted) {
llmContent = 'Command was cancelled by user before it could complete.';
if (result.output.trim()) {
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${result.output}`;
} else {
llmContent += ' There was no output before it was cancelled.';
}
} else {
// Create a formatted error string for display, replacing the wrapper command
// with the user-facing command.
const finalError = result.error
? result.error.message.replace(commandToExecute, params.command)
: '(none)';
llmContent = [
`Command: ${params.command}`,
`Directory: ${params.directory || '(root)'}`,
`Stdout: ${result.stdout || '(empty)'}`,
`Stderr: ${result.stderr || '(empty)'}`,
`Error: ${finalError}`, // Use the cleaned error string.
`Exit Code: ${result.exitCode ?? '(none)'}`,
`Signal: ${result.signal ?? '(none)'}`,
`Background PIDs: ${
backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'
}`,
`Process Group PGID: ${result.pid ?? '(none)'}`,
].join('\n');
}
let returnDisplayMessage = '';
if (this.config.getDebugMode()) {
returnDisplayMessage = llmContent;
} else {
if (result.output.trim()) {
returnDisplayMessage = result.output;
} else {
if (result.aborted) {
returnDisplayMessage = 'Command cancelled by user.';
} else if (result.signal) {
returnDisplayMessage = `Command terminated by signal: ${result.signal}`;
} else if (result.error) {
returnDisplayMessage = `Command failed: ${getErrorMessage(
result.error,
)}`;
} else if (result.exitCode !== null && result.exitCode !== 0) {
returnDisplayMessage = `Command exited with code: ${result.exitCode}`;
}
// If output is empty and command succeeded (code 0, no error/signal/abort),
// returnDisplayMessage will remain empty, which is fine.
}
}
const summarizeConfig = this.config.getSummarizeToolOutputConfig();
if (summarizeConfig && summarizeConfig[this.name]) {
const summary = await summarizeToolOutput(
llmContent,
this.config.getGeminiClient(),
signal,
summarizeConfig[this.name].tokenBudget,
);
return {
llmContent: summary,
returnDisplay: returnDisplayMessage,
};
}
return {
llmContent,
returnDisplay: returnDisplayMessage,
};
} finally {
abortSignal.removeEventListener('abort', abortHandler);
}
// parse pids (pgrep output) from temporary file and remove it
const backgroundPIDs: number[] = [];
if (os.platform() !== 'win32') {
if (fs.existsSync(tempFilePath)) {
const pgrepLines = fs
.readFileSync(tempFilePath, 'utf8')
.split('\n')
.filter(Boolean);
for (const line of pgrepLines) {
if (!/^\d+$/.test(line)) {
console.error(`pgrep: ${line}`);
}
const pid = Number(line);
// exclude the shell subprocess pid
if (pid !== shell.pid) {
backgroundPIDs.push(pid);
}
}
fs.unlinkSync(tempFilePath);
} else {
if (!abortSignal.aborted) {
console.error('missing pgrep output');
}
}
}
let llmContent = '';
if (abortSignal.aborted) {
llmContent = 'Command was cancelled by user before it could complete.';
if (output.trim()) {
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${output}`;
} else {
llmContent += ' There was no output before it was cancelled.';
}
} else {
llmContent = [
`Command: ${params.command}`,
`Directory: ${params.directory || '(root)'}`,
`Stdout: ${stdout || '(empty)'}`,
`Stderr: ${stderr || '(empty)'}`,
`Error: ${error ?? '(none)'}`,
`Exit Code: ${code ?? '(none)'}`,
`Signal: ${processSignal ?? '(none)'}`,
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
`Process Group PGID: ${shell.pid ?? '(none)'}`,
].join('\n');
}
let returnDisplayMessage = '';
if (this.config.getDebugMode()) {
returnDisplayMessage = llmContent;
} else {
if (output.trim()) {
returnDisplayMessage = output;
} else {
// Output is empty, let's provide a reason if the command failed or was cancelled
if (abortSignal.aborted) {
returnDisplayMessage = 'Command cancelled by user.';
} else if (processSignal) {
returnDisplayMessage = `Command terminated by signal: ${processSignal}`;
} else if (error) {
// If error is not null, it's an Error object (or other truthy value)
returnDisplayMessage = `Command failed: ${getErrorMessage(error)}`;
} else if (code !== null && code !== 0) {
returnDisplayMessage = `Command exited with code: ${code}`;
}
// If output is empty and command succeeded (code 0, no error/signal/abort),
// returnDisplayMessage will remain empty, which is fine.
}
}
const summary = await summarizeToolOutput(
llmContent,
this.config.getGeminiClient(),
abortSignal,
);
return {
llmContent: summary,
returnDisplay: returnDisplayMessage,
};
}
}

View File

@@ -14,14 +14,14 @@ import {
afterEach,
Mocked,
} from 'vitest';
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
import {
ToolRegistry,
DiscoveredTool,
sanitizeParameters,
} from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import {
FunctionDeclaration,
CallableTool,
@@ -104,8 +104,12 @@ const createMockCallableTool = (
});
class MockTool extends BaseTool<{ param: string }, ToolResult> {
constructor(name = 'mock-tool', description = 'A mock tool') {
super(name, name, description, {
constructor(
name = 'mock-tool',
displayName = 'A mock tool',
description = 'A mock tool description',
) {
super(name, displayName, description, Icon.Hammer, {
type: Type.OBJECT,
properties: {
param: { type: Type.STRING },
@@ -174,42 +178,85 @@ describe('ToolRegistry', () => {
});
});
describe('getAllTools', () => {
it('should return all registered tools sorted alphabetically by displayName', () => {
// Register tools with displayNames in non-alphabetical order
const toolC = new MockTool('c-tool', 'Tool C');
const toolA = new MockTool('a-tool', 'Tool A');
const toolB = new MockTool('b-tool', 'Tool B');
toolRegistry.registerTool(toolC);
toolRegistry.registerTool(toolA);
toolRegistry.registerTool(toolB);
const allTools = toolRegistry.getAllTools();
const displayNames = allTools.map((t) => t.displayName);
// Assert that the returned array is sorted by displayName
expect(displayNames).toEqual(['Tool A', 'Tool B', 'Tool C']);
});
});
describe('getToolsByServer', () => {
it('should return an empty array if no tools match the server name', () => {
toolRegistry.registerTool(new MockTool());
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
});
it('should return only tools matching the server name', async () => {
it('should return only tools matching the server name, sorted by name', async () => {
const server1Name = 'mcp-server-uno';
const server2Name = 'mcp-server-dos';
const mockCallable = {} as CallableTool;
const mcpTool1 = new DiscoveredMCPTool(
const mcpTool1_c = new DiscoveredMCPTool(
mockCallable,
server1Name,
'server1Name__tool-on-server1',
'zebra-tool',
'd1',
{},
'tool-on-server1',
);
const mcpTool1_a = new DiscoveredMCPTool(
mockCallable,
server1Name,
'apple-tool',
'd2',
{},
);
const mcpTool1_b = new DiscoveredMCPTool(
mockCallable,
server1Name,
'banana-tool',
'd3',
{},
);
const mcpTool2 = new DiscoveredMCPTool(
mockCallable,
server2Name,
'server2Name__tool-on-server2',
'd2',
{},
'tool-on-server2',
'd4',
{},
);
const nonMcpTool = new MockTool('regular-tool');
toolRegistry.registerTool(mcpTool1);
toolRegistry.registerTool(mcpTool1_c);
toolRegistry.registerTool(mcpTool1_a);
toolRegistry.registerTool(mcpTool1_b);
toolRegistry.registerTool(mcpTool2);
toolRegistry.registerTool(nonMcpTool);
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
expect(toolsFromServer1).toHaveLength(1);
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
const toolNames = toolsFromServer1.map((t) => t.name);
// Assert that the array has the correct tools and is sorted by name
expect(toolsFromServer1).toHaveLength(3);
expect(toolNames).toEqual(['apple-tool', 'banana-tool', 'zebra-tool']);
// Assert that all returned tools are indeed from the correct server
for (const tool of toolsFromServer1) {
expect((tool as DiscoveredMCPTool).serverName).toBe(server1Name);
}
// Assert that the other server's tools are returned correctly
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
expect(toolsFromServer2).toHaveLength(1);
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
@@ -265,7 +312,7 @@ describe('ToolRegistry', () => {
return mockChildProcess as any;
});
await toolRegistry.discoverTools();
await toolRegistry.discoverAllTools();
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
expect(discoveredTool).toBeDefined();
@@ -291,12 +338,13 @@ describe('ToolRegistry', () => {
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
undefined,
false,
);
});
@@ -313,12 +361,13 @@ describe('ToolRegistry', () => {
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
undefined,
false,
);
});

View File

@@ -5,7 +5,7 @@
*/
import { FunctionDeclaration, Schema, Type } from '@google/genai';
import { Tool, ToolResult, BaseTool } from './tools.js';
import { Tool, ToolResult, BaseTool, Icon } from './tools.js';
import { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
@@ -18,7 +18,7 @@ type ToolParams = Record<string, unknown>;
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
constructor(
private readonly config: Config,
readonly name: string,
name: string,
readonly description: string,
readonly parameterSchema: Record<string, unknown>,
) {
@@ -44,6 +44,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
name,
name,
description,
Icon.Hammer,
parameterSchema,
false, // isOutputMarkdown
false, // canUpdateOutput
@@ -137,10 +138,14 @@ export class ToolRegistry {
*/
registerTool(tool: Tool): void {
if (this.tools.has(tool.name)) {
// Decide on behavior: throw error, log warning, or allow overwrite
console.warn(
`Tool with name "${tool.name}" is already registered. Overwriting.`,
);
if (tool instanceof DiscoveredMCPTool) {
tool = tool.asFullyQualifiedTool();
} else {
// Decide on behavior: throw error, log warning, or allow overwrite
console.warn(
`Tool with name "${tool.name}" is already registered. Overwriting.`,
);
}
}
this.tools.set(tool.name, tool);
}
@@ -148,8 +153,9 @@ export class ToolRegistry {
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.
* This will discover tools from the command line and from MCP servers.
*/
async discoverTools(): Promise<void> {
async discoverAllTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
@@ -164,10 +170,59 @@ export class ToolRegistry {
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.
* This will NOT discover tools from the command line, only from MCP servers.
*/
async discoverMcpTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
}
}
// discover tools using MCP servers, if configured
await discoverMcpTools(
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
/**
* Discover or re-discover tools for a single MCP server.
* @param serverName - The name of the server to discover tools from.
*/
async discoverToolsForServer(serverName: string): Promise<void> {
// Remove any previously discovered tools from this server
for (const [name, tool] of this.tools.entries()) {
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
this.tools.delete(name);
}
}
const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {
await discoverMcpTools(
{ [serverName]: serverConfig },
undefined,
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
}
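// Typical usage (a sketch; the server name is illustrative): after a single MCP
// server restarts, refresh only its tools without re-running command discovery:
//   await toolRegistry.discoverToolsForServer('my-mcp-server');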
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
const discoveryCmd = this.config.getToolDiscoveryCommand();
if (!discoveryCmd) {
@@ -308,7 +363,9 @@ export class ToolRegistry {
* Returns an array of all registered and discovered tool instances.
*/
getAllTools(): Tool[] {
return Array.from(this.tools.values());
return Array.from(this.tools.values()).sort((a, b) =>
a.displayName.localeCompare(b.displayName),
);
}
/**
@@ -321,7 +378,7 @@ export class ToolRegistry {
serverTools.push(tool);
}
}
return serverTools;
return serverTools.sort((a, b) => a.name.localeCompare(b.name));
}
/**
@@ -379,6 +436,19 @@ function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
}
}
}
// Handle enum values - Gemini API only allows enum for STRING type
if (schema.enum && Array.isArray(schema.enum)) {
if (schema.type !== Type.STRING) {
// If enum is present but type is not STRING, convert type to STRING
schema.type = Type.STRING;
}
// Filter out null and undefined values, then convert remaining values to strings for Gemini API compatibility
schema.enum = schema.enum
.filter((value: unknown) => value !== null && value !== undefined)
.map((value: unknown) => String(value));
}
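// e.g. a discovered schema of { type: Type.INTEGER, enum: [1, 2, null] } is
// rewritten to { type: Type.STRING, enum: ['1', '2'] } before being sent on.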
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
if (schema.type === Type.STRING) {
if (

View File

@@ -28,6 +28,11 @@ export interface Tool<
*/
description: string;
/**
* The icon to display when interacting via ACP
*/
icon: Icon;
/**
* Function declaration schema from @google/genai
*/
@@ -60,6 +65,13 @@ export interface Tool<
*/
getDescription(params: TParams): string;
/**
* Determines what file system paths the tool will affect
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(params: TParams): ToolLocation[];
/**
* Determines if the tool should prompt for confirmation before execution
* @param params Parameters for the tool execution
@@ -97,12 +109,13 @@ export abstract class BaseTool<
* @param description Description of what the tool does
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
* @param canUpdateOutput Whether the tool supports live (streaming) output
* @param parameterSchema JSON Schema defining the parameters
* @param parameterSchema OpenAPI 3.0 schema defining the parameters
*/
constructor(
readonly name: string,
readonly displayName: string,
readonly description: string,
readonly icon: Icon,
readonly parameterSchema: Schema,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
@@ -158,6 +171,18 @@ export abstract class BaseTool<
return Promise.resolve(false);
}
/**
* Determines what file system paths the tool will affect
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
params: TParams,
): ToolLocation[] {
return [];
}
/**
* Abstract method to execute the tool with the given parameters
* Must be implemented by derived classes
@@ -199,6 +224,8 @@ export type ToolResultDisplay = string | FileDiff;
export interface FileDiff {
fileDiff: string;
fileName: string;
originalContent: string | null;
newContent: string;
}
export interface ToolEditConfirmationDetails {
@@ -210,6 +237,8 @@ export interface ToolEditConfirmationDetails {
) => Promise<void>;
fileName: string;
fileDiff: string;
originalContent: string | null;
newContent: string;
isModifying?: boolean;
}
@@ -258,3 +287,21 @@ export enum ToolConfirmationOutcome {
ModifyWithEditor = 'modify_with_editor',
Cancel = 'cancel',
}
export enum Icon {
FileSearch = 'fileSearch',
Folder = 'folder',
Globe = 'globe',
Hammer = 'hammer',
LightBulb = 'lightBulb',
Pencil = 'pencil',
Regex = 'regex',
Terminal = 'terminal',
}
export interface ToolLocation {
// Absolute path to the file
path: string;
// Which line (if known)
line?: number;
}
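// Example (a sketch; the class and parameter names are assumptions): a tool
// that inspects a single line of a file could report that location up front so
// callers know which path it will touch before execution.
class InspectLineLocationExample {
  toolLocations(params: { absolute_path: string; line?: number }): ToolLocation[] {
    return [{ path: params.absolute_path, line: params.line }];
  }
}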

View File

@@ -13,6 +13,7 @@ describe('WebFetchTool', () => {
const mockConfig = {
getApprovalMode: vi.fn(),
setApprovalMode: vi.fn(),
getProxy: vi.fn(),
} as unknown as Config;
describe('shouldConfirmExecute', () => {

View File

@@ -10,6 +10,7 @@ import {
ToolResult,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { getErrorMessage } from '../utils/errors.js';
@@ -17,9 +18,10 @@ import { Config, ApprovalMode } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { convert } from 'html-to-text';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 50000;
const MAX_CONTENT_LENGTH = 100000;
// Helper function to extract URLs from a string
function extractUrls(text: string): string[] {
@@ -69,6 +71,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
WebFetchTool.Name,
'WebFetch',
"Processes content from URL(s), including local and private network addresses (e.g., localhost), embedded in a prompt. Include up to 20 URLs and instructions (e.g., summarize, extract specific data) directly in the 'prompt' parameter.",
Icon.Globe,
{
properties: {
prompt: {
@@ -81,6 +84,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
type: Type.OBJECT,
},
);
const proxy = config.getProxy();
if (proxy) {
setGlobalDispatcher(new ProxyAgent(proxy as string));
}
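// e.g. with getProxy() returning 'http://localhost:8080' (an illustrative
// value), every subsequent fetch issued through undici is routed via that proxy.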
}
private async executeFallback(
@@ -94,70 +101,40 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
returnDisplay: 'Error: No URL found in the prompt for fallback.',
};
}
// For now, we only support one URL for fallback
let url = urls[0];
const results: string[] = [];
const processedUrls: string[] = [];
// Process multiple URLs (up to 20 as mentioned in description)
const urlsToProcess = urls.slice(0, 20);
for (const originalUrl of urlsToProcess) {
let url = originalUrl;
// Convert GitHub blob URL to raw URL
if (url.includes('github.com') && url.includes('/blob/')) {
url = url
.replace('github.com', 'raw.githubusercontent.com')
.replace('/blob/', '/');
}
try {
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!response.ok) {
throw new Error(
`Request failed with status code ${response.status} ${response.statusText}`,
);
}
const html = await response.text();
const textContent = convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
}).substring(0, MAX_CONTENT_LENGTH);
results.push(`Content from ${url}:\n${textContent}`);
processedUrls.push(url);
} catch (e) {
const error = e as Error;
results.push(`Error fetching ${url}: ${error.message}`);
processedUrls.push(url);
}
// Convert GitHub blob URL to raw URL
if (url.includes('github.com') && url.includes('/blob/')) {
url = url
.replace('github.com', 'raw.githubusercontent.com')
.replace('/blob/', '/');
}
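// e.g. (illustrative URL) https://github.com/owner/repo/blob/main/README.md
// becomes https://raw.githubusercontent.com/owner/repo/main/README.md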
try {
const geminiClient = this.config.getGeminiClient();
const combinedContent = results.join('\n\n---\n\n');
// Ensure the total prompt length doesn't exceed limits
const maxPromptLength = 200000; // Leave room for system instructions
const promptPrefix = `The user requested the following: "${params.prompt}".
I have fetched the content from the following URL(s). Please use this content to answer the user's request. Do not attempt to access the URL(s) again.
`;
let finalContent = combinedContent;
if (promptPrefix.length + combinedContent.length > maxPromptLength) {
const availableLength = maxPromptLength - promptPrefix.length - 100; // Leave some buffer
finalContent =
combinedContent.substring(0, availableLength) +
'\n\n[Content truncated due to length limits]';
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!response.ok) {
throw new Error(
`Request failed with status code ${response.status} ${response.statusText}`,
);
}
const html = await response.text();
const textContent = convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
}).substring(0, MAX_CONTENT_LENGTH);
const fallbackPrompt = promptPrefix + finalContent;
const geminiClient = this.config.getGeminiClient();
const fallbackPrompt = `The user requested the following: "${params.prompt}".
I was unable to access the URL directly. Instead, I have fetched the raw content of the page. Please use the following content to answer the user's request. Do not attempt to access the URL again.
---
${textContent}
---`;
const result = await geminiClient.generateContent(
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
{},
@@ -166,11 +143,11 @@ I have fetched the content from the following URL(s). Please use this content to
const resultText = getResponseText(result) || '';
return {
llmContent: resultText,
returnDisplay: `Content from ${processedUrls.length} URL(s) processed using fallback fetch.`,
returnDisplay: `Content for ${url} processed using fallback fetch.`,
};
} catch (e) {
const error = e as Error;
const errorMessage = `Error during fallback processing: ${error.message}`;
const errorMessage = `Error during fallback fetch for ${url}: ${error.message}`;
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
@@ -262,12 +239,6 @@ I have fetched the content from the following URL(s). Please use this content to
}
const geminiClient = this.config.getGeminiClient();
const contentGenerator = geminiClient.getContentGenerator();
// Check if using OpenAI content generator - if so, use fallback
if (contentGenerator.constructor.name === 'OpenAIContentGenerator') {
return this.executeFallback(params, signal);
}
try {
const response = await geminiClient.generateContent(

View File

@@ -5,7 +5,7 @@
*/
import { GroundingMetadata } from '@google/genai';
import { BaseTool, ToolResult } from './tools.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
@@ -69,6 +69,7 @@ export class WebSearchTool extends BaseTool<
WebSearchTool.Name,
'GoogleSearch',
'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.',
Icon.Globe,
{
type: Type.OBJECT,
properties: {

View File

@@ -15,6 +15,7 @@ import {
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
ToolCallConfirmationDetails,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
@@ -72,9 +73,10 @@ export class WriteFileTool
super(
WriteFileTool.Name,
'WriteFile',
`Writes content to a specified file in the local filesystem.
`Writes content to a specified file in the local filesystem.
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
Icon.Pencil,
{
properties: {
file_path: {
@@ -184,6 +186,8 @@ export class WriteFileTool
title: `Confirm Write: ${shortenPath(relativePath)}`,
fileName,
fileDiff,
originalContent,
newContent: correctedContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
@@ -269,7 +273,12 @@ export class WriteFileTool
);
}
const displayResult: FileDiff = { fileDiff, fileName };
const displayResult: FileDiff = {
fileDiff,
fileName,
originalContent: correctedContentResult.originalContent,
newContent: correctedContentResult.correctedContent,
};
const lines = fileContent.split('\n').length;
const mimetype = getSpecificMimeType(params.file_path);