mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
Merge tag 'v0.1.15' into feature/yiheng/sync-gemini-cli-0.1.15
This commit is contained in:
@@ -608,6 +608,19 @@ describe('EditTool', () => {
|
||||
/User modified the `new_string` content/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if old_string and new_string are identical', async () => {
|
||||
const initialContent = 'This is some identical text.';
|
||||
fs.writeFileSync(filePath, initialContent, 'utf8');
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
old_string: 'identical',
|
||||
new_string: 'identical',
|
||||
};
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
expect(result.llmContent).toMatch(/No changes to apply/);
|
||||
expect(result.returnDisplay).toMatch(/No changes to apply/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
|
||||
@@ -9,9 +9,11 @@ import * as path from 'path';
|
||||
import * as Diff from 'diff';
|
||||
import {
|
||||
BaseTool,
|
||||
Icon,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolLocation,
|
||||
ToolResult,
|
||||
ToolResultDisplay,
|
||||
} from './tools.js';
|
||||
@@ -89,6 +91,7 @@ Expectation for required parameters:
|
||||
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
|
||||
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
|
||||
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
|
||||
Icon.Pencil,
|
||||
{
|
||||
properties: {
|
||||
file_path: {
|
||||
@@ -141,6 +144,15 @@ Expectation for required parameters:
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines any file locations affected by the tool execution
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A list of such paths
|
||||
*/
|
||||
toolLocations(params: EditToolParams): ToolLocation[] {
|
||||
return [{ path: params.file_path }];
|
||||
}
|
||||
|
||||
private _applyReplacement(
|
||||
currentContent: string | null,
|
||||
oldString: string,
|
||||
@@ -197,7 +209,7 @@ Expectation for required parameters:
|
||||
// Creating a new file
|
||||
isNewFile = true;
|
||||
} else if (!fileExists) {
|
||||
// Trying to edit a non-existent file (and old_string is not empty)
|
||||
// Trying to edit a nonexistent file (and old_string is not empty)
|
||||
error = {
|
||||
display: `File not found. Cannot apply edit. Use an empty old_string to create a new file.`,
|
||||
raw: `File not found: ${params.file_path}`,
|
||||
@@ -227,12 +239,17 @@ Expectation for required parameters:
|
||||
raw: `Failed to edit, 0 occurrences found for old_string in ${params.file_path}. No edits made. The exact text in old_string was not found. Ensure you're not escaping content incorrectly and check whitespace, indentation, and context. Use ${ReadFileTool.Name} tool to verify.`,
|
||||
};
|
||||
} else if (occurrences !== expectedReplacements) {
|
||||
const occurenceTerm =
|
||||
const occurrenceTerm =
|
||||
expectedReplacements === 1 ? 'occurrence' : 'occurrences';
|
||||
|
||||
error = {
|
||||
display: `Failed to edit, expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences}.`,
|
||||
raw: `Failed to edit, Expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences} for old_string in file: ${params.file_path}`,
|
||||
display: `Failed to edit, expected ${expectedReplacements} ${occurrenceTerm} but found ${occurrences}.`,
|
||||
raw: `Failed to edit, Expected ${expectedReplacements} ${occurrenceTerm} but found ${occurrences} for old_string in file: ${params.file_path}`,
|
||||
};
|
||||
} else if (finalOldString === finalNewString) {
|
||||
error = {
|
||||
display: `No changes to apply. The old_string and new_string are identical.`,
|
||||
raw: `No changes to apply. The old_string and new_string are identical in file: ${params.file_path}`,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
@@ -306,6 +323,8 @@ Expectation for required parameters:
|
||||
title: `Confirm Edit: ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`,
|
||||
fileName,
|
||||
fileDiff,
|
||||
originalContent: editData.currentContent,
|
||||
newContent: editData.newContent,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
@@ -394,7 +413,12 @@ Expectation for required parameters:
|
||||
'Proposed',
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
displayResult = { fileDiff, fileName };
|
||||
displayResult = {
|
||||
fileDiff,
|
||||
fileName,
|
||||
originalContent: editData.currentContent,
|
||||
newContent: editData.newContent,
|
||||
};
|
||||
}
|
||||
|
||||
const llmSuccessMessageParts = [
|
||||
|
||||
@@ -150,11 +150,19 @@ describe('GlobTool', () => {
|
||||
expect(typeof llmContent).toBe('string');
|
||||
|
||||
const filesListed = llmContent
|
||||
.substring(llmContent.indexOf(':') + 1)
|
||||
.trim()
|
||||
.split('\n');
|
||||
expect(filesListed[0]).toContain(path.join(tempRootDir, 'newer.sortme'));
|
||||
expect(filesListed[1]).toContain(path.join(tempRootDir, 'older.sortme'));
|
||||
.split(/\r?\n/)
|
||||
.slice(1)
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean);
|
||||
|
||||
expect(filesListed).toHaveLength(2);
|
||||
expect(path.resolve(filesListed[0])).toBe(
|
||||
path.resolve(tempRootDir, 'newer.sortme'),
|
||||
);
|
||||
expect(path.resolve(filesListed[1])).toBe(
|
||||
path.resolve(tempRootDir, 'older.sortme'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { glob } from 'glob';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { shortenPath, makeRelative } from '../utils/paths.js';
|
||||
import { isWithinRoot } from '../utils/fileUtils.js';
|
||||
@@ -86,6 +86,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
|
||||
GlobTool.Name,
|
||||
'FindFiles',
|
||||
'Efficiently finds files matching specific glob patterns (e.g., `src/**/*.ts`, `**/*.md`), returning absolute paths sorted by modification time (newest first). Ideal for quickly locating files based on their name or path structure, especially in large codebases.',
|
||||
Icon.FileSearch,
|
||||
{
|
||||
properties: {
|
||||
pattern: {
|
||||
@@ -199,7 +200,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
|
||||
this.config.getFileFilteringRespectGitIgnore();
|
||||
const fileDiscovery = this.config.getFileService();
|
||||
|
||||
const entries = (await glob(params.pattern, {
|
||||
const entries = await glob(params.pattern, {
|
||||
cwd: searchDirAbsolute,
|
||||
withFileTypes: true,
|
||||
nodir: true,
|
||||
@@ -209,7 +210,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
|
||||
ignore: ['**/node_modules/**', '**/.git/**'],
|
||||
follow: false,
|
||||
signal,
|
||||
})) as GlobPath[];
|
||||
});
|
||||
|
||||
// Apply git-aware filtering if enabled and in git repository
|
||||
let filteredEntries = entries;
|
||||
|
||||
@@ -17,7 +17,7 @@ vi.mock('child_process', () => ({
|
||||
on: (event: string, cb: (...args: unknown[]) => void) => {
|
||||
if (event === 'error' || event === 'close') {
|
||||
// Simulate command not found or error for git grep and system grep
|
||||
// to force fallback to JS implementation.
|
||||
// to force it to fall back to JS implementation.
|
||||
setTimeout(() => cb(1), 0); // cb(1) for error/close
|
||||
}
|
||||
},
|
||||
@@ -125,7 +125,9 @@ describe('GrepTool', () => {
|
||||
expect(result.llmContent).toContain('File: fileA.txt');
|
||||
expect(result.llmContent).toContain('L1: hello world');
|
||||
expect(result.llmContent).toContain('L2: second line with world');
|
||||
expect(result.llmContent).toContain('File: sub/fileC.txt');
|
||||
expect(result.llmContent).toContain(
|
||||
`File: ${path.join('sub', 'fileC.txt')}`,
|
||||
);
|
||||
expect(result.llmContent).toContain('L1: another world in sub dir');
|
||||
expect(result.returnDisplay).toBe('Found 3 matches');
|
||||
});
|
||||
@@ -235,7 +237,7 @@ describe('GrepTool', () => {
|
||||
it('should generate correct description with pattern and path', () => {
|
||||
const params: GrepToolParams = {
|
||||
pattern: 'testPattern',
|
||||
path: 'src/app',
|
||||
path: path.join('src', 'app'),
|
||||
};
|
||||
// The path will be relative to the tempRootDir, so we check for containment.
|
||||
expect(grepTool.getDescription(params)).toContain("'testPattern' within");
|
||||
@@ -248,12 +250,14 @@ describe('GrepTool', () => {
|
||||
const params: GrepToolParams = {
|
||||
pattern: 'testPattern',
|
||||
include: '*.ts',
|
||||
path: 'src/app',
|
||||
path: path.join('src', 'app'),
|
||||
};
|
||||
expect(grepTool.getDescription(params)).toContain(
|
||||
"'testPattern' in *.ts within",
|
||||
);
|
||||
expect(grepTool.getDescription(params)).toContain('src/app');
|
||||
expect(grepTool.getDescription(params)).toContain(
|
||||
path.join('src', 'app'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use ./ for root path in description', () => {
|
||||
|
||||
@@ -9,8 +9,8 @@ import fsPromises from 'fs/promises';
|
||||
import path from 'path';
|
||||
import { EOL } from 'os';
|
||||
import { spawn } from 'child_process';
|
||||
import { globStream } from 'glob';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { globIterate } from 'glob';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
@@ -62,6 +62,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
|
||||
GrepTool.Name,
|
||||
'SearchText',
|
||||
'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
|
||||
Icon.Regex,
|
||||
{
|
||||
properties: {
|
||||
pattern: {
|
||||
@@ -498,7 +499,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
|
||||
'.hg/**',
|
||||
]; // Use glob patterns for ignores here
|
||||
|
||||
const filesStream = globStream(globPattern, {
|
||||
const filesIterator = globIterate(globPattern, {
|
||||
cwd: absolutePath,
|
||||
dot: true,
|
||||
ignore: ignorePatterns,
|
||||
@@ -510,7 +511,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
|
||||
const regex = new RegExp(pattern, 'i');
|
||||
const allMatches: GrepMatch[] = [];
|
||||
|
||||
for await (const filePath of filesStream) {
|
||||
for await (const filePath of filesIterator) {
|
||||
const fileAbsolutePath = filePath as string;
|
||||
try {
|
||||
const content = await fsPromises.readFile(fileAbsolutePath, 'utf8');
|
||||
|
||||
@@ -6,11 +6,11 @@
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
|
||||
import { isWithinRoot } from '../utils/fileUtils.js';
|
||||
|
||||
/**
|
||||
@@ -28,9 +28,12 @@ export interface LSToolParams {
|
||||
ignore?: string[];
|
||||
|
||||
/**
|
||||
* Whether to respect .gitignore patterns (optional, defaults to true)
|
||||
* Whether to respect .gitignore and .geminiignore patterns (optional, defaults to true)
|
||||
*/
|
||||
respect_git_ignore?: boolean;
|
||||
file_filtering_options?: {
|
||||
respect_git_ignore?: boolean;
|
||||
respect_gemini_ignore?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -74,6 +77,7 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
LSTool.Name,
|
||||
'ReadFolder',
|
||||
'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
|
||||
Icon.Folder,
|
||||
{
|
||||
properties: {
|
||||
path: {
|
||||
@@ -88,10 +92,22 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
},
|
||||
type: Type.ARRAY,
|
||||
},
|
||||
respect_git_ignore: {
|
||||
file_filtering_options: {
|
||||
description:
|
||||
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
respect_git_ignore: {
|
||||
description:
|
||||
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
},
|
||||
respect_gemini_ignore: {
|
||||
description:
|
||||
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
required: ['path'],
|
||||
@@ -198,14 +214,25 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
|
||||
const files = fs.readdirSync(params.path);
|
||||
|
||||
const defaultFileIgnores =
|
||||
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
|
||||
|
||||
const fileFilteringOptions = {
|
||||
respectGitIgnore:
|
||||
params.file_filtering_options?.respect_git_ignore ??
|
||||
defaultFileIgnores.respectGitIgnore,
|
||||
respectGeminiIgnore:
|
||||
params.file_filtering_options?.respect_gemini_ignore ??
|
||||
defaultFileIgnores.respectGeminiIgnore,
|
||||
};
|
||||
|
||||
// Get centralized file discovery service
|
||||
const respectGitIgnore =
|
||||
params.respect_git_ignore ??
|
||||
this.config.getFileFilteringRespectGitIgnore();
|
||||
|
||||
const fileDiscovery = this.config.getFileService();
|
||||
|
||||
const entries: FileEntry[] = [];
|
||||
let gitIgnoredCount = 0;
|
||||
let geminiIgnoredCount = 0;
|
||||
|
||||
if (files.length === 0) {
|
||||
// Changed error message to be more neutral for LLM
|
||||
@@ -226,14 +253,21 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
fullPath,
|
||||
);
|
||||
|
||||
// Check if this file should be git-ignored (only in git repositories)
|
||||
// Check if this file should be ignored based on git or gemini ignore rules
|
||||
if (
|
||||
respectGitIgnore &&
|
||||
fileFilteringOptions.respectGitIgnore &&
|
||||
fileDiscovery.shouldGitIgnoreFile(relativePath)
|
||||
) {
|
||||
gitIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
fileFilteringOptions.respectGeminiIgnore &&
|
||||
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
|
||||
) {
|
||||
geminiIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = fs.statSync(fullPath);
|
||||
@@ -264,13 +298,21 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
.join('\n');
|
||||
|
||||
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
|
||||
const ignoredMessages = [];
|
||||
if (gitIgnoredCount > 0) {
|
||||
resultMessage += `\n\n(${gitIgnoredCount} items were git-ignored)`;
|
||||
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
|
||||
}
|
||||
if (geminiIgnoredCount > 0) {
|
||||
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
|
||||
}
|
||||
|
||||
if (ignoredMessages.length > 0) {
|
||||
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
|
||||
}
|
||||
|
||||
let displayMessage = `Listed ${entries.length} item(s).`;
|
||||
if (gitIgnoredCount > 0) {
|
||||
displayMessage += ` (${gitIgnoredCount} git-ignored)`;
|
||||
if (ignoredMessages.length > 0) {
|
||||
displayMessage += ` (${ignoredMessages.join(', ')})`;
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -9,18 +9,23 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/
|
||||
import {
|
||||
populateMcpServerCommand,
|
||||
createTransport,
|
||||
generateValidName,
|
||||
isEnabled,
|
||||
discoverTools,
|
||||
discoverPrompts,
|
||||
} from './mcp-client.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import * as GenAiLib from '@google/genai';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('../mcp/oauth-provider.js');
|
||||
vi.mock('../mcp/oauth-token-storage.js');
|
||||
|
||||
describe('mcp-client', () => {
|
||||
afterEach(() => {
|
||||
@@ -47,6 +52,77 @@ describe('mcp-client', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverPrompts', () => {
|
||||
const mockedPromptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
|
||||
it('should discover and log prompts', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [
|
||||
{ name: 'prompt1', description: 'desc1' },
|
||||
{ name: 'prompt2' },
|
||||
],
|
||||
});
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should do nothing if no prompts are discovered', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
});
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'debug')
|
||||
.mockImplementation(() => {
|
||||
// no-op
|
||||
});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||
|
||||
consoleLogSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error if discovery fails', async () => {
|
||||
const testError = new Error('test error');
|
||||
testError.message = 'test error';
|
||||
const mockRequest = vi.fn().mockRejectedValue(testError);
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {
|
||||
// no-op
|
||||
});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
`Error discovering prompts from test-server: ${testError.message}`,
|
||||
);
|
||||
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
const out = populateMcpServerCommand({}, undefined);
|
||||
@@ -83,7 +159,7 @@ describe('mcp-client', () => {
|
||||
|
||||
describe('should connect via httpUrl', () => {
|
||||
it('without headers', async () => {
|
||||
const transport = createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
@@ -97,7 +173,7 @@ describe('mcp-client', () => {
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
const transport = createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
@@ -118,7 +194,7 @@ describe('mcp-client', () => {
|
||||
|
||||
describe('should connect via url', () => {
|
||||
it('without headers', async () => {
|
||||
const transport = createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
@@ -131,7 +207,7 @@ describe('mcp-client', () => {
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
const transport = createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
@@ -150,10 +226,10 @@ describe('mcp-client', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should connect via command', () => {
|
||||
it('should connect via command', async () => {
|
||||
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
|
||||
|
||||
createTransport(
|
||||
await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
@@ -172,91 +248,62 @@ describe('mcp-client', () => {
|
||||
stderr: 'pipe',
|
||||
});
|
||||
});
|
||||
});
|
||||
describe('generateValidName', () => {
|
||||
it('should return a valid name for a simple function', () => {
|
||||
const funcDecl = { name: 'myFunction' };
|
||||
const serverName = 'myServer';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('myServer__myFunction');
|
||||
});
|
||||
|
||||
it('should prepend the server name', () => {
|
||||
const funcDecl = { name: 'anotherFunction' };
|
||||
const serverName = 'production-server';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('production-server__anotherFunction');
|
||||
});
|
||||
describe('useGoogleCredentialProvider', () => {
|
||||
it('should use GoogleCredentialProvider when specified', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
},
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
it('should replace invalid characters with underscores', () => {
|
||||
const funcDecl = { name: 'invalid-name with spaces' };
|
||||
const serverName = 'test_server';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('test_server__invalid-name_with_spaces');
|
||||
});
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const authProvider = (transport as any)._authProvider;
|
||||
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
|
||||
});
|
||||
|
||||
it('should truncate long names', () => {
|
||||
const funcDecl = {
|
||||
name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit',
|
||||
};
|
||||
const serverName = 'a_long_server_name';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result.length).toBe(63);
|
||||
expect(result).toBe(
|
||||
'a_long_server_name__a_very_l___will_definitely_exceed_the_limit',
|
||||
);
|
||||
});
|
||||
it('should use GoogleCredentialProvider with SSE transport', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
},
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
it('should handle names with only invalid characters', () => {
|
||||
const funcDecl = { name: '!@#$%^&*()' };
|
||||
const serverName = 'special-chars';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('special-chars____________');
|
||||
});
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const authProvider = (transport as any)._authProvider;
|
||||
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
|
||||
});
|
||||
|
||||
it('should handle names that are already valid', () => {
|
||||
const funcDecl = { name: 'already_valid' };
|
||||
const serverName = 'validator';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('validator__already_valid');
|
||||
});
|
||||
|
||||
it('should handle names with leading/trailing invalid characters', () => {
|
||||
const funcDecl = { name: '-_invalid-_' };
|
||||
const serverName = 'trim-test';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe('trim-test__-_invalid-_');
|
||||
});
|
||||
|
||||
it('should handle names that are exactly 63 characters long', () => {
|
||||
const longName = 'a'.repeat(45);
|
||||
const funcDecl = { name: longName };
|
||||
const serverName = 'server';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result).toBe(`server__${longName}`);
|
||||
expect(result.length).toBe(53);
|
||||
});
|
||||
|
||||
it('should handle names that are exactly 64 characters long', () => {
|
||||
const longName = 'a'.repeat(55);
|
||||
const funcDecl = { name: longName };
|
||||
const serverName = 'server';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result.length).toBe(63);
|
||||
expect(result).toBe(
|
||||
'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle names that are longer than 64 characters', () => {
|
||||
const longName = 'a'.repeat(100);
|
||||
const funcDecl = { name: longName };
|
||||
const serverName = 'long-server';
|
||||
const result = generateValidName(funcDecl, serverName);
|
||||
expect(result.length).toBe(63);
|
||||
expect(result).toBe(
|
||||
'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
|
||||
);
|
||||
it('should throw an error if no URL is provided with GoogleCredentialProvider', async () => {
|
||||
await expect(
|
||||
createTransport(
|
||||
'test-server',
|
||||
{
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
},
|
||||
},
|
||||
false,
|
||||
),
|
||||
).rejects.toThrow(
|
||||
'No URL configured for Google Credentials MCP server',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe('isEnabled', () => {
|
||||
|
||||
@@ -15,14 +15,32 @@ import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
Prompt,
|
||||
ListPromptsResultSchema,
|
||||
GetPromptResult,
|
||||
GetPromptResultSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
|
||||
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
||||
|
||||
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
export type DiscoveredMCPPrompt = Prompt & {
|
||||
serverName: string;
|
||||
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum representing the connection status of an MCP server
|
||||
*/
|
||||
@@ -50,13 +68,18 @@ export enum MCPDiscoveryState {
|
||||
/**
|
||||
* Map to track the status of each MCP server within the core package
|
||||
*/
|
||||
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
|
||||
const serverStatuses: Map<string, MCPServerStatus> = new Map();
|
||||
|
||||
/**
|
||||
* Track the overall MCP discovery state
|
||||
*/
|
||||
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
|
||||
/**
|
||||
* Map to track which MCP servers have been discovered to require OAuth
|
||||
*/
|
||||
export const mcpServerRequiresOAuth: Map<string, boolean> = new Map();
|
||||
|
||||
/**
|
||||
* Event listeners for MCP server status changes
|
||||
*/
|
||||
@@ -94,7 +117,7 @@ function updateMCPServerStatus(
|
||||
serverName: string,
|
||||
status: MCPServerStatus,
|
||||
): void {
|
||||
mcpServerStatusesInternal.set(serverName, status);
|
||||
serverStatuses.set(serverName, status);
|
||||
// Notify all listeners
|
||||
for (const listener of statusChangeListeners) {
|
||||
listener(serverName, status);
|
||||
@@ -105,16 +128,14 @@ function updateMCPServerStatus(
|
||||
* Get the current status of an MCP server
|
||||
*/
|
||||
export function getMCPServerStatus(serverName: string): MCPServerStatus {
|
||||
return (
|
||||
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
|
||||
);
|
||||
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP server statuses
|
||||
*/
|
||||
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
|
||||
return new Map(mcpServerStatusesInternal);
|
||||
return new Map(serverStatuses);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -124,6 +145,165 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
|
||||
return mcpDiscoveryState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse www-authenticate header to extract OAuth metadata URI.
|
||||
*
|
||||
* @param wwwAuthenticate The www-authenticate header value
|
||||
* @returns The resource metadata URI if found, null otherwise
|
||||
*/
|
||||
function _parseWWWAuthenticate(wwwAuthenticate: string): string | null {
|
||||
// Parse header like: Bearer realm="MCP Server", resource_metadata_uri="https://..."
|
||||
const resourceMetadataMatch = wwwAuthenticate.match(
|
||||
/resource_metadata_uri="([^"]+)"/,
|
||||
);
|
||||
return resourceMetadataMatch ? resourceMetadataMatch[1] : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract WWW-Authenticate header from error message string.
|
||||
* This is a more robust approach than regex matching.
|
||||
*
|
||||
* @param errorString The error message string
|
||||
* @returns The www-authenticate header value if found, null otherwise
|
||||
*/
|
||||
function extractWWWAuthenticateHeader(errorString: string): string | null {
|
||||
// Try multiple patterns to extract the header
|
||||
const patterns = [
|
||||
/www-authenticate:\s*([^\n\r]+)/i,
|
||||
/WWW-Authenticate:\s*([^\n\r]+)/i,
|
||||
/"www-authenticate":\s*"([^"]+)"/i,
|
||||
/'www-authenticate':\s*'([^']+)'/i,
|
||||
];
|
||||
|
||||
for (const pattern of patterns) {
|
||||
const match = errorString.match(pattern);
|
||||
if (match) {
|
||||
return match[1].trim();
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle automatic OAuth discovery and authentication for a server.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
* @param wwwAuthenticate The www-authenticate header value
|
||||
* @returns True if OAuth was successfully configured and authenticated, false otherwise
|
||||
*/
|
||||
async function handleAutomaticOAuth(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
wwwAuthenticate: string,
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
console.log(`🔐 '${mcpServerName}' requires OAuth authentication`);
|
||||
|
||||
// Always try to parse the resource metadata URI from the www-authenticate header
|
||||
let oauthConfig;
|
||||
const resourceMetadataUri =
|
||||
OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate);
|
||||
if (resourceMetadataUri) {
|
||||
oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Fallback: try to discover OAuth config from the base URL for SSE
|
||||
const sseUrl = new URL(mcpServerConfig.url);
|
||||
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
|
||||
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||
} else if (mcpServerConfig.httpUrl) {
|
||||
// Fallback: try to discover OAuth config from the base URL for HTTP
|
||||
const httpUrl = new URL(mcpServerConfig.httpUrl);
|
||||
const baseUrl = `${httpUrl.protocol}//${httpUrl.host}`;
|
||||
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||
}
|
||||
|
||||
if (!oauthConfig) {
|
||||
console.error(
|
||||
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// OAuth configuration discovered - proceed with authentication
|
||||
|
||||
// Create OAuth configuration for authentication
|
||||
const oauthAuthConfig = {
|
||||
enabled: true,
|
||||
authorizationUrl: oauthConfig.authorizationUrl,
|
||||
tokenUrl: oauthConfig.tokenUrl,
|
||||
scopes: oauthConfig.scopes || [],
|
||||
};
|
||||
|
||||
// Perform OAuth authentication
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig);
|
||||
|
||||
console.log(
|
||||
`OAuth authentication successful for server '${mcpServerName}'`,
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport with OAuth token for the given server configuration.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
* @param accessToken The OAuth access token
|
||||
* @returns The transport with OAuth token, or null if creation fails
|
||||
*/
|
||||
async function createTransportWithOAuth(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
accessToken: string,
|
||||
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
|
||||
try {
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
// Create HTTP transport with OAuth token
|
||||
const oauthTransportOptions: StreamableHTTPClientTransportOptions = {
|
||||
requestInit: {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
oauthTransportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Create SSE transport with OAuth token in Authorization header
|
||||
return new SSEClientTransport(new URL(mcpServerConfig.url), {
|
||||
requestInit: {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools from all configured MCP servers and registers them with the tool registry.
|
||||
* It orchestrates the connection and discovery process for each server defined in the
|
||||
@@ -138,6 +318,7 @@ export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
@@ -150,6 +331,7 @@ export async function discoverMcpTools(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
toolRegistry,
|
||||
promptRegistry,
|
||||
debugMode,
|
||||
),
|
||||
);
|
||||
@@ -193,6 +375,7 @@ export async function connectAndDiscover(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
@@ -205,11 +388,11 @@ export async function connectAndDiscover(
|
||||
);
|
||||
try {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
|
||||
|
||||
mcpClient.onerror = (error) => {
|
||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
};
|
||||
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
|
||||
|
||||
const tools = await discoverTools(
|
||||
mcpServerName,
|
||||
@@ -224,7 +407,11 @@ export async function connectAndDiscover(
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
|
||||
console.error(
|
||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
}
|
||||
}
|
||||
@@ -259,32 +446,109 @@ export async function discoverTools(
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolNameForModel = generateValidName(funcDecl, mcpServerName);
|
||||
|
||||
sanitizeParameters(funcDecl.parameters);
|
||||
|
||||
discoveredTools.push(
|
||||
new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
mcpServerName,
|
||||
toolNameForModel,
|
||||
funcDecl.description ?? '',
|
||||
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
|
||||
funcDecl.name!,
|
||||
funcDecl.description ?? '',
|
||||
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
mcpServerConfig.trust,
|
||||
),
|
||||
);
|
||||
}
|
||||
if (discoveredTools.length === 0) {
|
||||
throw Error('No enabled tools found');
|
||||
}
|
||||
return discoveredTools;
|
||||
} catch (error) {
|
||||
throw new Error(`Error discovering tools: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and logs prompts from a connected MCP client.
|
||||
* It retrieves prompt declarations from the client and logs their names.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
*/
|
||||
export async function discoverPrompts(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptRegistry: PromptRegistry,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
ListPromptsResultSchema,
|
||||
);
|
||||
|
||||
for (const prompt of response.prompts) {
|
||||
promptRegistry.registerPrompt({
|
||||
...prompt,
|
||||
serverName: mcpServerName,
|
||||
invoke: (params: Record<string, unknown>) =>
|
||||
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// It's okay if this fails, not all servers will have prompts.
|
||||
// Don't log an error if the method is not found, which is a common case.
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes a prompt on a connected MCP client.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
* @param promptName The name of the prompt to invoke.
|
||||
* @param promptParams The parameters to pass to the prompt.
|
||||
* @returns A promise that resolves to the result of the prompt invocation.
|
||||
*/
|
||||
export async function invokeMcpPrompt(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptName: string,
|
||||
promptParams: Record<string, unknown>,
|
||||
): Promise<GetPromptResult> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{
|
||||
method: 'prompts/get',
|
||||
params: {
|
||||
name: promptName,
|
||||
arguments: promptParams,
|
||||
},
|
||||
},
|
||||
GetPromptResultSchema,
|
||||
);
|
||||
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||
@@ -318,7 +582,7 @@ export async function connectToMcpServer(
|
||||
}
|
||||
|
||||
try {
|
||||
const transport = createTransport(
|
||||
const transport = await createTransport(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
debugMode,
|
||||
@@ -333,40 +597,419 @@ export async function connectToMcpServer(
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
// Create a safe config object that excludes sensitive information
|
||||
const safeConfig = {
|
||||
command: mcpServerConfig.command,
|
||||
url: mcpServerConfig.url,
|
||||
httpUrl: mcpServerConfig.httpUrl,
|
||||
cwd: mcpServerConfig.cwd,
|
||||
timeout: mcpServerConfig.timeout,
|
||||
trust: mcpServerConfig.trust,
|
||||
// Exclude args, env, and headers which may contain sensitive data
|
||||
};
|
||||
// Check if this is a 401 error that might indicate OAuth is required
|
||||
const errorString = String(error);
|
||||
if (
|
||||
errorString.includes('401') &&
|
||||
(mcpServerConfig.httpUrl || mcpServerConfig.url)
|
||||
) {
|
||||
mcpServerRequiresOAuth.set(mcpServerName, true);
|
||||
// Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||
const shouldTriggerOAuth =
|
||||
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||
|
||||
let errorString =
|
||||
`failed to start or connect to MCP server '${mcpServerName}' ` +
|
||||
`${JSON.stringify(safeConfig)}; \n${error}`;
|
||||
if (process.env.SANDBOX) {
|
||||
errorString += `\nMake sure it is available in the sandbox`;
|
||||
if (!shouldTriggerOAuth) {
|
||||
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
if (credentials) {
|
||||
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (hasStoredTokens) {
|
||||
console.log(
|
||||
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
} else {
|
||||
console.log(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw new Error(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Try to extract www-authenticate header from the error
|
||||
let wwwAuthenticate = extractWWWAuthenticateHeader(errorString);
|
||||
|
||||
// If we didn't get the header from the error string, try to get it from the server
|
||||
if (!wwwAuthenticate && mcpServerConfig.url) {
|
||||
console.log(
|
||||
`No www-authenticate header in error, trying to fetch it from server...`,
|
||||
);
|
||||
try {
|
||||
const response = await fetch(mcpServerConfig.url, {
|
||||
method: 'HEAD',
|
||||
headers: {
|
||||
Accept: 'text/event-stream',
|
||||
},
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
wwwAuthenticate = response.headers.get('www-authenticate');
|
||||
if (wwwAuthenticate) {
|
||||
console.log(
|
||||
`Found www-authenticate header from server: ${wwwAuthenticate}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (fetchError) {
|
||||
console.debug(
|
||||
`Failed to fetch www-authenticate header: ${getErrorMessage(fetchError)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (wwwAuthenticate) {
|
||||
console.log(
|
||||
`Received 401 with www-authenticate header: ${wwwAuthenticate}`,
|
||||
);
|
||||
|
||||
// Try automatic OAuth discovery and authentication
|
||||
const oauthSuccess = await handleAutomaticOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
wwwAuthenticate,
|
||||
);
|
||||
if (oauthSuccess) {
|
||||
// Retry connection with OAuth token
|
||||
console.log(
|
||||
`Retrying connection to '${mcpServerName}' with OAuth token...`,
|
||||
);
|
||||
|
||||
// Get the valid token - we need to create a proper OAuth config
|
||||
// The token should already be available from the authentication process
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
if (credentials) {
|
||||
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
|
||||
if (accessToken) {
|
||||
// Create transport with OAuth token
|
||||
const oauthTransport = await createTransportWithOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
);
|
||||
if (oauthTransport) {
|
||||
try {
|
||||
await mcpClient.connect(oauthTransport, {
|
||||
timeout:
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return mcpClient;
|
||||
} catch (retryError) {
|
||||
console.error(
|
||||
`Failed to connect with OAuth token: ${getErrorMessage(
|
||||
retryError,
|
||||
)}`,
|
||||
);
|
||||
throw retryError;
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// No www-authenticate header found, but we got a 401
|
||||
// Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||
const shouldTryDiscovery =
|
||||
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (!shouldTryDiscovery) {
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
if (credentials) {
|
||||
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (hasStoredTokens) {
|
||||
console.log(
|
||||
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
} else {
|
||||
console.log(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw new Error(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
|
||||
// For SSE servers, try to discover OAuth configuration from the base URL
|
||||
console.log(`🔍 Attempting OAuth discovery for '${mcpServerName}'...`);
|
||||
|
||||
if (mcpServerConfig.url) {
|
||||
const sseUrl = new URL(mcpServerConfig.url);
|
||||
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
|
||||
|
||||
try {
|
||||
// Try to discover OAuth configuration from the base URL
|
||||
const oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||
if (oauthConfig) {
|
||||
console.log(
|
||||
`Discovered OAuth configuration from base URL for server '${mcpServerName}'`,
|
||||
);
|
||||
|
||||
// Create OAuth configuration for authentication
|
||||
const oauthAuthConfig = {
|
||||
enabled: true,
|
||||
authorizationUrl: oauthConfig.authorizationUrl,
|
||||
tokenUrl: oauthConfig.tokenUrl,
|
||||
scopes: oauthConfig.scopes || [],
|
||||
};
|
||||
|
||||
// Perform OAuth authentication
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(
|
||||
mcpServerName,
|
||||
oauthAuthConfig,
|
||||
);
|
||||
|
||||
// Retry connection with OAuth token
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
if (credentials) {
|
||||
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (accessToken) {
|
||||
// Create transport with OAuth token
|
||||
const oauthTransport = await createTransportWithOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
);
|
||||
if (oauthTransport) {
|
||||
try {
|
||||
await mcpClient.connect(oauthTransport, {
|
||||
timeout:
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return mcpClient;
|
||||
} catch (retryError) {
|
||||
console.error(
|
||||
`Failed to connect with OAuth token: ${getErrorMessage(
|
||||
retryError,
|
||||
)}`,
|
||||
);
|
||||
throw retryError;
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`Failed to get stored credentials for server '${mcpServerName}'`,
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to get stored credentials for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
throw new Error(
|
||||
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
} catch (discoveryError) {
|
||||
console.error(
|
||||
`❌ OAuth discovery failed for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
throw discoveryError;
|
||||
}
|
||||
} else {
|
||||
console.error(
|
||||
`❌ '${mcpServerName}' requires authentication but no OAuth configuration found`,
|
||||
);
|
||||
throw new Error(
|
||||
`MCP server '${mcpServerName}' requires authentication. Please configure OAuth or check server settings.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle other connection errors
|
||||
// Create a concise error message
|
||||
const errorMessage = (error as Error).message || String(error);
|
||||
const isNetworkError =
|
||||
errorMessage.includes('ENOTFOUND') ||
|
||||
errorMessage.includes('ECONNREFUSED');
|
||||
|
||||
let conciseError: string;
|
||||
if (isNetworkError) {
|
||||
conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`;
|
||||
} else {
|
||||
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
|
||||
}
|
||||
|
||||
if (process.env.SANDBOX) {
|
||||
conciseError += ` (check sandbox availability)`;
|
||||
}
|
||||
|
||||
throw new Error(conciseError);
|
||||
}
|
||||
throw new Error(errorString);
|
||||
}
|
||||
}
|
||||
|
||||
/** Visible for Testing */
|
||||
export function createTransport(
|
||||
export async function createTransport(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
): Transport {
|
||||
): Promise<Transport> {
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
const provider = new GoogleCredentialProvider(mcpServerConfig);
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
authProvider: provider,
|
||||
};
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error('No URL configured for Google Credentials MCP server');
|
||||
}
|
||||
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||
accessToken = await MCPOAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
mcpServerConfig.oauth,
|
||||
);
|
||||
|
||||
if (!accessToken) {
|
||||
console.error(
|
||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||
`Please authenticate using the /mcp auth command.`,
|
||||
);
|
||||
throw new Error(
|
||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||
`Please authenticate using the /mcp auth command.`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
if (credentials) {
|
||||
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
|
||||
if (accessToken) {
|
||||
hasOAuthConfig = true;
|
||||
console.log(`Found stored OAuth token for server '${mcpServerName}'`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
||||
if (mcpServerConfig.headers) {
|
||||
|
||||
// Set up headers with OAuth token if available
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
transportOptions.requestInit = {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
};
|
||||
} else if (mcpServerConfig.headers) {
|
||||
transportOptions.requestInit = {
|
||||
headers: mcpServerConfig.headers,
|
||||
};
|
||||
}
|
||||
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
@@ -375,11 +1018,21 @@ export function createTransport(
|
||||
|
||||
if (mcpServerConfig.url) {
|
||||
const transportOptions: SSEClientTransportOptions = {};
|
||||
if (mcpServerConfig.headers) {
|
||||
|
||||
// Set up headers with OAuth token if available
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
transportOptions.requestInit = {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
};
|
||||
} else if (mcpServerConfig.headers) {
|
||||
transportOptions.requestInit = {
|
||||
headers: mcpServerConfig.headers,
|
||||
};
|
||||
}
|
||||
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
@@ -411,26 +1064,6 @@ export function createTransport(
|
||||
);
|
||||
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function generateValidName(
|
||||
funcDecl: FunctionDeclaration,
|
||||
mcpServerName: string,
|
||||
) {
|
||||
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
|
||||
let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
|
||||
|
||||
// Prepend MCP server name to avoid conflicts with other tools
|
||||
validToolname = mcpServerName + '__' + validToolname;
|
||||
|
||||
// If longer than 63 characters, replace middle with '___'
|
||||
// (Gemini API says max length 64, but actual limit seems to be 63)
|
||||
if (validToolname.length > 63) {
|
||||
validToolname =
|
||||
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
|
||||
}
|
||||
return validToolname;
|
||||
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function isEnabled(
|
||||
funcDecl: FunctionDeclaration,
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js'; // Added getStringifiedResultForDisplay
|
||||
import { DiscoveredMCPTool, generateValidName } from './mcp-tool.js'; // Added getStringifiedResultForDisplay
|
||||
import { ToolResult, ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome
|
||||
import { CallableTool, Part } from '@google/genai';
|
||||
|
||||
@@ -29,9 +29,42 @@ const mockCallableToolInstance: Mocked<CallableTool> = {
|
||||
// Add other methods if DiscoveredMCPTool starts using them
|
||||
};
|
||||
|
||||
describe('generateValidName', () => {
|
||||
it('should return a valid name for a simple function', () => {
|
||||
expect(generateValidName('myFunction')).toBe('myFunction');
|
||||
});
|
||||
|
||||
it('should replace invalid characters with underscores', () => {
|
||||
expect(generateValidName('invalid-name with spaces')).toBe(
|
||||
'invalid-name_with_spaces',
|
||||
);
|
||||
});
|
||||
|
||||
it('should truncate long names', () => {
|
||||
expect(generateValidName('x'.repeat(80))).toBe(
|
||||
'xxxxxxxxxxxxxxxxxxxxxxxxxxxx___xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle names with only invalid characters', () => {
|
||||
expect(generateValidName('!@#$%^&*()')).toBe('__________');
|
||||
});
|
||||
|
||||
it('should handle names that are exactly 63 characters long', () => {
|
||||
expect(generateValidName('a'.repeat(63)).length).toBe(63);
|
||||
});
|
||||
|
||||
it('should handle names that are exactly 64 characters long', () => {
|
||||
expect(generateValidName('a'.repeat(64)).length).toBe(63);
|
||||
});
|
||||
|
||||
it('should handle names that are longer than 64 characters', () => {
|
||||
expect(generateValidName('a'.repeat(80)).length).toBe(63);
|
||||
});
|
||||
});
|
||||
|
||||
describe('DiscoveredMCPTool', () => {
|
||||
const serverName = 'mock-mcp-server';
|
||||
const toolNameForModel = 'test-mcp-tool-for-model';
|
||||
const serverToolName = 'actual-server-tool-name';
|
||||
const baseDescription = 'A test MCP tool.';
|
||||
const inputSchema: Record<string, unknown> = {
|
||||
@@ -52,46 +85,32 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should set properties correctly (non-generic server)', () => {
|
||||
it('should set properties correctly', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName, // serverName is 'mock-mcp-server', not 'mcp'
|
||||
toolNameForModel,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
|
||||
expect(tool.name).toBe(toolNameForModel);
|
||||
expect(tool.schema.name).toBe(toolNameForModel);
|
||||
expect(tool.name).toBe(serverToolName);
|
||||
expect(tool.schema.name).toBe(serverToolName);
|
||||
expect(tool.schema.description).toBe(baseDescription);
|
||||
expect(tool.schema.parameters).toEqual(inputSchema);
|
||||
expect(tool.schema.parameters).toBeUndefined();
|
||||
expect(tool.schema.parametersJsonSchema).toEqual(inputSchema);
|
||||
expect(tool.serverToolName).toBe(serverToolName);
|
||||
expect(tool.timeout).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should set properties correctly (generic "mcp" server)', () => {
|
||||
const genericServerName = 'mcp';
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
genericServerName, // serverName is 'mcp'
|
||||
toolNameForModel,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
expect(tool.schema.description).toBe(baseDescription);
|
||||
});
|
||||
|
||||
it('should accept and store a custom timeout', () => {
|
||||
const customTimeout = 5000;
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
customTimeout,
|
||||
);
|
||||
expect(tool.timeout).toBe(customTimeout);
|
||||
@@ -103,10 +122,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const params = { param: 'testValue' };
|
||||
const mockToolSuccessResultObject = {
|
||||
@@ -143,10 +161,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const params = { param: 'testValue' };
|
||||
const mockMcpToolResponsePartsEmpty: Part[] = [];
|
||||
@@ -159,10 +176,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const params = { param: 'failCase' };
|
||||
const expectedError = new Error('MCP call failed');
|
||||
@@ -179,10 +195,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
@@ -196,10 +211,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
expect(
|
||||
await tool.shouldConfirmExecute({}, new AbortController().signal),
|
||||
@@ -212,10 +226,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
expect(
|
||||
await tool.shouldConfirmExecute({}, new AbortController().signal),
|
||||
@@ -226,10 +239,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
{},
|
||||
@@ -257,10 +269,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
{},
|
||||
@@ -288,10 +299,9 @@ describe('DiscoveredMCPTool', () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
toolNameForModel,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
serverToolName,
|
||||
);
|
||||
const toolAllowlistKey = `${serverName}.${serverToolName}`;
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
@@ -315,5 +325,77 @@ describe('DiscoveredMCPTool', () => {
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle Cancel confirmation outcome', async () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
{},
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).not.toBe(false);
|
||||
if (
|
||||
confirmation &&
|
||||
typeof confirmation === 'object' &&
|
||||
'onConfirm' in confirmation &&
|
||||
typeof confirmation.onConfirm === 'function'
|
||||
) {
|
||||
// Cancel should not add anything to allowlist
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
|
||||
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
|
||||
false,
|
||||
);
|
||||
expect(
|
||||
(DiscoveredMCPTool as any).allowlist.has(
|
||||
`${serverName}.${serverToolName}`,
|
||||
),
|
||||
).toBe(false);
|
||||
} else {
|
||||
throw new Error(
|
||||
'Confirmation details or onConfirm not in expected format',
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle ProceedOnce confirmation outcome', async () => {
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
{},
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).not.toBe(false);
|
||||
if (
|
||||
confirmation &&
|
||||
typeof confirmation === 'object' &&
|
||||
'onConfirm' in confirmation &&
|
||||
typeof confirmation.onConfirm === 'function'
|
||||
) {
|
||||
// ProceedOnce should not add anything to allowlist
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
|
||||
false,
|
||||
);
|
||||
expect(
|
||||
(DiscoveredMCPTool as any).allowlist.has(
|
||||
`${serverName}.${serverToolName}`,
|
||||
),
|
||||
).toBe(false);
|
||||
} else {
|
||||
throw new Error(
|
||||
'Confirmation details or onConfirm not in expected format',
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,8 +10,15 @@ import {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolMcpConfirmationDetails,
|
||||
Icon,
|
||||
} from './tools.js';
|
||||
import { CallableTool, Part, FunctionCall, Schema } from '@google/genai';
|
||||
import {
|
||||
CallableTool,
|
||||
Part,
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
Type,
|
||||
} from '@google/genai';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
@@ -21,23 +28,49 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
|
||||
constructor(
|
||||
private readonly mcpTool: CallableTool,
|
||||
readonly serverName: string,
|
||||
readonly name: string,
|
||||
readonly description: string,
|
||||
readonly parameterSchema: Schema,
|
||||
readonly serverToolName: string,
|
||||
description: string,
|
||||
readonly parameterSchemaJson: unknown,
|
||||
readonly timeout?: number,
|
||||
readonly trust?: boolean,
|
||||
nameOverride?: string,
|
||||
) {
|
||||
super(
|
||||
name,
|
||||
nameOverride ?? generateValidName(serverToolName),
|
||||
`${serverToolName} (${serverName} MCP Server)`,
|
||||
description,
|
||||
parameterSchema,
|
||||
Icon.Hammer,
|
||||
{ type: Type.OBJECT }, // this is a dummy Schema for MCP, will be not be used to construct the FunctionDeclaration
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
);
|
||||
}
|
||||
|
||||
asFullyQualifiedTool(): DiscoveredMCPTool {
|
||||
return new DiscoveredMCPTool(
|
||||
this.mcpTool,
|
||||
this.serverName,
|
||||
this.serverToolName,
|
||||
this.description,
|
||||
this.parameterSchemaJson,
|
||||
this.timeout,
|
||||
this.trust,
|
||||
`${this.serverName}__${this.serverToolName}`,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Overrides the base schema to use parametersJsonSchema when building
|
||||
* FunctionDeclaration
|
||||
*/
|
||||
override get schema(): FunctionDeclaration {
|
||||
return {
|
||||
name: this.name,
|
||||
description: this.description,
|
||||
parametersJsonSchema: this.parameterSchemaJson,
|
||||
};
|
||||
}
|
||||
|
||||
async shouldConfirmExecute(
|
||||
_params: ToolParams,
|
||||
_abortSignal: AbortSignal,
|
||||
@@ -53,7 +86,7 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
|
||||
DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
|
||||
DiscoveredMCPTool.allowlist.has(toolAllowListKey)
|
||||
) {
|
||||
return false; // server and/or tool already allow listed
|
||||
return false; // server and/or tool already allowlisted
|
||||
}
|
||||
|
||||
const confirmationDetails: ToolMcpConfirmationDetails = {
|
||||
@@ -146,3 +179,17 @@ function getStringifiedResultForDisplay(result: Part[]) {
|
||||
|
||||
return '```json\n' + JSON.stringify(processedResults, null, 2) + '\n```';
|
||||
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function generateValidName(name: string) {
|
||||
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
|
||||
let validToolname = name.replace(/[^a-zA-Z0-9_.-]/g, '_');
|
||||
|
||||
// If longer than 63 characters, replace middle with '___'
|
||||
// (Gemini API says max length 64, but actual limit seems to be 63)
|
||||
if (validToolname.length > 63) {
|
||||
validToolname =
|
||||
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
|
||||
}
|
||||
return validToolname;
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ describe('MemoryTool', () => {
|
||||
describe('performAddMemoryEntry (static method)', () => {
|
||||
const testFilePath = path.join(
|
||||
'/mock/home',
|
||||
'.qwen',
|
||||
'.gemini',
|
||||
DEFAULT_CONTEXT_FILENAME, // Use the default for basic tests
|
||||
);
|
||||
|
||||
@@ -207,7 +207,7 @@ describe('MemoryTool', () => {
|
||||
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
|
||||
const expectedFilePath = path.join(
|
||||
'/mock/home',
|
||||
'.qwen',
|
||||
'.gemini',
|
||||
getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test
|
||||
);
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { FunctionDeclaration, Type } from '@google/genai';
|
||||
import * as fs from 'fs/promises';
|
||||
import * as path from 'path';
|
||||
@@ -46,8 +46,8 @@ Do NOT use this tool:
|
||||
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".
|
||||
`;
|
||||
|
||||
export const GEMINI_CONFIG_DIR = '.qwen';
|
||||
export const DEFAULT_CONTEXT_FILENAME = 'QWEN.md';
|
||||
export const GEMINI_CONFIG_DIR = '.gemini';
|
||||
export const DEFAULT_CONTEXT_FILENAME = 'GEMINI.md';
|
||||
export const MEMORY_SECTION_HEADER = '## Gemini Added Memories';
|
||||
|
||||
// This variable will hold the currently configured filename for GEMINI.md context files.
|
||||
@@ -105,6 +105,7 @@ export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
|
||||
MemoryTool.Name,
|
||||
'Save Memory',
|
||||
memoryToolDescription,
|
||||
Icon.LightBulb,
|
||||
memoryToolSchemaData.parameters as Record<string, unknown>,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,15 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
modifyWithEditor,
|
||||
ModifyContext,
|
||||
@@ -21,6 +13,7 @@ import {
|
||||
} from './modifiable-tool.js';
|
||||
import { EditorType } from '../utils/editor.js';
|
||||
import fs from 'fs';
|
||||
import fsp from 'fs/promises';
|
||||
import os from 'os';
|
||||
import * as path from 'path';
|
||||
|
||||
@@ -36,9 +29,6 @@ vi.mock('diff', () => ({
|
||||
createPatch: mockCreatePatch,
|
||||
}));
|
||||
|
||||
vi.mock('fs');
|
||||
vi.mock('os');
|
||||
|
||||
interface TestParams {
|
||||
filePath: string;
|
||||
someOtherParam: string;
|
||||
@@ -46,7 +36,7 @@ interface TestParams {
|
||||
}
|
||||
|
||||
describe('modifyWithEditor', () => {
|
||||
let tempDir: string;
|
||||
let testProjectDir: string;
|
||||
let mockModifyContext: ModifyContext<TestParams>;
|
||||
let mockParams: TestParams;
|
||||
let currentContent: string;
|
||||
@@ -54,17 +44,19 @@ describe('modifyWithEditor', () => {
|
||||
let modifiedContent: string;
|
||||
let abortSignal: AbortSignal;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
tempDir = '/tmp/test-dir';
|
||||
testProjectDir = await fsp.mkdtemp(
|
||||
path.join(os.tmpdir(), 'modifiable-tool-test-'),
|
||||
);
|
||||
abortSignal = new AbortController().signal;
|
||||
|
||||
currentContent = 'original content\nline 2\nline 3';
|
||||
proposedContent = 'modified content\nline 2\nline 3';
|
||||
modifiedContent = 'user modified content\nline 2\nline 3\nnew line';
|
||||
mockParams = {
|
||||
filePath: path.join(tempDir, 'test.txt'),
|
||||
filePath: path.join(testProjectDir, 'test.txt'),
|
||||
someOtherParam: 'value',
|
||||
};
|
||||
|
||||
@@ -81,26 +73,18 @@ describe('modifyWithEditor', () => {
|
||||
})),
|
||||
};
|
||||
|
||||
(os.tmpdir as Mock).mockReturnValue(tempDir);
|
||||
|
||||
(fs.existsSync as Mock).mockReturnValue(true);
|
||||
(fs.mkdirSync as Mock).mockImplementation(() => undefined);
|
||||
(fs.writeFileSync as Mock).mockImplementation(() => {});
|
||||
(fs.unlinkSync as Mock).mockImplementation(() => {});
|
||||
|
||||
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('-new-')) {
|
||||
return modifiedContent;
|
||||
}
|
||||
return currentContent;
|
||||
mockOpenDiff.mockImplementation(async (_oldPath, newPath) => {
|
||||
await fsp.writeFile(newPath, modifiedContent, 'utf8');
|
||||
});
|
||||
|
||||
mockCreatePatch.mockReturnValue('mock diff content');
|
||||
mockOpenDiff.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
afterEach(async () => {
|
||||
vi.restoreAllMocks();
|
||||
await fsp.rm(testProjectDir, { recursive: true, force: true });
|
||||
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
|
||||
await fsp.rm(diffDir, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
describe('successful modification', () => {
|
||||
@@ -120,38 +104,8 @@ describe('modifyWithEditor', () => {
|
||||
);
|
||||
expect(mockModifyContext.getFilePath).toHaveBeenCalledWith(mockParams);
|
||||
|
||||
expect(fs.writeFileSync).toHaveBeenCalledTimes(2);
|
||||
expect(fs.writeFileSync).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.stringContaining(
|
||||
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
|
||||
),
|
||||
currentContent,
|
||||
'utf8',
|
||||
);
|
||||
expect(fs.writeFileSync).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.stringContaining(
|
||||
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
|
||||
),
|
||||
proposedContent,
|
||||
'utf8',
|
||||
);
|
||||
|
||||
expect(mockOpenDiff).toHaveBeenCalledWith(
|
||||
expect.stringContaining('-old-'),
|
||||
expect.stringContaining('-new-'),
|
||||
'vscode',
|
||||
);
|
||||
|
||||
expect(fs.readFileSync).toHaveBeenCalledWith(
|
||||
expect.stringContaining('-old-'),
|
||||
'utf8',
|
||||
);
|
||||
expect(fs.readFileSync).toHaveBeenCalledWith(
|
||||
expect.stringContaining('-new-'),
|
||||
'utf8',
|
||||
);
|
||||
expect(mockOpenDiff).toHaveBeenCalledOnce();
|
||||
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
|
||||
|
||||
expect(mockModifyContext.createUpdatedParams).toHaveBeenCalledWith(
|
||||
currentContent,
|
||||
@@ -171,15 +125,9 @@ describe('modifyWithEditor', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
|
||||
expect(fs.unlinkSync).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.stringContaining('-old-'),
|
||||
);
|
||||
expect(fs.unlinkSync).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.stringContaining('-new-'),
|
||||
);
|
||||
// Check that temp files are deleted.
|
||||
await expect(fsp.access(oldFilePath)).rejects.toThrow();
|
||||
await expect(fsp.access(newFilePath)).rejects.toThrow();
|
||||
|
||||
expect(result).toEqual({
|
||||
updatedParams: {
|
||||
@@ -192,7 +140,8 @@ describe('modifyWithEditor', () => {
|
||||
});
|
||||
|
||||
it('should create temp directory if it does not exist', async () => {
|
||||
(fs.existsSync as Mock).mockReturnValue(false);
|
||||
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
|
||||
await fsp.rm(diffDir, { recursive: true, force: true }).catch(() => {});
|
||||
|
||||
await modifyWithEditor(
|
||||
mockParams,
|
||||
@@ -201,14 +150,15 @@ describe('modifyWithEditor', () => {
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fs.mkdirSync).toHaveBeenCalledWith(
|
||||
path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
|
||||
{ recursive: true },
|
||||
);
|
||||
const stats = await fsp.stat(diffDir);
|
||||
expect(stats.isDirectory()).toBe(true);
|
||||
});
|
||||
|
||||
it('should not create temp directory if it already exists', async () => {
|
||||
(fs.existsSync as Mock).mockReturnValue(true);
|
||||
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
|
||||
await fsp.mkdir(diffDir, { recursive: true });
|
||||
|
||||
const mkdirSpy = vi.spyOn(fs, 'mkdirSync');
|
||||
|
||||
await modifyWithEditor(
|
||||
mockParams,
|
||||
@@ -217,18 +167,15 @@ describe('modifyWithEditor', () => {
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fs.mkdirSync).not.toHaveBeenCalled();
|
||||
expect(mkdirSpy).not.toHaveBeenCalled();
|
||||
mkdirSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing old temp file gracefully', async () => {
|
||||
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('-old-')) {
|
||||
const error = new Error('ENOENT: no such file or directory');
|
||||
(error as NodeJS.ErrnoException).code = 'ENOENT';
|
||||
throw error;
|
||||
}
|
||||
return modifiedContent;
|
||||
mockOpenDiff.mockImplementation(async (oldPath, newPath) => {
|
||||
await fsp.writeFile(newPath, modifiedContent, 'utf8');
|
||||
await fsp.unlink(oldPath);
|
||||
});
|
||||
|
||||
const result = await modifyWithEditor(
|
||||
@@ -255,13 +202,8 @@ describe('modifyWithEditor', () => {
|
||||
});
|
||||
|
||||
it('should handle missing new temp file gracefully', async () => {
|
||||
(fs.readFileSync as Mock).mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('-new-')) {
|
||||
const error = new Error('ENOENT: no such file or directory');
|
||||
(error as NodeJS.ErrnoException).code = 'ENOENT';
|
||||
throw error;
|
||||
}
|
||||
return currentContent;
|
||||
mockOpenDiff.mockImplementation(async (_oldPath, newPath) => {
|
||||
await fsp.unlink(newPath);
|
||||
});
|
||||
|
||||
const result = await modifyWithEditor(
|
||||
@@ -291,6 +233,8 @@ describe('modifyWithEditor', () => {
|
||||
const editorError = new Error('Editor failed to open');
|
||||
mockOpenDiff.mockRejectedValue(editorError);
|
||||
|
||||
const writeSpy = vi.spyOn(fs, 'writeFileSync');
|
||||
|
||||
await expect(
|
||||
modifyWithEditor(
|
||||
mockParams,
|
||||
@@ -300,14 +244,21 @@ describe('modifyWithEditor', () => {
|
||||
),
|
||||
).rejects.toThrow('Editor failed to open');
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
|
||||
expect(writeSpy).toHaveBeenCalledTimes(2);
|
||||
const oldFilePath = writeSpy.mock.calls[0][0] as string;
|
||||
const newFilePath = writeSpy.mock.calls[1][0] as string;
|
||||
|
||||
await expect(fsp.access(oldFilePath)).rejects.toThrow();
|
||||
await expect(fsp.access(newFilePath)).rejects.toThrow();
|
||||
|
||||
writeSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle temp file cleanup errors gracefully', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
(fs.unlinkSync as Mock).mockImplementation((_filePath: string) => {
|
||||
vi.spyOn(fs, 'unlinkSync').mockImplementation(() => {
|
||||
throw new Error('Failed to delete file');
|
||||
});
|
||||
|
||||
@@ -327,7 +278,11 @@ describe('modifyWithEditor', () => {
|
||||
});
|
||||
|
||||
it('should create temp files with correct naming with extension', async () => {
|
||||
const testFilePath = path.join(tempDir, 'subfolder', 'test-file.txt');
|
||||
const testFilePath = path.join(
|
||||
testProjectDir,
|
||||
'subfolder',
|
||||
'test-file.txt',
|
||||
);
|
||||
mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);
|
||||
|
||||
await modifyWithEditor(
|
||||
@@ -337,20 +292,18 @@ describe('modifyWithEditor', () => {
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
|
||||
expect(writeFileCalls).toHaveLength(2);
|
||||
|
||||
const oldFilePath = writeFileCalls[0][0];
|
||||
const newFilePath = writeFileCalls[1][0];
|
||||
|
||||
expect(mockOpenDiff).toHaveBeenCalledOnce();
|
||||
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
|
||||
expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+\.txt$/);
|
||||
expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+\.txt$/);
|
||||
expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
|
||||
expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
|
||||
|
||||
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
|
||||
expect(path.dirname(oldFilePath)).toBe(diffDir);
|
||||
expect(path.dirname(newFilePath)).toBe(diffDir);
|
||||
});
|
||||
|
||||
it('should create temp files with correct naming without extension', async () => {
|
||||
const testFilePath = path.join(tempDir, 'subfolder', 'test-file');
|
||||
const testFilePath = path.join(testProjectDir, 'subfolder', 'test-file');
|
||||
mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);
|
||||
|
||||
await modifyWithEditor(
|
||||
@@ -360,16 +313,14 @@ describe('modifyWithEditor', () => {
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
|
||||
expect(writeFileCalls).toHaveLength(2);
|
||||
|
||||
const oldFilePath = writeFileCalls[0][0];
|
||||
const newFilePath = writeFileCalls[1][0];
|
||||
|
||||
expect(mockOpenDiff).toHaveBeenCalledOnce();
|
||||
const [oldFilePath, newFilePath] = mockOpenDiff.mock.calls[0];
|
||||
expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+$/);
|
||||
expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+$/);
|
||||
expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
|
||||
expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
|
||||
|
||||
const diffDir = path.join(os.tmpdir(), 'gemini-cli-tool-modify-diffs');
|
||||
expect(path.dirname(oldFilePath)).toBe(diffDir);
|
||||
expect(path.dirname(newFilePath)).toBe(diffDir);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -4,54 +4,37 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach, Mock } from 'vitest';
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { ReadFileTool, ReadFileToolParams } from './read-file.js';
|
||||
import * as fileUtils from '../utils/fileUtils.js';
|
||||
import path from 'path';
|
||||
import os from 'os';
|
||||
import fs from 'fs'; // For actual fs operations in setup
|
||||
import fs from 'fs';
|
||||
import fsp from 'fs/promises';
|
||||
import { Config } from '../config/config.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
|
||||
// Mock fileUtils.processSingleFileContent
|
||||
vi.mock('../utils/fileUtils', async () => {
|
||||
const actualFileUtils =
|
||||
await vi.importActual<typeof fileUtils>('../utils/fileUtils');
|
||||
return {
|
||||
...actualFileUtils, // Spread actual implementations
|
||||
processSingleFileContent: vi.fn(), // Mock specific function
|
||||
};
|
||||
});
|
||||
|
||||
const mockProcessSingleFileContent = fileUtils.processSingleFileContent as Mock;
|
||||
|
||||
describe('ReadFileTool', () => {
|
||||
let tempRootDir: string;
|
||||
let tool: ReadFileTool;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
// Create a unique temporary root directory for each test run
|
||||
tempRootDir = fs.mkdtempSync(
|
||||
tempRootDir = await fsp.mkdtemp(
|
||||
path.join(os.tmpdir(), 'read-file-tool-root-'),
|
||||
);
|
||||
fs.writeFileSync(
|
||||
path.join(tempRootDir, '.geminiignore'),
|
||||
['foo.*'].join('\n'),
|
||||
);
|
||||
const fileService = new FileDiscoveryService(tempRootDir);
|
||||
|
||||
const mockConfigInstance = {
|
||||
getFileService: () => fileService,
|
||||
getFileService: () => new FileDiscoveryService(tempRootDir),
|
||||
getTargetDir: () => tempRootDir,
|
||||
} as unknown as Config;
|
||||
tool = new ReadFileTool(mockConfigInstance);
|
||||
mockProcessSingleFileContent.mockReset();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
afterEach(async () => {
|
||||
// Clean up the temporary root directory
|
||||
if (fs.existsSync(tempRootDir)) {
|
||||
fs.rmSync(tempRootDir, { recursive: true, force: true });
|
||||
await fsp.rm(tempRootDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
@@ -129,9 +112,9 @@ describe('ReadFileTool', () => {
|
||||
it('should return a shortened, relative path', () => {
|
||||
const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
// Assuming tempRootDir is something like /tmp/read-file-tool-root-XXXXXX
|
||||
// The relative path would be sub/dir/file.txt
|
||||
expect(tool.getDescription(params)).toBe('sub/dir/file.txt');
|
||||
expect(tool.getDescription(params)).toBe(
|
||||
path.join('sub', 'dir', 'file.txt'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return . if path is the root directory', () => {
|
||||
@@ -142,111 +125,140 @@ describe('ReadFileTool', () => {
|
||||
|
||||
describe('execute', () => {
|
||||
it('should return validation error if params are invalid', async () => {
|
||||
const params: ReadFileToolParams = { absolute_path: 'relative/path.txt' };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toBe(
|
||||
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
);
|
||||
expect(result.returnDisplay).toBe(
|
||||
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
);
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: 'relative/path.txt',
|
||||
};
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent:
|
||||
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
returnDisplay:
|
||||
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error from processSingleFileContent if it fails', async () => {
|
||||
const filePath = path.join(tempRootDir, 'error.txt');
|
||||
it('should return error if file does not exist', async () => {
|
||||
const filePath = path.join(tempRootDir, 'nonexistent.txt');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
const errorMessage = 'Simulated read error';
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: `Error reading file ${filePath}: ${errorMessage}`,
|
||||
returnDisplay: `Error reading file ${filePath}: ${errorMessage}`,
|
||||
error: errorMessage,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toContain(errorMessage);
|
||||
expect(result.returnDisplay).toContain(errorMessage);
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: `File not found: ${filePath}`,
|
||||
returnDisplay: 'File not found.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return success result for a text file', async () => {
|
||||
const filePath = path.join(tempRootDir, 'textfile.txt');
|
||||
const fileContent = 'This is a test file.';
|
||||
await fsp.writeFile(filePath, fileContent, 'utf-8');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: fileContent,
|
||||
returnDisplay: `Read text file: ${path.basename(filePath)}`,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toBe(fileContent);
|
||||
expect(result.returnDisplay).toBe(
|
||||
`Read text file: ${path.basename(filePath)}`,
|
||||
);
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: fileContent,
|
||||
returnDisplay: '',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return success result for an image file', async () => {
|
||||
// A minimal 1x1 transparent PNG file.
|
||||
const pngContent = Buffer.from([
|
||||
137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0,
|
||||
1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65,
|
||||
84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73,
|
||||
69, 78, 68, 174, 66, 96, 130,
|
||||
]);
|
||||
const filePath = path.join(tempRootDir, 'image.png');
|
||||
const imageData = {
|
||||
inlineData: { mimeType: 'image/png', data: 'base64...' },
|
||||
};
|
||||
await fsp.writeFile(filePath, pngContent);
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: imageData,
|
||||
returnDisplay: `Read image file: ${path.basename(filePath)}`,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toEqual(imageData);
|
||||
expect(result.returnDisplay).toBe(
|
||||
`Read image file: ${path.basename(filePath)}`,
|
||||
);
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: {
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngContent.toString('base64'),
|
||||
},
|
||||
},
|
||||
returnDisplay: `Read image file: image.png`,
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass offset and limit to processSingleFileContent', async () => {
|
||||
it('should treat a non-image file with image extension as an image', async () => {
|
||||
const filePath = path.join(tempRootDir, 'fake-image.png');
|
||||
const fileContent = 'This is not a real png.';
|
||||
await fsp.writeFile(filePath, fileContent, 'utf-8');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: {
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: Buffer.from(fileContent).toString('base64'),
|
||||
},
|
||||
},
|
||||
returnDisplay: `Read image file: fake-image.png`,
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass offset and limit to read a slice of a text file', async () => {
|
||||
const filePath = path.join(tempRootDir, 'paginated.txt');
|
||||
const fileContent = Array.from(
|
||||
{ length: 20 },
|
||||
(_, i) => `Line ${i + 1}`,
|
||||
).join('\n');
|
||||
await fsp.writeFile(filePath, fileContent, 'utf-8');
|
||||
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: filePath,
|
||||
offset: 10,
|
||||
limit: 5,
|
||||
offset: 5, // Start from line 6
|
||||
limit: 3,
|
||||
};
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: 'some lines',
|
||||
returnDisplay: 'Read text file (paginated)',
|
||||
});
|
||||
|
||||
await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
10,
|
||||
5,
|
||||
);
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: [
|
||||
'[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]',
|
||||
'Line 6',
|
||||
'Line 7',
|
||||
'Line 8',
|
||||
].join('\n'),
|
||||
returnDisplay: '(truncated)',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error if path is ignored by a .geminiignore pattern', async () => {
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: path.join(tempRootDir, 'foo.bar'),
|
||||
};
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.returnDisplay).toContain('foo.bar');
|
||||
expect(result.returnDisplay).not.toContain('foo.baz');
|
||||
describe('with .geminiignore', () => {
|
||||
beforeEach(async () => {
|
||||
await fsp.writeFile(
|
||||
path.join(tempRootDir, '.geminiignore'),
|
||||
['foo.*', 'ignored/'].join('\n'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is ignored by a .geminiignore pattern', async () => {
|
||||
const ignoredFilePath = path.join(tempRootDir, 'foo.bar');
|
||||
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: ignoredFilePath,
|
||||
};
|
||||
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
|
||||
returnDisplay: expectedError,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error if path is in an ignored directory', async () => {
|
||||
const ignoredDirPath = path.join(tempRootDir, 'ignored');
|
||||
await fsp.mkdir(ignoredDirPath);
|
||||
const filePath = path.join(ignoredDirPath, 'somefile.txt');
|
||||
await fsp.writeFile(filePath, 'content', 'utf-8');
|
||||
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: filePath,
|
||||
};
|
||||
const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`;
|
||||
expect(await tool.execute(params, abortSignal)).toEqual({
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
|
||||
returnDisplay: expectedError,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import path from 'path';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import {
|
||||
isWithinRoot,
|
||||
@@ -51,6 +51,7 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
|
||||
ReadFileTool.Name,
|
||||
'ReadFile',
|
||||
'Reads and returns the content of a specified file from the local filesystem. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.',
|
||||
Icon.FileSearch,
|
||||
{
|
||||
properties: {
|
||||
absolute_path: {
|
||||
@@ -118,6 +119,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
|
||||
return shortenPath(relativePath);
|
||||
}
|
||||
|
||||
toolLocations(params: ReadFileToolParams): ToolLocation[] {
|
||||
return [{ path: params.absolute_path, line: params.offset }];
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: ReadFileToolParams,
|
||||
_signal: AbortSignal,
|
||||
|
||||
@@ -58,10 +58,13 @@ describe('ReadManyFilesTool', () => {
|
||||
const fileService = new FileDiscoveryService(tempRootDir);
|
||||
const mockConfig = {
|
||||
getFileService: () => fileService,
|
||||
getFileFilteringRespectGitIgnore: () => true,
|
||||
|
||||
getFileFilteringOptions: () => ({
|
||||
respectGitIgnore: true,
|
||||
respectGeminiIgnore: true,
|
||||
}),
|
||||
getTargetDir: () => tempRootDir,
|
||||
} as Partial<Config> as Config;
|
||||
|
||||
tool = new ReadManyFilesTool(mockConfig);
|
||||
|
||||
mockReadFileFn = mockControl.mockReadFile;
|
||||
@@ -269,7 +272,7 @@ describe('ReadManyFilesTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle non-existent specific files gracefully', async () => {
|
||||
it('should handle nonexistent specific files gracefully', async () => {
|
||||
const params = { paths: ['nonexistent-file.txt'] };
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
expect(result.llmContent).toEqual([
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import * as path from 'path';
|
||||
@@ -17,7 +17,7 @@ import {
|
||||
getSpecificMimeType,
|
||||
} from '../utils/fileUtils.js';
|
||||
import { PartListUnion, Schema, Type } from '@google/genai';
|
||||
import { Config } from '../config/config.js';
|
||||
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
|
||||
import {
|
||||
recordFileOperationMetric,
|
||||
FileOperation,
|
||||
@@ -62,9 +62,12 @@ export interface ReadManyFilesParams {
|
||||
useDefaultExcludes?: boolean;
|
||||
|
||||
/**
|
||||
* Optional. Whether to respect .gitignore patterns. Defaults to true.
|
||||
* Whether to respect .gitignore and .geminiignore patterns (optional, defaults to true)
|
||||
*/
|
||||
respect_git_ignore?: boolean;
|
||||
file_filtering_options?: {
|
||||
respect_git_ignore?: boolean;
|
||||
respect_gemini_ignore?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -125,8 +128,6 @@ export class ReadManyFilesTool extends BaseTool<
|
||||
> {
|
||||
static readonly Name: string = 'read_many_files';
|
||||
|
||||
private readonly geminiIgnorePatterns: string[] = [];
|
||||
|
||||
constructor(private config: Config) {
|
||||
const parameterSchema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
@@ -173,11 +174,22 @@ export class ReadManyFilesTool extends BaseTool<
|
||||
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
|
||||
default: true,
|
||||
},
|
||||
respect_git_ignore: {
|
||||
type: Type.BOOLEAN,
|
||||
file_filtering_options: {
|
||||
description:
|
||||
'Optional. Whether to respect .gitignore patterns when discovering files. Only available in git repositories. Defaults to true.',
|
||||
default: true,
|
||||
'Whether to respect ignore patterns from .gitignore or .geminiignore',
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
respect_git_ignore: {
|
||||
description:
|
||||
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
},
|
||||
respect_gemini_ignore: {
|
||||
description:
|
||||
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
required: ['paths'],
|
||||
@@ -196,11 +208,9 @@ This tool is useful when you need to understand or analyze a collection of files
|
||||
- When the user asks to "read all files in X directory" or "show me the content of all Y files".
|
||||
|
||||
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
|
||||
Icon.FileSearch,
|
||||
parameterSchema,
|
||||
);
|
||||
this.geminiIgnorePatterns = config
|
||||
.getFileService()
|
||||
.getGeminiIgnorePatterns();
|
||||
}
|
||||
|
||||
validateParams(params: ReadManyFilesParams): string | null {
|
||||
@@ -218,17 +228,19 @@ Use this tool when the user's query implies needing the content of several files
|
||||
// Determine the final list of exclusion patterns exactly as in execute method
|
||||
const paramExcludes = params.exclude || [];
|
||||
const paramUseDefaultExcludes = params.useDefaultExcludes !== false;
|
||||
|
||||
const geminiIgnorePatterns = this.config
|
||||
.getFileService()
|
||||
.getGeminiIgnorePatterns();
|
||||
const finalExclusionPatternsForDescription: string[] =
|
||||
paramUseDefaultExcludes
|
||||
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...this.geminiIgnorePatterns]
|
||||
: [...paramExcludes, ...this.geminiIgnorePatterns];
|
||||
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...geminiIgnorePatterns]
|
||||
: [...paramExcludes, ...geminiIgnorePatterns];
|
||||
|
||||
let excludeDesc = `Excluding: ${finalExclusionPatternsForDescription.length > 0 ? `patterns like \`${finalExclusionPatternsForDescription.slice(0, 2).join('`, `')}${finalExclusionPatternsForDescription.length > 2 ? '...`' : '`'}` : 'none specified'}`;
|
||||
|
||||
// Add a note if .geminiignore patterns contributed to the final list of exclusions
|
||||
if (this.geminiIgnorePatterns.length > 0) {
|
||||
const geminiPatternsInEffect = this.geminiIgnorePatterns.filter((p) =>
|
||||
if (geminiIgnorePatterns.length > 0) {
|
||||
const geminiPatternsInEffect = geminiIgnorePatterns.filter((p) =>
|
||||
finalExclusionPatternsForDescription.includes(p),
|
||||
).length;
|
||||
if (geminiPatternsInEffect > 0) {
|
||||
@@ -256,12 +268,19 @@ Use this tool when the user's query implies needing the content of several files
|
||||
include = [],
|
||||
exclude = [],
|
||||
useDefaultExcludes = true,
|
||||
respect_git_ignore = true,
|
||||
} = params;
|
||||
|
||||
const respectGitIgnore =
|
||||
respect_git_ignore ?? this.config.getFileFilteringRespectGitIgnore();
|
||||
const defaultFileIgnores =
|
||||
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
|
||||
|
||||
const fileFilteringOptions = {
|
||||
respectGitIgnore:
|
||||
params.file_filtering_options?.respect_git_ignore ??
|
||||
defaultFileIgnores.respectGitIgnore, // Use the property from the returned object
|
||||
respectGeminiIgnore:
|
||||
params.file_filtering_options?.respect_gemini_ignore ??
|
||||
defaultFileIgnores.respectGeminiIgnore, // Use the property from the returned object
|
||||
};
|
||||
// Get centralized file discovery service
|
||||
const fileDiscovery = this.config.getFileService();
|
||||
|
||||
@@ -271,8 +290,8 @@ Use this tool when the user's query implies needing the content of several files
|
||||
const contentParts: PartListUnion = [];
|
||||
|
||||
const effectiveExcludes = useDefaultExcludes
|
||||
? [...DEFAULT_EXCLUDES, ...exclude, ...this.geminiIgnorePatterns]
|
||||
: [...exclude, ...this.geminiIgnorePatterns];
|
||||
? [...DEFAULT_EXCLUDES, ...exclude]
|
||||
: [...exclude];
|
||||
|
||||
const searchPatterns = [...inputPatterns, ...include];
|
||||
if (searchPatterns.length === 0) {
|
||||
@@ -283,7 +302,8 @@ Use this tool when the user's query implies needing the content of several files
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = await glob(searchPatterns, {
|
||||
const patterns = searchPatterns.map((p) => p.replace(/\\/g, '/'));
|
||||
const entries: string[] = await glob(patterns, {
|
||||
cwd: this.config.getTargetDir(),
|
||||
ignore: effectiveExcludes,
|
||||
nodir: true,
|
||||
@@ -291,20 +311,39 @@ Use this tool when the user's query implies needing the content of several files
|
||||
absolute: true,
|
||||
nocase: true,
|
||||
signal,
|
||||
withFileTypes: false,
|
||||
});
|
||||
|
||||
const filteredEntries = respectGitIgnore
|
||||
const gitFilteredEntries = fileFilteringOptions.respectGitIgnore
|
||||
? fileDiscovery
|
||||
.filterFiles(
|
||||
entries.map((p) => path.relative(this.config.getTargetDir(), p)),
|
||||
{
|
||||
respectGitIgnore,
|
||||
respectGitIgnore: true,
|
||||
respectGeminiIgnore: false,
|
||||
},
|
||||
)
|
||||
.map((p) => path.resolve(this.config.getTargetDir(), p))
|
||||
: entries;
|
||||
|
||||
// Apply gemini ignore filtering if enabled
|
||||
const finalFilteredEntries = fileFilteringOptions.respectGeminiIgnore
|
||||
? fileDiscovery
|
||||
.filterFiles(
|
||||
gitFilteredEntries.map((p) =>
|
||||
path.relative(this.config.getTargetDir(), p),
|
||||
),
|
||||
{
|
||||
respectGitIgnore: false,
|
||||
respectGeminiIgnore: true,
|
||||
},
|
||||
)
|
||||
.map((p) => path.resolve(this.config.getTargetDir(), p))
|
||||
: gitFilteredEntries;
|
||||
|
||||
let gitIgnoredCount = 0;
|
||||
let geminiIgnoredCount = 0;
|
||||
|
||||
for (const absoluteFilePath of entries) {
|
||||
// Security check: ensure the glob library didn't return something outside targetDir.
|
||||
if (!absoluteFilePath.startsWith(this.config.getTargetDir())) {
|
||||
@@ -316,11 +355,23 @@ Use this tool when the user's query implies needing the content of several files
|
||||
}
|
||||
|
||||
// Check if this file was filtered out by git ignore
|
||||
if (respectGitIgnore && !filteredEntries.includes(absoluteFilePath)) {
|
||||
if (
|
||||
fileFilteringOptions.respectGitIgnore &&
|
||||
!gitFilteredEntries.includes(absoluteFilePath)
|
||||
) {
|
||||
gitIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this file was filtered out by gemini ignore
|
||||
if (
|
||||
fileFilteringOptions.respectGeminiIgnore &&
|
||||
!finalFilteredEntries.includes(absoluteFilePath)
|
||||
) {
|
||||
geminiIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
filesToConsider.add(absoluteFilePath);
|
||||
}
|
||||
|
||||
@@ -328,7 +379,15 @@ Use this tool when the user's query implies needing the content of several files
|
||||
if (gitIgnoredCount > 0) {
|
||||
skippedFiles.push({
|
||||
path: `${gitIgnoredCount} file(s)`,
|
||||
reason: 'ignored',
|
||||
reason: 'git ignored',
|
||||
});
|
||||
}
|
||||
|
||||
// Add info about gemini-ignored files if any were filtered
|
||||
if (geminiIgnoredCount > 0) {
|
||||
skippedFiles.push({
|
||||
path: `${geminiIgnoredCount} file(s)`,
|
||||
reason: 'gemini ignored',
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -345,7 +404,7 @@ Use this tool when the user's query implies needing the content of several files
|
||||
.relative(this.config.getTargetDir(), filePath)
|
||||
.replace(/\\/g, '/');
|
||||
|
||||
const fileType = detectFileType(filePath);
|
||||
const fileType = await detectFileType(filePath);
|
||||
|
||||
if (fileType === 'image' || fileType === 'pdf') {
|
||||
const fileExtension = path.extname(filePath).toLowerCase();
|
||||
|
||||
@@ -4,429 +4,384 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { expect, describe, it, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
const mockShellExecutionService = vi.hoisted(() => vi.fn());
|
||||
vi.mock('../services/shellExecutionService.js', () => ({
|
||||
ShellExecutionService: { execute: mockShellExecutionService },
|
||||
}));
|
||||
vi.mock('fs');
|
||||
vi.mock('os');
|
||||
vi.mock('crypto');
|
||||
vi.mock('../utils/summarizer.js');
|
||||
|
||||
import { isCommandAllowed } from '../utils/shell-utils.js';
|
||||
import { ShellTool } from './shell.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { type Config } from '../config/config.js';
|
||||
import {
|
||||
type ShellExecutionResult,
|
||||
type ShellOutputEvent,
|
||||
} from '../services/shellExecutionService.js';
|
||||
import * as fs from 'fs';
|
||||
import * as os from 'os';
|
||||
import * as path from 'path';
|
||||
import * as crypto from 'crypto';
|
||||
import * as summarizer from '../utils/summarizer.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import { OUTPUT_UPDATE_INTERVAL_MS } from './shell.js';
|
||||
|
||||
describe('ShellTool', () => {
|
||||
it('should allow a command if no restrictions are provided', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => undefined,
|
||||
} as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command if it is in the allowed list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is not in the allowed list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command if it is not in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is in both the allowed and blocked lists', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(rm -rf /)'],
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any command when ShellTool is in coreTools without specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block any command when ShellTool is in excludeTools without specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['ShellTool'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Shell tool is globally disabled in configuration',
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command if it is in the allowed list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['run_shell_command(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command when ShellTool is in excludeTools using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['run_shell_command'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Shell tool is globally disabled in configuration',
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command if coreTools contains an empty ShellTool command list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command()'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'any command' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command if coreTools contains an empty ShellTool command list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool()'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'any command' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command with extra whitespace if it is in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(' rm -rf / ');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any command when ShellTool is present with specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool', 'ShellTool(ls)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command on the blocklist even with a wildcard allow', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool'],
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command that starts with an allowed command prefix', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(
|
||||
'gh issue edit 1 --add-label "kind/feature"',
|
||||
);
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command that starts with an allowed command prefix using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(
|
||||
'gh issue edit 1 --add-label "kind/feature"',
|
||||
);
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that starts with an allowed command prefix but is chained with another command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue edit&&rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is a prefix of an allowed command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'gh issue' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is a prefix of a blocked command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['run_shell_command(gh issue edit)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a pipe', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list | rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a semicolon', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list; rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a chained command if any part is blocked', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo "hello")'],
|
||||
getExcludeTools: () => ['run_shell_command(rm)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" && rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command if its prefix is on the blocklist, even if the command itself is on the allowlist', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(git push)'],
|
||||
getExcludeTools: () => ['run_shell_command(git)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('git push');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'git push' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should be case-sensitive in its matching', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ECHO "hello"');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Command \'ECHO "hello"\' is not in the allowed commands list',
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly handle commands with extra whitespace around chaining operators', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(ls -l)'],
|
||||
getExcludeTools: () => ['run_shell_command(rm)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l ; rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a chained command if all parts are allowed', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [
|
||||
'run_shell_command(echo)',
|
||||
'run_shell_command(ls -l)',
|
||||
],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" && ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command with command substitution using backticks', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo `rm -rf /`');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command with command substitution using $()', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo $(rm -rf /)');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Command substitution using $() is not allowed for security reasons',
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command with I/O redirection', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" > file.txt');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a double pipe', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list || rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ShellTool Bug Reproduction', () => {
|
||||
let shellTool: ShellTool;
|
||||
let config: Config;
|
||||
let mockConfig: Config;
|
||||
let mockShellOutputCallback: (event: ShellOutputEvent) => void;
|
||||
let resolveExecutionPromise: (result: ShellExecutionResult) => void;
|
||||
|
||||
beforeEach(() => {
|
||||
config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => undefined,
|
||||
getDebugMode: () => false,
|
||||
getGeminiClient: () => ({}) as GeminiClient,
|
||||
getTargetDir: () => '.',
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockConfig = {
|
||||
getCoreTools: vi.fn().mockReturnValue([]),
|
||||
getExcludeTools: vi.fn().mockReturnValue([]),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
getTargetDir: vi.fn().mockReturnValue('/test/dir'),
|
||||
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),
|
||||
getGeminiClient: vi.fn(),
|
||||
} as unknown as Config;
|
||||
shellTool = new ShellTool(config);
|
||||
});
|
||||
|
||||
it('should not let the summarizer override the return display', async () => {
|
||||
const summarizeSpy = vi
|
||||
.spyOn(summarizer, 'summarizeToolOutput')
|
||||
.mockResolvedValue('summarized output');
|
||||
shellTool = new ShellTool(mockConfig);
|
||||
|
||||
const abortSignal = new AbortController().signal;
|
||||
const result = await shellTool.execute(
|
||||
{ command: 'echo "hello"' },
|
||||
abortSignal,
|
||||
vi.mocked(os.platform).mockReturnValue('linux');
|
||||
vi.mocked(os.tmpdir).mockReturnValue('/tmp');
|
||||
(vi.mocked(crypto.randomBytes) as Mock).mockReturnValue(
|
||||
Buffer.from('abcdef', 'hex'),
|
||||
);
|
||||
|
||||
expect(result.returnDisplay).toBe('hello\n');
|
||||
expect(result.llmContent).toBe('summarized output');
|
||||
expect(summarizeSpy).toHaveBeenCalled();
|
||||
// Capture the output callback to simulate streaming events from the service
|
||||
mockShellExecutionService.mockImplementation((_cmd, _cwd, callback) => {
|
||||
mockShellOutputCallback = callback;
|
||||
return {
|
||||
pid: 12345,
|
||||
result: new Promise((resolve) => {
|
||||
resolveExecutionPromise = resolve;
|
||||
}),
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
describe('isCommandAllowed', () => {
|
||||
it('should allow a command if no restrictions are provided', () => {
|
||||
(mockConfig.getCoreTools as Mock).mockReturnValue(undefined);
|
||||
(mockConfig.getExcludeTools as Mock).mockReturnValue(undefined);
|
||||
expect(isCommandAllowed('ls -l', mockConfig).allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command with command substitution using $()', () => {
|
||||
expect(isCommandAllowed('echo $(rm -rf /)', mockConfig).allowed).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
it('should return null for a valid command', () => {
|
||||
expect(shellTool.validateToolParams({ command: 'ls -l' })).toBeNull();
|
||||
});
|
||||
|
||||
it('should return an error for an empty command', () => {
|
||||
expect(shellTool.validateToolParams({ command: ' ' })).toBe(
|
||||
'Command cannot be empty.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return an error for a non-existent directory', () => {
|
||||
vi.mocked(fs.existsSync).mockReturnValue(false);
|
||||
expect(
|
||||
shellTool.validateToolParams({ command: 'ls', directory: 'rel/path' }),
|
||||
).toBe('Directory must exist.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
const mockAbortSignal = new AbortController().signal;
|
||||
|
||||
const resolveShellExecution = (
|
||||
result: Partial<ShellExecutionResult> = {},
|
||||
) => {
|
||||
const fullResult: ShellExecutionResult = {
|
||||
rawOutput: Buffer.from(result.output || ''),
|
||||
output: 'Success',
|
||||
stdout: 'Success',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
...result,
|
||||
};
|
||||
resolveExecutionPromise(fullResult);
|
||||
};
|
||||
|
||||
it('should wrap command on linux and parse pgrep output', async () => {
|
||||
const promise = shellTool.execute(
|
||||
{ command: 'my-command &' },
|
||||
mockAbortSignal,
|
||||
);
|
||||
resolveShellExecution({ pid: 54321 });
|
||||
|
||||
vi.mocked(fs.existsSync).mockReturnValue(true);
|
||||
vi.mocked(fs.readFileSync).mockReturnValue('54321\n54322\n'); // Service PID and background PID
|
||||
|
||||
const result = await promise;
|
||||
|
||||
const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp');
|
||||
const wrappedCommand = `{ my-command & }; __code=$?; pgrep -g 0 >${tmpFile} 2>&1; exit $__code;`;
|
||||
expect(mockShellExecutionService).toHaveBeenCalledWith(
|
||||
wrappedCommand,
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
);
|
||||
expect(result.llmContent).toContain('Background PIDs: 54322');
|
||||
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
|
||||
});
|
||||
|
||||
it('should not wrap command on windows', async () => {
|
||||
vi.mocked(os.platform).mockReturnValue('win32');
|
||||
const promise = shellTool.execute({ command: 'dir' }, mockAbortSignal);
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
});
|
||||
await promise;
|
||||
expect(mockShellExecutionService).toHaveBeenCalledWith(
|
||||
'dir',
|
||||
expect.any(String),
|
||||
expect.any(Function),
|
||||
mockAbortSignal,
|
||||
);
|
||||
});
|
||||
|
||||
it('should format error messages correctly', async () => {
|
||||
const error = new Error('wrapped command failed');
|
||||
const promise = shellTool.execute(
|
||||
{ command: 'user-command' },
|
||||
mockAbortSignal,
|
||||
);
|
||||
resolveShellExecution({
|
||||
error,
|
||||
exitCode: 1,
|
||||
output: 'err',
|
||||
stderr: 'err',
|
||||
rawOutput: Buffer.from('err'),
|
||||
stdout: '',
|
||||
signal: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
});
|
||||
|
||||
const result = await promise;
|
||||
// The final llmContent should contain the user's command, not the wrapper
|
||||
expect(result.llmContent).toContain('Error: wrapped command failed');
|
||||
expect(result.llmContent).not.toContain('pgrep');
|
||||
});
|
||||
|
||||
it('should summarize output when configured', async () => {
|
||||
(mockConfig.getSummarizeToolOutputConfig as Mock).mockReturnValue({
|
||||
[shellTool.name]: { tokenBudget: 1000 },
|
||||
});
|
||||
vi.mocked(summarizer.summarizeToolOutput).mockResolvedValue(
|
||||
'summarized output',
|
||||
);
|
||||
|
||||
const promise = shellTool.execute({ command: 'ls' }, mockAbortSignal);
|
||||
resolveExecutionPromise({
|
||||
output: 'long output',
|
||||
rawOutput: Buffer.from('long output'),
|
||||
stdout: 'long output',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
});
|
||||
|
||||
const result = await promise;
|
||||
|
||||
expect(summarizer.summarizeToolOutput).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
mockConfig.getGeminiClient(),
|
||||
mockAbortSignal,
|
||||
1000,
|
||||
);
|
||||
expect(result.llmContent).toBe('summarized output');
|
||||
expect(result.returnDisplay).toBe('long output');
|
||||
});
|
||||
|
||||
it('should clean up the temp file on synchronous execution error', async () => {
|
||||
const error = new Error('sync spawn error');
|
||||
mockShellExecutionService.mockImplementation(() => {
|
||||
throw error;
|
||||
});
|
||||
vi.mocked(fs.existsSync).mockReturnValue(true); // Pretend the file exists
|
||||
|
||||
await expect(
|
||||
shellTool.execute({ command: 'a-command' }, mockAbortSignal),
|
||||
).rejects.toThrow(error);
|
||||
|
||||
const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp');
|
||||
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
|
||||
});
|
||||
|
||||
describe('Streaming to `updateOutput`', () => {
|
||||
let updateOutputMock: Mock;
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers({ toFake: ['Date'] });
|
||||
updateOutputMock = vi.fn();
|
||||
});
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should throttle text output updates', async () => {
|
||||
const promise = shellTool.execute(
|
||||
{ command: 'stream' },
|
||||
mockAbortSignal,
|
||||
updateOutputMock,
|
||||
);
|
||||
|
||||
// First chunk, should be throttled.
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stdout',
|
||||
chunk: 'hello ',
|
||||
});
|
||||
expect(updateOutputMock).not.toHaveBeenCalled();
|
||||
|
||||
// Advance time past the throttle interval.
|
||||
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
|
||||
|
||||
// Send a second chunk. THIS event triggers the update with the CUMULATIVE content.
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stderr',
|
||||
chunk: 'world',
|
||||
});
|
||||
|
||||
// It should have been called once now with the combined output.
|
||||
expect(updateOutputMock).toHaveBeenCalledOnce();
|
||||
expect(updateOutputMock).toHaveBeenCalledWith('hello \nworld');
|
||||
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
});
|
||||
await promise;
|
||||
});
|
||||
|
||||
it('should immediately show binary detection message and throttle progress', async () => {
|
||||
const promise = shellTool.execute(
|
||||
{ command: 'cat img' },
|
||||
mockAbortSignal,
|
||||
updateOutputMock,
|
||||
);
|
||||
|
||||
mockShellOutputCallback({ type: 'binary_detected' });
|
||||
expect(updateOutputMock).toHaveBeenCalledOnce();
|
||||
expect(updateOutputMock).toHaveBeenCalledWith(
|
||||
'[Binary output detected. Halting stream...]',
|
||||
);
|
||||
|
||||
mockShellOutputCallback({
|
||||
type: 'binary_progress',
|
||||
bytesReceived: 1024,
|
||||
});
|
||||
expect(updateOutputMock).toHaveBeenCalledOnce();
|
||||
|
||||
// Advance time past the throttle interval.
|
||||
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
|
||||
|
||||
// Send a SECOND progress event. This one will trigger the flush.
|
||||
mockShellOutputCallback({
|
||||
type: 'binary_progress',
|
||||
bytesReceived: 2048,
|
||||
});
|
||||
|
||||
// Now it should be called a second time with the latest progress.
|
||||
expect(updateOutputMock).toHaveBeenCalledTimes(2);
|
||||
expect(updateOutputMock).toHaveBeenLastCalledWith(
|
||||
'[Receiving binary output... 2.0 KB received]',
|
||||
);
|
||||
|
||||
resolveExecutionPromise({
|
||||
rawOutput: Buffer.from(''),
|
||||
output: '',
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
});
|
||||
await promise;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
it('should request confirmation for a new command and whitelist it on "Always"', async () => {
|
||||
const params = { command: 'npm install' };
|
||||
const confirmation = await shellTool.shouldConfirmExecute(
|
||||
params,
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(confirmation).not.toBe(false);
|
||||
expect(confirmation && confirmation.type).toBe('exec');
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
await (confirmation as any).onConfirm(
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
|
||||
// Should now be whitelisted
|
||||
const secondConfirmation = await shellTool.shouldConfirmExecute(
|
||||
{ command: 'npm test' },
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(secondConfirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should skip confirmation if validation fails', async () => {
|
||||
const confirmation = await shellTool.shouldConfirmExecute(
|
||||
{ command: '' },
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,25 +15,34 @@ import {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
Icon,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import stripAnsi from 'strip-ansi';
|
||||
import { summarizeToolOutput } from '../utils/summarizer.js';
|
||||
import {
|
||||
ShellExecutionService,
|
||||
ShellOutputEvent,
|
||||
} from '../services/shellExecutionService.js';
|
||||
import { formatMemoryUsage } from '../utils/formatters.js';
|
||||
import {
|
||||
getCommandRoots,
|
||||
isCommandAllowed,
|
||||
stripShellWrapper,
|
||||
} from '../utils/shell-utils.js';
|
||||
|
||||
export const OUTPUT_UPDATE_INTERVAL_MS = 1000;
|
||||
|
||||
export interface ShellToolParams {
|
||||
command: string;
|
||||
description?: string;
|
||||
directory?: string;
|
||||
}
|
||||
import { spawn } from 'child_process';
|
||||
import { summarizeToolOutput } from '../utils/summarizer.js';
|
||||
|
||||
const OUTPUT_UPDATE_INTERVAL_MS = 1000;
|
||||
|
||||
export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||
static Name: string = 'run_shell_command';
|
||||
private whitelist: Set<string> = new Set();
|
||||
private allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
@@ -41,17 +50,18 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||
'Shell',
|
||||
`This tool executes a given shell command as \`bash -c <command>\`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
|
||||
|
||||
The following information is returned:
|
||||
The following information is returned:
|
||||
|
||||
Command: Executed command.
|
||||
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\``,
|
||||
Command: Executed command.
|
||||
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\``,
|
||||
Icon.Terminal,
|
||||
{
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
@@ -91,131 +101,8 @@ Process Group PGID: Process group started or \`(none)\``,
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the root command from a given shell command string.
|
||||
* This is used to identify the base command for permission checks.
|
||||
*
|
||||
* @param command The shell command string to parse
|
||||
* @returns The root command name, or undefined if it cannot be determined
|
||||
* @example getCommandRoot("ls -la /tmp") returns "ls"
|
||||
* @example getCommandRoot("git status && npm test") returns "git"
|
||||
*/
|
||||
getCommandRoot(command: string): string | undefined {
|
||||
return command
|
||||
.trim() // remove leading and trailing whitespace
|
||||
.replace(/[{}()]/g, '') // remove all grouping operators
|
||||
.split(/[\s;&|]+/)[0] // split on any whitespace or separator or chaining operators and take first part
|
||||
?.split(/[/\\]/) // split on any path separators (or return undefined if previous line was undefined)
|
||||
.pop(); // take last part and return command root (or undefined if previous line was empty)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether a given shell command is allowed to execute based on
|
||||
* the tool's configuration including allowlists and blocklists.
|
||||
*
|
||||
* @param command The shell command string to validate
|
||||
* @returns An object with 'allowed' boolean and optional 'reason' string if not allowed
|
||||
*/
|
||||
isCommandAllowed(command: string): { allowed: boolean; reason?: string } {
|
||||
// 0. Disallow command substitution
|
||||
if (command.includes('$(')) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason:
|
||||
'Command substitution using $() is not allowed for security reasons',
|
||||
};
|
||||
}
|
||||
|
||||
const SHELL_TOOL_NAMES = [ShellTool.name, ShellTool.Name];
|
||||
|
||||
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
|
||||
|
||||
/**
|
||||
* Checks if a command string starts with a given prefix, ensuring it's a
|
||||
* whole word match (i.e., followed by a space or it's an exact match).
|
||||
* e.g., `isPrefixedBy('npm install', 'npm')` -> true
|
||||
* e.g., `isPrefixedBy('npm', 'npm')` -> true
|
||||
* e.g., `isPrefixedBy('npminstall', 'npm')` -> false
|
||||
*/
|
||||
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
|
||||
if (!cmd.startsWith(prefix)) {
|
||||
return false;
|
||||
}
|
||||
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
|
||||
};
|
||||
|
||||
/**
|
||||
* Extracts and normalizes shell commands from a list of tool strings.
|
||||
* e.g., 'ShellTool("ls -l")' becomes 'ls -l'
|
||||
*/
|
||||
const extractCommands = (tools: string[]): string[] =>
|
||||
tools.flatMap((tool) => {
|
||||
for (const toolName of SHELL_TOOL_NAMES) {
|
||||
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
|
||||
return [normalize(tool.slice(toolName.length + 1, -1))];
|
||||
}
|
||||
}
|
||||
return [];
|
||||
});
|
||||
|
||||
const coreTools = this.config.getCoreTools() || [];
|
||||
const excludeTools = this.config.getExcludeTools() || [];
|
||||
|
||||
// 1. Check if the shell tool is globally disabled.
|
||||
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: 'Shell tool is globally disabled in configuration',
|
||||
};
|
||||
}
|
||||
|
||||
const blockedCommands = new Set(extractCommands(excludeTools));
|
||||
const allowedCommands = new Set(extractCommands(coreTools));
|
||||
|
||||
const hasSpecificAllowedCommands = allowedCommands.size > 0;
|
||||
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
|
||||
coreTools.includes(name),
|
||||
);
|
||||
|
||||
const commandsToValidate = command.split(/&&|\|\||\||;/).map(normalize);
|
||||
|
||||
const blockedCommandsArr = [...blockedCommands];
|
||||
|
||||
for (const cmd of commandsToValidate) {
|
||||
// 2. Check if the command is on the blocklist.
|
||||
const isBlocked = blockedCommandsArr.some((blocked) =>
|
||||
isPrefixedBy(cmd, blocked),
|
||||
);
|
||||
if (isBlocked) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: `Command '${cmd}' is blocked by configuration`,
|
||||
};
|
||||
}
|
||||
|
||||
// 3. If in strict allow-list mode, check if the command is permitted.
|
||||
const isStrictAllowlist =
|
||||
hasSpecificAllowedCommands && !isWildcardAllowed;
|
||||
const allowedCommandsArr = [...allowedCommands];
|
||||
if (isStrictAllowlist) {
|
||||
const isAllowed = allowedCommandsArr.some((allowed) =>
|
||||
isPrefixedBy(cmd, allowed),
|
||||
);
|
||||
if (!isAllowed) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: `Command '${cmd}' is not in the allowed commands list`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. If all checks pass, the command is allowed.
|
||||
return { allowed: true };
|
||||
}
|
||||
|
||||
validateToolParams(params: ShellToolParams): string | null {
|
||||
const commandCheck = this.isCommandAllowed(params.command);
|
||||
const commandCheck = isCommandAllowed(params.command, this.config);
|
||||
if (!commandCheck.allowed) {
|
||||
if (!commandCheck.reason) {
|
||||
console.error(
|
||||
@@ -232,7 +119,7 @@ Process Group PGID: Process group started or \`(none)\``,
|
||||
if (!params.command.trim()) {
|
||||
return 'Command cannot be empty.';
|
||||
}
|
||||
if (!this.getCommandRoot(params.command)) {
|
||||
if (getCommandRoots(params.command).length === 0) {
|
||||
return 'Could not identify command root to obtain permission from user.';
|
||||
}
|
||||
if (params.directory) {
|
||||
@@ -257,18 +144,25 @@ Process Group PGID: Process group started or \`(none)\``,
|
||||
if (this.validateToolParams(params)) {
|
||||
return false; // skip confirmation, execute call will fail immediately
|
||||
}
|
||||
const rootCommand = this.getCommandRoot(params.command)!; // must be non-empty string post-validation
|
||||
if (this.whitelist.has(rootCommand)) {
|
||||
|
||||
const command = stripShellWrapper(params.command);
|
||||
const rootCommands = [...new Set(getCommandRoots(command))];
|
||||
const commandsToConfirm = rootCommands.filter(
|
||||
(command) => !this.allowlist.has(command),
|
||||
);
|
||||
|
||||
if (commandsToConfirm.length === 0) {
|
||||
return false; // already approved and whitelisted
|
||||
}
|
||||
|
||||
const confirmationDetails: ToolExecuteConfirmationDetails = {
|
||||
type: 'exec',
|
||||
title: 'Confirm Shell Command',
|
||||
command: params.command,
|
||||
rootCommand,
|
||||
rootCommand: commandsToConfirm.join(', '),
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.whitelist.add(rootCommand);
|
||||
commandsToConfirm.forEach((command) => this.allowlist.add(command));
|
||||
}
|
||||
},
|
||||
};
|
||||
@@ -277,21 +171,22 @@ Process Group PGID: Process group started or \`(none)\``,
|
||||
|
||||
async execute(
|
||||
params: ShellToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
updateOutput?: (chunk: string) => void,
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
const strippedCommand = stripShellWrapper(params.command);
|
||||
const validationError = this.validateToolParams({
|
||||
...params,
|
||||
command: strippedCommand,
|
||||
});
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: [
|
||||
`Command rejected: ${params.command}`,
|
||||
`Reason: ${validationError}`,
|
||||
].join('\n'),
|
||||
returnDisplay: `Error: ${validationError}`,
|
||||
llmContent: validationError,
|
||||
returnDisplay: validationError,
|
||||
};
|
||||
}
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
if (signal.aborted) {
|
||||
return {
|
||||
llmContent: 'Command was cancelled by user before it could start.',
|
||||
returnDisplay: 'Command cancelled by user.',
|
||||
@@ -304,200 +199,182 @@ Process Group PGID: Process group started or \`(none)\``,
|
||||
.toString('hex')}.tmp`;
|
||||
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
||||
|
||||
// pgrep is not available on Windows, so we can't get background PIDs
|
||||
const command = isWindows
|
||||
? params.command
|
||||
: (() => {
|
||||
// wrap command to append subprocess pids (via pgrep) to temporary file
|
||||
let command = params.command.trim();
|
||||
if (!command.endsWith('&')) command += ';';
|
||||
return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`;
|
||||
})();
|
||||
|
||||
// spawn command in specified directory (or project root if not specified)
|
||||
const shell = isWindows
|
||||
? spawn('cmd.exe', ['/c', command], {
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
// detached: true, // ensure subprocess starts its own process group (esp. in Linux)
|
||||
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
|
||||
})
|
||||
: spawn('bash', ['-c', command], {
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
detached: true, // ensure subprocess starts its own process group (esp. in Linux)
|
||||
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
|
||||
});
|
||||
|
||||
let exited = false;
|
||||
let stdout = '';
|
||||
let output = '';
|
||||
let lastUpdateTime = Date.now();
|
||||
|
||||
const appendOutput = (str: string) => {
|
||||
output += str;
|
||||
if (
|
||||
updateOutput &&
|
||||
Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS
|
||||
) {
|
||||
updateOutput(output);
|
||||
lastUpdateTime = Date.now();
|
||||
}
|
||||
};
|
||||
|
||||
shell.stdout.on('data', (data: Buffer) => {
|
||||
// continue to consume post-exit for background processes
|
||||
// removing listeners can overflow OS buffer and block subprocesses
|
||||
// destroying (e.g. shell.stdout.destroy()) can terminate subprocesses via SIGPIPE
|
||||
if (!exited) {
|
||||
const str = stripAnsi(data.toString());
|
||||
stdout += str;
|
||||
appendOutput(str);
|
||||
}
|
||||
});
|
||||
|
||||
let stderr = '';
|
||||
shell.stderr.on('data', (data: Buffer) => {
|
||||
if (!exited) {
|
||||
const str = stripAnsi(data.toString());
|
||||
stderr += str;
|
||||
appendOutput(str);
|
||||
}
|
||||
});
|
||||
|
||||
let error: Error | null = null;
|
||||
shell.on('error', (err: Error) => {
|
||||
error = err;
|
||||
// remove wrapper from user's command in error message
|
||||
error.message = error.message.replace(command, params.command);
|
||||
});
|
||||
|
||||
let code: number | null = null;
|
||||
let processSignal: NodeJS.Signals | null = null;
|
||||
const exitHandler = (
|
||||
_code: number | null,
|
||||
_signal: NodeJS.Signals | null,
|
||||
) => {
|
||||
exited = true;
|
||||
code = _code;
|
||||
processSignal = _signal;
|
||||
};
|
||||
shell.on('exit', exitHandler);
|
||||
|
||||
const abortHandler = async () => {
|
||||
if (shell.pid && !exited) {
|
||||
if (os.platform() === 'win32') {
|
||||
// For Windows, use taskkill to kill the process tree
|
||||
spawn('taskkill', ['/pid', shell.pid.toString(), '/f', '/t']);
|
||||
} else {
|
||||
try {
|
||||
// attempt to SIGTERM process group (negative PID)
|
||||
// fall back to SIGKILL (to group) after 200ms
|
||||
process.kill(-shell.pid, 'SIGTERM');
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
if (shell.pid && !exited) {
|
||||
process.kill(-shell.pid, 'SIGKILL');
|
||||
}
|
||||
} catch (_e) {
|
||||
// if group kill fails, fall back to killing just the main process
|
||||
try {
|
||||
if (shell.pid) {
|
||||
shell.kill('SIGKILL');
|
||||
}
|
||||
} catch (_e) {
|
||||
console.error(`failed to kill shell process ${shell.pid}: ${_e}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
// wait for the shell to exit
|
||||
try {
|
||||
await new Promise((resolve) => shell.on('exit', resolve));
|
||||
// pgrep is not available on Windows, so we can't get background PIDs
|
||||
const commandToExecute = isWindows
|
||||
? strippedCommand
|
||||
: (() => {
|
||||
// wrap command to append subprocess pids (via pgrep) to temporary file
|
||||
let command = strippedCommand.trim();
|
||||
if (!command.endsWith('&')) command += ';';
|
||||
return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`;
|
||||
})();
|
||||
|
||||
const cwd = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
params.directory || '',
|
||||
);
|
||||
|
||||
let cumulativeStdout = '';
|
||||
let cumulativeStderr = '';
|
||||
|
||||
let lastUpdateTime = Date.now();
|
||||
let isBinaryStream = false;
|
||||
|
||||
const { result: resultPromise } = ShellExecutionService.execute(
|
||||
commandToExecute,
|
||||
cwd,
|
||||
(event: ShellOutputEvent) => {
|
||||
if (!updateOutput) {
|
||||
return;
|
||||
}
|
||||
|
||||
let currentDisplayOutput = '';
|
||||
let shouldUpdate = false;
|
||||
|
||||
switch (event.type) {
|
||||
case 'data':
|
||||
if (isBinaryStream) break; // Don't process text if we are in binary mode
|
||||
if (event.stream === 'stdout') {
|
||||
cumulativeStdout += event.chunk;
|
||||
} else {
|
||||
cumulativeStderr += event.chunk;
|
||||
}
|
||||
currentDisplayOutput =
|
||||
cumulativeStdout +
|
||||
(cumulativeStderr ? `\n${cumulativeStderr}` : '');
|
||||
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
|
||||
shouldUpdate = true;
|
||||
}
|
||||
break;
|
||||
case 'binary_detected':
|
||||
isBinaryStream = true;
|
||||
currentDisplayOutput =
|
||||
'[Binary output detected. Halting stream...]';
|
||||
shouldUpdate = true;
|
||||
break;
|
||||
case 'binary_progress':
|
||||
isBinaryStream = true;
|
||||
currentDisplayOutput = `[Receiving binary output... ${formatMemoryUsage(
|
||||
event.bytesReceived,
|
||||
)} received]`;
|
||||
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
|
||||
shouldUpdate = true;
|
||||
}
|
||||
break;
|
||||
default: {
|
||||
throw new Error('An unhandled ShellOutputEvent was found.');
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldUpdate) {
|
||||
updateOutput(currentDisplayOutput);
|
||||
lastUpdateTime = Date.now();
|
||||
}
|
||||
},
|
||||
signal,
|
||||
);
|
||||
|
||||
const result = await resultPromise;
|
||||
|
||||
const backgroundPIDs: number[] = [];
|
||||
if (os.platform() !== 'win32') {
|
||||
if (fs.existsSync(tempFilePath)) {
|
||||
const pgrepLines = fs
|
||||
.readFileSync(tempFilePath, 'utf8')
|
||||
.split('\n')
|
||||
.filter(Boolean);
|
||||
for (const line of pgrepLines) {
|
||||
if (!/^\d+$/.test(line)) {
|
||||
console.error(`pgrep: ${line}`);
|
||||
}
|
||||
const pid = Number(line);
|
||||
if (pid !== result.pid) {
|
||||
backgroundPIDs.push(pid);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!signal.aborted) {
|
||||
console.error('missing pgrep output');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let llmContent = '';
|
||||
if (result.aborted) {
|
||||
llmContent = 'Command was cancelled by user before it could complete.';
|
||||
if (result.output.trim()) {
|
||||
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${result.output}`;
|
||||
} else {
|
||||
llmContent += ' There was no output before it was cancelled.';
|
||||
}
|
||||
} else {
|
||||
// Create a formatted error string for display, replacing the wrapper command
|
||||
// with the user-facing command.
|
||||
const finalError = result.error
|
||||
? result.error.message.replace(commandToExecute, params.command)
|
||||
: '(none)';
|
||||
|
||||
llmContent = [
|
||||
`Command: ${params.command}`,
|
||||
`Directory: ${params.directory || '(root)'}`,
|
||||
`Stdout: ${result.stdout || '(empty)'}`,
|
||||
`Stderr: ${result.stderr || '(empty)'}`,
|
||||
`Error: ${finalError}`, // Use the cleaned error string.
|
||||
`Exit Code: ${result.exitCode ?? '(none)'}`,
|
||||
`Signal: ${result.signal ?? '(none)'}`,
|
||||
`Background PIDs: ${
|
||||
backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'
|
||||
}`,
|
||||
`Process Group PGID: ${result.pid ?? '(none)'}`,
|
||||
].join('\n');
|
||||
}
|
||||
|
||||
let returnDisplayMessage = '';
|
||||
if (this.config.getDebugMode()) {
|
||||
returnDisplayMessage = llmContent;
|
||||
} else {
|
||||
if (result.output.trim()) {
|
||||
returnDisplayMessage = result.output;
|
||||
} else {
|
||||
if (result.aborted) {
|
||||
returnDisplayMessage = 'Command cancelled by user.';
|
||||
} else if (result.signal) {
|
||||
returnDisplayMessage = `Command terminated by signal: ${result.signal}`;
|
||||
} else if (result.error) {
|
||||
returnDisplayMessage = `Command failed: ${getErrorMessage(
|
||||
result.error,
|
||||
)}`;
|
||||
} else if (result.exitCode !== null && result.exitCode !== 0) {
|
||||
returnDisplayMessage = `Command exited with code: ${result.exitCode}`;
|
||||
}
|
||||
// If output is empty and command succeeded (code 0, no error/signal/abort),
|
||||
// returnDisplayMessage will remain empty, which is fine.
|
||||
}
|
||||
}
|
||||
|
||||
const summarizeConfig = this.config.getSummarizeToolOutputConfig();
|
||||
if (summarizeConfig && summarizeConfig[this.name]) {
|
||||
const summary = await summarizeToolOutput(
|
||||
llmContent,
|
||||
this.config.getGeminiClient(),
|
||||
signal,
|
||||
summarizeConfig[this.name].tokenBudget,
|
||||
);
|
||||
return {
|
||||
llmContent: summary,
|
||||
returnDisplay: returnDisplayMessage,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent,
|
||||
returnDisplay: returnDisplayMessage,
|
||||
};
|
||||
} finally {
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
|
||||
// parse pids (pgrep output) from temporary file and remove it
|
||||
const backgroundPIDs: number[] = [];
|
||||
if (os.platform() !== 'win32') {
|
||||
if (fs.existsSync(tempFilePath)) {
|
||||
const pgrepLines = fs
|
||||
.readFileSync(tempFilePath, 'utf8')
|
||||
.split('\n')
|
||||
.filter(Boolean);
|
||||
for (const line of pgrepLines) {
|
||||
if (!/^\d+$/.test(line)) {
|
||||
console.error(`pgrep: ${line}`);
|
||||
}
|
||||
const pid = Number(line);
|
||||
// exclude the shell subprocess pid
|
||||
if (pid !== shell.pid) {
|
||||
backgroundPIDs.push(pid);
|
||||
}
|
||||
}
|
||||
fs.unlinkSync(tempFilePath);
|
||||
} else {
|
||||
if (!abortSignal.aborted) {
|
||||
console.error('missing pgrep output');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let llmContent = '';
|
||||
if (abortSignal.aborted) {
|
||||
llmContent = 'Command was cancelled by user before it could complete.';
|
||||
if (output.trim()) {
|
||||
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${output}`;
|
||||
} else {
|
||||
llmContent += ' There was no output before it was cancelled.';
|
||||
}
|
||||
} else {
|
||||
llmContent = [
|
||||
`Command: ${params.command}`,
|
||||
`Directory: ${params.directory || '(root)'}`,
|
||||
`Stdout: ${stdout || '(empty)'}`,
|
||||
`Stderr: ${stderr || '(empty)'}`,
|
||||
`Error: ${error ?? '(none)'}`,
|
||||
`Exit Code: ${code ?? '(none)'}`,
|
||||
`Signal: ${processSignal ?? '(none)'}`,
|
||||
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
|
||||
`Process Group PGID: ${shell.pid ?? '(none)'}`,
|
||||
].join('\n');
|
||||
}
|
||||
|
||||
let returnDisplayMessage = '';
|
||||
if (this.config.getDebugMode()) {
|
||||
returnDisplayMessage = llmContent;
|
||||
} else {
|
||||
if (output.trim()) {
|
||||
returnDisplayMessage = output;
|
||||
} else {
|
||||
// Output is empty, let's provide a reason if the command failed or was cancelled
|
||||
if (abortSignal.aborted) {
|
||||
returnDisplayMessage = 'Command cancelled by user.';
|
||||
} else if (processSignal) {
|
||||
returnDisplayMessage = `Command terminated by signal: ${processSignal}`;
|
||||
} else if (error) {
|
||||
// If error is not null, it's an Error object (or other truthy value)
|
||||
returnDisplayMessage = `Command failed: ${getErrorMessage(error)}`;
|
||||
} else if (code !== null && code !== 0) {
|
||||
returnDisplayMessage = `Command exited with code: ${code}`;
|
||||
}
|
||||
// If output is empty and command succeeded (code 0, no error/signal/abort),
|
||||
// returnDisplayMessage will remain empty, which is fine.
|
||||
}
|
||||
}
|
||||
|
||||
const summary = await summarizeToolOutput(
|
||||
llmContent,
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
return {
|
||||
llmContent: summary,
|
||||
returnDisplay: returnDisplayMessage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,14 @@ import {
|
||||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||
import {
|
||||
ToolRegistry,
|
||||
DiscoveredTool,
|
||||
sanitizeParameters,
|
||||
} from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import {
|
||||
FunctionDeclaration,
|
||||
CallableTool,
|
||||
@@ -104,8 +104,12 @@ const createMockCallableTool = (
|
||||
});
|
||||
|
||||
class MockTool extends BaseTool<{ param: string }, ToolResult> {
|
||||
constructor(name = 'mock-tool', description = 'A mock tool') {
|
||||
super(name, name, description, {
|
||||
constructor(
|
||||
name = 'mock-tool',
|
||||
displayName = 'A mock tool',
|
||||
description = 'A mock tool description',
|
||||
) {
|
||||
super(name, displayName, description, Icon.Hammer, {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
param: { type: Type.STRING },
|
||||
@@ -174,42 +178,85 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllTools', () => {
|
||||
it('should return all registered tools sorted alphabetically by displayName', () => {
|
||||
// Register tools with displayNames in non-alphabetical order
|
||||
const toolC = new MockTool('c-tool', 'Tool C');
|
||||
const toolA = new MockTool('a-tool', 'Tool A');
|
||||
const toolB = new MockTool('b-tool', 'Tool B');
|
||||
|
||||
toolRegistry.registerTool(toolC);
|
||||
toolRegistry.registerTool(toolA);
|
||||
toolRegistry.registerTool(toolB);
|
||||
|
||||
const allTools = toolRegistry.getAllTools();
|
||||
const displayNames = allTools.map((t) => t.displayName);
|
||||
|
||||
// Assert that the returned array is sorted by displayName
|
||||
expect(displayNames).toEqual(['Tool A', 'Tool B', 'Tool C']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolsByServer', () => {
|
||||
it('should return an empty array if no tools match the server name', () => {
|
||||
toolRegistry.registerTool(new MockTool());
|
||||
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return only tools matching the server name', async () => {
|
||||
it('should return only tools matching the server name, sorted by name', async () => {
|
||||
const server1Name = 'mcp-server-uno';
|
||||
const server2Name = 'mcp-server-dos';
|
||||
const mockCallable = {} as CallableTool;
|
||||
const mcpTool1 = new DiscoveredMCPTool(
|
||||
const mcpTool1_c = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
'server1Name__tool-on-server1',
|
||||
'zebra-tool',
|
||||
'd1',
|
||||
{},
|
||||
'tool-on-server1',
|
||||
);
|
||||
const mcpTool1_a = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
'apple-tool',
|
||||
'd2',
|
||||
{},
|
||||
);
|
||||
const mcpTool1_b = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
'banana-tool',
|
||||
'd3',
|
||||
{},
|
||||
);
|
||||
|
||||
const mcpTool2 = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server2Name,
|
||||
'server2Name__tool-on-server2',
|
||||
'd2',
|
||||
{},
|
||||
'tool-on-server2',
|
||||
'd4',
|
||||
{},
|
||||
);
|
||||
const nonMcpTool = new MockTool('regular-tool');
|
||||
|
||||
toolRegistry.registerTool(mcpTool1);
|
||||
toolRegistry.registerTool(mcpTool1_c);
|
||||
toolRegistry.registerTool(mcpTool1_a);
|
||||
toolRegistry.registerTool(mcpTool1_b);
|
||||
toolRegistry.registerTool(mcpTool2);
|
||||
toolRegistry.registerTool(nonMcpTool);
|
||||
|
||||
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
||||
expect(toolsFromServer1).toHaveLength(1);
|
||||
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
||||
const toolNames = toolsFromServer1.map((t) => t.name);
|
||||
|
||||
// Assert that the array has the correct tools and is sorted by name
|
||||
expect(toolsFromServer1).toHaveLength(3);
|
||||
expect(toolNames).toEqual(['apple-tool', 'banana-tool', 'zebra-tool']);
|
||||
|
||||
// Assert that all returned tools are indeed from the correct server
|
||||
for (const tool of toolsFromServer1) {
|
||||
expect((tool as DiscoveredMCPTool).serverName).toBe(server1Name);
|
||||
}
|
||||
|
||||
// Assert that the other server's tools are returned correctly
|
||||
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
||||
expect(toolsFromServer2).toHaveLength(1);
|
||||
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
||||
@@ -265,7 +312,7 @@ describe('ToolRegistry', () => {
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
|
||||
expect(discoveredTool).toBeDefined();
|
||||
@@ -291,12 +338,13 @@ describe('ToolRegistry', () => {
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
undefined,
|
||||
false,
|
||||
);
|
||||
});
|
||||
@@ -313,12 +361,13 @@ describe('ToolRegistry', () => {
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
undefined,
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { FunctionDeclaration, Schema, Type } from '@google/genai';
|
||||
import { Tool, ToolResult, BaseTool } from './tools.js';
|
||||
import { Tool, ToolResult, BaseTool, Icon } from './tools.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
@@ -18,7 +18,7 @@ type ToolParams = Record<string, unknown>;
|
||||
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
readonly name: string,
|
||||
name: string,
|
||||
readonly description: string,
|
||||
readonly parameterSchema: Record<string, unknown>,
|
||||
) {
|
||||
@@ -44,6 +44,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
name,
|
||||
name,
|
||||
description,
|
||||
Icon.Hammer,
|
||||
parameterSchema,
|
||||
false, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
@@ -137,10 +138,14 @@ export class ToolRegistry {
|
||||
*/
|
||||
registerTool(tool: Tool): void {
|
||||
if (this.tools.has(tool.name)) {
|
||||
// Decide on behavior: throw error, log warning, or allow overwrite
|
||||
console.warn(
|
||||
`Tool with name "${tool.name}" is already registered. Overwriting.`,
|
||||
);
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
tool = tool.asFullyQualifiedTool();
|
||||
} else {
|
||||
// Decide on behavior: throw error, log warning, or allow overwrite
|
||||
console.warn(
|
||||
`Tool with name "${tool.name}" is already registered. Overwriting.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
this.tools.set(tool.name, tool);
|
||||
}
|
||||
@@ -148,8 +153,9 @@ export class ToolRegistry {
|
||||
/**
|
||||
* Discovers tools from project (if available and configured).
|
||||
* Can be called multiple times to update discovered tools.
|
||||
* This will discover tools from the command line and from MCP servers.
|
||||
*/
|
||||
async discoverTools(): Promise<void> {
|
||||
async discoverAllTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
for (const tool of this.tools.values()) {
|
||||
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
|
||||
@@ -164,10 +170,59 @@ export class ToolRegistry {
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools from project (if available and configured).
|
||||
* Can be called multiple times to update discovered tools.
|
||||
* This will NOT discover tools from the command line, only from MCP servers.
|
||||
*/
|
||||
async discoverMcpTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
for (const tool of this.tools.values()) {
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
this.tools.delete(tool.name);
|
||||
}
|
||||
}
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await discoverMcpTools(
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or re-discover tools for a single MCP server.
|
||||
* @param serverName - The name of the server to discover tools from.
|
||||
*/
|
||||
async discoverToolsForServer(serverName: string): Promise<void> {
|
||||
// Remove any previously discovered tools from this server
|
||||
for (const [name, tool] of this.tools.entries()) {
|
||||
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||
this.tools.delete(name);
|
||||
}
|
||||
}
|
||||
|
||||
const mcpServers = this.config.getMcpServers() ?? {};
|
||||
const serverConfig = mcpServers[serverName];
|
||||
if (serverConfig) {
|
||||
await discoverMcpTools(
|
||||
{ [serverName]: serverConfig },
|
||||
undefined,
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
||||
if (!discoveryCmd) {
|
||||
@@ -308,7 +363,9 @@ export class ToolRegistry {
|
||||
* Returns an array of all registered and discovered tool instances.
|
||||
*/
|
||||
getAllTools(): Tool[] {
|
||||
return Array.from(this.tools.values());
|
||||
return Array.from(this.tools.values()).sort((a, b) =>
|
||||
a.displayName.localeCompare(b.displayName),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -321,7 +378,7 @@ export class ToolRegistry {
|
||||
serverTools.push(tool);
|
||||
}
|
||||
}
|
||||
return serverTools;
|
||||
return serverTools.sort((a, b) => a.name.localeCompare(b.name));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -379,6 +436,19 @@ function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle enum values - Gemini API only allows enum for STRING type
|
||||
if (schema.enum && Array.isArray(schema.enum)) {
|
||||
if (schema.type !== Type.STRING) {
|
||||
// If enum is present but type is not STRING, convert type to STRING
|
||||
schema.type = Type.STRING;
|
||||
}
|
||||
// Filter out null and undefined values, then convert remaining values to strings for Gemini API compatibility
|
||||
schema.enum = schema.enum
|
||||
.filter((value: unknown) => value !== null && value !== undefined)
|
||||
.map((value: unknown) => String(value));
|
||||
}
|
||||
|
||||
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
||||
if (schema.type === Type.STRING) {
|
||||
if (
|
||||
|
||||
@@ -28,6 +28,11 @@ export interface Tool<
|
||||
*/
|
||||
description: string;
|
||||
|
||||
/**
|
||||
* The icon to display when interacting via ACP
|
||||
*/
|
||||
icon: Icon;
|
||||
|
||||
/**
|
||||
* Function declaration schema from @google/genai
|
||||
*/
|
||||
@@ -60,6 +65,13 @@ export interface Tool<
|
||||
*/
|
||||
getDescription(params: TParams): string;
|
||||
|
||||
/**
|
||||
* Determines what file system paths the tool will affect
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A list of such paths
|
||||
*/
|
||||
toolLocations(params: TParams): ToolLocation[];
|
||||
|
||||
/**
|
||||
* Determines if the tool should prompt for confirmation before execution
|
||||
* @param params Parameters for the tool execution
|
||||
@@ -97,12 +109,13 @@ export abstract class BaseTool<
|
||||
* @param description Description of what the tool does
|
||||
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
|
||||
* @param canUpdateOutput Whether the tool supports live (streaming) output
|
||||
* @param parameterSchema JSON Schema defining the parameters
|
||||
* @param parameterSchema Open API 3.0 Schema defining the parameters
|
||||
*/
|
||||
constructor(
|
||||
readonly name: string,
|
||||
readonly displayName: string,
|
||||
readonly description: string,
|
||||
readonly icon: Icon,
|
||||
readonly parameterSchema: Schema,
|
||||
readonly isOutputMarkdown: boolean = true,
|
||||
readonly canUpdateOutput: boolean = false,
|
||||
@@ -158,6 +171,18 @@ export abstract class BaseTool<
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines what file system paths the tool will affect
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A list of such paths
|
||||
*/
|
||||
toolLocations(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
params: TParams,
|
||||
): ToolLocation[] {
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstract method to execute the tool with the given parameters
|
||||
* Must be implemented by derived classes
|
||||
@@ -199,6 +224,8 @@ export type ToolResultDisplay = string | FileDiff;
|
||||
export interface FileDiff {
|
||||
fileDiff: string;
|
||||
fileName: string;
|
||||
originalContent: string | null;
|
||||
newContent: string;
|
||||
}
|
||||
|
||||
export interface ToolEditConfirmationDetails {
|
||||
@@ -210,6 +237,8 @@ export interface ToolEditConfirmationDetails {
|
||||
) => Promise<void>;
|
||||
fileName: string;
|
||||
fileDiff: string;
|
||||
originalContent: string | null;
|
||||
newContent: string;
|
||||
isModifying?: boolean;
|
||||
}
|
||||
|
||||
@@ -258,3 +287,21 @@ export enum ToolConfirmationOutcome {
|
||||
ModifyWithEditor = 'modify_with_editor',
|
||||
Cancel = 'cancel',
|
||||
}
|
||||
|
||||
export enum Icon {
|
||||
FileSearch = 'fileSearch',
|
||||
Folder = 'folder',
|
||||
Globe = 'globe',
|
||||
Hammer = 'hammer',
|
||||
LightBulb = 'lightBulb',
|
||||
Pencil = 'pencil',
|
||||
Regex = 'regex',
|
||||
Terminal = 'terminal',
|
||||
}
|
||||
|
||||
export interface ToolLocation {
|
||||
// Absolute path to the file
|
||||
path: string;
|
||||
// Which line (if known)
|
||||
line?: number;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ describe('WebFetchTool', () => {
|
||||
const mockConfig = {
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
getProxy: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
ToolResult,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
Icon,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
@@ -17,9 +18,10 @@ import { Config, ApprovalMode } from '../config/config.js';
|
||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
|
||||
import { convert } from 'html-to-text';
|
||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
const URL_FETCH_TIMEOUT_MS = 10000;
|
||||
const MAX_CONTENT_LENGTH = 50000;
|
||||
const MAX_CONTENT_LENGTH = 100000;
|
||||
|
||||
// Helper function to extract URLs from a string
|
||||
function extractUrls(text: string): string[] {
|
||||
@@ -69,6 +71,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||
WebFetchTool.Name,
|
||||
'WebFetch',
|
||||
"Processes content from URL(s), including local and private network addresses (e.g., localhost), embedded in a prompt. Include up to 20 URLs and instructions (e.g., summarize, extract specific data) directly in the 'prompt' parameter.",
|
||||
Icon.Globe,
|
||||
{
|
||||
properties: {
|
||||
prompt: {
|
||||
@@ -81,6 +84,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||
type: Type.OBJECT,
|
||||
},
|
||||
);
|
||||
const proxy = config.getProxy();
|
||||
if (proxy) {
|
||||
setGlobalDispatcher(new ProxyAgent(proxy as string));
|
||||
}
|
||||
}
|
||||
|
||||
private async executeFallback(
|
||||
@@ -94,70 +101,40 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||
returnDisplay: 'Error: No URL found in the prompt for fallback.',
|
||||
};
|
||||
}
|
||||
// For now, we only support one URL for fallback
|
||||
let url = urls[0];
|
||||
|
||||
const results: string[] = [];
|
||||
const processedUrls: string[] = [];
|
||||
|
||||
// Process multiple URLs (up to 20 as mentioned in description)
|
||||
const urlsToProcess = urls.slice(0, 20);
|
||||
|
||||
for (const originalUrl of urlsToProcess) {
|
||||
let url = originalUrl;
|
||||
|
||||
// Convert GitHub blob URL to raw URL
|
||||
if (url.includes('github.com') && url.includes('/blob/')) {
|
||||
url = url
|
||||
.replace('github.com', 'raw.githubusercontent.com')
|
||||
.replace('/blob/', '/');
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Request failed with status code ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
const html = await response.text();
|
||||
const textContent = convert(html, {
|
||||
wordwrap: false,
|
||||
selectors: [
|
||||
{ selector: 'a', options: { ignoreHref: true } },
|
||||
{ selector: 'img', format: 'skip' },
|
||||
],
|
||||
}).substring(0, MAX_CONTENT_LENGTH);
|
||||
|
||||
results.push(`Content from ${url}:\n${textContent}`);
|
||||
processedUrls.push(url);
|
||||
} catch (e) {
|
||||
const error = e as Error;
|
||||
results.push(`Error fetching ${url}: ${error.message}`);
|
||||
processedUrls.push(url);
|
||||
}
|
||||
// Convert GitHub blob URL to raw URL
|
||||
if (url.includes('github.com') && url.includes('/blob/')) {
|
||||
url = url
|
||||
.replace('github.com', 'raw.githubusercontent.com')
|
||||
.replace('/blob/', '/');
|
||||
}
|
||||
|
||||
try {
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const combinedContent = results.join('\n\n---\n\n');
|
||||
|
||||
// Ensure the total prompt length doesn't exceed limits
|
||||
const maxPromptLength = 200000; // Leave room for system instructions
|
||||
const promptPrefix = `The user requested the following: "${params.prompt}".
|
||||
|
||||
I have fetched the content from the following URL(s). Please use this content to answer the user's request. Do not attempt to access the URL(s) again.
|
||||
|
||||
`;
|
||||
|
||||
let finalContent = combinedContent;
|
||||
if (promptPrefix.length + combinedContent.length > maxPromptLength) {
|
||||
const availableLength = maxPromptLength - promptPrefix.length - 100; // Leave some buffer
|
||||
finalContent =
|
||||
combinedContent.substring(0, availableLength) +
|
||||
'\n\n[Content truncated due to length limits]';
|
||||
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Request failed with status code ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
const html = await response.text();
|
||||
const textContent = convert(html, {
|
||||
wordwrap: false,
|
||||
selectors: [
|
||||
{ selector: 'a', options: { ignoreHref: true } },
|
||||
{ selector: 'img', format: 'skip' },
|
||||
],
|
||||
}).substring(0, MAX_CONTENT_LENGTH);
|
||||
|
||||
const fallbackPrompt = promptPrefix + finalContent;
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const fallbackPrompt = `The user requested the following: "${params.prompt}".
|
||||
|
||||
I was unable to access the URL directly. Instead, I have fetched the raw content of the page. Please use the following content to answer the user's request. Do not attempt to access the URL again.
|
||||
|
||||
---
|
||||
${textContent}
|
||||
---`;
|
||||
const result = await geminiClient.generateContent(
|
||||
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
|
||||
{},
|
||||
@@ -166,11 +143,11 @@ I have fetched the content from the following URL(s). Please use this content to
|
||||
const resultText = getResponseText(result) || '';
|
||||
return {
|
||||
llmContent: resultText,
|
||||
returnDisplay: `Content from ${processedUrls.length} URL(s) processed using fallback fetch.`,
|
||||
returnDisplay: `Content for ${url} processed using fallback fetch.`,
|
||||
};
|
||||
} catch (e) {
|
||||
const error = e as Error;
|
||||
const errorMessage = `Error during fallback processing: ${error.message}`;
|
||||
const errorMessage = `Error during fallback fetch for ${url}: ${error.message}`;
|
||||
return {
|
||||
llmContent: `Error: ${errorMessage}`,
|
||||
returnDisplay: `Error: ${errorMessage}`,
|
||||
@@ -262,12 +239,6 @@ I have fetched the content from the following URL(s). Please use this content to
|
||||
}
|
||||
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const contentGenerator = geminiClient.getContentGenerator();
|
||||
|
||||
// Check if using OpenAI content generator - if so, use fallback
|
||||
if (contentGenerator.constructor.name === 'OpenAIContentGenerator') {
|
||||
return this.executeFallback(params, signal);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await geminiClient.generateContent(
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { GroundingMetadata } from '@google/genai';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { BaseTool, Icon, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
|
||||
@@ -69,6 +69,7 @@ export class WebSearchTool extends BaseTool<
|
||||
WebSearchTool.Name,
|
||||
'GoogleSearch',
|
||||
'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.',
|
||||
Icon.Globe,
|
||||
{
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
ToolEditConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolCallConfirmationDetails,
|
||||
Icon,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
@@ -72,9 +73,10 @@ export class WriteFileTool
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
'WriteFile',
|
||||
`Writes content to a specified file in the local filesystem.
|
||||
|
||||
`Writes content to a specified file in the local filesystem.
|
||||
|
||||
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
|
||||
Icon.Pencil,
|
||||
{
|
||||
properties: {
|
||||
file_path: {
|
||||
@@ -184,6 +186,8 @@ export class WriteFileTool
|
||||
title: `Confirm Write: ${shortenPath(relativePath)}`,
|
||||
fileName,
|
||||
fileDiff,
|
||||
originalContent,
|
||||
newContent: correctedContent,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
@@ -269,7 +273,12 @@ export class WriteFileTool
|
||||
);
|
||||
}
|
||||
|
||||
const displayResult: FileDiff = { fileDiff, fileName };
|
||||
const displayResult: FileDiff = {
|
||||
fileDiff,
|
||||
fileName,
|
||||
originalContent: correctedContentResult.originalContent,
|
||||
newContent: correctedContentResult.correctedContent,
|
||||
};
|
||||
|
||||
const lines = fileContent.split('\n').length;
|
||||
const mimetype = getSpecificMimeType(params.file_path);
|
||||
|
||||
Reference in New Issue
Block a user