mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
pre-release commit
This commit is contained in:
12
packages/core/src/tools/diffOptions.ts
Normal file
12
packages/core/src/tools/diffOptions.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as Diff from 'diff';
|
||||
|
||||
export const DEFAULT_DIFF_OPTIONS: Diff.PatchOptions = {
|
||||
context: 3,
|
||||
ignoreWhitespace: true,
|
||||
};
|
||||
665
packages/core/src/tools/edit.test.ts
Normal file
665
packages/core/src/tools/edit.test.ts
Normal file
@@ -0,0 +1,665 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */

// Mock functions are created via vi.hoisted so they already exist when the
// vi.mock factories below execute (vitest hoists vi.mock calls to the top of
// the module, ahead of ordinary declarations).
const mockEnsureCorrectEdit = vi.hoisted(() => vi.fn());
const mockGenerateJson = vi.hoisted(() => vi.fn());
const mockOpenDiff = vi.hoisted(() => vi.fn());

// Replace the edit-correction helper with a controllable mock.
vi.mock('../utils/editCorrector.js', () => ({
  ensureCorrectEdit: mockEnsureCorrectEdit,
}));

// Stub the Gemini client so tests never touch the network or a real model.
vi.mock('../core/client.js', () => ({
  GeminiClient: vi.fn().mockImplementation(() => ({
    generateJson: mockGenerateJson,
  })),
}));

// Stub the external-editor integration.
vi.mock('../utils/editor.js', () => ({
  openDiff: mockOpenDiff,
}));

import { describe, it, expect, beforeEach, afterEach, vi, Mock } from 'vitest';
import { EditTool, EditToolParams } from './edit.js';
import { FileDiff } from './tools.js';
import path from 'path';
import fs from 'fs';
import os from 'os';
import { ApprovalMode, Config } from '../config/config.js';
import { Content, Part, SchemaUnion } from '@google/genai';
|
||||
describe('EditTool', () => {
|
||||
let tool: EditTool;
|
||||
let tempDir: string;
|
||||
let rootDir: string;
|
||||
let mockConfig: Config;
|
||||
let geminiClient: any;
|
||||
|
||||
// Per-test setup: a fresh temp workspace on disk, a minimal stubbed Config,
// and default implementations for the hoisted mocks. Order matters: the
// Config stub is built first, mocks are reset after, and the tool is
// constructed last so it sees the fully-wired Config.
beforeEach(() => {
  tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'edit-tool-test-'));
  rootDir = path.join(tempDir, 'root');
  fs.mkdirSync(rootDir);

  geminiClient = {
    generateJson: mockGenerateJson, // mockGenerateJson is already defined and hoisted
  };

  mockConfig = {
    getGeminiClient: vi.fn().mockReturnValue(geminiClient),
    getTargetDir: () => rootDir,
    getApprovalMode: vi.fn(),
    setApprovalMode: vi.fn(),
    // getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
    // Add other properties/methods of Config if EditTool uses them
    // Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses:
    getApiKey: () => 'test-api-key',
    getModel: () => 'test-model',
    getSandbox: () => false,
    getDebugMode: () => false,
    getQuestion: () => undefined,
    getFullContext: () => false,
    getToolDiscoveryCommand: () => undefined,
    getToolCallCommand: () => undefined,
    getMcpServerCommand: () => undefined,
    getMcpServers: () => undefined,
    getUserAgent: () => 'test-agent',
    getUserMemory: () => '',
    setUserMemory: vi.fn(),
    getGeminiMdFileCount: () => 0,
    setGeminiMdFileCount: vi.fn(),
    getToolRegistry: () => ({}) as any, // Minimal mock for ToolRegistry
  } as unknown as Config;

  // Reset mocks before each test
  (mockConfig.getApprovalMode as Mock).mockClear();
  // Default to not skipping confirmation
  (mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.DEFAULT);

  // Reset mocks and set default implementation for ensureCorrectEdit:
  // count literal occurrences of params.old_string in the current content.
  mockEnsureCorrectEdit.mockReset();
  mockEnsureCorrectEdit.mockImplementation(
    async (_, currentContent, params) => {
      let occurrences = 0;
      if (params.old_string && currentContent) {
        // Simple string counting for the mock
        let index = currentContent.indexOf(params.old_string);
        while (index !== -1) {
          occurrences++;
          index = currentContent.indexOf(params.old_string, index + 1);
        }
      } else if (params.old_string === '') {
        occurrences = 0; // Creating a new file
      }
      return Promise.resolve({ params, occurrences });
    },
  );

  // Default mock for generateJson to return the snippet unchanged.
  // It parses the prompt text the corrector would send and echoes back the
  // snippet/new_string it finds, keyed off which schema property is requested.
  mockGenerateJson.mockReset();
  mockGenerateJson.mockImplementation(
    async (contents: Content[], schema: SchemaUnion) => {
      // The problematic_snippet is the last part of the user's content
      const userContent = contents.find((c: Content) => c.role === 'user');
      let promptText = '';
      if (userContent && userContent.parts) {
        promptText = userContent.parts
          .filter((p: Part) => typeof (p as any).text === 'string')
          .map((p: Part) => (p as any).text)
          .join('\n');
      }
      const snippetMatch = promptText.match(
        /Problematic target snippet:\n```\n([\s\S]*?)\n```/,
      );
      const problematicSnippet =
        snippetMatch && snippetMatch[1] ? snippetMatch[1] : '';

      if (((schema as any).properties as any)?.corrected_target_snippet) {
        return Promise.resolve({
          corrected_target_snippet: problematicSnippet,
        });
      }
      if (((schema as any).properties as any)?.corrected_new_string) {
        // For new_string correction, we might need more sophisticated logic,
        // but for now, returning original is a safe default if not specified by a test.
        const originalNewStringMatch = promptText.match(
          /original_new_string \(what was intended to replace original_old_string\):\n```\n([\s\S]*?)\n```/,
        );
        const originalNewString =
          originalNewStringMatch && originalNewStringMatch[1]
            ? originalNewStringMatch[1]
            : '';
        return Promise.resolve({ corrected_new_string: originalNewString });
      }
      return Promise.resolve({}); // Default empty object if schema doesn't match
    },
  );

  tool = new EditTool(mockConfig);
});
|
||||
|
||||
// Remove the temp workspace created in beforeEach (force: ignore if missing).
afterEach(() => {
  fs.rmSync(tempDir, { recursive: true, force: true });
});
|
||||
|
||||
// Unit tests for the private _applyReplacement helper, reached via an `any`
// cast. Note: `tool` is initialized in `beforeEach` of the parent describe block.
describe('_applyReplacement', () => {
  it('should return newString if isNewFile is true', () => {
    expect((tool as any)._applyReplacement(null, 'old', 'new', true)).toBe(
      'new',
    );
    // isNewFile wins even when existing content is passed.
    expect(
      (tool as any)._applyReplacement('existing', 'old', 'new', true),
    ).toBe('new');
  });

  it('should return newString if currentContent is null and oldString is empty (defensive)', () => {
    expect((tool as any)._applyReplacement(null, '', 'new', false)).toBe(
      'new',
    );
  });

  it('should return empty string if currentContent is null and oldString is not empty (defensive)', () => {
    expect((tool as any)._applyReplacement(null, 'old', 'new', false)).toBe(
      '',
    );
  });

  it('should replace oldString with newString in currentContent', () => {
    // All occurrences are replaced, not just the first.
    expect(
      (tool as any)._applyReplacement(
        'hello old world old',
        'old',
        'new',
        false,
      ),
    ).toBe('hello new world new');
  });

  it('should return currentContent if oldString is empty and not a new file', () => {
    expect(
      (tool as any)._applyReplacement('hello world', '', 'new', false),
    ).toBe('hello world');
  });
});
|
||||
|
||||
// Parameter validation: absolute paths inside the configured root are valid;
// relative paths and paths outside the root produce error strings.
describe('validateToolParams', () => {
  it('should return null for valid params', () => {
    const params: EditToolParams = {
      file_path: path.join(rootDir, 'test.txt'),
      old_string: 'old',
      new_string: 'new',
    };
    expect(tool.validateToolParams(params)).toBeNull();
  });

  it('should return error for relative path', () => {
    const params: EditToolParams = {
      file_path: 'test.txt',
      old_string: 'old',
      new_string: 'new',
    };
    expect(tool.validateToolParams(params)).toMatch(
      /File path must be absolute/,
    );
  });

  it('should return error for path outside root', () => {
    // tempDir is the parent of rootDir, so this path is outside the root.
    const params: EditToolParams = {
      file_path: path.join(tempDir, 'outside-root.txt'),
      old_string: 'old',
      new_string: 'new',
    };
    expect(tool.validateToolParams(params)).toMatch(
      /File path must be within the root directory/,
    );
  });
});
|
||||
|
||||
// Confirmation flow: shouldConfirmExecute returns false for invalid/ambiguous
// edits and a confirmation-details object (title, fileName, fileDiff) for
// viable ones; it delegates occurrence counting to ensureCorrectEdit.
describe('shouldConfirmExecute', () => {
  const testFile = 'edit_me.txt';
  let filePath: string;

  beforeEach(() => {
    filePath = path.join(rootDir, testFile);
  });

  it('should return false if params are invalid', async () => {
    const params: EditToolParams = {
      file_path: 'relative.txt',
      old_string: 'old',
      new_string: 'new',
    };
    expect(
      await tool.shouldConfirmExecute(params, new AbortController().signal),
    ).toBe(false);
  });

  it('should request confirmation for valid edit', async () => {
    fs.writeFileSync(filePath, 'some old content here');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
    };
    // ensureCorrectEdit will be called by shouldConfirmExecute
    mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
    const confirmation = await tool.shouldConfirmExecute(
      params,
      new AbortController().signal,
    );
    expect(confirmation).toEqual(
      expect.objectContaining({
        title: `Confirm Edit: ${testFile}`,
        fileName: testFile,
        fileDiff: expect.any(String),
      }),
    );
  });

  it('should return false if old_string is not found (ensureCorrectEdit returns 0)', async () => {
    fs.writeFileSync(filePath, 'some content here');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'not_found',
      new_string: 'new',
    };
    mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
    expect(
      await tool.shouldConfirmExecute(params, new AbortController().signal),
    ).toBe(false);
  });

  it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
    fs.writeFileSync(filePath, 'old old content here');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
    };
    mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
    expect(
      await tool.shouldConfirmExecute(params, new AbortController().signal),
    ).toBe(false);
  });

  it('should request confirmation for creating a new file (empty old_string)', async () => {
    const newFileName = 'new_file.txt';
    const newFilePath = path.join(rootDir, newFileName);
    const params: EditToolParams = {
      file_path: newFilePath,
      old_string: '',
      new_string: 'new file content',
    };
    // ensureCorrectEdit might not be called if old_string is empty,
    // as shouldConfirmExecute handles this for diff generation.
    // If it is called, it should return 0 occurrences for a new file.
    mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
    const confirmation = await tool.shouldConfirmExecute(
      params,
      new AbortController().signal,
    );
    expect(confirmation).toEqual(
      expect.objectContaining({
        title: `Confirm Edit: ${newFileName}`,
        fileName: newFileName,
        fileDiff: expect.any(String),
      }),
    );
  });

  it('should use corrected params from ensureCorrectEdit for diff generation', async () => {
    const originalContent = 'This is the original string to be replaced.';
    const originalOldString = 'original string';
    const originalNewString = 'new string';

    const correctedOldString = 'original string to be replaced'; // More specific
    const correctedNewString = 'completely new string'; // Different replacement
    const expectedFinalContent = 'This is the completely new string.';

    fs.writeFileSync(filePath, originalContent);
    const params: EditToolParams = {
      file_path: filePath,
      old_string: originalOldString,
      new_string: originalNewString,
    };

    // The main beforeEach already calls mockEnsureCorrectEdit.mockReset()
    // Set a specific mock for this test case that also verifies the
    // arguments shouldConfirmExecute passes through to the corrector.
    let mockCalled = false;
    mockEnsureCorrectEdit.mockImplementationOnce(
      async (_, content, p, client) => {
        mockCalled = true;
        expect(content).toBe(originalContent);
        expect(p).toBe(params);
        expect(client).toBe(geminiClient);
        return {
          params: {
            file_path: filePath,
            old_string: correctedOldString,
            new_string: correctedNewString,
          },
          occurrences: 1,
        };
      },
    );

    const confirmation = (await tool.shouldConfirmExecute(
      params,
      new AbortController().signal,
    )) as FileDiff;

    expect(mockCalled).toBe(true); // Check if the mock implementation was run
    // expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(originalContent, params, expect.anything()); // Keep this commented for now
    expect(confirmation).toEqual(
      expect.objectContaining({
        title: `Confirm Edit: ${testFile}`,
        fileName: testFile,
      }),
    );
    // Check that the diff is based on the corrected strings leading to the new state
    expect(confirmation.fileDiff).toContain(`-${originalContent}`);
    expect(confirmation.fileDiff).toContain(`+${expectedFinalContent}`);

    // Verify that applying the correctedOldString and correctedNewString to originalContent
    // indeed produces the expectedFinalContent, which is what the diff should reflect.
    const patchedContent = originalContent.replace(
      correctedOldString, // This was the string identified by ensureCorrectEdit for replacement
      correctedNewString, // This was the string identified by ensureCorrectEdit as the replacement
    );
    expect(patchedContent).toBe(expectedFinalContent);
  });
});
|
||||
|
||||
// End-to-end execute() behavior: editing, file creation, occurrence-count
// errors, expected_replacements handling, and the user-modification note.
// Confirmation is bypassed either via the private shouldAlwaysEdit flag or by
// returning ApprovalMode.AUTO_EDIT from the Config stub.
describe('execute', () => {
  const testFile = 'execute_me.txt';
  let filePath: string;

  beforeEach(() => {
    filePath = path.join(rootDir, testFile);
    // Default for execute tests, can be overridden:
    // count literal occurrences of old_string in the file content.
    mockEnsureCorrectEdit.mockImplementation(async (_, content, params) => {
      let occurrences = 0;
      if (params.old_string && content) {
        let index = content.indexOf(params.old_string);
        while (index !== -1) {
          occurrences++;
          index = content.indexOf(params.old_string, index + 1);
        }
      } else if (params.old_string === '') {
        occurrences = 0;
      }
      return { params, occurrences };
    });
  });

  it('should return error if params are invalid', async () => {
    const params: EditToolParams = {
      file_path: 'relative.txt',
      old_string: 'old',
      new_string: 'new',
    };
    const result = await tool.execute(params, new AbortController().signal);
    expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
    expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
  });

  it('should edit an existing file and return diff with fileName', async () => {
    const initialContent = 'This is some old text.';
    const newContent = 'This is some new text.'; // old -> new
    fs.writeFileSync(filePath, initialContent, 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
    };

    // Specific mock for this test's execution path in calculateEdit
    // ensureCorrectEdit is NOT called by calculateEdit, only by shouldConfirmExecute
    // So, the default mockEnsureCorrectEdit should correctly return 1 occurrence for 'old' in initialContent

    // Simulate confirmation by setting shouldAlwaysEdit
    (tool as any).shouldAlwaysEdit = true;

    const result = await tool.execute(params, new AbortController().signal);

    (tool as any).shouldAlwaysEdit = false; // Reset for other tests

    expect(result.llmContent).toMatch(/Successfully modified file/);
    expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent);
    const display = result.returnDisplay as FileDiff;
    expect(display.fileDiff).toMatch(initialContent);
    expect(display.fileDiff).toMatch(newContent);
    expect(display.fileName).toBe(testFile);
  });

  it('should create a new file if old_string is empty and file does not exist, and return created message', async () => {
    const newFileName = 'brand_new_file.txt';
    const newFilePath = path.join(rootDir, newFileName);
    const fileContent = 'Content for the new file.';
    const params: EditToolParams = {
      file_path: newFilePath,
      old_string: '',
      new_string: fileContent,
    };

    (mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
      ApprovalMode.AUTO_EDIT,
    );
    const result = await tool.execute(params, new AbortController().signal);

    expect(result.llmContent).toMatch(/Created new file/);
    expect(fs.existsSync(newFilePath)).toBe(true);
    expect(fs.readFileSync(newFilePath, 'utf8')).toBe(fileContent);
    expect(result.returnDisplay).toBe(`Created ${newFileName}`);
  });

  it('should return error if old_string is not found in file', async () => {
    fs.writeFileSync(filePath, 'Some content.', 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'nonexistent',
      new_string: 'replacement',
    };
    // The default mockEnsureCorrectEdit will return 0 occurrences for 'nonexistent'
    const result = await tool.execute(params, new AbortController().signal);
    expect(result.llmContent).toMatch(
      /0 occurrences found for old_string in/,
    );
    expect(result.returnDisplay).toMatch(
      /Failed to edit, could not find the string to replace./,
    );
  });

  it('should return error if multiple occurrences of old_string are found', async () => {
    fs.writeFileSync(filePath, 'multiple old old strings', 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
    };
    // The default mockEnsureCorrectEdit will return 2 occurrences for 'old'
    const result = await tool.execute(params, new AbortController().signal);
    expect(result.llmContent).toMatch(
      /Expected 1 occurrence but found 2 for old_string in file/,
    );
    expect(result.returnDisplay).toMatch(
      /Failed to edit, expected 1 occurrence but found 2/,
    );
  });

  it('should successfully replace multiple occurrences when expected_replacements specified', async () => {
    fs.writeFileSync(filePath, 'old text old text old text', 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
      expected_replacements: 3,
    };

    // Simulate confirmation by setting shouldAlwaysEdit
    (tool as any).shouldAlwaysEdit = true;

    const result = await tool.execute(params, new AbortController().signal);

    (tool as any).shouldAlwaysEdit = false; // Reset for other tests

    expect(result.llmContent).toMatch(/Successfully modified file/);
    expect(fs.readFileSync(filePath, 'utf8')).toBe(
      'new text new text new text',
    );
    const display = result.returnDisplay as FileDiff;
    expect(display.fileDiff).toMatch(/old text old text old text/);
    expect(display.fileDiff).toMatch(/new text new text new text/);
    expect(display.fileName).toBe(testFile);
  });

  it('should return error if expected_replacements does not match actual occurrences', async () => {
    fs.writeFileSync(filePath, 'old text old text', 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
      expected_replacements: 3, // Expecting 3 but only 2 exist
    };
    const result = await tool.execute(params, new AbortController().signal);
    expect(result.llmContent).toMatch(
      /Expected 3 occurrences but found 2 for old_string in file/,
    );
    expect(result.returnDisplay).toMatch(
      /Failed to edit, expected 3 occurrences but found 2/,
    );
  });

  it('should return error if trying to create a file that already exists (empty old_string)', async () => {
    fs.writeFileSync(filePath, 'Existing content', 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: '',
      new_string: 'new content',
    };
    const result = await tool.execute(params, new AbortController().signal);
    expect(result.llmContent).toMatch(/File already exists, cannot create/);
    expect(result.returnDisplay).toMatch(
      /Attempted to create a file that already exists/,
    );
  });

  it('should include modification message when proposed content is modified', async () => {
    const initialContent = 'This is some old text.';
    fs.writeFileSync(filePath, initialContent, 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
      modified_by_user: true,
    };

    (mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
      ApprovalMode.AUTO_EDIT,
    );
    const result = await tool.execute(params, new AbortController().signal);

    expect(result.llmContent).toMatch(
      /User modified the `new_string` content/,
    );
  });

  it('should not include modification message when proposed content is not modified', async () => {
    const initialContent = 'This is some old text.';
    fs.writeFileSync(filePath, initialContent, 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
      modified_by_user: false,
    };

    (mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
      ApprovalMode.AUTO_EDIT,
    );
    const result = await tool.execute(params, new AbortController().signal);

    expect(result.llmContent).not.toMatch(
      /User modified the `new_string` content/,
    );
  });

  it('should not include modification message when modified_by_user is not provided', async () => {
    const initialContent = 'This is some old text.';
    fs.writeFileSync(filePath, initialContent, 'utf8');
    const params: EditToolParams = {
      file_path: filePath,
      old_string: 'old',
      new_string: 'new',
    };

    (mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
      ApprovalMode.AUTO_EDIT,
    );
    const result = await tool.execute(params, new AbortController().signal);

    expect(result.llmContent).not.toMatch(
      /User modified the `new_string` content/,
    );
  });
});
|
||||
|
||||
// Human-readable one-line descriptions of a pending edit, including the
// no-op case and snippet truncation for long strings.
describe('getDescription', () => {
  it('should return "No file changes to..." if old_string and new_string are the same', () => {
    const testFileName = 'test.txt';
    const params: EditToolParams = {
      file_path: path.join(rootDir, testFileName),
      old_string: 'identical_string',
      new_string: 'identical_string',
    };
    // shortenPath will be called internally, resulting in just the file name
    expect(tool.getDescription(params)).toBe(
      `No file changes to ${testFileName}`,
    );
  });

  it('should return a snippet of old and new strings if they are different', () => {
    const testFileName = 'test.txt';
    const params: EditToolParams = {
      file_path: path.join(rootDir, testFileName),
      old_string: 'this is the old string value',
      new_string: 'this is the new string value',
    };
    // shortenPath will be called internally, resulting in just the file name
    // The snippets are truncated at 30 chars + '...'
    expect(tool.getDescription(params)).toBe(
      `${testFileName}: this is the old string value => this is the new string value`,
    );
  });

  it('should handle very short strings correctly in the description', () => {
    const testFileName = 'short.txt';
    const params: EditToolParams = {
      file_path: path.join(rootDir, testFileName),
      old_string: 'old',
      new_string: 'new',
    };
    expect(tool.getDescription(params)).toBe(`${testFileName}: old => new`);
  });

  it('should truncate long strings in the description', () => {
    const testFileName = 'long.txt';
    const params: EditToolParams = {
      file_path: path.join(rootDir, testFileName),
      old_string:
        'this is a very long old string that will definitely be truncated',
      new_string:
        'this is a very long new string that will also be truncated',
    };
    expect(tool.getDescription(params)).toBe(
      `${testFileName}: this is a very long old string... => this is a very long new string...`,
    );
  });
});
|
||||
});
|
||||
471
packages/core/src/tools/edit.ts
Normal file
471
packages/core/src/tools/edit.ts
Normal file
@@ -0,0 +1,471 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import * as Diff from 'diff';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolResult,
|
||||
ToolResultDisplay,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
import { ensureCorrectEdit } from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
|
||||
import { ReadFileTool } from './read-file.js';
|
||||
import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
|
||||
import { isWithinRoot } from '../utils/fileUtils.js';
|
||||
|
||||
/**
 * Parameters for the Edit tool.
 */
export interface EditToolParams {
  /**
   * The absolute path to the file to modify. Validated to be absolute and
   * inside the configured root directory (see validateToolParams).
   */
  file_path: string;

  /**
   * The exact literal text to replace. An empty string signals creation of a
   * new file (the file must not already exist).
   */
  old_string: string;

  /**
   * The exact literal text to replace `old_string` with. When creating a new
   * file, this becomes the entire file content.
   */
  new_string: string;

  /**
   * Number of replacements expected. Defaults to 1 if not specified.
   * Use when you want to replace multiple occurrences.
   */
  expected_replacements?: number;

  /**
   * Whether the edit was modified manually by the user; when true, the tool's
   * result notes that `new_string` was user-modified.
   */
  modified_by_user?: boolean;
}
|
||||
|
||||
/**
 * Internal result of computing an edit before it is applied to disk.
 */
interface CalculatedEdit {
  // Contents of the file before the edit; null when there is no readable
  // current content (e.g. the file does not exist yet).
  currentContent: string | null;
  // Contents the file will have after the edit is applied.
  newContent: string;
  // Number of times old_string was found in the current content.
  occurrences: number;
  // Present when the edit cannot proceed: `display` is shown to the user,
  // `raw` is returned to the model.
  error?: { display: string; raw: string };
  // True when the edit creates a new file rather than modifying an existing one.
  isNewFile: boolean;
}
|
||||
|
||||
/**
|
||||
* Implementation of the Edit tool logic
|
||||
*/
|
||||
export class EditTool
|
||||
extends BaseTool<EditToolParams, ToolResult>
|
||||
implements ModifiableTool<EditToolParams>
|
||||
{
|
||||
static readonly Name = 'replace';
|
||||
|
||||
  /**
   * Wires up the tool's internal name (`replace`), display name, LLM-facing
   * usage instructions, and the JSON schema used for parameter validation.
   *
   * @param config Runtime configuration (target directory, approval mode,
   *   Gemini client) retained for use by validation and execution.
   */
  constructor(private readonly config: Config) {
    super(
      EditTool.Name,
      'Edit',
      // LLM-facing description; backticked names are escaped so they render
      // literally inside this template string.
      `Replaces text within a file. By default, replaces a single occurrence, but can replace multiple occurrences when \`expected_replacements\` is specified. This tool requires providing significant context around the change to ensure precise targeting. Always use the ${ReadFileTool.Name} tool to examine the file's current content before attempting a text replacement.

      The user has the ability to modify the \`new_string\` content. If modified, this will be stated in the response.

Expectation for required parameters:
1. \`file_path\` MUST be an absolute path; otherwise an error will be thrown.
2. \`old_string\` MUST be the exact literal text to replace (including all whitespace, indentation, newlines, and surrounding code etc.).
3. \`new_string\` MUST be the exact literal text to replace \`old_string\` with (also including all whitespace, indentation, newlines, and surrounding code etc.). Ensure the resulting code is correct and idiomatic.
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
      // JSON schema for the tool's parameters (mirrors EditToolParams).
      {
        properties: {
          file_path: {
            description:
              "The absolute path to the file to modify. Must start with '/'.",
            type: Type.STRING,
          },
          old_string: {
            description:
              'The exact literal text to replace, preferably unescaped. For single replacements (default), include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. For multiple replacements, specify expected_replacements parameter. If this string is not the exact literal text (i.e. you escaped it) or does not match exactly, the tool will fail.',
            type: Type.STRING,
          },
          new_string: {
            description:
              'The exact literal text to replace `old_string` with, preferably unescaped. Provide the EXACT text. Ensure the resulting code is correct and idiomatic.',
            type: Type.STRING,
          },
          expected_replacements: {
            type: Type.NUMBER,
            description:
              'Number of replacements expected. Defaults to 1 if not specified. Use when you want to replace multiple occurrences.',
            minimum: 1,
          },
        },
        required: ['file_path', 'old_string', 'new_string'],
        type: Type.OBJECT,
      },
    );
  }
|
||||
|
||||
/**
 * Validates the parameters for the Edit tool.
 *
 * Runs schema validation first, then enforces that the target path is
 * absolute and confined to the configured root directory.
 *
 * @param params Parameters to validate
 * @returns Error message string or null if valid
 */
validateToolParams(params: EditToolParams): string | null {
  const schemaError = SchemaValidator.validate(this.schema.parameters, params);
  if (schemaError) {
    return schemaError;
  }

  if (!path.isAbsolute(params.file_path)) {
    return `File path must be absolute: ${params.file_path}`;
  }

  const rootDir = this.config.getTargetDir();
  return isWithinRoot(params.file_path, rootDir)
    ? null
    : `File path must be within the root directory (${rootDir}): ${params.file_path}`;
}
|
||||
|
||||
/**
 * Computes the post-edit content for a file.
 *
 * For new files the result is simply `newString`; for existing files every
 * occurrence of `oldString` is replaced with `newString`.
 */
private _applyReplacement(
  currentContent: string | null,
  oldString: string,
  newString: string,
  isNewFile: boolean,
): string {
  // A brand-new file consists solely of the provided replacement text.
  if (isNewFile) {
    return newString;
  }

  const hasSearchText = oldString !== '';

  if (currentContent === null) {
    // Defensive fallback: an existing (non-new) file should always have
    // content loaded. Mirror the new-file result only when there is no
    // search text; otherwise return empty.
    return hasSearchText ? '' : newString;
  }

  if (!hasSearchText) {
    // An empty oldString on an existing file is a no-op, not an insert.
    return currentContent;
  }

  return currentContent.replaceAll(oldString, newString);
}
|
||||
|
||||
/**
 * Calculates the potential outcome of an edit operation.
 *
 * Reads the target file, passes the proposed edit through `ensureCorrectEdit`,
 * and reports either the resulting content or a structured error. Expected
 * failure modes (missing file, creating an existing file, zero or mismatched
 * occurrence counts) are returned in the `error` field rather than thrown.
 *
 * @param params Parameters for the edit operation
 * @param abortSignal Signal used to cancel the edit-correction request
 * @returns An object describing the potential edit outcome
 * @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
 */
private async calculateEdit(
  params: EditToolParams,
  abortSignal: AbortSignal,
): Promise<CalculatedEdit> {
  const expectedReplacements = params.expected_replacements ?? 1;
  let currentContent: string | null = null;
  let fileExists = false;
  let isNewFile = false;
  // The strings actually applied may differ from the raw params once the
  // corrector has had a chance to adjust them (see ensureCorrectEdit below).
  let finalNewString = params.new_string;
  let finalOldString = params.old_string;
  let occurrences = 0;
  let error: { display: string; raw: string } | undefined = undefined;

  try {
    currentContent = fs.readFileSync(params.file_path, 'utf8');
    // Normalize line endings to LF for consistent processing.
    currentContent = currentContent.replace(/\r\n/g, '\n');
    fileExists = true;
  } catch (err: unknown) {
    if (!isNodeError(err) || err.code !== 'ENOENT') {
      // Rethrow unexpected FS errors (permissions, etc.)
      throw err;
    }
    fileExists = false;
  }

  if (params.old_string === '' && !fileExists) {
    // Creating a new file
    isNewFile = true;
  } else if (!fileExists) {
    // Trying to edit a non-existent file (and old_string is not empty)
    error = {
      display: `File not found. Cannot apply edit. Use an empty old_string to create a new file.`,
      raw: `File not found: ${params.file_path}`,
    };
  } else if (currentContent !== null) {
    // Editing an existing file
    // ensureCorrectEdit may adjust the old/new strings and reports how many
    // times the (possibly corrected) old_string occurs in the content.
    const correctedEdit = await ensureCorrectEdit(
      params.file_path,
      currentContent,
      params,
      this.config.getGeminiClient(),
      abortSignal,
    );
    finalOldString = correctedEdit.params.old_string;
    finalNewString = correctedEdit.params.new_string;
    occurrences = correctedEdit.occurrences;

    if (params.old_string === '') {
      // Error: Trying to create a file that already exists
      error = {
        display: `Failed to edit. Attempted to create a file that already exists.`,
        raw: `File already exists, cannot create: ${params.file_path}`,
      };
    } else if (occurrences === 0) {
      error = {
        display: `Failed to edit, could not find the string to replace.`,
        raw: `Failed to edit, 0 occurrences found for old_string in ${params.file_path}. No edits made. The exact text in old_string was not found. Ensure you're not escaping content incorrectly and check whitespace, indentation, and context. Use ${ReadFileTool.Name} tool to verify.`,
      };
    } else if (occurrences !== expectedReplacements) {
      const occurenceTerm =
        expectedReplacements === 1 ? 'occurrence' : 'occurrences';

      error = {
        display: `Failed to edit, expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences}.`,
        raw: `Failed to edit, Expected ${expectedReplacements} ${occurenceTerm} but found ${occurrences} for old_string in file: ${params.file_path}`,
      };
    }
  } else {
    // Should not happen if fileExists and no exception was thrown, but defensively:
    error = {
      display: `Failed to read content of file.`,
      raw: `Failed to read content of existing file: ${params.file_path}`,
    };
  }

  // newContent is computed unconditionally, even when `error` is set;
  // callers must check `error` before acting on it.
  const newContent = this._applyReplacement(
    currentContent,
    finalOldString,
    finalNewString,
    isNewFile,
  );

  return {
    currentContent,
    newContent,
    occurrences,
    error,
    isNewFile,
  };
}
|
||||
|
||||
/**
|
||||
* Handles the confirmation prompt for the Edit tool in the CLI.
|
||||
* It needs to calculate the diff to show the user.
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
params: EditToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
console.error(
|
||||
`[EditTool Wrapper] Attempted confirmation with invalid parameters: ${validationError}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
let editData: CalculatedEdit;
|
||||
try {
|
||||
editData = await this.calculateEdit(params, abortSignal);
|
||||
} catch (error) {
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
console.log(`Error preparing edit: ${errorMsg}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (editData.error) {
|
||||
console.log(`Error: ${editData.error.display}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const fileName = path.basename(params.file_path);
|
||||
const fileDiff = Diff.createPatch(
|
||||
fileName,
|
||||
editData.currentContent ?? '',
|
||||
editData.newContent,
|
||||
'Current',
|
||||
'Proposed',
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
const confirmationDetails: ToolEditConfirmationDetails = {
|
||||
type: 'edit',
|
||||
title: `Confirm Edit: ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`,
|
||||
fileName,
|
||||
fileDiff,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
/**
 * Produces a short human-readable description of the edit for display.
 *
 * Missing parameters are detected with explicit type checks: an empty
 * string is a VALID value for old_string (file creation), so a plain
 * falsiness test would wrongly report "invalid parameters" and make the
 * `Create …` branch below unreachable.
 */
getDescription(params: EditToolParams): string {
  if (
    !params.file_path ||
    typeof params.old_string !== 'string' ||
    typeof params.new_string !== 'string'
  ) {
    return `Model did not provide valid parameters for edit tool`;
  }
  const relativePath = makeRelative(
    params.file_path,
    this.config.getTargetDir(),
  );
  if (params.old_string === '') {
    // Empty old_string means "create a new file".
    return `Create ${shortenPath(relativePath)}`;
  }

  // First line of each string, truncated to 30 chars for a compact summary.
  const oldStringSnippet =
    params.old_string.split('\n')[0].substring(0, 30) +
    (params.old_string.length > 30 ? '...' : '');
  const newStringSnippet =
    params.new_string.split('\n')[0].substring(0, 30) +
    (params.new_string.length > 30 ? '...' : '');

  if (params.old_string === params.new_string) {
    return `No file changes to ${shortenPath(relativePath)}`;
  }
  return `${shortenPath(relativePath)}: ${oldStringSnippet} => ${newStringSnippet}`;
}
|
||||
|
||||
/**
 * Executes the edit operation with the given parameters.
 *
 * Validates the params, computes the edit via calculateEdit, writes the new
 * content to disk, and reports the outcome. All failures are returned as a
 * ToolResult (llmContent + returnDisplay); nothing is thrown to the caller.
 *
 * @param params Parameters for the edit operation
 * @returns Result of the edit operation
 */
async execute(
  params: EditToolParams,
  signal: AbortSignal,
): Promise<ToolResult> {
  const validationError = this.validateToolParams(params);
  if (validationError) {
    return {
      llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
      returnDisplay: `Error: ${validationError}`,
    };
  }

  let editData: CalculatedEdit;
  try {
    editData = await this.calculateEdit(params, signal);
  } catch (error) {
    // Unexpected failures while reading/correcting (e.g. permissions).
    const errorMsg = error instanceof Error ? error.message : String(error);
    return {
      llmContent: `Error preparing edit: ${errorMsg}`,
      returnDisplay: `Error preparing edit: ${errorMsg}`,
    };
  }

  if (editData.error) {
    // Expected failure modes surfaced by calculateEdit (not found, 0 or
    // mismatched occurrences, create-over-existing, ...).
    return {
      llmContent: editData.error.raw,
      returnDisplay: `Error: ${editData.error.display}`,
    };
  }

  try {
    this.ensureParentDirectoriesExist(params.file_path);
    fs.writeFileSync(params.file_path, editData.newContent, 'utf8');

    let displayResult: ToolResultDisplay;
    if (editData.isNewFile) {
      displayResult = `Created ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`;
    } else {
      // Generate diff for display, even though core logic doesn't technically need it
      // The CLI wrapper will use this part of the ToolResult
      const fileName = path.basename(params.file_path);
      const fileDiff = Diff.createPatch(
        fileName,
        editData.currentContent ?? '', // Should not be null here if not isNewFile
        editData.newContent,
        'Current',
        'Proposed',
        DEFAULT_DIFF_OPTIONS,
      );
      displayResult = { fileDiff, fileName };
    }

    const llmSuccessMessageParts = [
      editData.isNewFile
        ? `Created new file: ${params.file_path} with provided content.`
        : `Successfully modified file: ${params.file_path} (${editData.occurrences} replacements).`,
    ];
    if (params.modified_by_user) {
      // Tell the model the user hand-edited the proposed content.
      llmSuccessMessageParts.push(
        `User modified the \`new_string\` content to be: ${params.new_string}.`,
      );
    }

    return {
      llmContent: llmSuccessMessageParts.join(' '),
      returnDisplay: displayResult,
    };
  } catch (error) {
    const errorMsg = error instanceof Error ? error.message : String(error);
    return {
      llmContent: `Error executing edit: ${errorMsg}`,
      returnDisplay: `Error writing file: ${errorMsg}`,
    };
  }
}
|
||||
|
||||
/**
|
||||
* Creates parent directories if they don't exist
|
||||
*/
|
||||
private ensureParentDirectoriesExist(filePath: string): void {
|
||||
const dirName = path.dirname(filePath);
|
||||
if (!fs.existsSync(dirName)) {
|
||||
fs.mkdirSync(dirName, { recursive: true });
|
||||
}
|
||||
}
|
||||
|
||||
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
|
||||
return {
|
||||
getFilePath: (params: EditToolParams) => params.file_path,
|
||||
getCurrentContent: async (params: EditToolParams): Promise<string> => {
|
||||
try {
|
||||
return fs.readFileSync(params.file_path, 'utf8');
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
return '';
|
||||
}
|
||||
},
|
||||
getProposedContent: async (params: EditToolParams): Promise<string> => {
|
||||
try {
|
||||
const currentContent = fs.readFileSync(params.file_path, 'utf8');
|
||||
return this._applyReplacement(
|
||||
currentContent,
|
||||
params.old_string,
|
||||
params.new_string,
|
||||
params.old_string === '' && currentContent === '',
|
||||
);
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
return '';
|
||||
}
|
||||
},
|
||||
createUpdatedParams: (
|
||||
oldContent: string,
|
||||
modifiedProposedContent: string,
|
||||
originalParams: EditToolParams,
|
||||
): EditToolParams => ({
|
||||
...originalParams,
|
||||
old_string: oldContent,
|
||||
new_string: modifiedProposedContent,
|
||||
modified_by_user: true,
|
||||
}),
|
||||
};
|
||||
}
|
||||
}
|
||||
376
packages/core/src/tools/glob.test.ts
Normal file
376
packages/core/src/tools/glob.test.ts
Normal file
@@ -0,0 +1,376 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GlobTool, GlobToolParams, GlobPath, sortFileEntries } from './glob.js';
|
||||
import { partListUnionToString } from '../core/geminiRequest.js';
|
||||
import path from 'path';
|
||||
import fs from 'fs/promises';
|
||||
import os from 'os';
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'; // Removed vi
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
describe('GlobTool', () => {
|
||||
let tempRootDir: string; // This will be the rootDirectory for the GlobTool instance
let globTool: GlobTool;
const abortSignal = new AbortController().signal;

// Mock config for testing
// NOTE: the getters read tempRootDir lazily, so they always reflect the
// directory created by the current test's beforeEach.
const mockConfig = {
  getFileService: () => new FileDiscoveryService(tempRootDir),
  getFileFilteringRespectGitIgnore: () => true,
  getTargetDir: () => tempRootDir,
} as unknown as Config;

beforeEach(async () => {
  // Create a unique root directory for each test run
  tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'glob-tool-root-'));
  globTool = new GlobTool(mockConfig);

  // Create some test files and directories within this root
  // Top-level files
  await fs.writeFile(path.join(tempRootDir, 'fileA.txt'), 'contentA');
  await fs.writeFile(path.join(tempRootDir, 'FileB.TXT'), 'contentB'); // Different case for testing

  // Subdirectory and files within it
  await fs.mkdir(path.join(tempRootDir, 'sub'));
  await fs.writeFile(path.join(tempRootDir, 'sub', 'fileC.md'), 'contentC');
  await fs.writeFile(path.join(tempRootDir, 'sub', 'FileD.MD'), 'contentD'); // Different case

  // Deeper subdirectory
  await fs.mkdir(path.join(tempRootDir, 'sub', 'deep'));
  await fs.writeFile(
    path.join(tempRootDir, 'sub', 'deep', 'fileE.log'),
    'contentE',
  );

  // Files for mtime sorting test
  await fs.writeFile(path.join(tempRootDir, 'older.sortme'), 'older_content');
  // Ensure a noticeable difference in modification time
  await new Promise((resolve) => setTimeout(resolve, 50));
  await fs.writeFile(path.join(tempRootDir, 'newer.sortme'), 'newer_content');
});

afterEach(async () => {
  // Clean up the temporary root directory
  await fs.rm(tempRootDir, { recursive: true, force: true });
});
|
||||
|
||||
// Behavioral tests for GlobTool.execute against the fixture tree created in
// beforeEach. Expectations pin the exact summary strings and absolute paths
// the tool reports.
describe('execute', () => {
  it('should find files matching a simple pattern in the root', async () => {
    const params: GlobToolParams = { pattern: '*.txt' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 2 file(s)');
    expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
    expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
    expect(result.returnDisplay).toBe('Found 2 matching file(s)');
  });

  it('should find files case-sensitively when case_sensitive is true', async () => {
    const params: GlobToolParams = { pattern: '*.txt', case_sensitive: true };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 1 file(s)');
    expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
    expect(result.llmContent).not.toContain(
      path.join(tempRootDir, 'FileB.TXT'),
    );
  });

  it('should find files case-insensitively by default (pattern: *.TXT)', async () => {
    const params: GlobToolParams = { pattern: '*.TXT' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 2 file(s)');
    expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
    expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
  });

  it('should find files case-insensitively when case_sensitive is false (pattern: *.TXT)', async () => {
    const params: GlobToolParams = {
      pattern: '*.TXT',
      case_sensitive: false,
    };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 2 file(s)');
    expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
    expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
  });

  it('should find files using a pattern that includes a subdirectory', async () => {
    const params: GlobToolParams = { pattern: 'sub/*.md' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 2 file(s)');
    expect(result.llmContent).toContain(
      path.join(tempRootDir, 'sub', 'fileC.md'),
    );
    expect(result.llmContent).toContain(
      path.join(tempRootDir, 'sub', 'FileD.MD'),
    );
  });

  it('should find files in a specified relative path (relative to rootDir)', async () => {
    const params: GlobToolParams = { pattern: '*.md', path: 'sub' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 2 file(s)');
    expect(result.llmContent).toContain(
      path.join(tempRootDir, 'sub', 'fileC.md'),
    );
    expect(result.llmContent).toContain(
      path.join(tempRootDir, 'sub', 'FileD.MD'),
    );
  });

  it('should find files using a deep globstar pattern (e.g., **/*.log)', async () => {
    const params: GlobToolParams = { pattern: '**/*.log' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain('Found 1 file(s)');
    expect(result.llmContent).toContain(
      path.join(tempRootDir, 'sub', 'deep', 'fileE.log'),
    );
  });

  it('should return "No files found" message when pattern matches nothing', async () => {
    const params: GlobToolParams = { pattern: '*.nonexistent' };
    const result = await globTool.execute(params, abortSignal);
    expect(result.llmContent).toContain(
      'No files found matching pattern "*.nonexistent"',
    );
    expect(result.returnDisplay).toBe('No files found');
  });

  it('should correctly sort files by modification time (newest first)', async () => {
    // Relies on beforeEach writing older.sortme, sleeping 50ms, then writing
    // newer.sortme so the two mtimes differ measurably.
    const params: GlobToolParams = { pattern: '*.sortme' };
    const result = await globTool.execute(params, abortSignal);
    const llmContent = partListUnionToString(result.llmContent);

    expect(llmContent).toContain('Found 2 file(s)');
    // Ensure llmContent is a string for TypeScript type checking
    expect(typeof llmContent).toBe('string');

    // Strip the "Found N file(s) ...:" prefix, then split the listing.
    const filesListed = llmContent
      .substring(llmContent.indexOf(':') + 1)
      .trim()
      .split('\n');
    expect(filesListed[0]).toContain(path.join(tempRootDir, 'newer.sortme'));
    expect(filesListed[1]).toContain(path.join(tempRootDir, 'older.sortme'));
  });
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
it('should return null for valid parameters (pattern only)', () => {
|
||||
const params: GlobToolParams = { pattern: '*.js' };
|
||||
expect(globTool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for valid parameters (pattern and path)', () => {
|
||||
const params: GlobToolParams = { pattern: '*.js', path: 'sub' };
|
||||
expect(globTool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for valid parameters (pattern, path, and case_sensitive)', () => {
|
||||
const params: GlobToolParams = {
|
||||
pattern: '*.js',
|
||||
path: 'sub',
|
||||
case_sensitive: true,
|
||||
};
|
||||
expect(globTool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return error if pattern is missing (schema validation)', () => {
|
||||
// Need to correctly define this as an object without pattern
|
||||
const params = { path: '.' };
|
||||
// @ts-expect-error - We're intentionally creating invalid params for testing
|
||||
expect(globTool.validateToolParams(params)).toBe(
|
||||
`params must have required property 'pattern'`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if pattern is an empty string', () => {
|
||||
const params: GlobToolParams = { pattern: '' };
|
||||
expect(globTool.validateToolParams(params)).toContain(
|
||||
"The 'pattern' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if pattern is only whitespace', () => {
|
||||
const params: GlobToolParams = { pattern: ' ' };
|
||||
expect(globTool.validateToolParams(params)).toContain(
|
||||
"The 'pattern' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is provided but is not a string (schema validation)', () => {
|
||||
const params = {
|
||||
pattern: '*.ts',
|
||||
path: 123,
|
||||
};
|
||||
// @ts-expect-error - We're intentionally creating invalid params for testing
|
||||
expect(globTool.validateToolParams(params)).toBe(
|
||||
'params/path must be string',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if case_sensitive is provided but is not a boolean (schema validation)', () => {
|
||||
const params = {
|
||||
pattern: '*.ts',
|
||||
case_sensitive: 'true',
|
||||
};
|
||||
// @ts-expect-error - We're intentionally creating invalid params for testing
|
||||
expect(globTool.validateToolParams(params)).toBe(
|
||||
'params/case_sensitive must be boolean',
|
||||
);
|
||||
});
|
||||
|
||||
it("should return error if search path resolves outside the tool's root directory", () => {
|
||||
// Create a globTool instance specifically for this test, with a deeper root
|
||||
tempRootDir = path.join(tempRootDir, 'sub');
|
||||
const specificGlobTool = new GlobTool(mockConfig);
|
||||
// const params: GlobToolParams = { pattern: '*.txt', path: '..' }; // This line is unused and will be removed.
|
||||
// This should be fine as tempRootDir is still within the original tempRootDir (the parent of deeperRootDir)
|
||||
// Let's try to go further up.
|
||||
const paramsOutside: GlobToolParams = {
|
||||
pattern: '*.txt',
|
||||
path: '../../../../../../../../../../tmp',
|
||||
}; // Definitely outside
|
||||
expect(specificGlobTool.validateToolParams(paramsOutside)).toContain(
|
||||
"resolves outside the tool's root directory",
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if specified search path does not exist', async () => {
|
||||
const params: GlobToolParams = {
|
||||
pattern: '*.txt',
|
||||
path: 'nonexistent_subdir',
|
||||
};
|
||||
expect(globTool.validateToolParams(params)).toContain(
|
||||
'Search path does not exist',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if specified search path is a file, not a directory', async () => {
|
||||
const params: GlobToolParams = { pattern: '*.txt', path: 'fileA.txt' };
|
||||
expect(globTool.validateToolParams(params)).toContain(
|
||||
'Search path is not a directory',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Unit tests for the pure sortFileEntries helper: entries modified within the
// recency threshold come first (newest to oldest by mtime); older entries
// follow, sorted alphabetically by full path.
describe('sortFileEntries', () => {
  const nowTimestamp = new Date('2024-01-15T12:00:00.000Z').getTime();
  const oneDayInMs = 24 * 60 * 60 * 1000;

  // Minimal GlobPath stub carrying only the fields the sorter reads.
  const createFileEntry = (fullpath: string, mtimeDate: Date): GlobPath => ({
    fullpath: () => fullpath,
    mtimeMs: mtimeDate.getTime(),
  });

  it('should sort a mix of recent and older files correctly', () => {
    const recentTime1 = new Date(nowTimestamp - 1 * 60 * 60 * 1000); // 1 hour ago
    const recentTime2 = new Date(nowTimestamp - 2 * 60 * 60 * 1000); // 2 hours ago
    const olderTime1 = new Date(
      nowTimestamp - (oneDayInMs + 1 * 60 * 60 * 1000),
    ); // 25 hours ago
    const olderTime2 = new Date(
      nowTimestamp - (oneDayInMs + 2 * 60 * 60 * 1000),
    ); // 26 hours ago

    const entries: GlobPath[] = [
      createFileEntry('older_zebra.txt', olderTime2),
      createFileEntry('recent_alpha.txt', recentTime1),
      createFileEntry('older_apple.txt', olderTime1),
      createFileEntry('recent_beta.txt', recentTime2),
      createFileEntry('older_banana.txt', olderTime1), // Same mtime as apple
    ];

    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    const sortedPaths = sorted.map((e) => e.fullpath());

    expect(sortedPaths).toEqual([
      'recent_alpha.txt', // Recent, newest
      'recent_beta.txt', // Recent, older
      'older_apple.txt', // Older, alphabetical
      'older_banana.txt', // Older, alphabetical
      'older_zebra.txt', // Older, alphabetical
    ]);
  });

  it('should sort only recent files by mtime descending', () => {
    const recentTime1 = new Date(nowTimestamp - 1000); // Newest
    const recentTime2 = new Date(nowTimestamp - 2000);
    const recentTime3 = new Date(nowTimestamp - 3000); // Oldest recent

    const entries: GlobPath[] = [
      createFileEntry('c.txt', recentTime2),
      createFileEntry('a.txt', recentTime3),
      createFileEntry('b.txt', recentTime1),
    ];
    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    expect(sorted.map((e) => e.fullpath())).toEqual([
      'b.txt',
      'c.txt',
      'a.txt',
    ]);
  });

  it('should sort only older files alphabetically by path', () => {
    const olderTime = new Date(nowTimestamp - 2 * oneDayInMs); // All equally old
    const entries: GlobPath[] = [
      createFileEntry('zebra.txt', olderTime),
      createFileEntry('apple.txt', olderTime),
      createFileEntry('banana.txt', olderTime),
    ];
    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    expect(sorted.map((e) => e.fullpath())).toEqual([
      'apple.txt',
      'banana.txt',
      'zebra.txt',
    ]);
  });

  it('should handle an empty array', () => {
    const entries: GlobPath[] = [];
    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    expect(sorted).toEqual([]);
  });

  it('should correctly sort files when mtimes are identical for older files', () => {
    const olderTime = new Date(nowTimestamp - 2 * oneDayInMs);
    const entries: GlobPath[] = [
      createFileEntry('b.txt', olderTime),
      createFileEntry('a.txt', olderTime),
    ];
    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    expect(sorted.map((e) => e.fullpath())).toEqual(['a.txt', 'b.txt']);
  });

  it('should correctly sort files when mtimes are identical for recent files (maintaining mtime sort)', () => {
    // Relative order of equal-mtime recent entries is unspecified, so only
    // membership and length are asserted here.
    const recentTime = new Date(nowTimestamp - 1000);
    const entries: GlobPath[] = [
      createFileEntry('b.txt', recentTime),
      createFileEntry('a.txt', recentTime),
    ];
    const sorted = sortFileEntries(entries, nowTimestamp, oneDayInMs);
    expect(sorted.map((e) => e.fullpath())).toContain('a.txt');
    expect(sorted.map((e) => e.fullpath())).toContain('b.txt');
    expect(sorted.length).toBe(2);
  });

  it('should use recencyThresholdMs parameter correctly', () => {
    const justOverThreshold = new Date(nowTimestamp - (1000 + 1)); // Barely older
    const justUnderThreshold = new Date(nowTimestamp - (1000 - 1)); // Barely recent
    const customThresholdMs = 1000; // 1 second

    const entries: GlobPath[] = [
      createFileEntry('older_file.txt', justOverThreshold),
      createFileEntry('recent_file.txt', justUnderThreshold),
    ];
    const sorted = sortFileEntries(entries, nowTimestamp, customThresholdMs);
    expect(sorted.map((e) => e.fullpath())).toEqual([
      'recent_file.txt',
      'older_file.txt',
    ]);
  });
});
|
||||
285
packages/core/src/tools/glob.ts
Normal file
285
packages/core/src/tools/glob.ts
Normal file
@@ -0,0 +1,285 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { glob } from 'glob';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { shortenPath, makeRelative } from '../utils/paths.js';
|
||||
import { isWithinRoot } from '../utils/fileUtils.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
// Subset of 'Path' interface provided by 'glob' that we can implement for testing
export interface GlobPath {
  /** Full path of the matched entry. */
  fullpath(): string;
  /** Modification time in ms since the epoch; may be absent for entries the glob walk did not stat. */
  mtimeMs?: number;
}
|
||||
|
||||
/**
|
||||
* Sorts file entries based on recency and then alphabetically.
|
||||
* Recent files (modified within recencyThresholdMs) are listed first, newest to oldest.
|
||||
* Older files are listed after recent ones, sorted alphabetically by path.
|
||||
*/
|
||||
export function sortFileEntries(
|
||||
entries: GlobPath[],
|
||||
nowTimestamp: number,
|
||||
recencyThresholdMs: number,
|
||||
): GlobPath[] {
|
||||
const sortedEntries = [...entries];
|
||||
sortedEntries.sort((a, b) => {
|
||||
const mtimeA = a.mtimeMs ?? 0;
|
||||
const mtimeB = b.mtimeMs ?? 0;
|
||||
const aIsRecent = nowTimestamp - mtimeA < recencyThresholdMs;
|
||||
const bIsRecent = nowTimestamp - mtimeB < recencyThresholdMs;
|
||||
|
||||
if (aIsRecent && bIsRecent) {
|
||||
return mtimeB - mtimeA;
|
||||
} else if (aIsRecent) {
|
||||
return -1;
|
||||
} else if (bIsRecent) {
|
||||
return 1;
|
||||
} else {
|
||||
return a.fullpath().localeCompare(b.fullpath());
|
||||
}
|
||||
});
|
||||
return sortedEntries;
|
||||
}
|
||||
|
||||
/**
 * Parameters for the GlobTool
 */
export interface GlobToolParams {
  /**
   * The glob pattern to match files against
   */
  pattern: string;

  /**
   * The directory to search in (optional, defaults to current directory).
   * Relative values are resolved against the tool's configured root directory.
   */
  path?: string;

  /**
   * Whether the search should be case-sensitive (optional, defaults to false)
   */
  case_sensitive?: boolean;

  /**
   * Whether to respect .gitignore patterns (optional, defaults to true)
   */
  respect_git_ignore?: boolean;
}
|
||||
|
||||
/**
 * Implementation of the Glob tool logic
 *
 * Finds files matching a glob pattern under the configured target directory,
 * optionally filtering out git-ignored entries, and returns absolute paths
 * sorted so the most recently modified files (within the last day) come first.
 */
export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
  static readonly Name = 'glob';

  constructor(private config: Config) {
    // Registers the tool's name, display name, description, and JSON
    // parameter schema with the BaseTool machinery.
    super(
      GlobTool.Name,
      'FindFiles',
      'Efficiently finds files matching specific glob patterns (e.g., `src/**/*.ts`, `**/*.md`), returning absolute paths sorted by modification time (newest first). Ideal for quickly locating files based on their name or path structure, especially in large codebases.',
      {
        properties: {
          pattern: {
            description:
              "The glob pattern to match against (e.g., '**/*.py', 'docs/*.md').",
            type: Type.STRING,
          },
          path: {
            description:
              'Optional: The absolute path to the directory to search within. If omitted, searches the root directory.',
            type: Type.STRING,
          },
          case_sensitive: {
            description:
              'Optional: Whether the search should be case-sensitive. Defaults to false.',
            type: Type.BOOLEAN,
          },
          respect_git_ignore: {
            description:
              'Optional: Whether to respect .gitignore patterns when finding files. Only available in git repositories. Defaults to true.',
            type: Type.BOOLEAN,
          },
        },
        required: ['pattern'],
        type: Type.OBJECT,
      },
    );
  }

  /**
   * Validates the parameters for the tool.
   *
   * Checks, in order: schema conformance, that the search path stays inside
   * the root directory, that it exists and is a directory, and that the
   * pattern is a non-empty string.
   *
   * @param params Parameters to validate.
   * @returns A human-readable error string, or null when params are valid.
   */
  validateToolParams(params: GlobToolParams): string | null {
    const errors = SchemaValidator.validate(this.schema.parameters, params);
    if (errors) {
      return errors;
    }

    const searchDirAbsolute = path.resolve(
      this.config.getTargetDir(),
      params.path || '.',
    );

    // Reject any path that escapes the configured root directory.
    if (!isWithinRoot(searchDirAbsolute, this.config.getTargetDir())) {
      return `Search path ("${searchDirAbsolute}") resolves outside the tool's root directory ("${this.config.getTargetDir()}").`;
    }

    // NOTE(review): path.resolve never returns an empty string, so this
    // fallback branch looks unreachable — confirm before removing.
    const targetDir = searchDirAbsolute || this.config.getTargetDir();
    try {
      if (!fs.existsSync(targetDir)) {
        return `Search path does not exist ${targetDir}`;
      }
      if (!fs.statSync(targetDir).isDirectory()) {
        return `Search path is not a directory: ${targetDir}`;
      }
    } catch (e: unknown) {
      return `Error accessing search path: ${e}`;
    }

    // Schema validation does not reject whitespace-only patterns, so an
    // explicit emptiness check is done here.
    if (
      !params.pattern ||
      typeof params.pattern !== 'string' ||
      params.pattern.trim() === ''
    ) {
      return "The 'pattern' parameter cannot be empty.";
    }

    return null;
  }

  /**
   * Gets a description of the glob operation.
   *
   * @param params Parameters for the glob operation.
   * @returns The quoted pattern, plus a shortened relative search path when
   *   one was supplied.
   */
  getDescription(params: GlobToolParams): string {
    let description = `'${params.pattern}'`;
    if (params.path) {
      const searchDir = path.resolve(
        this.config.getTargetDir(),
        params.path || '.',
      );
      const relativePath = makeRelative(searchDir, this.config.getTargetDir());
      description += ` within ${shortenPath(relativePath)}`;
    }
    return description;
  }

  /**
   * Executes the glob search with the given parameters
   *
   * @param params Parameters for the search (validated before use).
   * @param signal Abort signal forwarded to the underlying glob call.
   * @returns A ToolResult with the sorted absolute paths, or an error/empty
   *   result message. Never throws; failures are reported in the result.
   */
  async execute(
    params: GlobToolParams,
    signal: AbortSignal,
  ): Promise<ToolResult> {
    const validationError = this.validateToolParams(params);
    if (validationError) {
      return {
        llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
        returnDisplay: validationError,
      };
    }

    try {
      const searchDirAbsolute = path.resolve(
        this.config.getTargetDir(),
        params.path || '.',
      );

      // Get centralized file discovery service
      const respectGitIgnore =
        params.respect_git_ignore ??
        this.config.getFileFilteringRespectGitIgnore();
      const fileDiscovery = this.config.getFileService();

      // stat: true is requested so entries carry mtimeMs for recency sorting;
      // node_modules and .git are always excluded.
      const entries = (await glob(params.pattern, {
        cwd: searchDirAbsolute,
        withFileTypes: true,
        nodir: true,
        stat: true,
        nocase: !params.case_sensitive,
        dot: true,
        ignore: ['**/node_modules/**', '**/.git/**'],
        follow: false,
        signal,
      })) as GlobPath[];

      // Apply git-aware filtering if enabled and in git repository
      let filteredEntries = entries;
      let gitIgnoredCount = 0;

      if (respectGitIgnore) {
        // filterFiles operates on root-relative paths, so convert each
        // absolute entry path before filtering and map back afterwards.
        const relativePaths = entries.map((p) =>
          path.relative(this.config.getTargetDir(), p.fullpath()),
        );
        const filteredRelativePaths = fileDiscovery.filterFiles(relativePaths, {
          respectGitIgnore,
        });
        const filteredAbsolutePaths = new Set(
          filteredRelativePaths.map((p) =>
            path.resolve(this.config.getTargetDir(), p),
          ),
        );

        filteredEntries = entries.filter((entry) =>
          filteredAbsolutePaths.has(entry.fullpath()),
        );
        gitIgnoredCount = entries.length - filteredEntries.length;
      }

      if (!filteredEntries || filteredEntries.length === 0) {
        let message = `No files found matching pattern "${params.pattern}" within ${searchDirAbsolute}.`;
        if (gitIgnoredCount > 0) {
          message += ` (${gitIgnoredCount} files were git-ignored)`;
        }
        return {
          llmContent: message,
          returnDisplay: `No files found`,
        };
      }

      // Set filtering such that we first show the most recent files
      const oneDayInMs = 24 * 60 * 60 * 1000;
      const nowTimestamp = new Date().getTime();

      // Sort the filtered entries using the new helper function
      const sortedEntries = sortFileEntries(
        filteredEntries,
        nowTimestamp,
        oneDayInMs,
      );

      const sortedAbsolutePaths = sortedEntries.map((entry) =>
        entry.fullpath(),
      );
      const fileListDescription = sortedAbsolutePaths.join('\n');
      const fileCount = sortedAbsolutePaths.length;

      let resultMessage = `Found ${fileCount} file(s) matching "${params.pattern}" within ${searchDirAbsolute}`;
      if (gitIgnoredCount > 0) {
        resultMessage += ` (${gitIgnoredCount} additional files were git-ignored)`;
      }
      resultMessage += `, sorted by modification time (newest first):\n${fileListDescription}`;

      return {
        llmContent: resultMessage,
        returnDisplay: `Found ${fileCount} matching file(s)`,
      };
    } catch (error) {
      // All failures (glob errors, abort, file-service errors) are reported
      // as a ToolResult rather than propagated.
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      console.error(`GlobLogic execute Error: ${errorMessage}`, error);
      return {
        llmContent: `Error during glob search operation: ${errorMessage}`,
        returnDisplay: `Error: An unexpected error occurred.`,
      };
    }
  }
}
|
||||
264
packages/core/src/tools/grep.test.ts
Normal file
264
packages/core/src/tools/grep.test.ts
Normal file
@@ -0,0 +1,264 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { GrepTool, GrepToolParams } from './grep.js';
|
||||
import path from 'path';
|
||||
import fs from 'fs/promises';
|
||||
import os from 'os';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
// Mock the child_process module to control grep/git grep behavior
vi.mock('child_process', () => ({
  spawn: vi.fn(() => ({
    on: (event: string, cb: (...args: unknown[]) => void) => {
      if (event === 'error' || event === 'close') {
        // Simulate command not found or error for git grep and system grep
        // to force fallback to JS implementation.
        setTimeout(() => cb(1), 0); // cb(1) for error/close
      }
    },
    stdout: { on: vi.fn() },
    stderr: { on: vi.fn() },
  })),
}));

// Unit tests for GrepTool. Real files are written into a temp directory per
// test; the child_process mock above forces the pure-JS search fallback.
describe('GrepTool', () => {
  let tempRootDir: string;
  let grepTool: GrepTool;
  const abortSignal = new AbortController().signal;

  // Only getTargetDir is exercised by GrepTool, so a minimal stub suffices.
  const mockConfig = {
    getTargetDir: () => tempRootDir,
  } as unknown as Config;

  beforeEach(async () => {
    tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-'));
    grepTool = new GrepTool(mockConfig);

    // Create some test files and directories
    await fs.writeFile(
      path.join(tempRootDir, 'fileA.txt'),
      'hello world\nsecond line with world',
    );
    await fs.writeFile(
      path.join(tempRootDir, 'fileB.js'),
      'const foo = "bar";\nfunction baz() { return "hello"; }',
    );
    await fs.mkdir(path.join(tempRootDir, 'sub'));
    await fs.writeFile(
      path.join(tempRootDir, 'sub', 'fileC.txt'),
      'another world in sub dir',
    );
    await fs.writeFile(
      path.join(tempRootDir, 'sub', 'fileD.md'),
      '# Markdown file\nThis is a test.',
    );
  });

  afterEach(async () => {
    await fs.rm(tempRootDir, { recursive: true, force: true });
  });

  describe('validateToolParams', () => {
    it('should return null for valid params (pattern only)', () => {
      const params: GrepToolParams = { pattern: 'hello' };
      expect(grepTool.validateToolParams(params)).toBeNull();
    });

    it('should return null for valid params (pattern and path)', () => {
      const params: GrepToolParams = { pattern: 'hello', path: '.' };
      expect(grepTool.validateToolParams(params)).toBeNull();
    });

    it('should return null for valid params (pattern, path, and include)', () => {
      const params: GrepToolParams = {
        pattern: 'hello',
        path: '.',
        include: '*.txt',
      };
      expect(grepTool.validateToolParams(params)).toBeNull();
    });

    it('should return error if pattern is missing', () => {
      const params = { path: '.' } as unknown as GrepToolParams;
      expect(grepTool.validateToolParams(params)).toBe(
        `params must have required property 'pattern'`,
      );
    });

    it('should return error for invalid regex pattern', () => {
      const params: GrepToolParams = { pattern: '[[' };
      expect(grepTool.validateToolParams(params)).toContain(
        'Invalid regular expression pattern',
      );
    });

    it('should return error if path does not exist', () => {
      const params: GrepToolParams = { pattern: 'hello', path: 'nonexistent' };
      // Check for the core error message, as the full path might vary
      expect(grepTool.validateToolParams(params)).toContain(
        'Failed to access path stats for',
      );
      expect(grepTool.validateToolParams(params)).toContain('nonexistent');
    });

    it('should return error if path is a file, not a directory', async () => {
      const filePath = path.join(tempRootDir, 'fileA.txt');
      const params: GrepToolParams = { pattern: 'hello', path: filePath };
      expect(grepTool.validateToolParams(params)).toContain(
        `Path is not a directory: ${filePath}`,
      );
    });
  });

  describe('execute', () => {
    it('should find matches for a simple pattern in all files', async () => {
      const params: GrepToolParams = { pattern: 'world' };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 3 matches for pattern "world" in path "."',
      );
      expect(result.llmContent).toContain('File: fileA.txt');
      expect(result.llmContent).toContain('L1: hello world');
      expect(result.llmContent).toContain('L2: second line with world');
      expect(result.llmContent).toContain('File: sub/fileC.txt');
      expect(result.llmContent).toContain('L1: another world in sub dir');
      expect(result.returnDisplay).toBe('Found 3 matches');
    });

    it('should find matches in a specific path', async () => {
      const params: GrepToolParams = { pattern: 'world', path: 'sub' };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 1 match for pattern "world" in path "sub"',
      );
      expect(result.llmContent).toContain('File: fileC.txt'); // Path relative to 'sub'
      expect(result.llmContent).toContain('L1: another world in sub dir');
      expect(result.returnDisplay).toBe('Found 1 match');
    });

    it('should find matches with an include glob', async () => {
      const params: GrepToolParams = { pattern: 'hello', include: '*.js' };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 1 match for pattern "hello" in path "." (filter: "*.js")',
      );
      expect(result.llmContent).toContain('File: fileB.js');
      expect(result.llmContent).toContain(
        'L2: function baz() { return "hello"; }',
      );
      expect(result.returnDisplay).toBe('Found 1 match');
    });

    it('should find matches with an include glob and path', async () => {
      await fs.writeFile(
        path.join(tempRootDir, 'sub', 'another.js'),
        'const greeting = "hello";',
      );
      const params: GrepToolParams = {
        pattern: 'hello',
        path: 'sub',
        include: '*.js',
      };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 1 match for pattern "hello" in path "sub" (filter: "*.js")',
      );
      expect(result.llmContent).toContain('File: another.js');
      expect(result.llmContent).toContain('L1: const greeting = "hello";');
      expect(result.returnDisplay).toBe('Found 1 match');
    });

    it('should return "No matches found" when pattern does not exist', async () => {
      const params: GrepToolParams = { pattern: 'nonexistentpattern' };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'No matches found for pattern "nonexistentpattern" in path "."',
      );
      expect(result.returnDisplay).toBe('No matches found');
    });

    it('should handle regex special characters correctly', async () => {
      const params: GrepToolParams = { pattern: 'foo.*bar' }; // Matches 'const foo = "bar";'
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 1 match for pattern "foo.*bar" in path "."',
      );
      expect(result.llmContent).toContain('File: fileB.js');
      expect(result.llmContent).toContain('L1: const foo = "bar";');
    });

    it('should be case-insensitive by default (JS fallback)', async () => {
      const params: GrepToolParams = { pattern: 'HELLO' };
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toContain(
        'Found 2 matches for pattern "HELLO" in path "."',
      );
      expect(result.llmContent).toContain('File: fileA.txt');
      expect(result.llmContent).toContain('L1: hello world');
      expect(result.llmContent).toContain('File: fileB.js');
      expect(result.llmContent).toContain(
        'L2: function baz() { return "hello"; }',
      );
    });

    it('should return an error if params are invalid', async () => {
      const params = { path: '.' } as unknown as GrepToolParams; // Invalid: pattern missing
      const result = await grepTool.execute(params, abortSignal);
      expect(result.llmContent).toBe(
        "Error: Invalid parameters provided. Reason: params must have required property 'pattern'",
      );
      expect(result.returnDisplay).toBe(
        "Model provided invalid parameters. Error: params must have required property 'pattern'",
      );
    });
  });

  describe('getDescription', () => {
    it('should generate correct description with pattern only', () => {
      const params: GrepToolParams = { pattern: 'testPattern' };
      expect(grepTool.getDescription(params)).toBe("'testPattern'");
    });

    it('should generate correct description with pattern and include', () => {
      const params: GrepToolParams = {
        pattern: 'testPattern',
        include: '*.ts',
      };
      expect(grepTool.getDescription(params)).toBe("'testPattern' in *.ts");
    });

    it('should generate correct description with pattern and path', () => {
      const params: GrepToolParams = {
        pattern: 'testPattern',
        path: 'src/app',
      };
      // The path will be relative to the tempRootDir, so we check for containment.
      expect(grepTool.getDescription(params)).toContain("'testPattern' within");
      expect(grepTool.getDescription(params)).toContain(
        path.join('src', 'app'),
      );
    });

    it('should generate correct description with pattern, include, and path', () => {
      const params: GrepToolParams = {
        pattern: 'testPattern',
        include: '*.ts',
        path: 'src/app',
      };
      expect(grepTool.getDescription(params)).toContain(
        "'testPattern' in *.ts within",
      );
      expect(grepTool.getDescription(params)).toContain('src/app');
    });

    it('should use ./ for root path in description', () => {
      const params: GrepToolParams = { pattern: 'testPattern', path: '.' };
      expect(grepTool.getDescription(params)).toBe("'testPattern' within ./");
    });
  });
});
|
||||
547
packages/core/src/tools/grep.ts
Normal file
547
packages/core/src/tools/grep.ts
Normal file
@@ -0,0 +1,547 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import fs from 'fs';
|
||||
import fsPromises from 'fs/promises';
|
||||
import path from 'path';
|
||||
import { EOL } from 'os';
|
||||
import { spawn } from 'child_process';
|
||||
import { globStream } from 'glob';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import { isGitRepository } from '../utils/gitUtils.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
// --- Interfaces ---
|
||||
|
||||
/**
 * Parameters for the GrepTool
 */
export interface GrepToolParams {
  /**
   * The regular expression pattern to search for in file contents.
   * Must compile as a JavaScript RegExp.
   */
  pattern: string;

  /**
   * The directory to search in (optional, defaults to current directory relative to root)
   */
  path?: string;

  /**
   * File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")
   */
  include?: string;
}
|
||||
|
||||
/**
 * Result object for a single grep match
 */
interface GrepMatch {
  // Path of the matched file, relative to the searched directory.
  filePath: string;
  // 1-based line number of the match within the file.
  lineNumber: number;
  // Full text of the matching line (without a trailing newline).
  line: string;
}
|
||||
|
||||
// --- GrepLogic Class ---
|
||||
|
||||
/**
|
||||
* Implementation of the Grep tool logic (moved from CLI)
|
||||
*/
|
||||
export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
|
||||
static readonly Name = 'search_file_content'; // Keep static name
|
||||
|
||||
  /**
   * Registers the tool's name ('search_file_content'), display name,
   * description, and JSON parameter schema (pattern required; path and
   * include optional) with the BaseTool machinery.
   */
  constructor(private readonly config: Config) {
    super(
      GrepTool.Name,
      'SearchText',
      'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
      {
        properties: {
          pattern: {
            description:
              "The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').",
            type: Type.STRING,
          },
          path: {
            description:
              'Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.',
            type: Type.STRING,
          },
          include: {
            description:
              "Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).",
            type: Type.STRING,
          },
        },
        required: ['pattern'],
        type: Type.OBJECT,
      },
    );
  }
|
||||
|
||||
// --- Validation Methods ---
|
||||
|
||||
/**
|
||||
* Checks if a path is within the root directory and resolves it.
|
||||
* @param relativePath Path relative to the root directory (or undefined for root).
|
||||
* @returns The absolute path if valid and exists.
|
||||
* @throws {Error} If path is outside root, doesn't exist, or isn't a directory.
|
||||
*/
|
||||
private resolveAndValidatePath(relativePath?: string): string {
|
||||
const targetPath = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
relativePath || '.',
|
||||
);
|
||||
|
||||
// Security Check: Ensure the resolved path is still within the root directory.
|
||||
if (
|
||||
!targetPath.startsWith(this.config.getTargetDir()) &&
|
||||
targetPath !== this.config.getTargetDir()
|
||||
) {
|
||||
throw new Error(
|
||||
`Path validation failed: Attempted path "${relativePath || '.'}" resolves outside the allowed root directory "${this.config.getTargetDir()}".`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check existence and type after resolving
|
||||
try {
|
||||
const stats = fs.statSync(targetPath);
|
||||
if (!stats.isDirectory()) {
|
||||
throw new Error(`Path is not a directory: ${targetPath}`);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
if (isNodeError(error) && error.code !== 'ENOENT') {
|
||||
throw new Error(`Path does not exist: ${targetPath}`);
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to access path stats for ${targetPath}: ${error}`,
|
||||
);
|
||||
}
|
||||
|
||||
return targetPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the tool
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
validateToolParams(params: GrepToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
try {
|
||||
new RegExp(params.pattern);
|
||||
} catch (error) {
|
||||
return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`;
|
||||
}
|
||||
|
||||
try {
|
||||
this.resolveAndValidatePath(params.path);
|
||||
} catch (error) {
|
||||
return getErrorMessage(error);
|
||||
}
|
||||
|
||||
return null; // Parameters are valid
|
||||
}
|
||||
|
||||
  // --- Core Execution ---

  /**
   * Executes the grep search with the given parameters
   * @param params Parameters for the grep search
   * @returns Result of the grep search
   *
   * Matches are grouped by file (relative to the searched directory) and
   * rendered as "File: <path>" followed by "L<n>: <line>" entries. Never
   * throws; failures are reported inside the ToolResult.
   */
  async execute(
    params: GrepToolParams,
    signal: AbortSignal,
  ): Promise<ToolResult> {
    const validationError = this.validateToolParams(params);
    if (validationError) {
      return {
        llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
        returnDisplay: `Model provided invalid parameters. Error: ${validationError}`,
      };
    }

    let searchDirAbs: string;
    try {
      searchDirAbs = this.resolveAndValidatePath(params.path);
      // User-facing messages show the path as it was provided, not resolved.
      const searchDirDisplay = params.path || '.';

      const matches: GrepMatch[] = await this.performGrepSearch({
        pattern: params.pattern,
        path: searchDirAbs,
        include: params.include,
        signal,
      });

      if (matches.length === 0) {
        const noMatchMsg = `No matches found for pattern "${params.pattern}" in path "${searchDirDisplay}"${params.include ? ` (filter: "${params.include}")` : ''}.`;
        return { llmContent: noMatchMsg, returnDisplay: `No matches found` };
      }

      // Group matches by file path relative to the searched directory.
      const matchesByFile = matches.reduce(
        (acc, match) => {
          const relativeFilePath =
            path.relative(
              searchDirAbs,
              path.resolve(searchDirAbs, match.filePath),
            ) || path.basename(match.filePath);
          if (!acc[relativeFilePath]) {
            acc[relativeFilePath] = [];
          }
          acc[relativeFilePath].push(match);
          // NOTE(review): sorting inside the reduce re-sorts the file's list
          // on every insertion; sorting once per file afterwards would be
          // cheaper — behavior is the same either way.
          acc[relativeFilePath].sort((a, b) => a.lineNumber - b.lineNumber);
          return acc;
        },
        {} as Record<string, GrepMatch[]>,
      );

      const matchCount = matches.length;
      const matchTerm = matchCount === 1 ? 'match' : 'matches';

      let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${params.pattern}" in path "${searchDirDisplay}"${params.include ? ` (filter: "${params.include}")` : ''}:\n---\n`;

      for (const filePath in matchesByFile) {
        llmContent += `File: ${filePath}\n`;
        matchesByFile[filePath].forEach((match) => {
          const trimmedLine = match.line.trim();
          llmContent += `L${match.lineNumber}: ${trimmedLine}\n`;
        });
        llmContent += '---\n';
      }

      return {
        llmContent: llmContent.trim(),
        returnDisplay: `Found ${matchCount} ${matchTerm}`,
      };
    } catch (error) {
      console.error(`Error during GrepLogic execution: ${error}`);
      const errorMessage = getErrorMessage(error);
      return {
        llmContent: `Error during grep search operation: ${errorMessage}`,
        returnDisplay: `Error: ${errorMessage}`,
      };
    }
  }
|
||||
|
||||
// --- Grep Implementation Logic ---
|
||||
|
||||
/**
|
||||
* Checks if a command is available in the system's PATH.
|
||||
* @param {string} command The command name (e.g., 'git', 'grep').
|
||||
* @returns {Promise<boolean>} True if the command is available, false otherwise.
|
||||
*/
|
||||
private isCommandAvailable(command: string): Promise<boolean> {
|
||||
return new Promise((resolve) => {
|
||||
const checkCommand = process.platform === 'win32' ? 'where' : 'command';
|
||||
const checkArgs =
|
||||
process.platform === 'win32' ? [command] : ['-v', command];
|
||||
try {
|
||||
const child = spawn(checkCommand, checkArgs, {
|
||||
stdio: 'ignore',
|
||||
shell: process.platform === 'win32',
|
||||
});
|
||||
child.on('close', (code) => resolve(code === 0));
|
||||
child.on('error', () => resolve(false));
|
||||
} catch {
|
||||
resolve(false);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the standard output of grep-like commands (git grep, system grep).
|
||||
* Expects format: filePath:lineNumber:lineContent
|
||||
* Handles colons within file paths and line content correctly.
|
||||
* @param {string} output The raw stdout string.
|
||||
* @param {string} basePath The absolute directory the search was run from, for relative paths.
|
||||
* @returns {GrepMatch[]} Array of match objects.
|
||||
*/
|
||||
private parseGrepOutput(output: string, basePath: string): GrepMatch[] {
|
||||
const results: GrepMatch[] = [];
|
||||
if (!output) return results;
|
||||
|
||||
const lines = output.split(EOL); // Use OS-specific end-of-line
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.trim()) continue;
|
||||
|
||||
// Find the index of the first colon.
|
||||
const firstColonIndex = line.indexOf(':');
|
||||
if (firstColonIndex === -1) continue; // Malformed
|
||||
|
||||
// Find the index of the second colon, searching *after* the first one.
|
||||
const secondColonIndex = line.indexOf(':', firstColonIndex + 1);
|
||||
if (secondColonIndex === -1) continue; // Malformed
|
||||
|
||||
// Extract parts based on the found colon indices
|
||||
const filePathRaw = line.substring(0, firstColonIndex);
|
||||
const lineNumberStr = line.substring(
|
||||
firstColonIndex + 1,
|
||||
secondColonIndex,
|
||||
);
|
||||
const lineContent = line.substring(secondColonIndex + 1);
|
||||
|
||||
const lineNumber = parseInt(lineNumberStr, 10);
|
||||
|
||||
if (!isNaN(lineNumber)) {
|
||||
const absoluteFilePath = path.resolve(basePath, filePathRaw);
|
||||
const relativeFilePath = path.relative(basePath, absoluteFilePath);
|
||||
|
||||
results.push({
|
||||
filePath: relativeFilePath || path.basename(absoluteFilePath),
|
||||
lineNumber,
|
||||
line: lineContent,
|
||||
});
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a description of the grep operation
|
||||
* @param params Parameters for the grep operation
|
||||
* @returns A string describing the grep
|
||||
*/
|
||||
getDescription(params: GrepToolParams): string {
|
||||
let description = `'${params.pattern}'`;
|
||||
if (params.include) {
|
||||
description += ` in ${params.include}`;
|
||||
}
|
||||
if (params.path) {
|
||||
const resolvedPath = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
params.path,
|
||||
);
|
||||
if (resolvedPath === this.config.getTargetDir() || params.path === '.') {
|
||||
description += ` within ./`;
|
||||
} else {
|
||||
const relativePath = makeRelative(
|
||||
resolvedPath,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
description += ` within ${shortenPath(relativePath)}`;
|
||||
}
|
||||
}
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the actual search using the prioritized strategies.
|
||||
* @param options Search options including pattern, absolute path, and include glob.
|
||||
* @returns A promise resolving to an array of match objects.
|
||||
*/
|
||||
private async performGrepSearch(options: {
|
||||
pattern: string;
|
||||
path: string; // Expects absolute path
|
||||
include?: string;
|
||||
signal: AbortSignal;
|
||||
}): Promise<GrepMatch[]> {
|
||||
const { pattern, path: absolutePath, include } = options;
|
||||
let strategyUsed = 'none';
|
||||
|
||||
try {
|
||||
// --- Strategy 1: git grep ---
|
||||
const isGit = isGitRepository(absolutePath);
|
||||
const gitAvailable = isGit && (await this.isCommandAvailable('git'));
|
||||
|
||||
if (gitAvailable) {
|
||||
strategyUsed = 'git grep';
|
||||
const gitArgs = [
|
||||
'grep',
|
||||
'--untracked',
|
||||
'-n',
|
||||
'-E',
|
||||
'--ignore-case',
|
||||
pattern,
|
||||
];
|
||||
if (include) {
|
||||
gitArgs.push('--', include);
|
||||
}
|
||||
|
||||
try {
|
||||
const output = await new Promise<string>((resolve, reject) => {
|
||||
const child = spawn('git', gitArgs, {
|
||||
cwd: absolutePath,
|
||||
windowsHide: true,
|
||||
});
|
||||
const stdoutChunks: Buffer[] = [];
|
||||
const stderrChunks: Buffer[] = [];
|
||||
|
||||
child.stdout.on('data', (chunk) => stdoutChunks.push(chunk));
|
||||
child.stderr.on('data', (chunk) => stderrChunks.push(chunk));
|
||||
child.on('error', (err) =>
|
||||
reject(new Error(`Failed to start git grep: ${err.message}`)),
|
||||
);
|
||||
child.on('close', (code) => {
|
||||
const stdoutData = Buffer.concat(stdoutChunks).toString('utf8');
|
||||
const stderrData = Buffer.concat(stderrChunks).toString('utf8');
|
||||
if (code === 0) resolve(stdoutData);
|
||||
else if (code === 1)
|
||||
resolve(''); // No matches
|
||||
else
|
||||
reject(
|
||||
new Error(`git grep exited with code ${code}: ${stderrData}`),
|
||||
);
|
||||
});
|
||||
});
|
||||
return this.parseGrepOutput(output, absolutePath);
|
||||
} catch (gitError: unknown) {
|
||||
console.debug(
|
||||
`GrepLogic: git grep failed: ${getErrorMessage(gitError)}. Falling back...`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Strategy 2: System grep ---
|
||||
const grepAvailable = await this.isCommandAvailable('grep');
|
||||
if (grepAvailable) {
|
||||
strategyUsed = 'system grep';
|
||||
const grepArgs = ['-r', '-n', '-H', '-E'];
|
||||
const commonExcludes = ['.git', 'node_modules', 'bower_components'];
|
||||
commonExcludes.forEach((dir) => grepArgs.push(`--exclude-dir=${dir}`));
|
||||
if (include) {
|
||||
grepArgs.push(`--include=${include}`);
|
||||
}
|
||||
grepArgs.push(pattern);
|
||||
grepArgs.push('.');
|
||||
|
||||
try {
|
||||
const output = await new Promise<string>((resolve, reject) => {
|
||||
const child = spawn('grep', grepArgs, {
|
||||
cwd: absolutePath,
|
||||
windowsHide: true,
|
||||
});
|
||||
const stdoutChunks: Buffer[] = [];
|
||||
const stderrChunks: Buffer[] = [];
|
||||
|
||||
const onData = (chunk: Buffer) => stdoutChunks.push(chunk);
|
||||
const onStderr = (chunk: Buffer) => {
|
||||
const stderrStr = chunk.toString();
|
||||
// Suppress common harmless stderr messages
|
||||
if (
|
||||
!stderrStr.includes('Permission denied') &&
|
||||
!/grep:.*: Is a directory/i.test(stderrStr)
|
||||
) {
|
||||
stderrChunks.push(chunk);
|
||||
}
|
||||
};
|
||||
const onError = (err: Error) => {
|
||||
cleanup();
|
||||
reject(new Error(`Failed to start system grep: ${err.message}`));
|
||||
};
|
||||
const onClose = (code: number | null) => {
|
||||
const stdoutData = Buffer.concat(stdoutChunks).toString('utf8');
|
||||
const stderrData = Buffer.concat(stderrChunks)
|
||||
.toString('utf8')
|
||||
.trim();
|
||||
cleanup();
|
||||
if (code === 0) resolve(stdoutData);
|
||||
else if (code === 1)
|
||||
resolve(''); // No matches
|
||||
else {
|
||||
if (stderrData)
|
||||
reject(
|
||||
new Error(
|
||||
`System grep exited with code ${code}: ${stderrData}`,
|
||||
),
|
||||
);
|
||||
else resolve(''); // Exit code > 1 but no stderr, likely just suppressed errors
|
||||
}
|
||||
};
|
||||
|
||||
const cleanup = () => {
|
||||
child.stdout.removeListener('data', onData);
|
||||
child.stderr.removeListener('data', onStderr);
|
||||
child.removeListener('error', onError);
|
||||
child.removeListener('close', onClose);
|
||||
if (child.connected) {
|
||||
child.disconnect();
|
||||
}
|
||||
};
|
||||
|
||||
child.stdout.on('data', onData);
|
||||
child.stderr.on('data', onStderr);
|
||||
child.on('error', onError);
|
||||
child.on('close', onClose);
|
||||
});
|
||||
return this.parseGrepOutput(output, absolutePath);
|
||||
} catch (grepError: unknown) {
|
||||
console.debug(
|
||||
`GrepLogic: System grep failed: ${getErrorMessage(grepError)}. Falling back...`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Strategy 3: Pure JavaScript Fallback ---
|
||||
console.debug(
|
||||
'GrepLogic: Falling back to JavaScript grep implementation.',
|
||||
);
|
||||
strategyUsed = 'javascript fallback';
|
||||
const globPattern = include ? include : '**/*';
|
||||
const ignorePatterns = [
|
||||
'.git/**',
|
||||
'node_modules/**',
|
||||
'bower_components/**',
|
||||
'.svn/**',
|
||||
'.hg/**',
|
||||
]; // Use glob patterns for ignores here
|
||||
|
||||
const filesStream = globStream(globPattern, {
|
||||
cwd: absolutePath,
|
||||
dot: true,
|
||||
ignore: ignorePatterns,
|
||||
absolute: true,
|
||||
nodir: true,
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
const regex = new RegExp(pattern, 'i');
|
||||
const allMatches: GrepMatch[] = [];
|
||||
|
||||
for await (const filePath of filesStream) {
|
||||
const fileAbsolutePath = filePath as string;
|
||||
try {
|
||||
const content = await fsPromises.readFile(fileAbsolutePath, 'utf8');
|
||||
const lines = content.split(/\r?\n/);
|
||||
lines.forEach((line, index) => {
|
||||
if (regex.test(line)) {
|
||||
allMatches.push({
|
||||
filePath:
|
||||
path.relative(absolutePath, fileAbsolutePath) ||
|
||||
path.basename(fileAbsolutePath),
|
||||
lineNumber: index + 1,
|
||||
line,
|
||||
});
|
||||
}
|
||||
});
|
||||
} catch (readError: unknown) {
|
||||
// Ignore errors like permission denied or file gone during read
|
||||
if (!isNodeError(readError) || readError.code !== 'ENOENT') {
|
||||
console.debug(
|
||||
`GrepLogic: Could not read/process ${fileAbsolutePath}: ${getErrorMessage(readError)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allMatches;
|
||||
} catch (error: unknown) {
|
||||
console.error(
|
||||
`GrepLogic: Error in performGrepSearch (Strategy: ${strategyUsed}): ${getErrorMessage(error)}`,
|
||||
);
|
||||
throw error; // Re-throw
|
||||
}
|
||||
}
|
||||
}
|
||||
285
packages/core/src/tools/ls.ts
Normal file
285
packages/core/src/tools/ls.ts
Normal file
@@ -0,0 +1,285 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { isWithinRoot } from '../utils/fileUtils.js';
|
||||
|
||||
/**
 * Parameters for the LS tool
 */
export interface LSToolParams {
  /**
   * The absolute path to the directory to list.
   * Must be absolute and inside the configured target root
   * (both are enforced by LSTool.validateToolParams).
   */
  path: string;

  /**
   * Array of glob patterns to ignore (optional).
   * Only `*` and `?` wildcards are supported; an entry is skipped when any
   * pattern matches its bare filename.
   */
  ignore?: string[];

  /**
   * Whether to respect .gitignore patterns (optional, defaults to true).
   * When unset, LSTool falls back to the value from
   * Config.getFileFilteringRespectGitIgnore().
   */
  respect_git_ignore?: boolean;
}
|
||||
|
||||
/**
 * File entry returned by LS tool
 */
export interface FileEntry {
  /**
   * Name of the file or directory (basename only, no path separators)
   */
  name: string;

  /**
   * Absolute path to the file or directory
   */
  path: string;

  /**
   * Whether this entry is a directory
   */
  isDirectory: boolean;

  /**
   * Size of the file in bytes (always 0 for directories)
   */
  size: number;

  /**
   * Last modified timestamp (taken from fs.Stats.mtime)
   */
  modifiedTime: Date;
}
|
||||
|
||||
/**
|
||||
* Implementation of the LS tool logic
|
||||
*/
|
||||
export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
||||
static readonly Name = 'list_directory';
|
||||
|
||||
constructor(private config: Config) {
|
||||
super(
|
||||
LSTool.Name,
|
||||
'ReadFolder',
|
||||
'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
|
||||
{
|
||||
properties: {
|
||||
path: {
|
||||
description:
|
||||
'The absolute path to the directory to list (must be absolute, not relative)',
|
||||
type: Type.STRING,
|
||||
},
|
||||
ignore: {
|
||||
description: 'List of glob patterns to ignore',
|
||||
items: {
|
||||
type: Type.STRING,
|
||||
},
|
||||
type: Type.ARRAY,
|
||||
},
|
||||
respect_git_ignore: {
|
||||
description:
|
||||
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
|
||||
type: Type.BOOLEAN,
|
||||
},
|
||||
},
|
||||
required: ['path'],
|
||||
type: Type.OBJECT,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the tool
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
validateToolParams(params: LSToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
if (!path.isAbsolute(params.path)) {
|
||||
return `Path must be absolute: ${params.path}`;
|
||||
}
|
||||
if (!isWithinRoot(params.path, this.config.getTargetDir())) {
|
||||
return `Path must be within the root directory (${this.config.getTargetDir()}): ${params.path}`;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a filename matches any of the ignore patterns
|
||||
* @param filename Filename to check
|
||||
* @param patterns Array of glob patterns to check against
|
||||
* @returns True if the filename should be ignored
|
||||
*/
|
||||
private shouldIgnore(filename: string, patterns?: string[]): boolean {
|
||||
if (!patterns || patterns.length === 0) {
|
||||
return false;
|
||||
}
|
||||
for (const pattern of patterns) {
|
||||
// Convert glob pattern to RegExp
|
||||
const regexPattern = pattern
|
||||
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
|
||||
.replace(/\*/g, '.*')
|
||||
.replace(/\?/g, '.');
|
||||
const regex = new RegExp(`^${regexPattern}$`);
|
||||
if (regex.test(filename)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a description of the file reading operation
|
||||
* @param params Parameters for the file reading
|
||||
* @returns A string describing the file being read
|
||||
*/
|
||||
getDescription(params: LSToolParams): string {
|
||||
const relativePath = makeRelative(params.path, this.config.getTargetDir());
|
||||
return shortenPath(relativePath);
|
||||
}
|
||||
|
||||
// Helper for consistent error formatting
|
||||
private errorResult(llmContent: string, returnDisplay: string): ToolResult {
|
||||
return {
|
||||
llmContent,
|
||||
// Keep returnDisplay simpler in core logic
|
||||
returnDisplay: `Error: ${returnDisplay}`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the LS operation with the given parameters
|
||||
* @param params Parameters for the LS operation
|
||||
* @returns Result of the LS operation
|
||||
*/
|
||||
async execute(
|
||||
params: LSToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return this.errorResult(
|
||||
`Error: Invalid parameters provided. Reason: ${validationError}`,
|
||||
`Failed to execute tool.`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = fs.statSync(params.path);
|
||||
if (!stats) {
|
||||
// fs.statSync throws on non-existence, so this check might be redundant
|
||||
// but keeping for clarity. Error message adjusted.
|
||||
return this.errorResult(
|
||||
`Error: Directory not found or inaccessible: ${params.path}`,
|
||||
`Directory not found or inaccessible.`,
|
||||
);
|
||||
}
|
||||
if (!stats.isDirectory()) {
|
||||
return this.errorResult(
|
||||
`Error: Path is not a directory: ${params.path}`,
|
||||
`Path is not a directory.`,
|
||||
);
|
||||
}
|
||||
|
||||
const files = fs.readdirSync(params.path);
|
||||
|
||||
// Get centralized file discovery service
|
||||
const respectGitIgnore =
|
||||
params.respect_git_ignore ??
|
||||
this.config.getFileFilteringRespectGitIgnore();
|
||||
const fileDiscovery = this.config.getFileService();
|
||||
|
||||
const entries: FileEntry[] = [];
|
||||
let gitIgnoredCount = 0;
|
||||
|
||||
if (files.length === 0) {
|
||||
// Changed error message to be more neutral for LLM
|
||||
return {
|
||||
llmContent: `Directory ${params.path} is empty.`,
|
||||
returnDisplay: `Directory is empty.`,
|
||||
};
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
if (this.shouldIgnore(file, params.ignore)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fullPath = path.join(params.path, file);
|
||||
const relativePath = path.relative(
|
||||
this.config.getTargetDir(),
|
||||
fullPath,
|
||||
);
|
||||
|
||||
// Check if this file should be git-ignored (only in git repositories)
|
||||
if (
|
||||
respectGitIgnore &&
|
||||
fileDiscovery.shouldGitIgnoreFile(relativePath)
|
||||
) {
|
||||
gitIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = fs.statSync(fullPath);
|
||||
const isDir = stats.isDirectory();
|
||||
entries.push({
|
||||
name: file,
|
||||
path: fullPath,
|
||||
isDirectory: isDir,
|
||||
size: isDir ? 0 : stats.size,
|
||||
modifiedTime: stats.mtime,
|
||||
});
|
||||
} catch (error) {
|
||||
// Log error internally but don't fail the whole listing
|
||||
console.error(`Error accessing ${fullPath}: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort entries (directories first, then alphabetically)
|
||||
entries.sort((a, b) => {
|
||||
if (a.isDirectory && !b.isDirectory) return -1;
|
||||
if (!a.isDirectory && b.isDirectory) return 1;
|
||||
return a.name.localeCompare(b.name);
|
||||
});
|
||||
|
||||
// Create formatted content for LLM
|
||||
const directoryContent = entries
|
||||
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
|
||||
.join('\n');
|
||||
|
||||
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
|
||||
if (gitIgnoredCount > 0) {
|
||||
resultMessage += `\n\n(${gitIgnoredCount} items were git-ignored)`;
|
||||
}
|
||||
|
||||
let displayMessage = `Listed ${entries.length} item(s).`;
|
||||
if (gitIgnoredCount > 0) {
|
||||
displayMessage += ` (${gitIgnoredCount} git-ignored)`;
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: resultMessage,
|
||||
returnDisplay: displayMessage,
|
||||
};
|
||||
} catch (error) {
|
||||
const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`;
|
||||
return this.errorResult(errorMsg, 'Failed to list directory.');
|
||||
}
|
||||
}
|
||||
}
|
||||
307
packages/core/src/tools/mcp-client.test.ts
Normal file
307
packages/core/src/tools/mcp-client.test.ts
Normal file
@@ -0,0 +1,307 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
populateMcpServerCommand,
|
||||
createTransport,
|
||||
generateValidName,
|
||||
isEnabled,
|
||||
discoverTools,
|
||||
} from './mcp-client.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import * as GenAiLib from '@google/genai';
|
||||
|
||||
// Replace the SDK stdio transport, the SDK Client, and @google/genai with
// vitest auto-mocks so no real MCP process or network is touched.
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');

describe('mcp-client', () => {
  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('discoverTools', () => {
    it('should discover tools', async () => {
      // A bare object is enough: discoverTools only hands the client to the
      // mocked mcpToTool, which returns a canned single-function tool.
      const mockedClient = {} as unknown as ClientLib.Client;
      const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
        tool: () => ({
          functionDeclarations: [
            {
              name: 'testFunction',
            },
          ],
        }),
      } as unknown as GenAiLib.CallableTool);

      const tools = await discoverTools('test-server', {}, mockedClient);

      expect(tools.length).toBe(1);
      expect(mockedMcpToTool).toHaveBeenCalledOnce();
    });
  });

  // NOTE(review): suite name predates the rename — it exercises
  // populateMcpServerCommand.
  describe('appendMcpServerCommand', () => {
    it('should do nothing if no MCP servers or command are configured', () => {
      const out = populateMcpServerCommand({}, undefined);
      expect(out).toEqual({});
    });

    it('should discover tools via mcpServerCommand', () => {
      const commandString = 'command --arg1 value1';
      const out = populateMcpServerCommand({}, commandString);
      // Command-line servers are registered under the generic name 'mcp'.
      expect(out).toEqual({
        mcp: {
          command: 'command',
          args: ['--arg1', 'value1'],
        },
      });
    });

    it('should handle error if mcpServerCommand parsing fails', () => {
      // shell operators parse to non-string tokens, which must be rejected.
      expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError();
    });
  });

  describe('createTransport', () => {
    // createTransport reads process.env (e.g. for proxy/env expansion), so
    // each test runs with a clean environment and restores it afterwards.
    const originalEnv = process.env;

    beforeEach(() => {
      vi.resetModules();
      process.env = {};
    });

    afterEach(() => {
      process.env = originalEnv;
    });

    describe('should connect via httpUrl', () => {
      it('without headers', async () => {
        const transport = createTransport(
          'test-server',
          {
            httpUrl: 'http://test-server',
          },
          false,
        );

        expect(transport).toEqual(
          new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
        );
      });

      it('with headers', async () => {
        const transport = createTransport(
          'test-server',
          {
            httpUrl: 'http://test-server',
            headers: { Authorization: 'derp' },
          },
          false,
        );

        // Headers must be forwarded through requestInit on the transport.
        expect(transport).toEqual(
          new StreamableHTTPClientTransport(new URL('http://test-server'), {
            requestInit: {
              headers: { Authorization: 'derp' },
            },
          }),
        );
      });
    });

    describe('should connect via url', () => {
      it('without headers', async () => {
        const transport = createTransport(
          'test-server',
          {
            url: 'http://test-server',
          },
          false,
        );
        expect(transport).toEqual(
          new SSEClientTransport(new URL('http://test-server'), {}),
        );
      });

      it('with headers', async () => {
        const transport = createTransport(
          'test-server',
          {
            url: 'http://test-server',
            headers: { Authorization: 'derp' },
          },
          false,
        );

        expect(transport).toEqual(
          new SSEClientTransport(new URL('http://test-server'), {
            requestInit: {
              headers: { Authorization: 'derp' },
            },
          }),
        );
      });
    });

    it('should connect via command', () => {
      const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);

      createTransport(
        'test-server',
        {
          command: 'test-command',
          args: ['--foo', 'bar'],
          env: { FOO: 'bar' },
          cwd: 'test/cwd',
        },
        false,
      );

      // stderr must be piped so server diagnostics can be surfaced.
      expect(mockedTransport).toHaveBeenCalledWith({
        command: 'test-command',
        args: ['--foo', 'bar'],
        cwd: 'test/cwd',
        env: { FOO: 'bar' },
        stderr: 'pipe',
      });
    });
  });
  describe('generateValidName', () => {
    it('should return a valid name for a simple function', () => {
      const funcDecl = { name: 'myFunction' };
      const serverName = 'myServer';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('myServer__myFunction');
    });

    it('should prepend the server name', () => {
      const funcDecl = { name: 'anotherFunction' };
      const serverName = 'production-server';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('production-server__anotherFunction');
    });

    it('should replace invalid characters with underscores', () => {
      const funcDecl = { name: 'invalid-name with spaces' };
      const serverName = 'test_server';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('test_server__invalid-name_with_spaces');
    });

    // Names longer than 63 chars are truncated by keeping the head and tail
    // with '___' in the middle (see assertions below).
    it('should truncate long names', () => {
      const funcDecl = {
        name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit',
      };
      const serverName = 'a_long_server_name';
      const result = generateValidName(funcDecl, serverName);
      expect(result.length).toBe(63);
      expect(result).toBe(
        'a_long_server_name__a_very_l___will_definitely_exceed_the_limit',
      );
    });

    it('should handle names with only invalid characters', () => {
      const funcDecl = { name: '!@#$%^&*()' };
      const serverName = 'special-chars';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('special-chars____________');
    });

    it('should handle names that are already valid', () => {
      const funcDecl = { name: 'already_valid' };
      const serverName = 'validator';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('validator__already_valid');
    });

    it('should handle names with leading/trailing invalid characters', () => {
      const funcDecl = { name: '-_invalid-_' };
      const serverName = 'trim-test';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe('trim-test__-_invalid-_');
    });

    it('should handle names that are exactly 63 characters long', () => {
      const longName = 'a'.repeat(45);
      const funcDecl = { name: longName };
      const serverName = 'server';
      const result = generateValidName(funcDecl, serverName);
      expect(result).toBe(`server__${longName}`);
      expect(result.length).toBe(53);
    });

    it('should handle names that are exactly 64 characters long', () => {
      const longName = 'a'.repeat(55);
      const funcDecl = { name: longName };
      const serverName = 'server';
      const result = generateValidName(funcDecl, serverName);
      expect(result.length).toBe(63);
      expect(result).toBe(
        'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
      );
    });

    it('should handle names that are longer than 64 characters', () => {
      const longName = 'a'.repeat(100);
      const funcDecl = { name: longName };
      const serverName = 'long-server';
      const result = generateValidName(funcDecl, serverName);
      expect(result.length).toBe(63);
      expect(result).toBe(
        'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
      );
    });
  });
  describe('isEnabled', () => {
    const funcDecl = { name: 'myTool' };
    const serverName = 'myServer';

    it('should return true if no include or exclude lists are provided', () => {
      const mcpServerConfig = {};
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
    });

    it('should return false if the tool is in the exclude list', () => {
      const mcpServerConfig = { excludeTools: ['myTool'] };
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
    });

    it('should return true if the tool is in the include list', () => {
      const mcpServerConfig = { includeTools: ['myTool'] };
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
    });

    it('should return true if the tool is in the include list with parentheses', () => {
      const mcpServerConfig = { includeTools: ['myTool()'] };
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
    });

    it('should return false if the include list exists but does not contain the tool', () => {
      const mcpServerConfig = { includeTools: ['anotherTool'] };
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
    });

    // Exclude always wins over include.
    it('should return false if the tool is in both the include and exclude lists', () => {
      const mcpServerConfig = {
        includeTools: ['myTool'],
        excludeTools: ['myTool'],
      };
      expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
    });

    it('should return false if the function declaration has no name', () => {
      const namelessFuncDecl = {};
      const mcpServerConfig = {};
      expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe(
        false,
      );
    });
  });
});
|
||||
459
packages/core/src/tools/mcp-client.ts
Normal file
459
packages/core/src/tools/mcp-client.ts
Normal file
@@ -0,0 +1,459 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import {
|
||||
SSEClientTransport,
|
||||
SSEClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
|
||||
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
||||
|
||||
// Default per-request timeout for MCP servers that do not configure one.
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes

/**
 * Enum representing the connection status of an MCP server.
 * Unknown servers are reported as DISCONNECTED (see getMCPServerStatus).
 */
export enum MCPServerStatus {
  /** Server is disconnected or experiencing errors */
  DISCONNECTED = 'disconnected',
  /** Server is in the process of connecting */
  CONNECTING = 'connecting',
  /** Server is connected and ready to use */
  CONNECTED = 'connected',
}

/**
 * Enum representing the overall MCP discovery state.
 * Driven by discoverMcpTools: IN_PROGRESS while discovery runs,
 * COMPLETED afterwards (even if individual servers failed).
 */
export enum MCPDiscoveryState {
  /** Discovery has not started yet */
  NOT_STARTED = 'not_started',
  /** Discovery is currently in progress */
  IN_PROGRESS = 'in_progress',
  /** Discovery has completed (with or without errors) */
  COMPLETED = 'completed',
}
|
||||
|
||||
/**
 * Map to track the status of each MCP server within the core package.
 * Keyed by server name; updated only through updateMCPServerStatus.
 */
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();

/**
 * Track the overall MCP discovery state (mutated by discoverMcpTools).
 */
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;

/**
 * Event listeners for MCP server status changes
 */
type StatusChangeListener = (
  serverName: string,
  status: MCPServerStatus,
) => void;
const statusChangeListeners: StatusChangeListener[] = [];

/**
 * Add a listener for MCP server status changes.
 * The same listener may be added more than once; each registration is
 * invoked separately.
 */
export function addMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  statusChangeListeners.push(listener);
}

/**
 * Remove a listener for MCP server status changes.
 * Only the first matching registration is removed; unknown listeners are
 * ignored.
 */
export function removeMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  const index = statusChangeListeners.indexOf(listener);
  if (index !== -1) {
    statusChangeListeners.splice(index, 1);
  }
}

/**
 * Update the status of an MCP server and notify all registered listeners.
 */
function updateMCPServerStatus(
  serverName: string,
  status: MCPServerStatus,
): void {
  mcpServerStatusesInternal.set(serverName, status);
  // Notify all listeners
  for (const listener of statusChangeListeners) {
    listener(serverName, status);
  }
}

/**
 * Get the current status of an MCP server.
 * Servers that were never tracked report DISCONNECTED.
 */
export function getMCPServerStatus(serverName: string): MCPServerStatus {
  return (
    mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
  );
}

/**
 * Get all MCP server statuses as a defensive copy — mutating the returned
 * map does not affect the internal tracking state.
 */
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
  return new Map(mcpServerStatusesInternal);
}

/**
 * Get the current MCP discovery state.
 */
export function getMCPDiscoveryState(): MCPDiscoveryState {
  return mcpDiscoveryState;
}
|
||||
|
||||
/**
|
||||
* Discovers tools from all configured MCP servers and registers them with the tool registry.
|
||||
* It orchestrates the connection and discovery process for each server defined in the
|
||||
* configuration, as well as any server specified via a command-line argument.
|
||||
*
|
||||
* @param mcpServers A record of named MCP server configurations.
|
||||
* @param mcpServerCommand An optional command string for a dynamically specified MCP server.
|
||||
* @param toolRegistry The central registry where discovered tools will be registered.
|
||||
* @returns A promise that resolves when the discovery process has been attempted for all servers.
|
||||
*/
|
||||
export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
try {
|
||||
mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand);
|
||||
|
||||
const discoveryPromises = Object.entries(mcpServers).map(
|
||||
([mcpServerName, mcpServerConfig]) =>
|
||||
connectAndDiscover(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
toolRegistry,
|
||||
debugMode,
|
||||
),
|
||||
);
|
||||
await Promise.all(discoveryPromises);
|
||||
} finally {
|
||||
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
|
||||
}
|
||||
}
|
||||
|
||||
/** Visible for Testing */
|
||||
export function populateMcpServerCommand(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
): Record<string, MCPServerConfig> {
|
||||
if (mcpServerCommand) {
|
||||
const cmd = mcpServerCommand;
|
||||
const args = parse(cmd, process.env) as string[];
|
||||
if (args.some((arg) => typeof arg !== 'string')) {
|
||||
throw new Error('failed to parse mcpServerCommand: ' + cmd);
|
||||
}
|
||||
// use generic server name 'mcp'
|
||||
mcpServers['mcp'] = {
|
||||
command: args[0],
|
||||
args: args.slice(1),
|
||||
};
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connects to an MCP server and discovers available tools, registering them with the tool registry.
|
||||
* This function handles the complete lifecycle of connecting to a server, discovering tools,
|
||||
* and cleaning up resources if no tools are found.
|
||||
*
|
||||
* @param mcpServerName The name identifier for this MCP server
|
||||
* @param mcpServerConfig Configuration object containing connection details
|
||||
* @param toolRegistry The registry to register discovered tools with
|
||||
* @returns Promise that resolves when discovery is complete
|
||||
*/
|
||||
export async function connectAndDiscover(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
toolRegistry: ToolRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
|
||||
try {
|
||||
const mcpClient = await connectToMcpServer(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
debugMode,
|
||||
);
|
||||
try {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
|
||||
|
||||
mcpClient.onerror = (error) => {
|
||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
};
|
||||
|
||||
const tools = await discoverTools(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
mcpClient,
|
||||
);
|
||||
for (const tool of tools) {
|
||||
toolRegistry.registerTool(tool);
|
||||
}
|
||||
} catch (error) {
|
||||
mcpClient.close();
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and sanitizes tools from a connected MCP client.
|
||||
* It retrieves function declarations from the client, filters out disabled tools,
|
||||
* generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpServerConfig The configuration for the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
* @returns A promise that resolves to an array of discovered and enabled tools.
|
||||
* @throws An error if no enabled tools are found or if the server provides invalid function declarations.
|
||||
*/
|
||||
export async function discoverTools(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
mcpClient: Client,
|
||||
): Promise<DiscoveredMCPTool[]> {
|
||||
try {
|
||||
const mcpCallableTool = mcpToTool(mcpClient);
|
||||
const tool = await mcpCallableTool.tool();
|
||||
|
||||
if (!Array.isArray(tool.functionDeclarations)) {
|
||||
throw new Error(`Server did not return valid function declarations.`);
|
||||
}
|
||||
|
||||
const discoveredTools: DiscoveredMCPTool[] = [];
|
||||
for (const funcDecl of tool.functionDeclarations) {
|
||||
if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolNameForModel = generateValidName(funcDecl, mcpServerName);
|
||||
|
||||
sanitizeParameters(funcDecl.parameters);
|
||||
|
||||
discoveredTools.push(
|
||||
new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
mcpServerName,
|
||||
toolNameForModel,
|
||||
funcDecl.description ?? '',
|
||||
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
|
||||
funcDecl.name!,
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
mcpServerConfig.trust,
|
||||
),
|
||||
);
|
||||
}
|
||||
if (discoveredTools.length === 0) {
|
||||
throw Error('No enabled tools found');
|
||||
}
|
||||
return discoveredTools;
|
||||
} catch (error) {
|
||||
throw new Error(`Error discovering tools: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Creates and connects an MCP client to a server based on the provided configuration.
 * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
 * establishes a connection. It also applies a patch to handle request timeouts.
 *
 * @param mcpServerName The name of the MCP server, used for logging and identification.
 * @param mcpServerConfig The configuration specifying how to connect to the server.
 * @param debugMode Whether to enable debug logging on the transport.
 * @returns A promise that resolves to a connected MCP `Client` instance.
 * @throws An error if the connection fails or the configuration is invalid.
 */
export async function connectToMcpServer(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  debugMode: boolean,
): Promise<Client> {
  const mcpClient = new Client({
    name: 'gemini-cli-mcp-client',
    version: '0.0.1',
  });

  // patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
  // TODO: remove this hack once GenAI SDK does callTool with request options
  if ('callTool' in mcpClient) {
    // Bind first so the wrapper delegates with the client as `this`.
    const origCallTool = mcpClient.callTool.bind(mcpClient);
    mcpClient.callTool = function (params, resultSchema, options) {
      // Forward caller options, but force the configured per-request timeout.
      return origCallTool(params, resultSchema, {
        ...options,
        timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
      });
    };
  }

  try {
    const transport = createTransport(
      mcpServerName,
      mcpServerConfig,
      debugMode,
    );
    try {
      // The same timeout also bounds the initial connection handshake.
      await mcpClient.connect(transport, {
        timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
      });
      return mcpClient;
    } catch (error) {
      // Connection failed: release the transport before rethrowing so the
      // outer handler can build the user-facing error message.
      await transport.close();
      throw error;
    }
  } catch (error) {
    // Create a safe config object that excludes sensitive information
    const safeConfig = {
      command: mcpServerConfig.command,
      url: mcpServerConfig.url,
      httpUrl: mcpServerConfig.httpUrl,
      cwd: mcpServerConfig.cwd,
      timeout: mcpServerConfig.timeout,
      trust: mcpServerConfig.trust,
      // Exclude args, env, and headers which may contain sensitive data
    };

    let errorString =
      `failed to start or connect to MCP server '${mcpServerName}' ` +
      `${JSON.stringify(safeConfig)}; \n${error}`;
    if (process.env.SANDBOX) {
      // Inside a sandbox the server binary may simply not be mounted.
      errorString += `\nMake sure it is available in the sandbox`;
    }
    throw new Error(errorString);
  }
}
|
||||
|
||||
/**
 * Visible for Testing.
 * Builds the transport matching the server config, checked in priority
 * order: httpUrl (Streamable HTTP) > url (SSE) > command (stdio).
 *
 * @param mcpServerName The server name, used only for debug log prefixes.
 * @param mcpServerConfig Connection settings; exactly one of httpUrl/url/command is used.
 * @param debugMode When true, pipes the stdio child's stderr to console.debug.
 * @returns An unconnected Transport instance.
 * @throws If none of httpUrl, url, or command is configured.
 */
export function createTransport(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  debugMode: boolean,
): Transport {
  if (mcpServerConfig.httpUrl) {
    const transportOptions: StreamableHTTPClientTransportOptions = {};
    // Optional custom headers (e.g. auth) ride along on every request.
    if (mcpServerConfig.headers) {
      transportOptions.requestInit = {
        headers: mcpServerConfig.headers,
      };
    }
    return new StreamableHTTPClientTransport(
      new URL(mcpServerConfig.httpUrl),
      transportOptions,
    );
  }

  if (mcpServerConfig.url) {
    const transportOptions: SSEClientTransportOptions = {};
    if (mcpServerConfig.headers) {
      transportOptions.requestInit = {
        headers: mcpServerConfig.headers,
      };
    }
    return new SSEClientTransport(
      new URL(mcpServerConfig.url),
      transportOptions,
    );
  }

  if (mcpServerConfig.command) {
    const transport = new StdioClientTransport({
      command: mcpServerConfig.command,
      args: mcpServerConfig.args || [],
      // Child inherits our environment plus any per-server overrides.
      env: {
        ...process.env,
        ...(mcpServerConfig.env || {}),
      } as Record<string, string>,
      cwd: mcpServerConfig.cwd,
      // 'pipe' so stderr can be captured for debug logging below.
      stderr: 'pipe',
    });
    if (debugMode) {
      // stderr is non-null because stderr: 'pipe' was requested above.
      transport.stderr!.on('data', (data) => {
        const stderrStr = data.toString().trim();
        console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
      });
    }
    return transport;
  }

  throw new Error(
    `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
  );
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function generateValidName(
|
||||
funcDecl: FunctionDeclaration,
|
||||
mcpServerName: string,
|
||||
) {
|
||||
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
|
||||
let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
|
||||
|
||||
// Prepend MCP server name to avoid conflicts with other tools
|
||||
validToolname = mcpServerName + '__' + validToolname;
|
||||
|
||||
// If longer than 63 characters, replace middle with '___'
|
||||
// (Gemini API says max length 64, but actual limit seems to be 63)
|
||||
if (validToolname.length > 63) {
|
||||
validToolname =
|
||||
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
|
||||
}
|
||||
return validToolname;
|
||||
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function isEnabled(
|
||||
funcDecl: FunctionDeclaration,
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
): boolean {
|
||||
if (!funcDecl.name) {
|
||||
console.warn(
|
||||
`Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
const { includeTools, excludeTools } = mcpServerConfig;
|
||||
|
||||
// excludeTools takes precedence over includeTools
|
||||
if (excludeTools && excludeTools.includes(funcDecl.name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
!includeTools ||
|
||||
includeTools.some(
|
||||
(tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
|
||||
)
|
||||
);
|
||||
}
|
||||
319
packages/core/src/tools/mcp-tool.test.ts
Normal file
319
packages/core/src/tools/mcp-tool.test.ts
Normal file
@@ -0,0 +1,319 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js'; // Added getStringifiedResultForDisplay
|
||||
import { ToolResult, ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome
|
||||
import { CallableTool, Part } from '@google/genai';
|
||||
|
||||
// Mock @google/genai mcpToTool and CallableTool.
// We only need to mock the parts of CallableTool that DiscoveredMCPTool uses.
// `mockCallTool` backs execute(); `mockToolMethod` backs tool().
const mockCallTool = vi.fn();
const mockToolMethod = vi.fn();

const mockCallableToolInstance: Mocked<CallableTool> = {
  tool: mockToolMethod as any, // Not directly used by DiscoveredMCPTool instance methods
  callTool: mockCallTool as any,
  // Add other methods if DiscoveredMCPTool starts using them
};
|
||||
|
||||
// Unit tests for DiscoveredMCPTool: constructor wiring, execute() output
// formatting, and the shouldConfirmExecute allowlist/trust behavior.
describe('DiscoveredMCPTool', () => {
  // Shared fixture values used by every test below.
  const serverName = 'mock-mcp-server';
  const toolNameForModel = 'test-mcp-tool-for-model';
  const serverToolName = 'actual-server-tool-name';
  const baseDescription = 'A test MCP tool.';
  const inputSchema: Record<string, unknown> = {
    type: 'object' as const,
    properties: { param: { type: 'string' } },
    required: ['param'],
  };

  beforeEach(() => {
    mockCallTool.mockClear();
    mockToolMethod.mockClear();
    // Clear allowlist before each relevant test, especially for shouldConfirmExecute.
    // Reaches into the private static via `any` on purpose — there is no public reset API.
    (DiscoveredMCPTool as any).allowlist.clear();
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('constructor', () => {
    it('should set properties correctly (non-generic server)', () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName, // serverName is 'mock-mcp-server', not 'mcp'
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );

      expect(tool.name).toBe(toolNameForModel);
      expect(tool.schema.name).toBe(toolNameForModel);
      expect(tool.schema.description).toBe(baseDescription);
      expect(tool.schema.parameters).toEqual(inputSchema);
      expect(tool.serverToolName).toBe(serverToolName);
      expect(tool.timeout).toBeUndefined();
    });

    it('should set properties correctly (generic "mcp" server)', () => {
      const genericServerName = 'mcp';
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        genericServerName, // serverName is 'mcp'
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      expect(tool.schema.description).toBe(baseDescription);
    });

    it('should accept and store a custom timeout', () => {
      const customTimeout = 5000;
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
        customTimeout,
      );
      expect(tool.timeout).toBe(customTimeout);
    });
  });

  describe('execute', () => {
    it('should call mcpTool.callTool with correct parameters and format display output', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const params = { param: 'testValue' };
      const mockToolSuccessResultObject = {
        success: true,
        details: 'executed',
      };
      // Server responds with a functionResponse whose content is all text
      // parts, so returnDisplay should be the concatenated raw text.
      const mockFunctionResponseContent: Part[] = [
        { text: JSON.stringify(mockToolSuccessResultObject) },
      ];
      const mockMcpToolResponseParts: Part[] = [
        {
          functionResponse: {
            name: serverToolName,
            response: { content: mockFunctionResponseContent },
          },
        },
      ];
      mockCallTool.mockResolvedValue(mockMcpToolResponseParts);

      const toolResult: ToolResult = await tool.execute(params);

      // execute() must address the server by its original tool name,
      // not the namespaced model-facing name.
      expect(mockCallTool).toHaveBeenCalledWith([
        { name: serverToolName, args: params },
      ]);
      expect(toolResult.llmContent).toEqual(mockMcpToolResponseParts);

      const stringifiedResponseContent = JSON.stringify(
        mockToolSuccessResultObject,
      );
      expect(toolResult.returnDisplay).toBe(stringifiedResponseContent);
    });

    it('should handle empty result from getStringifiedResultForDisplay', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const params = { param: 'testValue' };
      const mockMcpToolResponsePartsEmpty: Part[] = [];
      mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
      const toolResult: ToolResult = await tool.execute(params);
      // Empty part arrays render as an empty JSON code block.
      expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
    });

    it('should propagate rejection if mcpTool.callTool rejects', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const params = { param: 'failCase' };
      const expectedError = new Error('MCP call failed');
      mockCallTool.mockRejectedValue(expectedError);

      // execute() has no catch of its own; the rejection must bubble up.
      await expect(tool.execute(params)).rejects.toThrow(expectedError);
    });
  });

  describe('shouldConfirmExecute', () => {
    // beforeEach is already clearing allowlist

    it('should return false if trust is true', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
        undefined,
        true, // trust — skips all confirmation
      );
      expect(
        await tool.shouldConfirmExecute({}, new AbortController().signal),
      ).toBe(false);
    });

    it('should return false if server is allowlisted', async () => {
      (DiscoveredMCPTool as any).allowlist.add(serverName);
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      expect(
        await tool.shouldConfirmExecute({}, new AbortController().signal),
      ).toBe(false);
    });

    it('should return false if tool is allowlisted', async () => {
      // Tool-level allowlist keys are '<server>.<serverToolName>'.
      const toolAllowlistKey = `${serverName}.${serverToolName}`;
      (DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey);
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      expect(
        await tool.shouldConfirmExecute({}, new AbortController().signal),
      ).toBe(false);
    });

    it('should return confirmation details if not trusted and not allowlisted', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const confirmation = await tool.shouldConfirmExecute(
        {},
        new AbortController().signal,
      );
      expect(confirmation).not.toBe(false);
      if (confirmation && confirmation.type === 'mcp') {
        // Type guard for ToolMcpConfirmationDetails
        expect(confirmation.type).toBe('mcp');
        expect(confirmation.serverName).toBe(serverName);
        expect(confirmation.toolName).toBe(serverToolName);
      } else if (confirmation) {
        // Handle other possible confirmation types if necessary, or strengthen test if only MCP is expected
        throw new Error(
          'Confirmation was not of expected type MCP or was false',
        );
      } else {
        throw new Error(
          'Confirmation details not in expected format or was false',
        );
      }
    });

    it('should add server to allowlist on ProceedAlwaysServer', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const confirmation = await tool.shouldConfirmExecute(
        {},
        new AbortController().signal,
      );
      expect(confirmation).not.toBe(false);
      if (
        confirmation &&
        typeof confirmation === 'object' &&
        'onConfirm' in confirmation &&
        typeof confirmation.onConfirm === 'function'
      ) {
        await confirmation.onConfirm(
          ToolConfirmationOutcome.ProceedAlwaysServer,
        );
        expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true);
      } else {
        throw new Error(
          'Confirmation details or onConfirm not in expected format',
        );
      }
    });

    it('should add tool to allowlist on ProceedAlwaysTool', async () => {
      const tool = new DiscoveredMCPTool(
        mockCallableToolInstance,
        serverName,
        toolNameForModel,
        baseDescription,
        inputSchema,
        serverToolName,
      );
      const toolAllowlistKey = `${serverName}.${serverToolName}`;
      const confirmation = await tool.shouldConfirmExecute(
        {},
        new AbortController().signal,
      );
      expect(confirmation).not.toBe(false);
      if (
        confirmation &&
        typeof confirmation === 'object' &&
        'onConfirm' in confirmation &&
        typeof confirmation.onConfirm === 'function'
      ) {
        await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool);
        expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe(
          true,
        );
      } else {
        throw new Error(
          'Confirmation details or onConfirm not in expected format',
        );
      }
    });
  });
});
|
||||
148
packages/core/src/tools/mcp-tool.ts
Normal file
148
packages/core/src/tools/mcp-tool.ts
Normal file
@@ -0,0 +1,148 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
BaseTool,
|
||||
ToolResult,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolMcpConfirmationDetails,
|
||||
} from './tools.js';
|
||||
import { CallableTool, Part, FunctionCall, Schema } from '@google/genai';
|
||||
|
||||
// MCP tool parameters are an open JSON object; the schema is only known at runtime.
type ToolParams = Record<string, unknown>;

/**
 * A tool discovered from an MCP server, adapted to the BaseTool interface.
 * Wraps a GenAI `CallableTool` and forwards execution to the server under
 * the tool's original (server-side) name while exposing a namespaced,
 * sanitized name to the model.
 */
export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
  // Process-wide set of confirmation exemptions. Entries are either a
  // server name or '<server>.<serverToolName>'. Shared by ALL instances.
  private static readonly allowlist: Set<string> = new Set();

  constructor(
    private readonly mcpTool: CallableTool,
    readonly serverName: string,
    readonly name: string,
    readonly description: string,
    readonly parameterSchema: Schema,
    readonly serverToolName: string,
    readonly timeout?: number,
    readonly trust?: boolean,
  ) {
    super(
      name,
      `${serverToolName} (${serverName} MCP Server)`,
      description,
      parameterSchema,
      true, // isOutputMarkdown
      false, // canUpdateOutput
    );
  }

  /**
   * Asks for user confirmation unless the server is trusted or the server
   * or this specific tool has previously been allowlisted via
   * "proceed always" outcomes.
   */
  async shouldConfirmExecute(
    _params: ToolParams,
    _abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false> {
    const serverAllowListKey = this.serverName;
    const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;

    if (this.trust) {
      return false; // server is trusted, no confirmation needed
    }

    if (
      DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
      DiscoveredMCPTool.allowlist.has(toolAllowListKey)
    ) {
      return false; // server and/or tool already allow listed
    }

    const confirmationDetails: ToolMcpConfirmationDetails = {
      type: 'mcp',
      title: 'Confirm MCP Tool Execution',
      serverName: this.serverName,
      toolName: this.serverToolName, // Display original tool name in confirmation
      toolDisplayName: this.name, // Display global registry name exposed to model and user
      onConfirm: async (outcome: ToolConfirmationOutcome) => {
        // Persist "always allow" choices for the rest of the process.
        if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
          DiscoveredMCPTool.allowlist.add(serverAllowListKey);
        } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
          DiscoveredMCPTool.allowlist.add(toolAllowListKey);
        }
      },
    };
    return confirmationDetails;
  }

  /**
   * Invokes the tool on the MCP server using its original server-side name
   * and returns the raw response parts plus a display-friendly rendering.
   * Rejections from the underlying callTool propagate to the caller.
   */
  async execute(params: ToolParams): Promise<ToolResult> {
    const functionCalls: FunctionCall[] = [
      {
        name: this.serverToolName,
        args: params,
      },
    ];

    const responseParts: Part[] = await this.mcpTool.callTool(functionCalls);

    return {
      llmContent: responseParts,
      returnDisplay: getStringifiedResultForDisplay(responseParts),
    };
  }
}
|
||||
|
||||
/**
|
||||
* Processes an array of `Part` objects, primarily from a tool's execution result,
|
||||
* to generate a user-friendly string representation, typically for display in a CLI.
|
||||
*
|
||||
* The `result` array can contain various types of `Part` objects:
|
||||
* 1. `FunctionResponse` parts:
|
||||
* - If the `response.content` of a `FunctionResponse` is an array consisting solely
|
||||
* of `TextPart` objects, their text content is concatenated into a single string.
|
||||
* This is to present simple textual outputs directly.
|
||||
* - If `response.content` is an array but contains other types of `Part` objects (or a mix),
|
||||
* the `content` array itself is preserved. This handles structured data like JSON objects or arrays
|
||||
* returned by a tool.
|
||||
* - If `response.content` is not an array or is missing, the entire `functionResponse`
|
||||
* object is preserved.
|
||||
* 2. Other `Part` types (e.g., `TextPart` directly in the `result` array):
|
||||
* - These are preserved as is.
|
||||
*
|
||||
* All processed parts are then collected into an array, which is JSON.stringify-ed
|
||||
* with indentation and wrapped in a markdown JSON code block.
|
||||
*/
|
||||
function getStringifiedResultForDisplay(result: Part[]) {
|
||||
if (!result || result.length === 0) {
|
||||
return '```json\n[]\n```';
|
||||
}
|
||||
|
||||
const processFunctionResponse = (part: Part) => {
|
||||
if (part.functionResponse) {
|
||||
const responseContent = part.functionResponse.response?.content;
|
||||
if (responseContent && Array.isArray(responseContent)) {
|
||||
// Check if all parts in responseContent are simple TextParts
|
||||
const allTextParts = responseContent.every(
|
||||
(p: Part) => p.text !== undefined,
|
||||
);
|
||||
if (allTextParts) {
|
||||
return responseContent.map((p: Part) => p.text).join('');
|
||||
}
|
||||
// If not all simple text parts, return the array of these content parts for JSON stringification
|
||||
return responseContent;
|
||||
}
|
||||
|
||||
// If no content, or not an array, or not a functionResponse, stringify the whole functionResponse part for inspection
|
||||
return part.functionResponse;
|
||||
}
|
||||
return part; // Fallback for unexpected structure or non-FunctionResponsePart
|
||||
};
|
||||
|
||||
const processedResults =
|
||||
result.length === 1
|
||||
? processFunctionResponse(result[0])
|
||||
: result.map(processFunctionResponse);
|
||||
if (typeof processedResults === 'string') {
|
||||
return processedResults;
|
||||
}
|
||||
|
||||
return '```json\n' + JSON.stringify(processedResults, null, 2) + '\n```';
|
||||
}
|
||||
265
packages/core/src/tools/memoryTool.test.ts
Normal file
265
packages/core/src/tools/memoryTool.test.ts
Normal file
@@ -0,0 +1,265 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach, Mock } from 'vitest';
|
||||
import {
|
||||
MemoryTool,
|
||||
setGeminiMdFilename,
|
||||
getCurrentGeminiMdFilename,
|
||||
getAllGeminiMdFilenames,
|
||||
DEFAULT_CONTEXT_FILENAME,
|
||||
} from './memoryTool.js';
|
||||
import * as fs from 'fs/promises';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('fs/promises');
|
||||
vi.mock('os');
|
||||
|
||||
// Markdown heading under which MemoryTool appends facts; must match the
// implementation's section header exactly for the tests to be meaningful.
const MEMORY_SECTION_HEADER = '## Gemini Added Memories';

// Define a type for our fsAdapter to ensure consistency.
// Mirrors the subset of fs/promises that MemoryTool.performAddMemoryEntry uses.
interface FsAdapter {
  readFile: (path: string, encoding: 'utf-8') => Promise<string>;
  writeFile: (path: string, data: string, encoding: 'utf-8') => Promise<void>;
  mkdir: (
    path: string,
    options: { recursive: boolean },
  ) => Promise<string | undefined>;
}
|
||||
|
||||
describe('MemoryTool', () => {
|
||||
const mockAbortSignal = new AbortController().signal;
|
||||
|
||||
const mockFsAdapter: {
|
||||
readFile: Mock<FsAdapter['readFile']>;
|
||||
writeFile: Mock<FsAdapter['writeFile']>;
|
||||
mkdir: Mock<FsAdapter['mkdir']>;
|
||||
} = {
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.mocked(os.homedir).mockReturnValue('/mock/home');
|
||||
mockFsAdapter.readFile.mockReset();
|
||||
mockFsAdapter.writeFile.mockReset().mockResolvedValue(undefined);
|
||||
mockFsAdapter.mkdir
|
||||
.mockReset()
|
||||
.mockResolvedValue(undefined as string | undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
// Reset GEMINI_MD_FILENAME to its original value after each test
|
||||
setGeminiMdFilename(DEFAULT_CONTEXT_FILENAME);
|
||||
});
|
||||
|
||||
describe('setGeminiMdFilename', () => {
|
||||
it('should update currentGeminiMdFilename when a valid new name is provided', () => {
|
||||
const newName = 'CUSTOM_CONTEXT.md';
|
||||
setGeminiMdFilename(newName);
|
||||
expect(getCurrentGeminiMdFilename()).toBe(newName);
|
||||
});
|
||||
|
||||
it('should not update currentGeminiMdFilename if the new name is empty or whitespace', () => {
|
||||
const initialName = getCurrentGeminiMdFilename(); // Get current before trying to change
|
||||
setGeminiMdFilename(' ');
|
||||
expect(getCurrentGeminiMdFilename()).toBe(initialName);
|
||||
|
||||
setGeminiMdFilename('');
|
||||
expect(getCurrentGeminiMdFilename()).toBe(initialName);
|
||||
});
|
||||
|
||||
it('should handle an array of filenames', () => {
|
||||
const newNames = ['CUSTOM_CONTEXT.md', 'ANOTHER_CONTEXT.md'];
|
||||
setGeminiMdFilename(newNames);
|
||||
expect(getCurrentGeminiMdFilename()).toBe('CUSTOM_CONTEXT.md');
|
||||
expect(getAllGeminiMdFilenames()).toEqual(newNames);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performAddMemoryEntry (static method)', () => {
|
||||
const testFilePath = path.join(
|
||||
'/mock/home',
|
||||
'.qwen',
|
||||
DEFAULT_CONTEXT_FILENAME, // Use the default for basic tests
|
||||
);
|
||||
|
||||
it('should create section and save a fact if file does not exist', async () => {
|
||||
mockFsAdapter.readFile.mockRejectedValue({ code: 'ENOENT' }); // Simulate file not found
|
||||
const fact = 'The sky is blue';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
|
||||
expect(mockFsAdapter.mkdir).toHaveBeenCalledWith(
|
||||
path.dirname(testFilePath),
|
||||
{
|
||||
recursive: true,
|
||||
},
|
||||
);
|
||||
expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce();
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
expect(writeFileCall[0]).toBe(testFilePath);
|
||||
const expectedContent = `${MEMORY_SECTION_HEADER}\n- ${fact}\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
expect(writeFileCall[2]).toBe('utf-8');
|
||||
});
|
||||
|
||||
it('should create section and save a fact if file is empty', async () => {
|
||||
mockFsAdapter.readFile.mockResolvedValue(''); // Simulate empty file
|
||||
const fact = 'The sky is blue';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
const expectedContent = `${MEMORY_SECTION_HEADER}\n- ${fact}\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
});
|
||||
|
||||
it('should add a fact to an existing section', async () => {
|
||||
const initialContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- Existing fact 1\n`;
|
||||
mockFsAdapter.readFile.mockResolvedValue(initialContent);
|
||||
const fact = 'New fact 2';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
|
||||
expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce();
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
const expectedContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- Existing fact 1\n- ${fact}\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
});
|
||||
|
||||
it('should add a fact to an existing empty section', async () => {
|
||||
const initialContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n`; // Empty section
|
||||
mockFsAdapter.readFile.mockResolvedValue(initialContent);
|
||||
const fact = 'First fact in section';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
|
||||
expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce();
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
const expectedContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- ${fact}\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
});
|
||||
|
||||
it('should add a fact when other ## sections exist and preserve spacing', async () => {
|
||||
const initialContent = `${MEMORY_SECTION_HEADER}\n- Fact 1\n\n## Another Section\nSome other text.`;
|
||||
mockFsAdapter.readFile.mockResolvedValue(initialContent);
|
||||
const fact = 'Fact 2';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
|
||||
expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce();
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
// Note: The implementation ensures a single newline at the end if content exists.
|
||||
const expectedContent = `${MEMORY_SECTION_HEADER}\n- Fact 1\n- ${fact}\n\n## Another Section\nSome other text.\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
});
|
||||
|
||||
it('should correctly trim and add a fact that starts with a dash', async () => {
|
||||
mockFsAdapter.readFile.mockResolvedValue(`${MEMORY_SECTION_HEADER}\n`);
|
||||
const fact = '- - My fact with dashes';
|
||||
await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter);
|
||||
const writeFileCall = mockFsAdapter.writeFile.mock.calls[0];
|
||||
const expectedContent = `${MEMORY_SECTION_HEADER}\n- My fact with dashes\n`;
|
||||
expect(writeFileCall[1]).toBe(expectedContent);
|
||||
});
|
||||
|
||||
it('should handle error from fsAdapter.writeFile', async () => {
|
||||
mockFsAdapter.readFile.mockResolvedValue('');
|
||||
mockFsAdapter.writeFile.mockRejectedValue(new Error('Disk full'));
|
||||
const fact = 'This will fail';
|
||||
await expect(
|
||||
MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter),
|
||||
).rejects.toThrow('[MemoryTool] Failed to add memory entry: Disk full');
|
||||
});
|
||||
});
|
||||
|
||||
// Tests for MemoryTool.execute(): schema metadata, delegation to the static
// performAddMemoryEntry (spied so no real file I/O happens), and the
// success/error ToolResult shapes.
describe('execute (instance method)', () => {
  let memoryTool: MemoryTool;
  let performAddMemoryEntrySpy: Mock<typeof MemoryTool.performAddMemoryEntry>;

  beforeEach(() => {
    memoryTool = new MemoryTool();
    // Spy on the static method for these tests
    performAddMemoryEntrySpy = vi
      .spyOn(MemoryTool, 'performAddMemoryEntry')
      .mockResolvedValue(undefined) as Mock<
      typeof MemoryTool.performAddMemoryEntry
    >;
    // Cast needed as spyOn returns MockInstance
  });

  it('should have correct name, displayName, description, and schema', () => {
    expect(memoryTool.name).toBe('save_memory');
    expect(memoryTool.displayName).toBe('Save Memory');
    expect(memoryTool.description).toContain(
      'Saves a specific piece of information',
    );
    expect(memoryTool.schema).toBeDefined();
    expect(memoryTool.schema.name).toBe('save_memory');
    expect(memoryTool.schema.parameters?.properties?.fact).toBeDefined();
  });

  it('should call performAddMemoryEntry with correct parameters and return success', async () => {
    const params = { fact: 'The sky is blue' };
    const result = await memoryTool.execute(params, mockAbortSignal);
    // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
    const expectedFilePath = path.join(
      '/mock/home',
      '.qwen',
      getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test
    );

    // For this test, we expect the actual fs methods to be passed
    const expectedFsArgument = {
      readFile: fs.readFile,
      writeFile: fs.writeFile,
      mkdir: fs.mkdir,
    };

    expect(performAddMemoryEntrySpy).toHaveBeenCalledWith(
      params.fact,
      expectedFilePath,
      expectedFsArgument,
    );
    const successMessage = `Okay, I've remembered that: "${params.fact}"`;
    expect(result.llmContent).toBe(
      JSON.stringify({ success: true, message: successMessage }),
    );
    expect(result.returnDisplay).toBe(successMessage);
  });

  it('should return an error if fact is empty', async () => {
    const params = { fact: ' ' }; // Empty fact
    const result = await memoryTool.execute(params, mockAbortSignal);
    const errorMessage = 'Parameter "fact" must be a non-empty string.';

    // Validation failure short-circuits before the static method is reached.
    expect(performAddMemoryEntrySpy).not.toHaveBeenCalled();
    expect(result.llmContent).toBe(
      JSON.stringify({ success: false, error: errorMessage }),
    );
    expect(result.returnDisplay).toBe(`Error: ${errorMessage}`);
  });

  it('should handle errors from performAddMemoryEntry', async () => {
    const params = { fact: 'This will fail' };
    const underlyingError = new Error(
      '[MemoryTool] Failed to add memory entry: Disk full',
    );
    performAddMemoryEntrySpy.mockRejectedValue(underlyingError);

    const result = await memoryTool.execute(params, mockAbortSignal);

    // execute wraps the underlying message rather than rethrowing.
    expect(result.llmContent).toBe(
      JSON.stringify({
        success: false,
        error: `Failed to save memory. Detail: ${underlyingError.message}`,
      }),
    );
    expect(result.returnDisplay).toBe(
      `Error saving memory: ${underlyingError.message}`,
    );
  });
});
|
||||
});
|
||||
223
packages/core/src/tools/memoryTool.ts
Normal file
223
packages/core/src/tools/memoryTool.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { FunctionDeclaration, Type } from '@google/genai';
|
||||
import * as fs from 'fs/promises';
|
||||
import * as path from 'path';
|
||||
import { homedir } from 'os';
|
||||
|
||||
// Function declaration advertised to the model for the save_memory tool.
// The short description here is the model-facing summary; the long-form usage
// guidance lives in memoryToolDescription below.
const memoryToolSchemaData: FunctionDeclaration = {
  name: 'save_memory',
  description:
    'Saves a specific piece of information or fact to your long-term memory. Use this when the user explicitly asks you to remember something, or when they state a clear, concise fact that seems important to retain for future interactions.',
  parameters: {
    type: Type.OBJECT,
    properties: {
      fact: {
        type: Type.STRING,
        description:
          'The specific fact or piece of information to remember. Should be a clear, self-contained statement.',
      },
    },
    required: ['fact'],
  },
};

// Long-form tool description shown to the model (markdown). This is runtime
// text — edits here change model behavior, not just documentation.
const memoryToolDescription = `
Saves a specific piece of information or fact to your long-term memory.

Use this tool:

- When the user explicitly asks you to remember something (e.g., "Remember that I like pineapple on pizza", "Please save this: my cat's name is Whiskers").
- When the user states a clear, concise fact about themselves, their preferences, or their environment that seems important for you to retain for future interactions to provide a more personalized and effective assistance.

Do NOT use this tool:

- To remember conversational context that is only relevant for the current session.
- To save long, complex, or rambling pieces of text. The fact should be relatively short and to the point.
- If you are unsure whether the information is a fact worth remembering long-term. If in doubt, you can ask the user, "Should I remember that for you?"

## Parameters

- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".
`;
|
||||
|
||||
export const GEMINI_CONFIG_DIR = '.qwen';
|
||||
export const DEFAULT_CONTEXT_FILENAME = 'QWEN.md';
|
||||
export const MEMORY_SECTION_HEADER = '## Gemini Added Memories';
|
||||
|
||||
// This variable will hold the currently configured filename for GEMINI.md context files.
|
||||
// It defaults to DEFAULT_CONTEXT_FILENAME but can be overridden by setGeminiMdFilename.
|
||||
let currentGeminiMdFilename: string | string[] = DEFAULT_CONTEXT_FILENAME;
|
||||
|
||||
export function setGeminiMdFilename(newFilename: string | string[]): void {
|
||||
if (Array.isArray(newFilename)) {
|
||||
if (newFilename.length > 0) {
|
||||
currentGeminiMdFilename = newFilename.map((name) => name.trim());
|
||||
}
|
||||
} else if (newFilename && newFilename.trim() !== '') {
|
||||
currentGeminiMdFilename = newFilename.trim();
|
||||
}
|
||||
}
|
||||
|
||||
export function getCurrentGeminiMdFilename(): string {
|
||||
if (Array.isArray(currentGeminiMdFilename)) {
|
||||
return currentGeminiMdFilename[0];
|
||||
}
|
||||
return currentGeminiMdFilename;
|
||||
}
|
||||
|
||||
export function getAllGeminiMdFilenames(): string[] {
|
||||
if (Array.isArray(currentGeminiMdFilename)) {
|
||||
return currentGeminiMdFilename;
|
||||
}
|
||||
return [currentGeminiMdFilename];
|
||||
}
|
||||
|
||||
interface SaveMemoryParams {
|
||||
fact: string;
|
||||
}
|
||||
|
||||
function getGlobalMemoryFilePath(): string {
|
||||
return path.join(homedir(), GEMINI_CONFIG_DIR, getCurrentGeminiMdFilename());
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures proper newline separation before appending content.
|
||||
*/
|
||||
function ensureNewlineSeparation(currentContent: string): string {
|
||||
if (currentContent.length === 0) return '';
|
||||
if (currentContent.endsWith('\n\n') || currentContent.endsWith('\r\n\r\n'))
|
||||
return '';
|
||||
if (currentContent.endsWith('\n') || currentContent.endsWith('\r\n'))
|
||||
return '\n';
|
||||
return '\n\n';
|
||||
}
|
||||
|
||||
/**
 * Tool that appends user-stated facts to a "memories" section of the global
 * context markdown file. The file-mutation logic lives in the static
 * performAddMemoryEntry so tests can inject a fake fs adapter.
 */
export class MemoryTool extends BaseTool<SaveMemoryParams, ToolResult> {
  // Tool name as advertised to the model ('save_memory').
  static readonly Name: string = memoryToolSchemaData.name!;
  constructor() {
    super(
      MemoryTool.Name,
      'Save Memory',
      memoryToolDescription,
      memoryToolSchemaData.parameters as Record<string, unknown>,
    );
  }

  /**
   * Appends `text` as a "- " bullet under MEMORY_SECTION_HEADER in the file
   * at `memoryFilePath`, creating the file, its directory, and the section
   * header as needed.
   *
   * @param text Raw fact text; trimmed, with any leading "-" bullets stripped.
   * @param memoryFilePath Absolute path of the markdown file to update.
   * @param fsAdapter Injected fs-like interface (lets tests avoid real I/O).
   * @throws Error wrapping any underlying fs failure (message is prefixed
   *   with "[MemoryTool] Failed to add memory entry:").
   */
  static async performAddMemoryEntry(
    text: string,
    memoryFilePath: string,
    fsAdapter: {
      readFile: (path: string, encoding: 'utf-8') => Promise<string>;
      writeFile: (
        path: string,
        data: string,
        encoding: 'utf-8',
      ) => Promise<void>;
      mkdir: (
        path: string,
        options: { recursive: boolean },
      ) => Promise<string | undefined>;
    },
  ): Promise<void> {
    let processedText = text.trim();
    // Remove leading hyphens and spaces that might be misinterpreted as markdown list items
    processedText = processedText.replace(/^(-+\s*)+/, '').trim();
    const newMemoryItem = `- ${processedText}`;

    try {
      await fsAdapter.mkdir(path.dirname(memoryFilePath), { recursive: true });
      let content = '';
      try {
        content = await fsAdapter.readFile(memoryFilePath, 'utf-8');
      } catch (_e) {
        // File doesn't exist, will be created with header and item.
      }

      const headerIndex = content.indexOf(MEMORY_SECTION_HEADER);

      if (headerIndex === -1) {
        // Header not found, append header and then the entry
        const separator = ensureNewlineSeparation(content);
        content += `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`;
      } else {
        // Header found, find where to insert the new memory entry
        const startOfSectionContent =
          headerIndex + MEMORY_SECTION_HEADER.length;
        // The memory section ends at the next "## " heading, or at EOF.
        let endOfSectionIndex = content.indexOf('\n## ', startOfSectionContent);
        if (endOfSectionIndex === -1) {
          endOfSectionIndex = content.length; // End of file
        }

        const beforeSectionMarker = content
          .substring(0, startOfSectionContent)
          .trimEnd();
        let sectionContent = content
          .substring(startOfSectionContent, endOfSectionIndex)
          .trimEnd();
        const afterSectionMarker = content.substring(endOfSectionIndex);

        sectionContent += `\n${newMemoryItem}`;
        // Reassemble, normalizing to exactly one trailing newline.
        content =
          `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
          '\n';
      }
      await fsAdapter.writeFile(memoryFilePath, content, 'utf-8');
    } catch (error) {
      console.error(
        `[MemoryTool] Error adding memory entry to ${memoryFilePath}:`,
        error,
      );
      throw new Error(
        `[MemoryTool] Failed to add memory entry: ${error instanceof Error ? error.message : String(error)}`,
      );
    }
  }

  /**
   * Validates the fact and delegates to performAddMemoryEntry with the real
   * fs/promises functions. Never throws: failures are reported via the
   * returned ToolResult (JSON llmContent + human-readable returnDisplay).
   */
  async execute(
    params: SaveMemoryParams,
    _signal: AbortSignal,
  ): Promise<ToolResult> {
    const { fact } = params;

    if (!fact || typeof fact !== 'string' || fact.trim() === '') {
      const errorMessage = 'Parameter "fact" must be a non-empty string.';
      return {
        llmContent: JSON.stringify({ success: false, error: errorMessage }),
        returnDisplay: `Error: ${errorMessage}`,
      };
    }

    try {
      // Use the static method with actual fs promises
      await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), {
        readFile: fs.readFile,
        writeFile: fs.writeFile,
        mkdir: fs.mkdir,
      });
      const successMessage = `Okay, I've remembered that: "${fact}"`;
      return {
        llmContent: JSON.stringify({ success: true, message: successMessage }),
        returnDisplay: successMessage,
      };
    } catch (error) {
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      console.error(
        `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
      );
      return {
        llmContent: JSON.stringify({
          success: false,
          error: `Failed to save memory. Detail: ${errorMessage}`,
        }),
        returnDisplay: `Error saving memory: ${errorMessage}`,
      };
    }
  }
}
|
||||
393
packages/core/src/tools/modifiable-tool.test.ts
Normal file
393
packages/core/src/tools/modifiable-tool.test.ts
Normal file
@@ -0,0 +1,393 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import {
|
||||
modifyWithEditor,
|
||||
ModifyContext,
|
||||
ModifiableTool,
|
||||
isModifiableTool,
|
||||
} from './modifiable-tool.js';
|
||||
import { EditorType } from '../utils/editor.js';
|
||||
import fs from 'fs';
|
||||
import os from 'os';
|
||||
import * as path from 'path';
|
||||
|
||||
// Mock dependencies
// vi.hoisted ensures these fns exist before the hoisted vi.mock factories run.
const mockOpenDiff = vi.hoisted(() => vi.fn());
const mockCreatePatch = vi.hoisted(() => vi.fn());

vi.mock('../utils/editor.js', () => ({
  openDiff: mockOpenDiff,
}));

vi.mock('diff', () => ({
  createPatch: mockCreatePatch,
}));

// Auto-mock fs/os so no real filesystem access happens in these tests.
vi.mock('fs');
vi.mock('os');

// Minimal params shape used to exercise the generic ModifyContext plumbing.
interface TestParams {
  filePath: string;
  someOtherParam: string;
  modifiedContent?: string;
}
|
||||
|
||||
// End-to-end tests for modifyWithEditor with fs/os/diff/editor all mocked:
// temp-file creation, editor invocation, reading edited content back,
// patch generation, and cleanup (including on failure).
describe('modifyWithEditor', () => {
  let tempDir: string;
  let mockModifyContext: ModifyContext<TestParams>;
  let mockParams: TestParams;
  let currentContent: string;
  let proposedContent: string;
  let modifiedContent: string;
  let abortSignal: AbortSignal;

  beforeEach(() => {
    vi.resetAllMocks();

    tempDir = '/tmp/test-dir';
    abortSignal = new AbortController().signal;

    currentContent = 'original content\nline 2\nline 3';
    proposedContent = 'modified content\nline 2\nline 3';
    modifiedContent = 'user modified content\nline 2\nline 3\nnew line';
    mockParams = {
      filePath: path.join(tempDir, 'test.txt'),
      someOtherParam: 'value',
    };

    mockModifyContext = {
      getFilePath: vi.fn().mockReturnValue(mockParams.filePath),
      getCurrentContent: vi.fn().mockResolvedValue(currentContent),
      getProposedContent: vi.fn().mockResolvedValue(proposedContent),
      // Note: the 'modifiedContent' parameter intentionally shadows the outer
      // variable; the implementation just echoes what it was given.
      createUpdatedParams: vi
        .fn()
        .mockImplementation((oldContent, modifiedContent, originalParams) => ({
          ...originalParams,
          modifiedContent,
          oldContent,
        })),
    };

    (os.tmpdir as Mock).mockReturnValue(tempDir);

    (fs.existsSync as Mock).mockReturnValue(true);
    (fs.mkdirSync as Mock).mockImplementation(() => undefined);
    (fs.writeFileSync as Mock).mockImplementation(() => {});
    (fs.unlinkSync as Mock).mockImplementation(() => {});

    // Simulate the user editing only the "-new-" temp file in the editor.
    (fs.readFileSync as Mock).mockImplementation((filePath: string) => {
      if (filePath.includes('-new-')) {
        return modifiedContent;
      }
      return currentContent;
    });

    mockCreatePatch.mockReturnValue('mock diff content');
    mockOpenDiff.mockResolvedValue(undefined);
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('successful modification', () => {
    it('should successfully modify content with VSCode editor', async () => {
      const result = await modifyWithEditor(
        mockParams,
        mockModifyContext,
        'vscode' as EditorType,
        abortSignal,
      );

      expect(mockModifyContext.getCurrentContent).toHaveBeenCalledWith(
        mockParams,
      );
      expect(mockModifyContext.getProposedContent).toHaveBeenCalledWith(
        mockParams,
      );
      expect(mockModifyContext.getFilePath).toHaveBeenCalledWith(mockParams);

      // Old content then proposed content are written into the diff dir.
      expect(fs.writeFileSync).toHaveBeenCalledTimes(2);
      expect(fs.writeFileSync).toHaveBeenNthCalledWith(
        1,
        expect.stringContaining(
          path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
        ),
        currentContent,
        'utf8',
      );
      expect(fs.writeFileSync).toHaveBeenNthCalledWith(
        2,
        expect.stringContaining(
          path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
        ),
        proposedContent,
        'utf8',
      );

      expect(mockOpenDiff).toHaveBeenCalledWith(
        expect.stringContaining('-old-'),
        expect.stringContaining('-new-'),
        'vscode',
      );

      expect(fs.readFileSync).toHaveBeenCalledWith(
        expect.stringContaining('-old-'),
        'utf8',
      );
      expect(fs.readFileSync).toHaveBeenCalledWith(
        expect.stringContaining('-new-'),
        'utf8',
      );

      expect(mockModifyContext.createUpdatedParams).toHaveBeenCalledWith(
        currentContent,
        modifiedContent,
        mockParams,
      );

      expect(mockCreatePatch).toHaveBeenCalledWith(
        path.basename(mockParams.filePath),
        currentContent,
        modifiedContent,
        'Current',
        'Proposed',
        expect.objectContaining({
          context: 3,
          ignoreWhitespace: true,
        }),
      );

      // Both temp files are removed afterwards.
      expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
      expect(fs.unlinkSync).toHaveBeenNthCalledWith(
        1,
        expect.stringContaining('-old-'),
      );
      expect(fs.unlinkSync).toHaveBeenNthCalledWith(
        2,
        expect.stringContaining('-new-'),
      );

      expect(result).toEqual({
        updatedParams: {
          ...mockParams,
          modifiedContent,
          oldContent: currentContent,
        },
        updatedDiff: 'mock diff content',
      });
    });

    it('should create temp directory if it does not exist', async () => {
      (fs.existsSync as Mock).mockReturnValue(false);

      await modifyWithEditor(
        mockParams,
        mockModifyContext,
        'vscode' as EditorType,
        abortSignal,
      );

      expect(fs.mkdirSync).toHaveBeenCalledWith(
        path.join(tempDir, 'gemini-cli-tool-modify-diffs'),
        { recursive: true },
      );
    });

    it('should not create temp directory if it already exists', async () => {
      (fs.existsSync as Mock).mockReturnValue(true);

      await modifyWithEditor(
        mockParams,
        mockModifyContext,
        'vscode' as EditorType,
        abortSignal,
      );

      expect(fs.mkdirSync).not.toHaveBeenCalled();
    });
  });

  // A missing (deleted) temp file is treated as empty content, not an error.
  it('should handle missing old temp file gracefully', async () => {
    (fs.readFileSync as Mock).mockImplementation((filePath: string) => {
      if (filePath.includes('-old-')) {
        const error = new Error('ENOENT: no such file or directory');
        (error as NodeJS.ErrnoException).code = 'ENOENT';
        throw error;
      }
      return modifiedContent;
    });

    const result = await modifyWithEditor(
      mockParams,
      mockModifyContext,
      'vscode' as EditorType,
      abortSignal,
    );

    expect(mockCreatePatch).toHaveBeenCalledWith(
      path.basename(mockParams.filePath),
      '',
      modifiedContent,
      'Current',
      'Proposed',
      expect.objectContaining({
        context: 3,
        ignoreWhitespace: true,
      }),
    );

    expect(result.updatedParams).toBeDefined();
    expect(result.updatedDiff).toBe('mock diff content');
  });

  it('should handle missing new temp file gracefully', async () => {
    (fs.readFileSync as Mock).mockImplementation((filePath: string) => {
      if (filePath.includes('-new-')) {
        const error = new Error('ENOENT: no such file or directory');
        (error as NodeJS.ErrnoException).code = 'ENOENT';
        throw error;
      }
      return currentContent;
    });

    const result = await modifyWithEditor(
      mockParams,
      mockModifyContext,
      'vscode' as EditorType,
      abortSignal,
    );

    expect(mockCreatePatch).toHaveBeenCalledWith(
      path.basename(mockParams.filePath),
      currentContent,
      '',
      'Current',
      'Proposed',
      expect.objectContaining({
        context: 3,
        ignoreWhitespace: true,
      }),
    );

    expect(result.updatedParams).toBeDefined();
    expect(result.updatedDiff).toBe('mock diff content');
  });

  // Cleanup runs in a finally block, so it must happen even when the editor throws.
  it('should clean up temp files even if editor fails', async () => {
    const editorError = new Error('Editor failed to open');
    mockOpenDiff.mockRejectedValue(editorError);

    await expect(
      modifyWithEditor(
        mockParams,
        mockModifyContext,
        'vscode' as EditorType,
        abortSignal,
      ),
    ).rejects.toThrow('Editor failed to open');

    expect(fs.unlinkSync).toHaveBeenCalledTimes(2);
  });

  it('should handle temp file cleanup errors gracefully', async () => {
    const consoleErrorSpy = vi
      .spyOn(console, 'error')
      .mockImplementation(() => {});
    (fs.unlinkSync as Mock).mockImplementation((_filePath: string) => {
      throw new Error('Failed to delete file');
    });

    await modifyWithEditor(
      mockParams,
      mockModifyContext,
      'vscode' as EditorType,
      abortSignal,
    );

    // One logged error per temp file; the operation itself still succeeds.
    expect(consoleErrorSpy).toHaveBeenCalledTimes(2);
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      expect.stringContaining('Error deleting temp diff file:'),
    );

    consoleErrorSpy.mockRestore();
  });

  it('should create temp files with correct naming with extension', async () => {
    const testFilePath = path.join(tempDir, 'subfolder', 'test-file.txt');
    mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);

    await modifyWithEditor(
      mockParams,
      mockModifyContext,
      'vscode' as EditorType,
      abortSignal,
    );

    const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
    expect(writeFileCalls).toHaveLength(2);

    const oldFilePath = writeFileCalls[0][0];
    const newFilePath = writeFileCalls[1][0];

    // Extension is preserved after the -old-/-new- timestamp suffix.
    expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+\.txt$/);
    expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+\.txt$/);
    expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
    expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
  });

  it('should create temp files with correct naming without extension', async () => {
    const testFilePath = path.join(tempDir, 'subfolder', 'test-file');
    mockModifyContext.getFilePath = vi.fn().mockReturnValue(testFilePath);

    await modifyWithEditor(
      mockParams,
      mockModifyContext,
      'vscode' as EditorType,
      abortSignal,
    );

    const writeFileCalls = (fs.writeFileSync as Mock).mock.calls;
    expect(writeFileCalls).toHaveLength(2);

    const oldFilePath = writeFileCalls[0][0];
    const newFilePath = writeFileCalls[1][0];

    expect(oldFilePath).toMatch(/gemini-cli-modify-test-file-old-\d+$/);
    expect(newFilePath).toMatch(/gemini-cli-modify-test-file-new-\d+$/);
    expect(oldFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
    expect(newFilePath).toContain(`${tempDir}/gemini-cli-tool-modify-diffs/`);
  });
});
|
||||
|
||||
// The type guard is a simple structural check for a getModifyContext member.
describe('isModifiableTool', () => {
  it('should return true for objects with getModifyContext method', () => {
    const mockTool = {
      name: 'test-tool',
      getModifyContext: vi.fn(),
    } as unknown as ModifiableTool<TestParams>;

    expect(isModifiableTool(mockTool)).toBe(true);
  });

  it('should return false for objects without getModifyContext method', () => {
    const mockTool = {
      name: 'test-tool',
    } as unknown as ModifiableTool<TestParams>;

    expect(isModifiableTool(mockTool)).toBe(false);
  });
});
|
||||
165
packages/core/src/tools/modifiable-tool.ts
Normal file
165
packages/core/src/tools/modifiable-tool.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { EditorType, openDiff } from '../utils/editor.js';
|
||||
import os from 'os';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import * as Diff from 'diff';
|
||||
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import { Tool } from './tools.js';
|
||||
|
||||
/**
 * A tool that supports a modify operation.
 */
export interface ModifiableTool<ToolParams> extends Tool<ToolParams> {
  /** Returns the context object modifyWithEditor uses to drive the edit flow. */
  getModifyContext(abortSignal: AbortSignal): ModifyContext<ToolParams>;
}

/**
 * Adapter a modifiable tool provides so modifyWithEditor can read/write its
 * content and rebuild tool params from user-edited text.
 */
export interface ModifyContext<ToolParams> {
  /** Absolute path of the file the params refer to. */
  getFilePath: (params: ToolParams) => string;

  /** Current on-disk content for the params. */
  getCurrentContent: (params: ToolParams) => Promise<string>;

  /** Content the tool proposes to write. */
  getProposedContent: (params: ToolParams) => Promise<string>;

  /** Builds new params that reflect the user's edits to the proposed content. */
  createUpdatedParams: (
    oldContent: string,
    modifiedProposedContent: string,
    originalParams: ToolParams,
  ) => ToolParams;
}

/** Result of a modify-with-editor round trip. */
export interface ModifyResult<ToolParams> {
  /** Params updated to incorporate the user's edits. */
  updatedParams: ToolParams;
  /** Unified diff between the (possibly edited) old and new contents. */
  updatedDiff: string;
}
|
||||
|
||||
export function isModifiableTool<TParams>(
|
||||
tool: Tool<TParams>,
|
||||
): tool is ModifiableTool<TParams> {
|
||||
return 'getModifyContext' in tool;
|
||||
}
|
||||
|
||||
function createTempFilesForModify(
|
||||
currentContent: string,
|
||||
proposedContent: string,
|
||||
file_path: string,
|
||||
): { oldPath: string; newPath: string } {
|
||||
const tempDir = os.tmpdir();
|
||||
const diffDir = path.join(tempDir, 'gemini-cli-tool-modify-diffs');
|
||||
|
||||
if (!fs.existsSync(diffDir)) {
|
||||
fs.mkdirSync(diffDir, { recursive: true });
|
||||
}
|
||||
|
||||
const ext = path.extname(file_path);
|
||||
const fileName = path.basename(file_path, ext);
|
||||
const timestamp = Date.now();
|
||||
const tempOldPath = path.join(
|
||||
diffDir,
|
||||
`gemini-cli-modify-${fileName}-old-${timestamp}${ext}`,
|
||||
);
|
||||
const tempNewPath = path.join(
|
||||
diffDir,
|
||||
`gemini-cli-modify-${fileName}-new-${timestamp}${ext}`,
|
||||
);
|
||||
|
||||
fs.writeFileSync(tempOldPath, currentContent, 'utf8');
|
||||
fs.writeFileSync(tempNewPath, proposedContent, 'utf8');
|
||||
|
||||
return { oldPath: tempOldPath, newPath: tempNewPath };
|
||||
}
|
||||
|
||||
function getUpdatedParams<ToolParams>(
|
||||
tmpOldPath: string,
|
||||
tempNewPath: string,
|
||||
originalParams: ToolParams,
|
||||
modifyContext: ModifyContext<ToolParams>,
|
||||
): { updatedParams: ToolParams; updatedDiff: string } {
|
||||
let oldContent = '';
|
||||
let newContent = '';
|
||||
|
||||
try {
|
||||
oldContent = fs.readFileSync(tmpOldPath, 'utf8');
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
oldContent = '';
|
||||
}
|
||||
|
||||
try {
|
||||
newContent = fs.readFileSync(tempNewPath, 'utf8');
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
newContent = '';
|
||||
}
|
||||
|
||||
const updatedParams = modifyContext.createUpdatedParams(
|
||||
oldContent,
|
||||
newContent,
|
||||
originalParams,
|
||||
);
|
||||
const updatedDiff = Diff.createPatch(
|
||||
path.basename(modifyContext.getFilePath(originalParams)),
|
||||
oldContent,
|
||||
newContent,
|
||||
'Current',
|
||||
'Proposed',
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
|
||||
return { updatedParams, updatedDiff };
|
||||
}
|
||||
|
||||
function deleteTempFiles(oldPath: string, newPath: string): void {
|
||||
try {
|
||||
fs.unlinkSync(oldPath);
|
||||
} catch {
|
||||
console.error(`Error deleting temp diff file: ${oldPath}`);
|
||||
}
|
||||
|
||||
try {
|
||||
fs.unlinkSync(newPath);
|
||||
} catch {
|
||||
console.error(`Error deleting temp diff file: ${newPath}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Triggers an external editor for the user to modify the proposed content,
|
||||
* and returns the updated tool parameters and the diff after the user has modified the proposed content.
|
||||
*/
|
||||
export async function modifyWithEditor<ToolParams>(
|
||||
originalParams: ToolParams,
|
||||
modifyContext: ModifyContext<ToolParams>,
|
||||
editorType: EditorType,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ModifyResult<ToolParams>> {
|
||||
const currentContent = await modifyContext.getCurrentContent(originalParams);
|
||||
const proposedContent =
|
||||
await modifyContext.getProposedContent(originalParams);
|
||||
|
||||
const { oldPath, newPath } = createTempFilesForModify(
|
||||
currentContent,
|
||||
proposedContent,
|
||||
modifyContext.getFilePath(originalParams),
|
||||
);
|
||||
|
||||
try {
|
||||
await openDiff(oldPath, newPath, editorType);
|
||||
const result = getUpdatedParams(
|
||||
oldPath,
|
||||
newPath,
|
||||
originalParams,
|
||||
modifyContext,
|
||||
);
|
||||
|
||||
return result;
|
||||
} finally {
|
||||
deleteTempFiles(oldPath, newPath);
|
||||
}
|
||||
}
|
||||
252
packages/core/src/tools/read-file.test.ts
Normal file
252
packages/core/src/tools/read-file.test.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach, Mock } from 'vitest';
|
||||
import { ReadFileTool, ReadFileToolParams } from './read-file.js';
|
||||
import * as fileUtils from '../utils/fileUtils.js';
|
||||
import path from 'path';
|
||||
import os from 'os';
|
||||
import fs from 'fs'; // For actual fs operations in setup
|
||||
import { Config } from '../config/config.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
|
||||
// Mock fileUtils.processSingleFileContent
// Keep the rest of fileUtils real; only the file-content processor is faked.
vi.mock('../utils/fileUtils', async () => {
  const actualFileUtils =
    await vi.importActual<typeof fileUtils>('../utils/fileUtils');
  return {
    ...actualFileUtils, // Spread actual implementations
    processSingleFileContent: vi.fn(), // Mock specific function
  };
});

// Typed handle to the mocked function for per-test configuration.
const mockProcessSingleFileContent = fileUtils.processSingleFileContent as Mock;
|
||||
|
||||
describe('ReadFileTool', () => {
|
||||
let tempRootDir: string;
|
||||
let tool: ReadFileTool;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
  // Create a unique temporary root directory for each test run
  tempRootDir = fs.mkdtempSync(
    path.join(os.tmpdir(), 'read-file-tool-root-'),
  );
  // Ignore file so gitignore/geminiignore-aware paths can be exercised.
  fs.writeFileSync(
    path.join(tempRootDir, '.geminiignore'),
    ['foo.*'].join('\n'),
  );
  const fileService = new FileDiscoveryService(tempRootDir);
  // Minimal Config stub: only the members ReadFileTool reads are provided.
  const mockConfigInstance = {
    getFileService: () => fileService,
    getTargetDir: () => tempRootDir,
  } as unknown as Config;
  tool = new ReadFileTool(mockConfigInstance);
  mockProcessSingleFileContent.mockReset();
});

afterEach(() => {
  // Clean up the temporary root directory
  if (fs.existsSync(tempRootDir)) {
    fs.rmSync(tempRootDir, { recursive: true, force: true });
  }
});
|
||||
|
||||
// Parameter validation: absolute-path requirement, root confinement, and
// offset/limit range checks, plus schema-level required-property errors.
describe('validateToolParams', () => {
  it('should return null for valid params (absolute path within root)', () => {
    const params: ReadFileToolParams = {
      absolute_path: path.join(tempRootDir, 'test.txt'),
    };
    expect(tool.validateToolParams(params)).toBeNull();
  });

  it('should return null for valid params with offset and limit', () => {
    const params: ReadFileToolParams = {
      absolute_path: path.join(tempRootDir, 'test.txt'),
      offset: 0,
      limit: 10,
    };
    expect(tool.validateToolParams(params)).toBeNull();
  });

  it('should return error for relative path', () => {
    const params: ReadFileToolParams = { absolute_path: 'test.txt' };
    expect(tool.validateToolParams(params)).toBe(
      `File path must be absolute, but was relative: test.txt. You must provide an absolute path.`,
    );
  });

  it('should return error for path outside root', () => {
    const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt');
    const params: ReadFileToolParams = { absolute_path: outsidePath };
    // Message includes the concrete paths, so only the prefix is asserted.
    expect(tool.validateToolParams(params)).toMatch(
      /File path must be within the root directory/,
    );
  });

  it('should return error for negative offset', () => {
    const params: ReadFileToolParams = {
      absolute_path: path.join(tempRootDir, 'test.txt'),
      offset: -1,
      limit: 10,
    };
    expect(tool.validateToolParams(params)).toBe(
      'Offset must be a non-negative number',
    );
  });

  it('should return error for non-positive limit', () => {
    const paramsZero: ReadFileToolParams = {
      absolute_path: path.join(tempRootDir, 'test.txt'),
      offset: 0,
      limit: 0,
    };
    expect(tool.validateToolParams(paramsZero)).toBe(
      'Limit must be a positive number',
    );
    const paramsNegative: ReadFileToolParams = {
      absolute_path: path.join(tempRootDir, 'test.txt'),
      offset: 0,
      limit: -5,
    };
    expect(tool.validateToolParams(paramsNegative)).toBe(
      'Limit must be a positive number',
    );
  });

  it('should return error for schema validation failure (e.g. missing path)', () => {
    const params = { offset: 0 } as unknown as ReadFileToolParams;
    expect(tool.validateToolParams(params)).toBe(
      `params must have required property 'absolute_path'`,
    );
  });
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
it('should return a shortened, relative path', () => {
|
||||
const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
// Assuming tempRootDir is something like /tmp/read-file-tool-root-XXXXXX
|
||||
// The relative path would be sub/dir/file.txt
|
||||
expect(tool.getDescription(params)).toBe('sub/dir/file.txt');
|
||||
});
|
||||
|
||||
it('should return . if path is the root directory', () => {
|
||||
const params: ReadFileToolParams = { absolute_path: tempRootDir };
|
||||
expect(tool.getDescription(params)).toBe('.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
it('should return validation error if params are invalid', async () => {
|
||||
const params: ReadFileToolParams = { absolute_path: 'relative/path.txt' };
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toBe(
|
||||
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
);
|
||||
expect(result.returnDisplay).toBe(
|
||||
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error from processSingleFileContent if it fails', async () => {
|
||||
const filePath = path.join(tempRootDir, 'error.txt');
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
const errorMessage = 'Simulated read error';
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: `Error reading file ${filePath}: ${errorMessage}`,
|
||||
returnDisplay: `Error reading file ${filePath}: ${errorMessage}`,
|
||||
error: errorMessage,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toContain(errorMessage);
|
||||
expect(result.returnDisplay).toContain(errorMessage);
|
||||
});
|
||||
|
||||
it('should return success result for a text file', async () => {
|
||||
const filePath = path.join(tempRootDir, 'textfile.txt');
|
||||
const fileContent = 'This is a test file.';
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: fileContent,
|
||||
returnDisplay: `Read text file: ${path.basename(filePath)}`,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toBe(fileContent);
|
||||
expect(result.returnDisplay).toBe(
|
||||
`Read text file: ${path.basename(filePath)}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return success result for an image file', async () => {
|
||||
const filePath = path.join(tempRootDir, 'image.png');
|
||||
const imageData = {
|
||||
inlineData: { mimeType: 'image/png', data: 'base64...' },
|
||||
};
|
||||
const params: ReadFileToolParams = { absolute_path: filePath };
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: imageData,
|
||||
returnDisplay: `Read image file: ${path.basename(filePath)}`,
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(result.llmContent).toEqual(imageData);
|
||||
expect(result.returnDisplay).toBe(
|
||||
`Read image file: ${path.basename(filePath)}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass offset and limit to processSingleFileContent', async () => {
|
||||
const filePath = path.join(tempRootDir, 'paginated.txt');
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: filePath,
|
||||
offset: 10,
|
||||
limit: 5,
|
||||
};
|
||||
mockProcessSingleFileContent.mockResolvedValue({
|
||||
llmContent: 'some lines',
|
||||
returnDisplay: 'Read text file (paginated)',
|
||||
});
|
||||
|
||||
await tool.execute(params, abortSignal);
|
||||
expect(mockProcessSingleFileContent).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
tempRootDir,
|
||||
10,
|
||||
5,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is ignored by a .geminiignore pattern', async () => {
|
||||
const params: ReadFileToolParams = {
|
||||
absolute_path: path.join(tempRootDir, 'foo.bar'),
|
||||
};
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.returnDisplay).toContain('foo.bar');
|
||||
expect(result.returnDisplay).not.toContain('foo.baz');
|
||||
});
|
||||
});
|
||||
});
|
||||
165
packages/core/src/tools/read-file.ts
Normal file
165
packages/core/src/tools/read-file.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import path from 'path';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import {
|
||||
isWithinRoot,
|
||||
processSingleFileContent,
|
||||
getSpecificMimeType,
|
||||
} from '../utils/fileUtils.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import {
|
||||
recordFileOperationMetric,
|
||||
FileOperation,
|
||||
} from '../telemetry/metrics.js';
|
||||
|
||||
/**
|
||||
* Parameters for the ReadFile tool
|
||||
*/
|
||||
export interface ReadFileToolParams {
  /**
   * The absolute path to the file to read.
   * Relative paths are rejected by {@link ReadFileTool.validateToolParams}.
   */
  absolute_path: string;

  /**
   * The 0-based line number to start reading from (optional).
   * Must be non-negative when provided.
   */
  offset?: number;

  /**
   * The maximum number of lines to read (optional).
   * Must be strictly positive when provided.
   */
  limit?: number;
}
|
||||
|
||||
/**
|
||||
* Implementation of the ReadFile tool logic
|
||||
*/
|
||||
export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
  static readonly Name: string = 'read_file';

  constructor(private config: Config) {
    super(
      ReadFileTool.Name,
      'ReadFile',
      'Reads and returns the content of a specified file from the local filesystem. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.',
      {
        properties: {
          absolute_path: {
            description:
              "The absolute path to the file to read (e.g., '/home/user/project/file.txt'). Relative paths are not supported. You must provide an absolute path.",
            type: Type.STRING,
          },
          offset: {
            description:
              "Optional: For text files, the 0-based line number to start reading from. Requires 'limit' to be set. Use for paginating through large files.",
            type: Type.NUMBER,
          },
          limit: {
            description:
              "Optional: For text files, maximum number of lines to read. Use with 'offset' to paginate through large files. If omitted, reads the entire file (if feasible, up to a default limit).",
            type: Type.NUMBER,
          },
        },
        required: ['absolute_path'],
        type: Type.OBJECT,
      },
    );
  }

  /**
   * Validates parameters before execution.
   *
   * Checks, in order: JSON-schema conformance, path absoluteness,
   * confinement to the configured target directory, offset/limit ranges,
   * and .geminiignore exclusion.
   *
   * @returns An error message string, or null when the params are valid.
   */
  validateToolParams(params: ReadFileToolParams): string | null {
    const errors = SchemaValidator.validate(this.schema.parameters, params);
    if (errors) {
      return errors;
    }

    const filePath = params.absolute_path;
    if (!path.isAbsolute(filePath)) {
      return `File path must be absolute, but was relative: ${filePath}. You must provide an absolute path.`;
    }
    if (!isWithinRoot(filePath, this.config.getTargetDir())) {
      return `File path must be within the root directory (${this.config.getTargetDir()}): ${filePath}`;
    }
    if (params.offset !== undefined && params.offset < 0) {
      return 'Offset must be a non-negative number';
    }
    if (params.limit !== undefined && params.limit <= 0) {
      return 'Limit must be a positive number';
    }

    const fileService = this.config.getFileService();
    if (fileService.shouldGeminiIgnoreFile(params.absolute_path)) {
      return `File path '${filePath}' is ignored by .geminiignore pattern(s).`;
    }

    return null;
  }

  /**
   * Returns a short, human-readable description of the target file:
   * its path shortened and made relative to the target directory
   * (e.g. 'sub/dir/file.txt', or '.' for the root itself).
   */
  getDescription(params: ReadFileToolParams): string {
    if (
      !params ||
      typeof params.absolute_path !== 'string' ||
      params.absolute_path.trim() === ''
    ) {
      return `Path unavailable`;
    }
    const relativePath = makeRelative(
      params.absolute_path,
      this.config.getTargetDir(),
    );
    return shortenPath(relativePath);
  }

  /**
   * Reads the file via processSingleFileContent and records a READ
   * file-operation metric on success.
   *
   * @param params - Validated (re-validated here) read parameters.
   * @param _signal - Unused; reading is delegated and not cancellable here.
   * @returns ToolResult whose llmContent is either the file text or an
   *   inlineData Part (images/PDF), or an error description on failure.
   */
  async execute(
    params: ReadFileToolParams,
    _signal: AbortSignal,
  ): Promise<ToolResult> {
    const validationError = this.validateToolParams(params);
    if (validationError) {
      return {
        llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
        returnDisplay: validationError,
      };
    }

    const result = await processSingleFileContent(
      params.absolute_path,
      this.config.getTargetDir(),
      params.offset,
      params.limit,
    );

    if (result.error) {
      return {
        llmContent: result.error, // The detailed error for LLM
        returnDisplay: result.returnDisplay, // User-friendly error
      };
    }

    // Line count is only meaningful for text content; binary/inlineData
    // results record the metric without a line count.
    const lines =
      typeof result.llmContent === 'string'
        ? result.llmContent.split('\n').length
        : undefined;
    const mimetype = getSpecificMimeType(params.absolute_path);
    recordFileOperationMetric(
      this.config,
      FileOperation.READ,
      lines,
      mimetype,
      path.extname(params.absolute_path),
    );

    return {
      llmContent: result.llmContent,
      returnDisplay: result.returnDisplay,
    };
  }
}
|
||||
425
packages/core/src/tools/read-many-files.test.ts
Normal file
425
packages/core/src/tools/read-many-files.test.ts
Normal file
@@ -0,0 +1,425 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import { mockControl } from '../__mocks__/fs/promises.js';
|
||||
import { ReadManyFilesTool } from './read-many-files.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import path from 'path';
|
||||
import fs from 'fs'; // Actual fs for setup
|
||||
import os from 'os';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
vi.mock('mime-types', () => {
|
||||
const lookup = (filename: string) => {
|
||||
if (filename.endsWith('.ts') || filename.endsWith('.js')) {
|
||||
return 'text/plain';
|
||||
}
|
||||
if (filename.endsWith('.png')) {
|
||||
return 'image/png';
|
||||
}
|
||||
if (filename.endsWith('.pdf')) {
|
||||
return 'application/pdf';
|
||||
}
|
||||
if (filename.endsWith('.mp3') || filename.endsWith('.wav')) {
|
||||
return 'audio/mpeg';
|
||||
}
|
||||
if (filename.endsWith('.mp4') || filename.endsWith('.mov')) {
|
||||
return 'video/mp4';
|
||||
}
|
||||
return false;
|
||||
};
|
||||
return {
|
||||
default: {
|
||||
lookup,
|
||||
},
|
||||
lookup,
|
||||
};
|
||||
});
|
||||
|
||||
describe('ReadManyFilesTool', () => {
  let tool: ReadManyFilesTool;
  let tempRootDir: string;
  // A second temp dir OUTSIDE the root, used to verify absolute/escaping paths.
  let tempDirOutsideRoot: string;
  let mockReadFileFn: Mock;

  beforeEach(async () => {
    tempRootDir = fs.mkdtempSync(
      path.join(os.tmpdir(), 'read-many-files-root-'),
    );
    tempDirOutsideRoot = fs.mkdtempSync(
      path.join(os.tmpdir(), 'read-many-files-external-'),
    );
    // .geminiignore: any "foo.*" file in the root is silently excluded.
    fs.writeFileSync(path.join(tempRootDir, '.geminiignore'), 'foo.*');
    const fileService = new FileDiscoveryService(tempRootDir);
    // Minimal Config stub with only the accessors the tool calls.
    const mockConfig = {
      getFileService: () => fileService,
      getFileFilteringRespectGitIgnore: () => true,
      getTargetDir: () => tempRootDir,
    } as Partial<Config> as Config;

    tool = new ReadManyFilesTool(mockConfig);

    mockReadFileFn = mockControl.mockReadFile;
    mockReadFileFn.mockReset();

    // fs.promises.readFile stub: delegate to the real fs for files the test
    // actually created; otherwise synthesize ENOENT/EACCES errors or fixed
    // binary payloads based on the requested path's suffix.
    mockReadFileFn.mockImplementation(
      async (filePath: fs.PathLike, options?: Record<string, unknown>) => {
        const fp =
          typeof filePath === 'string'
            ? filePath
            : (filePath as Buffer).toString();

        if (fs.existsSync(fp)) {
          const originalFs = await vi.importActual<typeof fs>('fs');
          return originalFs.promises.readFile(fp, options);
        }

        if (fp.endsWith('nonexistent-file.txt')) {
          const err = new Error(
            `ENOENT: no such file or directory, open '${fp}'`,
          );
          (err as NodeJS.ErrnoException).code = 'ENOENT';
          throw err;
        }
        if (fp.endsWith('unreadable.txt')) {
          const err = new Error(`EACCES: permission denied, open '${fp}'`);
          (err as NodeJS.ErrnoException).code = 'EACCES';
          throw err;
        }
        if (fp.endsWith('.png'))
          return Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]); // PNG header
        if (fp.endsWith('.pdf')) return Buffer.from('%PDF-1.4...'); // PDF start
        if (fp.endsWith('binary.bin'))
          return Buffer.from([0x00, 0x01, 0x02, 0x00, 0x03]);

        const err = new Error(
          `ENOENT: no such file or directory, open '${fp}' (unmocked path)`,
        );
        (err as NodeJS.ErrnoException).code = 'ENOENT';
        throw err;
      },
    );
  });

  afterEach(() => {
    // Remove both temp dirs created in beforeEach.
    if (fs.existsSync(tempRootDir)) {
      fs.rmSync(tempRootDir, { recursive: true, force: true });
    }
    if (fs.existsSync(tempDirOutsideRoot)) {
      fs.rmSync(tempDirOutsideRoot, { recursive: true, force: true });
    }
  });

  describe('validateParams', () => {
    it('should return null for valid relative paths within root', () => {
      const params = { paths: ['file1.txt', 'subdir/file2.txt'] };
      expect(tool.validateParams(params)).toBeNull();
    });

    it('should return null for valid glob patterns within root', () => {
      const params = { paths: ['*.txt', 'subdir/**/*.js'] };
      expect(tool.validateParams(params)).toBeNull();
    });

    it('should return null for paths trying to escape the root (e.g., ../) as execute handles this', () => {
      const params = { paths: ['../outside.txt'] };
      expect(tool.validateParams(params)).toBeNull();
    });

    it('should return null for absolute paths as execute handles this', () => {
      const params = { paths: [path.join(tempDirOutsideRoot, 'absolute.txt')] };
      expect(tool.validateParams(params)).toBeNull();
    });

    it('should return error if paths array is empty', () => {
      const params = { paths: [] };
      expect(tool.validateParams(params)).toBe(
        'params/paths must NOT have fewer than 1 items',
      );
    });

    it('should return null for valid exclude and include patterns', () => {
      const params = {
        paths: ['src/**/*.ts'],
        exclude: ['**/*.test.ts'],
        include: ['src/utils/*.ts'],
      };
      expect(tool.validateParams(params)).toBeNull();
    });

    it('should return error if paths array contains an empty string', () => {
      const params = { paths: ['file1.txt', ''] };
      expect(tool.validateParams(params)).toBe(
        'params/paths/1 must NOT have fewer than 1 characters',
      );
    });

    it('should return error if include array contains non-string elements', () => {
      const params = {
        paths: ['file1.txt'],
        include: ['*.ts', 123] as string[],
      };
      expect(tool.validateParams(params)).toBe(
        'params/include/1 must be string',
      );
    });

    it('should return error if exclude array contains non-string elements', () => {
      const params = {
        paths: ['file1.txt'],
        exclude: ['*.log', {}] as string[],
      };
      expect(tool.validateParams(params)).toBe(
        'params/exclude/1 must be string',
      );
    });
  });

  describe('execute', () => {
    // Write a text file under tempRootDir, creating parent dirs as needed.
    const createFile = (filePath: string, content = '') => {
      const fullPath = path.join(tempRootDir, filePath);
      fs.mkdirSync(path.dirname(fullPath), { recursive: true });
      fs.writeFileSync(fullPath, content);
    };
    // Write raw bytes under tempRootDir, creating parent dirs as needed.
    const createBinaryFile = (filePath: string, data: Uint8Array) => {
      const fullPath = path.join(tempRootDir, filePath);
      fs.mkdirSync(path.dirname(fullPath), { recursive: true });
      fs.writeFileSync(fullPath, data);
    };

    it('should read a single specified file', async () => {
      createFile('file1.txt', 'Content of file1');
      const params = { paths: ['file1.txt'] };
      const result = await tool.execute(params, new AbortController().signal);
      const expectedPath = path.join(tempRootDir, 'file1.txt');
      expect(result.llmContent).toEqual([
        `--- ${expectedPath} ---\n\nContent of file1\n\n`,
      ]);
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **1 file(s)**',
      );
    });

    it('should read multiple specified files', async () => {
      createFile('file1.txt', 'Content1');
      createFile('subdir/file2.js', 'Content2');
      const params = { paths: ['file1.txt', 'subdir/file2.js'] };
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath1 = path.join(tempRootDir, 'file1.txt');
      const expectedPath2 = path.join(tempRootDir, 'subdir/file2.js');
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath1} ---\n\nContent1\n\n`),
        ),
      ).toBe(true);
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath2} ---\n\nContent2\n\n`),
        ),
      ).toBe(true);
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **2 file(s)**',
      );
    });

    it('should handle glob patterns', async () => {
      createFile('file.txt', 'Text file');
      createFile('another.txt', 'Another text');
      createFile('sub/data.json', '{}');
      const params = { paths: ['*.txt'] };
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath1 = path.join(tempRootDir, 'file.txt');
      const expectedPath2 = path.join(tempRootDir, 'another.txt');
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath1} ---\n\nText file\n\n`),
        ),
      ).toBe(true);
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath2} ---\n\nAnother text\n\n`),
        ),
      ).toBe(true);
      // The non-matching json file must not be picked up by *.txt.
      expect(content.find((c) => c.includes('sub/data.json'))).toBeUndefined();
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **2 file(s)**',
      );
    });

    it('should respect exclude patterns', async () => {
      createFile('src/main.ts', 'Main content');
      createFile('src/main.test.ts', 'Test content');
      const params = { paths: ['src/**/*.ts'], exclude: ['**/*.test.ts'] };
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath = path.join(tempRootDir, 'src/main.ts');
      expect(content).toEqual([`--- ${expectedPath} ---\n\nMain content\n\n`]);
      expect(
        content.find((c) => c.includes('src/main.test.ts')),
      ).toBeUndefined();
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **1 file(s)**',
      );
    });

    it('should handle non-existent specific files gracefully', async () => {
      const params = { paths: ['nonexistent-file.txt'] };
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.llmContent).toEqual([
        'No files matching the criteria were found or all were skipped.',
      ]);
      expect(result.returnDisplay).toContain(
        'No files were read and concatenated based on the criteria.',
      );
    });

    it('should use default excludes', async () => {
      createFile('node_modules/some-lib/index.js', 'lib code');
      createFile('src/app.js', 'app code');
      const params = { paths: ['**/*.js'] };
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath = path.join(tempRootDir, 'src/app.js');
      expect(content).toEqual([`--- ${expectedPath} ---\n\napp code\n\n`]);
      expect(
        content.find((c) => c.includes('node_modules/some-lib/index.js')),
      ).toBeUndefined();
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **1 file(s)**',
      );
    });

    it('should NOT use default excludes if useDefaultExcludes is false', async () => {
      createFile('node_modules/some-lib/index.js', 'lib code');
      createFile('src/app.js', 'app code');
      const params = { paths: ['**/*.js'], useDefaultExcludes: false };
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath1 = path.join(
        tempRootDir,
        'node_modules/some-lib/index.js',
      );
      const expectedPath2 = path.join(tempRootDir, 'src/app.js');
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath1} ---\n\nlib code\n\n`),
        ),
      ).toBe(true);
      expect(
        content.some((c) =>
          c.includes(`--- ${expectedPath2} ---\n\napp code\n\n`),
        ),
      ).toBe(true);
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **2 file(s)**',
      );
    });

    it('should include images as inlineData parts if explicitly requested by extension', async () => {
      createBinaryFile(
        'image.png',
        Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]),
      );
      const params = { paths: ['*.png'] }; // Explicitly requesting .png
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.llmContent).toEqual([
        {
          inlineData: {
            data: Buffer.from([
              0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
            ]).toString('base64'),
            mimeType: 'image/png',
          },
        },
      ]);
      expect(result.returnDisplay).toContain(
        'Successfully read and concatenated content from **1 file(s)**',
      );
    });

    it('should include images as inlineData parts if explicitly requested by name', async () => {
      createBinaryFile(
        'myExactImage.png',
        Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]),
      );
      const params = { paths: ['myExactImage.png'] }; // Explicitly requesting by full name
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.llmContent).toEqual([
        {
          inlineData: {
            data: Buffer.from([
              0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
            ]).toString('base64'),
            mimeType: 'image/png',
          },
        },
      ]);
    });

    it('should skip PDF files if not explicitly requested by extension or name', async () => {
      createBinaryFile('document.pdf', Buffer.from('%PDF-1.4...'));
      createFile('notes.txt', 'text notes');
      const params = { paths: ['*'] }; // Generic glob, not specific to .pdf
      const result = await tool.execute(params, new AbortController().signal);
      const content = result.llmContent as string[];
      const expectedPath = path.join(tempRootDir, 'notes.txt');
      expect(
        content.some(
          (c) =>
            typeof c === 'string' &&
            c.includes(`--- ${expectedPath} ---\n\ntext notes\n\n`),
        ),
      ).toBe(true);
      expect(result.returnDisplay).toContain('**Skipped 1 item(s):**');
      expect(result.returnDisplay).toContain(
        '- `document.pdf` (Reason: asset file (image/pdf) was not explicitly requested by name or extension)',
      );
    });

    it('should include PDF files as inlineData parts if explicitly requested by extension', async () => {
      createBinaryFile('important.pdf', Buffer.from('%PDF-1.4...'));
      const params = { paths: ['*.pdf'] }; // Explicitly requesting .pdf files
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.llmContent).toEqual([
        {
          inlineData: {
            data: Buffer.from('%PDF-1.4...').toString('base64'),
            mimeType: 'application/pdf',
          },
        },
      ]);
    });

    it('should include PDF files as inlineData parts if explicitly requested by name', async () => {
      createBinaryFile('report-final.pdf', Buffer.from('%PDF-1.4...'));
      const params = { paths: ['report-final.pdf'] };
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.llmContent).toEqual([
        {
          inlineData: {
            data: Buffer.from('%PDF-1.4...').toString('base64'),
            mimeType: 'application/pdf',
          },
        },
      ]);
    });

    it('should return error if path is ignored by a .geminiignore pattern', async () => {
      // foo.bar and foo.quux match the "foo.*" ignore pattern; bar.ts does not.
      createFile('foo.bar', '');
      createFile('bar.ts', '');
      createFile('foo.quux', '');
      const params = { paths: ['foo.bar', 'bar.ts', 'foo.quux'] };
      const result = await tool.execute(params, new AbortController().signal);
      expect(result.returnDisplay).not.toContain('foo.bar');
      expect(result.returnDisplay).not.toContain('foo.quux');
      expect(result.returnDisplay).toContain('bar.ts');
    });
  });
});
|
||||
457
packages/core/src/tools/read-many-files.ts
Normal file
457
packages/core/src/tools/read-many-files.ts
Normal file
@@ -0,0 +1,457 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import * as path from 'path';
|
||||
import { glob } from 'glob';
|
||||
import { getCurrentGeminiMdFilename } from './memoryTool.js';
|
||||
import {
|
||||
detectFileType,
|
||||
processSingleFileContent,
|
||||
DEFAULT_ENCODING,
|
||||
getSpecificMimeType,
|
||||
} from '../utils/fileUtils.js';
|
||||
import { PartListUnion, Schema, Type } from '@google/genai';
|
||||
import { Config } from '../config/config.js';
|
||||
import {
|
||||
recordFileOperationMetric,
|
||||
FileOperation,
|
||||
} from '../telemetry/metrics.js';
|
||||
|
||||
/**
|
||||
* Parameters for the ReadManyFilesTool.
|
||||
*/
|
||||
export interface ReadManyFilesParams {
  /**
   * An array of file paths or directory paths to search within.
   * Paths are relative to the tool's configured target directory.
   * Glob patterns can be used directly in these paths.
   */
  paths: string[];

  /**
   * Optional. Glob patterns for files to include.
   * These are effectively combined with the `paths`.
   * Example: ["*.ts", "src/** /*.md"]
   */
  include?: string[];

  /**
   * Optional. Glob patterns for files/directories to exclude.
   * Applied as ignore patterns.
   * Example: ["*.log", "dist/**"]
   */
  exclude?: string[];

  /**
   * Optional. Search directories recursively.
   * This is generally controlled by glob patterns (e.g., `**`).
   * The glob implementation is recursive by default for `**`.
   * For simplicity, we'll rely on `**` for recursion.
   */
  recursive?: boolean;

  /**
   * Optional. Apply default exclusion patterns (DEFAULT_EXCLUDES).
   * Defaults to true.
   */
  useDefaultExcludes?: boolean;

  /**
   * Optional. Whether to respect .gitignore patterns. Defaults to true.
   */
  respect_git_ignore?: boolean;
}
|
||||
|
||||
/**
|
||||
* Default exclusion patterns for commonly ignored directories and binary file types.
|
||||
* These are compatible with glob ignore patterns.
|
||||
* TODO(adh): Consider making this configurable or extendable through a command line argument.
|
||||
* TODO(adh): Look into sharing this list with the glob tool.
|
||||
*/
|
||||
const DEFAULT_EXCLUDES: string[] = [
|
||||
'**/node_modules/**',
|
||||
'**/.git/**',
|
||||
'**/.vscode/**',
|
||||
'**/.idea/**',
|
||||
'**/dist/**',
|
||||
'**/build/**',
|
||||
'**/coverage/**',
|
||||
'**/__pycache__/**',
|
||||
'**/*.pyc',
|
||||
'**/*.pyo',
|
||||
'**/*.bin',
|
||||
'**/*.exe',
|
||||
'**/*.dll',
|
||||
'**/*.so',
|
||||
'**/*.dylib',
|
||||
'**/*.class',
|
||||
'**/*.jar',
|
||||
'**/*.war',
|
||||
'**/*.zip',
|
||||
'**/*.tar',
|
||||
'**/*.gz',
|
||||
'**/*.bz2',
|
||||
'**/*.rar',
|
||||
'**/*.7z',
|
||||
'**/*.doc',
|
||||
'**/*.docx',
|
||||
'**/*.xls',
|
||||
'**/*.xlsx',
|
||||
'**/*.ppt',
|
||||
'**/*.pptx',
|
||||
'**/*.odt',
|
||||
'**/*.ods',
|
||||
'**/*.odp',
|
||||
'**/*.DS_Store',
|
||||
'**/.env',
|
||||
`**/${getCurrentGeminiMdFilename()}`,
|
||||
];
|
||||
|
||||
// Separator emitted before each file's content in the concatenated output;
// '{filePath}' is replaced with the file's path (see execute()).
const DEFAULT_OUTPUT_SEPARATOR_FORMAT = '--- {filePath} ---';
|
||||
|
||||
/**
|
||||
* Tool implementation for finding and reading multiple text files from the local filesystem
|
||||
* within a specified target directory. The content is concatenated.
|
||||
* It is intended to run in an environment with access to the local file system (e.g., a Node.js backend).
|
||||
*/
|
||||
export class ReadManyFilesTool extends BaseTool<
|
||||
ReadManyFilesParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name: string = 'read_many_files';
|
||||
|
||||
private readonly geminiIgnorePatterns: string[] = [];
|
||||
|
||||
constructor(private config: Config) {
|
||||
const parameterSchema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
paths: {
|
||||
type: Type.ARRAY,
|
||||
items: {
|
||||
type: Type.STRING,
|
||||
minLength: '1',
|
||||
},
|
||||
minItems: '1',
|
||||
description:
|
||||
"Required. An array of glob patterns or paths relative to the tool's target directory. Examples: ['src/**/*.ts'], ['README.md', 'docs/']",
|
||||
},
|
||||
include: {
|
||||
type: Type.ARRAY,
|
||||
items: {
|
||||
type: Type.STRING,
|
||||
minLength: '1',
|
||||
},
|
||||
description:
|
||||
'Optional. Additional glob patterns to include. These are merged with `paths`. Example: ["*.test.ts"] to specifically add test files if they were broadly excluded.',
|
||||
default: [],
|
||||
},
|
||||
exclude: {
|
||||
type: Type.ARRAY,
|
||||
items: {
|
||||
type: Type.STRING,
|
||||
minLength: '1',
|
||||
},
|
||||
description:
|
||||
'Optional. Glob patterns for files/directories to exclude. Added to default excludes if useDefaultExcludes is true. Example: ["**/*.log", "temp/"]',
|
||||
default: [],
|
||||
},
|
||||
recursive: {
|
||||
type: Type.BOOLEAN,
|
||||
description:
|
||||
'Optional. Whether to search recursively (primarily controlled by `**` in glob patterns). Defaults to true.',
|
||||
default: true,
|
||||
},
|
||||
useDefaultExcludes: {
|
||||
type: Type.BOOLEAN,
|
||||
description:
|
||||
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
|
||||
default: true,
|
||||
},
|
||||
respect_git_ignore: {
|
||||
type: Type.BOOLEAN,
|
||||
description:
|
||||
'Optional. Whether to respect .gitignore patterns when discovering files. Only available in git repositories. Defaults to true.',
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
required: ['paths'],
|
||||
};
|
||||
|
||||
super(
|
||||
ReadManyFilesTool.Name,
|
||||
'ReadManyFiles',
|
||||
`Reads content from multiple files specified by paths or glob patterns within a configured target directory. For text files, it concatenates their content into a single string. It is primarily designed for text-based files. However, it can also process image (e.g., .png, .jpg) and PDF (.pdf) files if their file names or extensions are explicitly included in the 'paths' argument. For these explicitly requested non-text files, their data is read and included in a format suitable for model consumption (e.g., base64 encoded).
|
||||
|
||||
This tool is useful when you need to understand or analyze a collection of files, such as:
|
||||
- Getting an overview of a codebase or parts of it (e.g., all TypeScript files in the 'src' directory).
|
||||
- Finding where specific functionality is implemented if the user asks broad questions about code.
|
||||
- Reviewing documentation files (e.g., all Markdown files in the 'docs' directory).
|
||||
- Gathering context from multiple configuration files.
|
||||
- When the user asks to "read all files in X directory" or "show me the content of all Y files".
|
||||
|
||||
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
|
||||
parameterSchema,
|
||||
);
|
||||
this.geminiIgnorePatterns = config
|
||||
.getFileService()
|
||||
.getGeminiIgnorePatterns();
|
||||
}
|
||||
|
||||
validateParams(params: ReadManyFilesParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
getDescription(params: ReadManyFilesParams): string {
|
||||
const allPatterns = [...params.paths, ...(params.include || [])];
|
||||
const pathDesc = `using patterns: \`${allPatterns.join('`, `')}\` (within target directory: \`${this.config.getTargetDir()}\`)`;
|
||||
|
||||
// Determine the final list of exclusion patterns exactly as in execute method
|
||||
const paramExcludes = params.exclude || [];
|
||||
const paramUseDefaultExcludes = params.useDefaultExcludes !== false;
|
||||
|
||||
const finalExclusionPatternsForDescription: string[] =
|
||||
paramUseDefaultExcludes
|
||||
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...this.geminiIgnorePatterns]
|
||||
: [...paramExcludes, ...this.geminiIgnorePatterns];
|
||||
|
||||
let excludeDesc = `Excluding: ${finalExclusionPatternsForDescription.length > 0 ? `patterns like \`${finalExclusionPatternsForDescription.slice(0, 2).join('`, `')}${finalExclusionPatternsForDescription.length > 2 ? '...`' : '`'}` : 'none specified'}`;
|
||||
|
||||
// Add a note if .geminiignore patterns contributed to the final list of exclusions
|
||||
if (this.geminiIgnorePatterns.length > 0) {
|
||||
const geminiPatternsInEffect = this.geminiIgnorePatterns.filter((p) =>
|
||||
finalExclusionPatternsForDescription.includes(p),
|
||||
).length;
|
||||
if (geminiPatternsInEffect > 0) {
|
||||
excludeDesc += ` (includes ${geminiPatternsInEffect} from .geminiignore)`;
|
||||
}
|
||||
}
|
||||
|
||||
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: ReadManyFilesParams,
|
||||
signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: `Error: Invalid parameters for ${this.displayName}. Reason: ${validationError}`,
|
||||
returnDisplay: `## Parameter Error\n\n${validationError}`,
|
||||
};
|
||||
}
|
||||
|
||||
const {
|
||||
paths: inputPatterns,
|
||||
include = [],
|
||||
exclude = [],
|
||||
useDefaultExcludes = true,
|
||||
respect_git_ignore = true,
|
||||
} = params;
|
||||
|
||||
const respectGitIgnore =
|
||||
respect_git_ignore ?? this.config.getFileFilteringRespectGitIgnore();
|
||||
|
||||
// Get centralized file discovery service
|
||||
const fileDiscovery = this.config.getFileService();
|
||||
|
||||
const filesToConsider = new Set<string>();
|
||||
const skippedFiles: Array<{ path: string; reason: string }> = [];
|
||||
const processedFilesRelativePaths: string[] = [];
|
||||
const contentParts: PartListUnion = [];
|
||||
|
||||
const effectiveExcludes = useDefaultExcludes
|
||||
? [...DEFAULT_EXCLUDES, ...exclude, ...this.geminiIgnorePatterns]
|
||||
: [...exclude, ...this.geminiIgnorePatterns];
|
||||
|
||||
const searchPatterns = [...inputPatterns, ...include];
|
||||
if (searchPatterns.length === 0) {
|
||||
return {
|
||||
llmContent: 'No search paths or include patterns provided.',
|
||||
returnDisplay: `## Information\n\nNo search paths or include patterns were specified. Nothing to read or concatenate.`,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = await glob(searchPatterns, {
|
||||
cwd: this.config.getTargetDir(),
|
||||
ignore: effectiveExcludes,
|
||||
nodir: true,
|
||||
dot: true,
|
||||
absolute: true,
|
||||
nocase: true,
|
||||
signal,
|
||||
});
|
||||
|
||||
const filteredEntries = respectGitIgnore
|
||||
? fileDiscovery
|
||||
.filterFiles(
|
||||
entries.map((p) => path.relative(this.config.getTargetDir(), p)),
|
||||
{
|
||||
respectGitIgnore,
|
||||
},
|
||||
)
|
||||
.map((p) => path.resolve(this.config.getTargetDir(), p))
|
||||
: entries;
|
||||
|
||||
let gitIgnoredCount = 0;
|
||||
for (const absoluteFilePath of entries) {
|
||||
// Security check: ensure the glob library didn't return something outside targetDir.
|
||||
if (!absoluteFilePath.startsWith(this.config.getTargetDir())) {
|
||||
skippedFiles.push({
|
||||
path: absoluteFilePath,
|
||||
reason: `Security: Glob library returned path outside target directory. Base: ${this.config.getTargetDir()}, Path: ${absoluteFilePath}`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this file was filtered out by git ignore
|
||||
if (respectGitIgnore && !filteredEntries.includes(absoluteFilePath)) {
|
||||
gitIgnoredCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
filesToConsider.add(absoluteFilePath);
|
||||
}
|
||||
|
||||
// Add info about git-ignored files if any were filtered
|
||||
if (gitIgnoredCount > 0) {
|
||||
skippedFiles.push({
|
||||
path: `${gitIgnoredCount} file(s)`,
|
||||
reason: 'ignored',
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
llmContent: `Error during file search: ${getErrorMessage(error)}`,
|
||||
returnDisplay: `## File Search Error\n\nAn error occurred while searching for files:\n\`\`\`\n${getErrorMessage(error)}\n\`\`\``,
|
||||
};
|
||||
}
|
||||
|
||||
const sortedFiles = Array.from(filesToConsider).sort();
|
||||
|
||||
for (const filePath of sortedFiles) {
|
||||
const relativePathForDisplay = path
|
||||
.relative(this.config.getTargetDir(), filePath)
|
||||
.replace(/\\/g, '/');
|
||||
|
||||
const fileType = detectFileType(filePath);
|
||||
|
||||
if (fileType === 'image' || fileType === 'pdf') {
|
||||
const fileExtension = path.extname(filePath).toLowerCase();
|
||||
const fileNameWithoutExtension = path.basename(filePath, fileExtension);
|
||||
const requestedExplicitly = inputPatterns.some(
|
||||
(pattern: string) =>
|
||||
pattern.toLowerCase().includes(fileExtension) ||
|
||||
pattern.includes(fileNameWithoutExtension),
|
||||
);
|
||||
|
||||
if (!requestedExplicitly) {
|
||||
skippedFiles.push({
|
||||
path: relativePathForDisplay,
|
||||
reason:
|
||||
'asset file (image/pdf) was not explicitly requested by name or extension',
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Use processSingleFileContent for all file types now
|
||||
const fileReadResult = await processSingleFileContent(
|
||||
filePath,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
|
||||
if (fileReadResult.error) {
|
||||
skippedFiles.push({
|
||||
path: relativePathForDisplay,
|
||||
reason: `Read error: ${fileReadResult.error}`,
|
||||
});
|
||||
} else {
|
||||
if (typeof fileReadResult.llmContent === 'string') {
|
||||
const separator = DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace(
|
||||
'{filePath}',
|
||||
filePath,
|
||||
);
|
||||
contentParts.push(`${separator}\n\n${fileReadResult.llmContent}\n\n`);
|
||||
} else {
|
||||
contentParts.push(fileReadResult.llmContent); // This is a Part for image/pdf
|
||||
}
|
||||
processedFilesRelativePaths.push(relativePathForDisplay);
|
||||
const lines =
|
||||
typeof fileReadResult.llmContent === 'string'
|
||||
? fileReadResult.llmContent.split('\n').length
|
||||
: undefined;
|
||||
const mimetype = getSpecificMimeType(filePath);
|
||||
recordFileOperationMetric(
|
||||
this.config,
|
||||
FileOperation.READ,
|
||||
lines,
|
||||
mimetype,
|
||||
path.extname(filePath),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let displayMessage = `### ReadManyFiles Result (Target Dir: \`${this.config.getTargetDir()}\`)\n\n`;
|
||||
if (processedFilesRelativePaths.length > 0) {
|
||||
displayMessage += `Successfully read and concatenated content from **${processedFilesRelativePaths.length} file(s)**.\n`;
|
||||
if (processedFilesRelativePaths.length <= 10) {
|
||||
displayMessage += `\n**Processed Files:**\n`;
|
||||
processedFilesRelativePaths.forEach(
|
||||
(p) => (displayMessage += `- \`${p}\`\n`),
|
||||
);
|
||||
} else {
|
||||
displayMessage += `\n**Processed Files (first 10 shown):**\n`;
|
||||
processedFilesRelativePaths
|
||||
.slice(0, 10)
|
||||
.forEach((p) => (displayMessage += `- \`${p}\`\n`));
|
||||
displayMessage += `- ...and ${processedFilesRelativePaths.length - 10} more.\n`;
|
||||
}
|
||||
}
|
||||
|
||||
if (skippedFiles.length > 0) {
|
||||
if (processedFilesRelativePaths.length === 0) {
|
||||
displayMessage += `No files were read and concatenated based on the criteria.\n`;
|
||||
}
|
||||
if (skippedFiles.length <= 5) {
|
||||
displayMessage += `\n**Skipped ${skippedFiles.length} item(s):**\n`;
|
||||
} else {
|
||||
displayMessage += `\n**Skipped ${skippedFiles.length} item(s) (first 5 shown):**\n`;
|
||||
}
|
||||
skippedFiles
|
||||
.slice(0, 5)
|
||||
.forEach(
|
||||
(f) => (displayMessage += `- \`${f.path}\` (Reason: ${f.reason})\n`),
|
||||
);
|
||||
if (skippedFiles.length > 5) {
|
||||
displayMessage += `- ...and ${skippedFiles.length - 5} more.\n`;
|
||||
}
|
||||
} else if (
|
||||
processedFilesRelativePaths.length === 0 &&
|
||||
skippedFiles.length === 0
|
||||
) {
|
||||
displayMessage += `No files were read and concatenated based on the criteria.\n`;
|
||||
}
|
||||
|
||||
if (contentParts.length === 0) {
|
||||
contentParts.push(
|
||||
'No files matching the criteria were found or all were skipped.',
|
||||
);
|
||||
}
|
||||
return {
|
||||
llmContent: contentParts,
|
||||
returnDisplay: displayMessage.trim(),
|
||||
};
|
||||
}
|
||||
}
|
||||
432
packages/core/src/tools/shell.test.ts
Normal file
432
packages/core/src/tools/shell.test.ts
Normal file
@@ -0,0 +1,432 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { expect, describe, it, vi, beforeEach } from 'vitest';
|
||||
import { ShellTool } from './shell.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import * as summarizer from '../utils/summarizer.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
|
||||
describe('ShellTool', () => {
|
||||
it('should allow a command if no restrictions are provided', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => undefined,
|
||||
} as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command if it is in the allowed list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is not in the allowed list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command if it is not in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is in both the allowed and blocked lists', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(rm -rf /)'],
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any command when ShellTool is in coreTools without specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block any command when ShellTool is in excludeTools without specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['ShellTool'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Shell tool is globally disabled in configuration',
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command if it is in the allowed list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(ls -l)'],
|
||||
getExcludeTools: () => undefined,
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['run_shell_command(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command when ShellTool is in excludeTools using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['run_shell_command'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Shell tool is globally disabled in configuration',
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command if coreTools contains an empty ShellTool command list using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command()'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'any command' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block any command if coreTools contains an empty ShellTool command list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool()'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'any command' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command with extra whitespace if it is in the blocked list', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(' rm -rf / ');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow any command when ShellTool is present with specific commands', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool', 'ShellTool(ls)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('any command');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command on the blocklist even with a wildcard allow', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool'],
|
||||
getExcludeTools: () => ['ShellTool(rm -rf /)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command that starts with an allowed command prefix', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['ShellTool(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(
|
||||
'gh issue edit 1 --add-label "kind/feature"',
|
||||
);
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command that starts with an allowed command prefix using the public-facing name', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed(
|
||||
'gh issue edit 1 --add-label "kind/feature"',
|
||||
);
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that starts with an allowed command prefix but is chained with another command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue edit&&rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is a prefix of an allowed command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue edit)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'gh issue' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is a prefix of a blocked command', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => ['run_shell_command(gh issue edit)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a pipe', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list | rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a semicolon', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list; rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a chained command if any part is blocked', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo "hello")'],
|
||||
getExcludeTools: () => ['run_shell_command(rm)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" && rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should block a command if its prefix is on the blocklist, even if the command itself is on the allowlist', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(git push)'],
|
||||
getExcludeTools: () => ['run_shell_command(git)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('git push');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'git push' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should be case-sensitive in its matching', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ECHO "hello"');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Command \'ECHO "hello"\' is not in the allowed commands list',
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly handle commands with extra whitespace around chaining operators', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(ls -l)'],
|
||||
getExcludeTools: () => ['run_shell_command(rm)'],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('ls -l ; rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is blocked by configuration",
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a chained command if all parts are allowed', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => [
|
||||
'run_shell_command(echo)',
|
||||
'run_shell_command(ls -l)',
|
||||
],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" && ls -l');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow a command with command substitution using backticks', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo `rm -rf /`');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should block a command with command substitution using $()', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo $(rm -rf /)');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
'Command substitution using $() is not allowed for security reasons',
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow a command with I/O redirection', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(echo)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('echo "hello" > file.txt');
|
||||
expect(result.allowed).toBe(true);
|
||||
});
|
||||
|
||||
it('should not allow a command that is chained with a double pipe', async () => {
|
||||
const config = {
|
||||
getCoreTools: () => ['run_shell_command(gh issue list)'],
|
||||
getExcludeTools: () => [],
|
||||
} as unknown as Config;
|
||||
const shellTool = new ShellTool(config);
|
||||
const result = shellTool.isCommandAllowed('gh issue list || rm -rf /');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
"Command 'rm -rf /' is not in the allowed commands list",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ShellTool Bug Reproduction', () => {
|
||||
let shellTool: ShellTool;
|
||||
let config: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
config = {
|
||||
getCoreTools: () => undefined,
|
||||
getExcludeTools: () => undefined,
|
||||
getDebugMode: () => false,
|
||||
getGeminiClient: () => ({}) as GeminiClient,
|
||||
getTargetDir: () => '.',
|
||||
} as unknown as Config;
|
||||
shellTool = new ShellTool(config);
|
||||
});
|
||||
|
||||
it('should not let the summarizer override the return display', async () => {
|
||||
const summarizeSpy = vi
|
||||
.spyOn(summarizer, 'summarizeToolOutput')
|
||||
.mockResolvedValue('summarized output');
|
||||
|
||||
const abortSignal = new AbortController().signal;
|
||||
const result = await shellTool.execute(
|
||||
{ command: 'echo "hello"' },
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result.returnDisplay).toBe('hello\n');
|
||||
expect(result.llmContent).toBe('summarized output');
|
||||
expect(summarizeSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
503
packages/core/src/tools/shell.ts
Normal file
503
packages/core/src/tools/shell.ts
Normal file
@@ -0,0 +1,503 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import os from 'os';
|
||||
import crypto from 'crypto';
|
||||
import { Config } from '../config/config.js';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolResult,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import stripAnsi from 'strip-ansi';
|
||||
|
||||
/** Parameters accepted by the `run_shell_command` tool. */
export interface ShellToolParams {
  /** Exact bash command line, executed as `bash -c <command>`. */
  command: string;
  /** Brief human-readable summary of the command, shown to the user. */
  description?: string;
  /** Working directory relative to the project root; must already exist. */
  directory?: string;
}
|
||||
import { spawn } from 'child_process';
|
||||
import { summarizeToolOutput } from '../utils/summarizer.js';
|
||||
|
||||
// Minimum delay between streamed-output callbacks to the UI, so that chatty
// subprocesses do not flood the renderer with per-chunk updates.
const OUTPUT_UPDATE_INTERVAL_MS = 1000;
|
||||
|
||||
export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||
static Name: string = 'run_shell_command';
|
||||
private whitelist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
ShellTool.Name,
|
||||
'Shell',
|
||||
`This tool executes a given shell command as \`bash -c <command>\`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
|
||||
|
||||
The following information is returned:
|
||||
|
||||
Command: Executed command.
|
||||
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
|
||||
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
|
||||
Error: Error or \`(none)\` if no error was reported for the subprocess.
|
||||
Exit Code: Exit code or \`(none)\` if terminated by signal.
|
||||
Signal: Signal number or \`(none)\` if no signal was received.
|
||||
Background PIDs: List of background processes started or \`(none)\`.
|
||||
Process Group PGID: Process group started or \`(none)\``,
|
||||
{
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
command: {
|
||||
type: Type.STRING,
|
||||
description: 'Exact bash command to execute as `bash -c <command>`',
|
||||
},
|
||||
description: {
|
||||
type: Type.STRING,
|
||||
description:
|
||||
'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.',
|
||||
},
|
||||
directory: {
|
||||
type: Type.STRING,
|
||||
description:
|
||||
'(OPTIONAL) Directory to run the command in, if not the project root directory. Must be relative to the project root directory and must already exist.',
|
||||
},
|
||||
},
|
||||
required: ['command'],
|
||||
},
|
||||
false, // output is not markdown
|
||||
true, // output can be updated
|
||||
);
|
||||
}
|
||||
|
||||
getDescription(params: ShellToolParams): string {
|
||||
let description = `${params.command}`;
|
||||
// append optional [in directory]
|
||||
// note description is needed even if validation fails due to absolute path
|
||||
if (params.directory) {
|
||||
description += ` [in ${params.directory}]`;
|
||||
}
|
||||
// append optional (description), replacing any line breaks with spaces
|
||||
if (params.description) {
|
||||
description += ` (${params.description.replace(/\n/g, ' ')})`;
|
||||
}
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the root command from a given shell command string.
|
||||
* This is used to identify the base command for permission checks.
|
||||
*
|
||||
* @param command The shell command string to parse
|
||||
* @returns The root command name, or undefined if it cannot be determined
|
||||
* @example getCommandRoot("ls -la /tmp") returns "ls"
|
||||
* @example getCommandRoot("git status && npm test") returns "git"
|
||||
*/
|
||||
getCommandRoot(command: string): string | undefined {
|
||||
return command
|
||||
.trim() // remove leading and trailing whitespace
|
||||
.replace(/[{}()]/g, '') // remove all grouping operators
|
||||
.split(/[\s;&|]+/)[0] // split on any whitespace or separator or chaining operators and take first part
|
||||
?.split(/[/\\]/) // split on any path separators (or return undefined if previous line was undefined)
|
||||
.pop(); // take last part and return command root (or undefined if previous line was empty)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether a given shell command is allowed to execute based on
|
||||
* the tool's configuration including allowlists and blocklists.
|
||||
*
|
||||
* @param command The shell command string to validate
|
||||
* @returns An object with 'allowed' boolean and optional 'reason' string if not allowed
|
||||
*/
|
||||
isCommandAllowed(command: string): { allowed: boolean; reason?: string } {
|
||||
// 0. Disallow command substitution
|
||||
if (command.includes('$(')) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason:
|
||||
'Command substitution using $() is not allowed for security reasons',
|
||||
};
|
||||
}
|
||||
|
||||
const SHELL_TOOL_NAMES = [ShellTool.name, ShellTool.Name];
|
||||
|
||||
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
|
||||
|
||||
/**
|
||||
* Checks if a command string starts with a given prefix, ensuring it's a
|
||||
* whole word match (i.e., followed by a space or it's an exact match).
|
||||
* e.g., `isPrefixedBy('npm install', 'npm')` -> true
|
||||
* e.g., `isPrefixedBy('npm', 'npm')` -> true
|
||||
* e.g., `isPrefixedBy('npminstall', 'npm')` -> false
|
||||
*/
|
||||
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
|
||||
if (!cmd.startsWith(prefix)) {
|
||||
return false;
|
||||
}
|
||||
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
|
||||
};
|
||||
|
||||
/**
|
||||
* Extracts and normalizes shell commands from a list of tool strings.
|
||||
* e.g., 'ShellTool("ls -l")' becomes 'ls -l'
|
||||
*/
|
||||
const extractCommands = (tools: string[]): string[] =>
|
||||
tools.flatMap((tool) => {
|
||||
for (const toolName of SHELL_TOOL_NAMES) {
|
||||
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
|
||||
return [normalize(tool.slice(toolName.length + 1, -1))];
|
||||
}
|
||||
}
|
||||
return [];
|
||||
});
|
||||
|
||||
const coreTools = this.config.getCoreTools() || [];
|
||||
const excludeTools = this.config.getExcludeTools() || [];
|
||||
|
||||
// 1. Check if the shell tool is globally disabled.
|
||||
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: 'Shell tool is globally disabled in configuration',
|
||||
};
|
||||
}
|
||||
|
||||
const blockedCommands = new Set(extractCommands(excludeTools));
|
||||
const allowedCommands = new Set(extractCommands(coreTools));
|
||||
|
||||
const hasSpecificAllowedCommands = allowedCommands.size > 0;
|
||||
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
|
||||
coreTools.includes(name),
|
||||
);
|
||||
|
||||
const commandsToValidate = command.split(/&&|\|\||\||;/).map(normalize);
|
||||
|
||||
const blockedCommandsArr = [...blockedCommands];
|
||||
|
||||
for (const cmd of commandsToValidate) {
|
||||
// 2. Check if the command is on the blocklist.
|
||||
const isBlocked = blockedCommandsArr.some((blocked) =>
|
||||
isPrefixedBy(cmd, blocked),
|
||||
);
|
||||
if (isBlocked) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: `Command '${cmd}' is blocked by configuration`,
|
||||
};
|
||||
}
|
||||
|
||||
// 3. If in strict allow-list mode, check if the command is permitted.
|
||||
const isStrictAllowlist =
|
||||
hasSpecificAllowedCommands && !isWildcardAllowed;
|
||||
const allowedCommandsArr = [...allowedCommands];
|
||||
if (isStrictAllowlist) {
|
||||
const isAllowed = allowedCommandsArr.some((allowed) =>
|
||||
isPrefixedBy(cmd, allowed),
|
||||
);
|
||||
if (!isAllowed) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: `Command '${cmd}' is not in the allowed commands list`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. If all checks pass, the command is allowed.
|
||||
return { allowed: true };
|
||||
}
|
||||
|
||||
validateToolParams(params: ShellToolParams): string | null {
|
||||
const commandCheck = this.isCommandAllowed(params.command);
|
||||
if (!commandCheck.allowed) {
|
||||
if (!commandCheck.reason) {
|
||||
console.error(
|
||||
'Unexpected: isCommandAllowed returned false without a reason',
|
||||
);
|
||||
return `Command is not allowed: ${params.command}`;
|
||||
}
|
||||
return commandCheck.reason;
|
||||
}
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
if (!params.command.trim()) {
|
||||
return 'Command cannot be empty.';
|
||||
}
|
||||
if (!this.getCommandRoot(params.command)) {
|
||||
return 'Could not identify command root to obtain permission from user.';
|
||||
}
|
||||
if (params.directory) {
|
||||
if (path.isAbsolute(params.directory)) {
|
||||
return 'Directory cannot be absolute. Must be relative to the project root directory.';
|
||||
}
|
||||
const directory = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
params.directory,
|
||||
);
|
||||
if (!fs.existsSync(directory)) {
|
||||
return 'Directory must exist.';
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async shouldConfirmExecute(
|
||||
params: ShellToolParams,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.validateToolParams(params)) {
|
||||
return false; // skip confirmation, execute call will fail immediately
|
||||
}
|
||||
const rootCommand = this.getCommandRoot(params.command)!; // must be non-empty string post-validation
|
||||
if (this.whitelist.has(rootCommand)) {
|
||||
return false; // already approved and whitelisted
|
||||
}
|
||||
const confirmationDetails: ToolExecuteConfirmationDetails = {
|
||||
type: 'exec',
|
||||
title: 'Confirm Shell Command',
|
||||
command: params.command,
|
||||
rootCommand,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.whitelist.add(rootCommand);
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: ShellToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
updateOutput?: (chunk: string) => void,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: [
|
||||
`Command rejected: ${params.command}`,
|
||||
`Reason: ${validationError}`,
|
||||
].join('\n'),
|
||||
returnDisplay: `Error: ${validationError}`,
|
||||
};
|
||||
}
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
return {
|
||||
llmContent: 'Command was cancelled by user before it could start.',
|
||||
returnDisplay: 'Command cancelled by user.',
|
||||
};
|
||||
}
|
||||
|
||||
const isWindows = os.platform() === 'win32';
|
||||
const tempFileName = `shell_pgrep_${crypto
|
||||
.randomBytes(6)
|
||||
.toString('hex')}.tmp`;
|
||||
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
||||
|
||||
// pgrep is not available on Windows, so we can't get background PIDs
|
||||
const command = isWindows
|
||||
? params.command
|
||||
: (() => {
|
||||
// wrap command to append subprocess pids (via pgrep) to temporary file
|
||||
let command = params.command.trim();
|
||||
if (!command.endsWith('&')) command += ';';
|
||||
return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`;
|
||||
})();
|
||||
|
||||
// spawn command in specified directory (or project root if not specified)
|
||||
const shell = isWindows
|
||||
? spawn('cmd.exe', ['/c', command], {
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
// detached: true, // ensure subprocess starts its own process group (esp. in Linux)
|
||||
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
|
||||
})
|
||||
: spawn('bash', ['-c', command], {
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
detached: true, // ensure subprocess starts its own process group (esp. in Linux)
|
||||
cwd: path.resolve(this.config.getTargetDir(), params.directory || ''),
|
||||
});
|
||||
|
||||
let exited = false;
|
||||
let stdout = '';
|
||||
let output = '';
|
||||
let lastUpdateTime = Date.now();
|
||||
|
||||
const appendOutput = (str: string) => {
|
||||
output += str;
|
||||
if (
|
||||
updateOutput &&
|
||||
Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS
|
||||
) {
|
||||
updateOutput(output);
|
||||
lastUpdateTime = Date.now();
|
||||
}
|
||||
};
|
||||
|
||||
shell.stdout.on('data', (data: Buffer) => {
|
||||
// continue to consume post-exit for background processes
|
||||
// removing listeners can overflow OS buffer and block subprocesses
|
||||
// destroying (e.g. shell.stdout.destroy()) can terminate subprocesses via SIGPIPE
|
||||
if (!exited) {
|
||||
const str = stripAnsi(data.toString());
|
||||
stdout += str;
|
||||
appendOutput(str);
|
||||
}
|
||||
});
|
||||
|
||||
let stderr = '';
|
||||
shell.stderr.on('data', (data: Buffer) => {
|
||||
if (!exited) {
|
||||
const str = stripAnsi(data.toString());
|
||||
stderr += str;
|
||||
appendOutput(str);
|
||||
}
|
||||
});
|
||||
|
||||
let error: Error | null = null;
|
||||
shell.on('error', (err: Error) => {
|
||||
error = err;
|
||||
// remove wrapper from user's command in error message
|
||||
error.message = error.message.replace(command, params.command);
|
||||
});
|
||||
|
||||
let code: number | null = null;
|
||||
let processSignal: NodeJS.Signals | null = null;
|
||||
const exitHandler = (
|
||||
_code: number | null,
|
||||
_signal: NodeJS.Signals | null,
|
||||
) => {
|
||||
exited = true;
|
||||
code = _code;
|
||||
processSignal = _signal;
|
||||
};
|
||||
shell.on('exit', exitHandler);
|
||||
|
||||
const abortHandler = async () => {
|
||||
if (shell.pid && !exited) {
|
||||
if (os.platform() === 'win32') {
|
||||
// For Windows, use taskkill to kill the process tree
|
||||
spawn('taskkill', ['/pid', shell.pid.toString(), '/f', '/t']);
|
||||
} else {
|
||||
try {
|
||||
// attempt to SIGTERM process group (negative PID)
|
||||
// fall back to SIGKILL (to group) after 200ms
|
||||
process.kill(-shell.pid, 'SIGTERM');
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
if (shell.pid && !exited) {
|
||||
process.kill(-shell.pid, 'SIGKILL');
|
||||
}
|
||||
} catch (_e) {
|
||||
// if group kill fails, fall back to killing just the main process
|
||||
try {
|
||||
if (shell.pid) {
|
||||
shell.kill('SIGKILL');
|
||||
}
|
||||
} catch (_e) {
|
||||
console.error(`failed to kill shell process ${shell.pid}: ${_e}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
// wait for the shell to exit
|
||||
try {
|
||||
await new Promise((resolve) => shell.on('exit', resolve));
|
||||
} finally {
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
|
||||
// parse pids (pgrep output) from temporary file and remove it
|
||||
const backgroundPIDs: number[] = [];
|
||||
if (os.platform() !== 'win32') {
|
||||
if (fs.existsSync(tempFilePath)) {
|
||||
const pgrepLines = fs
|
||||
.readFileSync(tempFilePath, 'utf8')
|
||||
.split('\n')
|
||||
.filter(Boolean);
|
||||
for (const line of pgrepLines) {
|
||||
if (!/^\d+$/.test(line)) {
|
||||
console.error(`pgrep: ${line}`);
|
||||
}
|
||||
const pid = Number(line);
|
||||
// exclude the shell subprocess pid
|
||||
if (pid !== shell.pid) {
|
||||
backgroundPIDs.push(pid);
|
||||
}
|
||||
}
|
||||
fs.unlinkSync(tempFilePath);
|
||||
} else {
|
||||
if (!abortSignal.aborted) {
|
||||
console.error('missing pgrep output');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let llmContent = '';
|
||||
if (abortSignal.aborted) {
|
||||
llmContent = 'Command was cancelled by user before it could complete.';
|
||||
if (output.trim()) {
|
||||
llmContent += ` Below is the output (on stdout and stderr) before it was cancelled:\n${output}`;
|
||||
} else {
|
||||
llmContent += ' There was no output before it was cancelled.';
|
||||
}
|
||||
} else {
|
||||
llmContent = [
|
||||
`Command: ${params.command}`,
|
||||
`Directory: ${params.directory || '(root)'}`,
|
||||
`Stdout: ${stdout || '(empty)'}`,
|
||||
`Stderr: ${stderr || '(empty)'}`,
|
||||
`Error: ${error ?? '(none)'}`,
|
||||
`Exit Code: ${code ?? '(none)'}`,
|
||||
`Signal: ${processSignal ?? '(none)'}`,
|
||||
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
|
||||
`Process Group PGID: ${shell.pid ?? '(none)'}`,
|
||||
].join('\n');
|
||||
}
|
||||
|
||||
let returnDisplayMessage = '';
|
||||
if (this.config.getDebugMode()) {
|
||||
returnDisplayMessage = llmContent;
|
||||
} else {
|
||||
if (output.trim()) {
|
||||
returnDisplayMessage = output;
|
||||
} else {
|
||||
// Output is empty, let's provide a reason if the command failed or was cancelled
|
||||
if (abortSignal.aborted) {
|
||||
returnDisplayMessage = 'Command cancelled by user.';
|
||||
} else if (processSignal) {
|
||||
returnDisplayMessage = `Command terminated by signal: ${processSignal}`;
|
||||
} else if (error) {
|
||||
// If error is not null, it's an Error object (or other truthy value)
|
||||
returnDisplayMessage = `Command failed: ${getErrorMessage(error)}`;
|
||||
} else if (code !== null && code !== 0) {
|
||||
returnDisplayMessage = `Command exited with code: ${code}`;
|
||||
}
|
||||
// If output is empty and command succeeded (code 0, no error/signal/abort),
|
||||
// returnDisplayMessage will remain empty, which is fine.
|
||||
}
|
||||
}
|
||||
|
||||
const summary = await summarizeToolOutput(
|
||||
llmContent,
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
return {
|
||||
llmContent: summary,
|
||||
returnDisplay: returnDisplayMessage,
|
||||
};
|
||||
}
|
||||
}
|
||||
510
packages/core/src/tools/tool-registry.test.ts
Normal file
510
packages/core/src/tools/tool-registry.test.ts
Normal file
@@ -0,0 +1,510 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import {
|
||||
ToolRegistry,
|
||||
DiscoveredTool,
|
||||
sanitizeParameters,
|
||||
} from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import {
|
||||
FunctionDeclaration,
|
||||
CallableTool,
|
||||
mcpToTool,
|
||||
Type,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { spawn } from 'node:child_process';
|
||||
|
||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());

// Mock ./mcp-client.js to control its behavior within tool-registry tests
vi.mock('./mcp-client.js', () => ({
  discoverMcpTools: mockDiscoverMcpTools,
}));

// Mock node:child_process
// Keep the real module but stub the two entry points the registry uses for
// command-based discovery.
vi.mock('node:child_process', async () => {
  const actual = await vi.importActual('node:child_process');
  return {
    ...actual,
    execSync: vi.fn(),
    spawn: vi.fn(),
  };
});

// Mock MCP SDK Client and Transports
const mockMcpClientConnect = vi.fn();
const mockMcpClientOnError = vi.fn();
const mockStdioTransportClose = vi.fn();
const mockSseTransportClose = vi.fn();

vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
  const MockClient = vi.fn().mockImplementation(() => ({
    connect: mockMcpClientConnect,
    // Record the onerror handler assignment so tests could trigger it.
    set onerror(handler: any) {
      mockMcpClientOnError(handler);
    },
  }));
  return { Client: MockClient };
});

vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
  const MockStdioClientTransport = vi.fn().mockImplementation(() => ({
    stderr: {
      on: vi.fn(),
    },
    close: mockStdioTransportClose,
  }));
  return { StdioClientTransport: MockStdioClientTransport };
});

vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
  const MockSSEClientTransport = vi.fn().mockImplementation(() => ({
    close: mockSseTransportClose,
  }));
  return { SSEClientTransport: MockSSEClientTransport };
});

// Mock @google/genai mcpToTool
// Default: a callable tool that reports no function declarations; individual
// tests override via createMockCallableTool.
vi.mock('@google/genai', async () => {
  const actualGenai =
    await vi.importActual<typeof import('@google/genai')>('@google/genai');
  return {
    ...actualGenai,
    mcpToTool: vi.fn().mockImplementation(() => ({
      tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
      callTool: vi.fn(),
    })),
  };
});

// Helper to create a mock CallableTool for specific test needs
const createMockCallableTool = (
  toolDeclarations: FunctionDeclaration[],
): Mocked<CallableTool> => ({
  tool: vi.fn().mockResolvedValue({ functionDeclarations: toolDeclarations }),
  callTool: vi.fn(),
});
|
||||
|
||||
class MockTool extends BaseTool<{ param: string }, ToolResult> {
|
||||
constructor(name = 'mock-tool', description = 'A mock tool') {
|
||||
super(name, name, description, {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
param: { type: Type.STRING },
|
||||
},
|
||||
required: ['param'],
|
||||
});
|
||||
}
|
||||
async execute(params: { param: string }): Promise<ToolResult> {
|
||||
return {
|
||||
llmContent: `Executed with ${params.param}`,
|
||||
returnDisplay: `Executed with ${params.param}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Shared minimal Config constructor arguments; tests spy on the resulting
// Config instance (getToolDiscoveryCommand, getMcpServers, ...) rather than
// varying these values.
const baseConfigParams: ConfigParameters = {
  cwd: '/tmp',
  model: 'test-model',
  embeddingModel: 'test-embedding-model',
  sandbox: undefined,
  targetDir: '/test/dir',
  debugMode: false,
  userMemory: '',
  geminiMdFileCount: 0,
  approvalMode: ApprovalMode.DEFAULT,
  sessionId: 'test-session-id',
};
|
||||
|
||||
describe('ToolRegistry', () => {
|
||||
let config: Config;
|
||||
let toolRegistry: ToolRegistry;
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
config = new Config(baseConfigParams);
|
||||
toolRegistry = new ToolRegistry(config);
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
|
||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
|
||||
mockStdioTransportClose.mockReset();
|
||||
mockSseTransportClose.mockReset();
|
||||
vi.mocked(mcpToTool).mockClear();
|
||||
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
|
||||
|
||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
vi.spyOn(config, 'getMcpServers');
|
||||
vi.spyOn(config, 'getMcpServerCommand');
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('registerTool', () => {
|
||||
it('should register a new tool', () => {
|
||||
const tool = new MockTool();
|
||||
toolRegistry.registerTool(tool);
|
||||
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolsByServer', () => {
|
||||
it('should return an empty array if no tools match the server name', () => {
|
||||
toolRegistry.registerTool(new MockTool());
|
||||
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return only tools matching the server name', async () => {
|
||||
const server1Name = 'mcp-server-uno';
|
||||
const server2Name = 'mcp-server-dos';
|
||||
const mockCallable = {} as CallableTool;
|
||||
const mcpTool1 = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
'server1Name__tool-on-server1',
|
||||
'd1',
|
||||
{},
|
||||
'tool-on-server1',
|
||||
);
|
||||
const mcpTool2 = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server2Name,
|
||||
'server2Name__tool-on-server2',
|
||||
'd2',
|
||||
{},
|
||||
'tool-on-server2',
|
||||
);
|
||||
const nonMcpTool = new MockTool('regular-tool');
|
||||
|
||||
toolRegistry.registerTool(mcpTool1);
|
||||
toolRegistry.registerTool(mcpTool2);
|
||||
toolRegistry.registerTool(nonMcpTool);
|
||||
|
||||
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
||||
expect(toolsFromServer1).toHaveLength(1);
|
||||
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
||||
|
||||
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
||||
expect(toolsFromServer2).toHaveLength(1);
|
||||
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverTools', () => {
|
||||
it('should sanitize tool parameters during discovery from command', async () => {
|
||||
const discoveryCommand = 'my-discovery-command';
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
||||
|
||||
const unsanitizedToolDeclaration: FunctionDeclaration = {
|
||||
name: 'tool-with-bad-format',
|
||||
description: 'A tool with an invalid format property',
|
||||
parameters: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
some_string: {
|
||||
type: Type.STRING,
|
||||
format: 'uuid', // This is an unsupported format
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockSpawn = vi.mocked(spawn);
|
||||
const mockChildProcess = {
|
||||
stdout: { on: vi.fn() },
|
||||
stderr: { on: vi.fn() },
|
||||
on: vi.fn(),
|
||||
};
|
||||
mockSpawn.mockReturnValue(mockChildProcess as any);
|
||||
|
||||
// Simulate stdout data
|
||||
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
|
||||
if (event === 'data') {
|
||||
callback(
|
||||
Buffer.from(
|
||||
JSON.stringify([
|
||||
{ function_declarations: [unsanitizedToolDeclaration] },
|
||||
]),
|
||||
),
|
||||
);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
// Simulate process close
|
||||
mockChildProcess.on.mockImplementation((event, callback) => {
|
||||
if (event === 'close') {
|
||||
callback(0);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
|
||||
expect(discoveredTool).toBeDefined();
|
||||
|
||||
const registeredParams = (discoveredTool as DiscoveredTool).schema
|
||||
.parameters as Schema;
|
||||
expect(registeredParams.properties?.['some_string']).toBeDefined();
|
||||
expect(registeredParams.properties?.['some_string']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('sanitizeParameters', () => {
|
||||
it('should remove default when anyOf is present', () => {
|
||||
const schema: Schema = {
|
||||
anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
|
||||
default: 'hello',
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.default).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should recursively sanitize items in anyOf', () => {
|
||||
const schema: Schema = {
|
||||
anyOf: [
|
||||
{
|
||||
anyOf: [{ type: Type.STRING }],
|
||||
default: 'world',
|
||||
},
|
||||
{ type: Type.NUMBER },
|
||||
],
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.anyOf![0].default).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should recursively sanitize items in items', () => {
|
||||
const schema: Schema = {
|
||||
items: {
|
||||
anyOf: [{ type: Type.STRING }],
|
||||
default: 'world',
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.items!.default).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should recursively sanitize items in properties', () => {
|
||||
const schema: Schema = {
|
||||
properties: {
|
||||
prop1: {
|
||||
anyOf: [{ type: Type.STRING }],
|
||||
default: 'world',
|
||||
},
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.properties!.prop1.default).toBeUndefined();
|
||||
});
|
||||
|
||||
// Exercises recursion through `items` and a nested `anyOf` branch.
it('should handle complex nested schemas', () => {
  const schema: Schema = {
    properties: {
      prop1: {
        items: {
          anyOf: [{ type: Type.STRING }],
          default: 'world',
        },
      },
      prop2: {
        anyOf: [
          {
            properties: {
              nestedProp: {
                anyOf: [{ type: Type.NUMBER }],
                default: 123,
              },
            },
          },
        ],
      },
    },
  };
  sanitizeParameters(schema);
  // `default` must be stripped wherever it co-exists with `anyOf`.
  expect(schema.properties!.prop1.items!.default).toBeUndefined();
  const nestedProp =
    schema.properties!.prop2.anyOf![0].properties!.nestedProp;
  expect(nestedProp?.default).toBeUndefined();
});

// Unsupported string formats (e.g. 'uuid') are cleared; properties without a
// format are left untouched (no `format` key is added).
it('should remove unsupported format from a simple string property', () => {
  const schema: Schema = {
    type: Type.OBJECT,
    properties: {
      name: { type: Type.STRING },
      id: { type: Type.STRING, format: 'uuid' },
    },
  };
  sanitizeParameters(schema);
  expect(schema.properties?.['id']).toHaveProperty('format', undefined);
  expect(schema.properties?.['name']).not.toHaveProperty('format');
});

// 'enum' and 'date-time' are the formats the sanitizer treats as supported —
// a schema using only those must come back byte-identical.
it('should NOT remove supported format values', () => {
  const schema: Schema = {
    type: Type.OBJECT,
    properties: {
      date: { type: Type.STRING, format: 'date-time' },
      role: {
        type: Type.STRING,
        format: 'enum',
        enum: ['admin', 'user'],
      },
    },
  };
  const originalSchema = JSON.parse(JSON.stringify(schema));
  sanitizeParameters(schema);
  expect(schema).toEqual(originalSchema);
});

// Sanitization must descend through ARRAY item schemas.
it('should handle arrays of objects', () => {
  const schema: Schema = {
    type: Type.OBJECT,
    properties: {
      items: {
        type: Type.ARRAY,
        items: {
          type: Type.OBJECT,
          properties: {
            itemId: { type: Type.STRING, format: 'uuid' },
          },
        },
      },
    },
  };
  sanitizeParameters(schema);
  expect(
    (schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
  ).toHaveProperty('format', undefined);
});

// A schema with nothing to strip should pass through unchanged.
it('should handle schemas with no properties to sanitize', () => {
  const schema: Schema = {
    type: Type.OBJECT,
    properties: {
      count: { type: Type.NUMBER },
      isActive: { type: Type.BOOLEAN },
    },
  };
  const originalSchema = JSON.parse(JSON.stringify(schema));
  sanitizeParameters(schema);
  expect(schema).toEqual(originalSchema);
});

// Degenerate inputs: empty object and undefined must both be no-ops.
it('should not crash on an empty or undefined schema', () => {
  expect(() => sanitizeParameters({})).not.toThrow();
  expect(() => sanitizeParameters(undefined)).not.toThrow();
});

// The visited-set must break self-referential schema graphs instead of
// recursing forever (userNode contains itself via `reports.items`).
it('should handle complex nested schemas with cycles', () => {
  const userNode: any = {
    type: Type.OBJECT,
    properties: {
      id: { type: Type.STRING, format: 'uuid' },
      name: { type: Type.STRING },
      manager: {
        type: Type.OBJECT,
        properties: {
          id: { type: Type.STRING, format: 'uuid' },
        },
      },
    },
  };
  userNode.properties.reports = {
    type: Type.ARRAY,
    items: userNode,
  };

  const schema: Schema = {
    type: Type.OBJECT,
    properties: {
      ceo: userNode,
    },
  };

  expect(() => sanitizeParameters(schema)).not.toThrow();
  expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
    'format',
    undefined,
  );
  expect(
    schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
  ).toHaveProperty('format', undefined);
});
|
||||
});
|
||||
392
packages/core/src/tools/tool-registry.ts
Normal file
392
packages/core/src/tools/tool-registry.ts
Normal file
@@ -0,0 +1,392 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { FunctionDeclaration, Schema, Type } from '@google/genai';
|
||||
import { Tool, ToolResult, BaseTool } from './tools.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { discoverMcpTools } from './mcp-client.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
|
||||
// Parameters for a discovered tool are free-form JSON forwarded to the
// subprocess; there is no static schema at compile time.
type ToolParams = Record<string, unknown>;

/**
 * A tool discovered at runtime via the project's configured tool-discovery
 * command. Executing it spawns the configured tool-call command with this
 * tool's name as the single argument and writes `params` as JSON to its stdin.
 */
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
  constructor(
    private readonly config: Config,
    readonly name: string,
    readonly description: string,
    readonly parameterSchema: Record<string, unknown>,
  ) {
    const discoveryCmd = config.getToolDiscoveryCommand()!;
    const callCommand = config.getToolCallCommand()!;
    // Append provenance and the subprocess output contract to the description
    // so the model knows where the tool came from and how results/errors look.
    description += `

This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
Tool discovery and call commands can be configured in project or user settings.

When called, the tool call command is executed as a subprocess.
On success, tool output is returned as a json string.
Otherwise, the following information is returned:

Stdout: Output on stdout stream. Can be \`(empty)\` or partial.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
`;
    super(
      name,
      name, // displayName intentionally mirrors the internal name
      description,
      parameterSchema,
      false, // isOutputMarkdown
      false, // canUpdateOutput
    );
  }

  /**
   * Runs the tool-call command as a subprocess, piping `params` to stdin as
   * JSON, and resolves once the child's stdio closes.
   *
   * On any failure indicator (spawn error, non-zero exit, signal, or ANY
   * stderr output) the result is a diagnostic report instead of stdout.
   */
  async execute(params: ToolParams): Promise<ToolResult> {
    const callCommand = this.config.getToolCallCommand()!;
    const child = spawn(callCommand, [this.name]);
    child.stdin.write(JSON.stringify(params));
    child.stdin.end();

    let stdout = '';
    let stderr = '';
    let error: Error | null = null;
    let code: number | null = null;
    let signal: NodeJS.Signals | null = null;

    // The promise resolves only on 'close'; an 'error' event alone (e.g.
    // spawn failure) is recorded and still followed by 'close'.
    await new Promise<void>((resolve) => {
      const onStdout = (data: Buffer) => {
        stdout += data?.toString();
      };

      const onStderr = (data: Buffer) => {
        stderr += data?.toString();
      };

      const onError = (err: Error) => {
        error = err;
      };

      const onClose = (
        _code: number | null,
        _signal: NodeJS.Signals | null,
      ) => {
        code = _code;
        signal = _signal;
        cleanup();
        resolve();
      };

      // Detach all listeners so the child can be garbage collected and no
      // stray callbacks fire after resolution.
      const cleanup = () => {
        child.stdout.removeListener('data', onStdout);
        child.stderr.removeListener('data', onStderr);
        child.removeListener('error', onError);
        child.removeListener('close', onClose);
        if (child.connected) {
          child.disconnect();
        }
      };

      child.stdout.on('data', onStdout);
      child.stderr.on('data', onStderr);
      child.on('error', onError);
      child.on('close', onClose);
    });

    // if there is any error, non-zero exit code, signal, or stderr, return error details instead of stdout
    if (error || code !== 0 || signal || stderr) {
      const llmContent = [
        `Stdout: ${stdout || '(empty)'}`,
        `Stderr: ${stderr || '(empty)'}`,
        `Error: ${error ?? '(none)'}`,
        `Exit Code: ${code ?? '(none)'}`,
        `Signal: ${signal ?? '(none)'}`,
      ].join('\n');
      return {
        llmContent,
        returnDisplay: llmContent,
      };
    }

    return {
      llmContent: stdout,
      returnDisplay: stdout,
    };
  }
}
|
||||
|
||||
/**
 * Central registry of all tools available to the model: statically registered
 * tools, tools discovered via the project's discovery command, and tools
 * discovered from configured MCP servers.
 */
export class ToolRegistry {
  // Keyed by tool name; names are unique (later registrations overwrite).
  private tools: Map<string, Tool> = new Map();
  private config: Config;

  constructor(config: Config) {
    this.config = config;
  }

  /**
   * Registers a tool definition.
   * @param tool - The tool object containing schema and execution logic.
   */
  registerTool(tool: Tool): void {
    if (this.tools.has(tool.name)) {
      // Decide on behavior: throw error, log warning, or allow overwrite
      console.warn(
        `Tool with name "${tool.name}" is already registered. Overwriting.`,
      );
    }
    this.tools.set(tool.name, tool);
  }

  /**
   * Discovers tools from project (if available and configured).
   * Can be called multiple times to update discovered tools.
   */
  async discoverTools(): Promise<void> {
    // remove any previously discovered tools
    // (deleting while iterating a Map's values is safe in JS)
    for (const tool of this.tools.values()) {
      if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
        this.tools.delete(tool.name);
      }
    }

    await this.discoverAndRegisterToolsFromCommand();

    // discover tools using MCP servers, if configured
    await discoverMcpTools(
      this.config.getMcpServers() ?? {},
      this.config.getMcpServerCommand(),
      this,
      this.config.getDebugMode(),
    );
  }

  /**
   * Runs the configured discovery command, parses its stdout as a JSON array
   * of FunctionDeclarations (optionally wrapped in "tool" objects), and
   * registers each as a DiscoveredTool. No-op when no command is configured.
   * Rethrows any failure after logging it, so callers see discovery errors.
   */
  private async discoverAndRegisterToolsFromCommand(): Promise<void> {
    const discoveryCmd = this.config.getToolDiscoveryCommand();
    if (!discoveryCmd) {
      return;
    }

    try {
      // shell-quote parsing avoids invoking a shell with the raw string.
      const cmdParts = parse(discoveryCmd);
      if (cmdParts.length === 0) {
        throw new Error(
          'Tool discovery command is empty or contains only whitespace.',
        );
      }
      const proc = spawn(cmdParts[0] as string, cmdParts.slice(1) as string[]);
      let stdout = '';
      // StringDecoder prevents splitting multi-byte UTF-8 sequences across
      // 'data' chunks.
      const stdoutDecoder = new StringDecoder('utf8');
      let stderr = '';
      const stderrDecoder = new StringDecoder('utf8');
      let sizeLimitExceeded = false;
      const MAX_STDOUT_SIZE = 10 * 1024 * 1024; // 10MB limit
      const MAX_STDERR_SIZE = 10 * 1024 * 1024; // 10MB limit

      let stdoutByteLength = 0;
      let stderrByteLength = 0;

      proc.stdout.on('data', (data) => {
        if (sizeLimitExceeded) return;
        if (stdoutByteLength + data.length > MAX_STDOUT_SIZE) {
          // Kill the child rather than buffering unbounded output.
          sizeLimitExceeded = true;
          proc.kill();
          return;
        }
        stdoutByteLength += data.length;
        stdout += stdoutDecoder.write(data);
      });

      proc.stderr.on('data', (data) => {
        if (sizeLimitExceeded) return;
        if (stderrByteLength + data.length > MAX_STDERR_SIZE) {
          sizeLimitExceeded = true;
          proc.kill();
          return;
        }
        stderrByteLength += data.length;
        stderr += stderrDecoder.write(data);
      });

      await new Promise<void>((resolve, reject) => {
        proc.on('error', reject);
        proc.on('close', (code) => {
          // Flush any bytes still buffered inside the decoders.
          stdout += stdoutDecoder.end();
          stderr += stderrDecoder.end();

          if (sizeLimitExceeded) {
            // NOTE(review): both caps are 10MB so the number reported is
            // correct either way, but the flag does not record whether stdout
            // or stderr tripped the limit — confirm that's acceptable.
            return reject(
              new Error(
                `Tool discovery command output exceeded size limit of ${MAX_STDOUT_SIZE} bytes.`,
              ),
            );
          }

          if (code !== 0) {
            console.error(`Command failed with code ${code}`);
            console.error(stderr);
            return reject(
              new Error(`Tool discovery command failed with exit code ${code}`),
            );
          }
          resolve();
        });
      });

      // execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
      const functions: FunctionDeclaration[] = [];
      const discoveredItems = JSON.parse(stdout.trim());

      if (!discoveredItems || !Array.isArray(discoveredItems)) {
        throw new Error(
          'Tool discovery command did not return a JSON array of tools.',
        );
      }

      // Accept three shapes per item: snake_case wrapper, camelCase wrapper,
      // or a bare FunctionDeclaration (identified by a `name` key).
      for (const tool of discoveredItems) {
        if (tool && typeof tool === 'object') {
          if (Array.isArray(tool['function_declarations'])) {
            functions.push(...tool['function_declarations']);
          } else if (Array.isArray(tool['functionDeclarations'])) {
            functions.push(...tool['functionDeclarations']);
          } else if (tool['name']) {
            functions.push(tool as FunctionDeclaration);
          }
        }
      }
      // register each function as a tool
      for (const func of functions) {
        if (!func.name) {
          console.warn('Discovered a tool with no name. Skipping.');
          continue;
        }
        // Sanitize the parameters before registering the tool.
        // Non-object parameter payloads are replaced with an empty schema.
        const parameters =
          func.parameters &&
          typeof func.parameters === 'object' &&
          !Array.isArray(func.parameters)
            ? (func.parameters as Schema)
            : {};
        sanitizeParameters(parameters);
        this.registerTool(
          new DiscoveredTool(
            this.config,
            func.name,
            func.description ?? '',
            parameters as Record<string, unknown>,
          ),
        );
      }
    } catch (e) {
      console.error(`Tool discovery command "${discoveryCmd}" failed:`, e);
      throw e;
    }
  }

  /**
   * Retrieves the list of tool schemas (FunctionDeclaration array).
   * Extracts the declarations from the ToolListUnion structure.
   * Includes discovered (vs registered) tools if configured.
   * @returns An array of FunctionDeclarations.
   */
  getFunctionDeclarations(): FunctionDeclaration[] {
    const declarations: FunctionDeclaration[] = [];
    this.tools.forEach((tool) => {
      declarations.push(tool.schema);
    });
    return declarations;
  }

  /**
   * Returns an array of all registered and discovered tool instances.
   */
  getAllTools(): Tool[] {
    return Array.from(this.tools.values());
  }

  /**
   * Returns an array of tools registered from a specific MCP server.
   */
  getToolsByServer(serverName: string): Tool[] {
    const serverTools: Tool[] = [];
    for (const tool of this.tools.values()) {
      // Non-MCP tools lack `serverName`, so the comparison simply fails.
      if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
        serverTools.push(tool);
      }
    }
    return serverTools;
  }

  /**
   * Get the definition of a specific tool.
   */
  getTool(name: string): Tool | undefined {
    return this.tools.get(name);
  }
}
|
||||
|
||||
/**
|
||||
* Sanitizes a schema object in-place to ensure compatibility with the Gemini API.
|
||||
*
|
||||
* NOTE: This function mutates the passed schema object.
|
||||
*
|
||||
* It performs the following actions:
|
||||
* - Removes the `default` property when `anyOf` is present.
|
||||
* - Removes unsupported `format` values from string properties, keeping only 'enum' and 'date-time'.
|
||||
* - Recursively sanitizes nested schemas within `anyOf`, `items`, and `properties`.
|
||||
* - Handles circular references within the schema to prevent infinite loops.
|
||||
*
|
||||
* @param schema The schema object to sanitize. It will be modified directly.
|
||||
*/
|
||||
export function sanitizeParameters(schema?: Schema) {
|
||||
_sanitizeParameters(schema, new Set<Schema>());
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal recursive implementation for sanitizeParameters.
|
||||
* @param schema The schema object to sanitize.
|
||||
* @param visited A set used to track visited schema objects during recursion.
|
||||
*/
|
||||
function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
|
||||
if (!schema || visited.has(schema)) {
|
||||
return;
|
||||
}
|
||||
visited.add(schema);
|
||||
|
||||
if (schema.anyOf) {
|
||||
// Vertex AI gets confused if both anyOf and default are set.
|
||||
schema.default = undefined;
|
||||
for (const item of schema.anyOf) {
|
||||
if (typeof item !== 'boolean') {
|
||||
_sanitizeParameters(item, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (schema.items && typeof schema.items !== 'boolean') {
|
||||
_sanitizeParameters(schema.items, visited);
|
||||
}
|
||||
if (schema.properties) {
|
||||
for (const item of Object.values(schema.properties)) {
|
||||
if (typeof item !== 'boolean') {
|
||||
_sanitizeParameters(item, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
||||
if (schema.type === Type.STRING) {
|
||||
if (
|
||||
schema.format &&
|
||||
schema.format !== 'enum' &&
|
||||
schema.format !== 'date-time'
|
||||
) {
|
||||
schema.format = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
260
packages/core/src/tools/tools.ts
Normal file
260
packages/core/src/tools/tools.ts
Normal file
@@ -0,0 +1,260 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { FunctionDeclaration, PartListUnion, Schema } from '@google/genai';
|
||||
|
||||
/**
 * Interface representing the base Tool functionality.
 * A Tool couples a FunctionDeclaration (exposed to the model) with
 * validation, confirmation, and execution logic.
 */
export interface Tool<
  TParams = unknown,
  TResult extends ToolResult = ToolResult,
> {
  /**
   * The internal name of the tool (used for API calls)
   */
  name: string;

  /**
   * The user-friendly display name of the tool
   */
  displayName: string;

  /**
   * Description of what the tool does
   */
  description: string;

  /**
   * Function declaration schema from @google/genai
   */
  schema: FunctionDeclaration;

  /**
   * Whether the tool's output should be rendered as markdown
   */
  isOutputMarkdown: boolean;

  /**
   * Whether the tool supports live (streaming) output
   */
  canUpdateOutput: boolean;

  /**
   * Validates the parameters for the tool
   * Should be called from both `shouldConfirmExecute` and `execute`
   * `shouldConfirmExecute` should return false immediately if invalid
   * @param params Parameters to validate
   * @returns An error message string if invalid, null otherwise
   */
  validateToolParams(params: TParams): string | null;

  /**
   * Gets a pre-execution description of the tool operation
   * @param params Parameters for the tool execution
   * @returns A markdown string describing what the tool will do
   */
  getDescription(params: TParams): string;

  /**
   * Determines if the tool should prompt for confirmation before execution
   * @param params Parameters for the tool execution
   * @param abortSignal Signal allowing the confirmation flow to be cancelled
   * @returns Confirmation details, or false when no confirmation is needed.
   */
  shouldConfirmExecute(
    params: TParams,
    abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false>;

  /**
   * Executes the tool with the given parameters
   * @param params Parameters for the tool execution
   * @param signal AbortSignal for tool cancellation
   * @param updateOutput Optional callback receiving live output chunks
   *   (only meaningful when `canUpdateOutput` is true)
   * @returns Result of the tool execution
   */
  execute(
    params: TParams,
    signal: AbortSignal,
    updateOutput?: (output: string) => void,
  ): Promise<TResult>;
}
|
||||
|
||||
/**
|
||||
* Base implementation for tools with common functionality
|
||||
*/
|
||||
export abstract class BaseTool<
|
||||
TParams = unknown,
|
||||
TResult extends ToolResult = ToolResult,
|
||||
> implements Tool<TParams, TResult>
|
||||
{
|
||||
/**
|
||||
* Creates a new instance of BaseTool
|
||||
* @param name Internal name of the tool (used for API calls)
|
||||
* @param displayName User-friendly display name of the tool
|
||||
* @param description Description of what the tool does
|
||||
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
|
||||
* @param canUpdateOutput Whether the tool supports live (streaming) output
|
||||
* @param parameterSchema JSON Schema defining the parameters
|
||||
*/
|
||||
constructor(
|
||||
readonly name: string,
|
||||
readonly displayName: string,
|
||||
readonly description: string,
|
||||
readonly parameterSchema: Schema,
|
||||
readonly isOutputMarkdown: boolean = true,
|
||||
readonly canUpdateOutput: boolean = false,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Function declaration schema computed from name, description, and parameterSchema
|
||||
*/
|
||||
get schema(): FunctionDeclaration {
|
||||
return {
|
||||
name: this.name,
|
||||
description: this.description,
|
||||
parameters: this.parameterSchema,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the tool
|
||||
* This is a placeholder implementation and should be overridden
|
||||
* Should be called from both `shouldConfirmExecute` and `execute`
|
||||
* `shouldConfirmExecute` should return false immediately if invalid
|
||||
* @param params Parameters to validate
|
||||
* @returns An error message string if invalid, null otherwise
|
||||
*/
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
validateToolParams(params: TParams): string | null {
|
||||
// Implementation would typically use a JSON Schema validator
|
||||
// This is a placeholder that should be implemented by derived classes
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a pre-execution description of the tool operation
|
||||
* Default implementation that should be overridden by derived classes
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns A markdown string describing what the tool will do
|
||||
*/
|
||||
getDescription(params: TParams): string {
|
||||
return JSON.stringify(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if the tool should prompt for confirmation before execution
|
||||
* @param params Parameters for the tool execution
|
||||
* @returns Whether or not execute should be confirmed by the user.
|
||||
*/
|
||||
shouldConfirmExecute(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
params: TParams,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstract method to execute the tool with the given parameters
|
||||
* Must be implemented by derived classes
|
||||
* @param params Parameters for the tool execution
|
||||
* @param signal AbortSignal for tool cancellation
|
||||
* @returns Result of the tool execution
|
||||
*/
|
||||
abstract execute(
|
||||
params: TParams,
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
): Promise<TResult>;
|
||||
}
|
||||
|
||||
/** The outcome of a single tool execution, in both model- and user-facing forms. */
export interface ToolResult {
  /**
   * A short, one-line summary of the tool's action and result.
   * e.g., "Read 5 files", "Wrote 256 bytes to foo.txt"
   */
  summary?: string;
  /**
   * Content meant to be included in LLM history.
   * This should represent the factual outcome of the tool execution.
   */
  llmContent: PartListUnion;

  /**
   * Markdown string for user display.
   * This provides a user-friendly summary or visualization of the result.
   * NOTE: This might also be considered UI-specific and could potentially be
   * removed or modified in a further refactor if the server becomes purely API-driven.
   * For now, we keep it as the core logic in ReadFileTool currently produces it.
   */
  returnDisplay: ToolResultDisplay;
}
|
||||
|
||||
// A display payload is either plain markdown text or a structured file diff.
export type ToolResultDisplay = string | FileDiff;

/** A unified diff plus the name of the file it applies to, for UI rendering. */
export interface FileDiff {
  fileDiff: string;
  fileName: string;
}

/** Confirmation request for a file-editing tool (shows a diff to approve). */
export interface ToolEditConfirmationDetails {
  type: 'edit';
  title: string;
  // Invoked with the user's decision; `payload` carries inline-modified
  // content when the user edits the proposal before approving.
  onConfirm: (
    outcome: ToolConfirmationOutcome,
    payload?: ToolConfirmationPayload,
  ) => Promise<void>;
  fileName: string;
  fileDiff: string;
  isModifying?: boolean;
}

export interface ToolConfirmationPayload {
  // used to override `modifiedProposedContent` for modifiable tools in the
  // inline modify flow
  newContent: string;
}

/** Confirmation request for a shell-command tool. */
export interface ToolExecuteConfirmationDetails {
  type: 'exec';
  title: string;
  onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
  command: string;
  // The leading command word (e.g. `git` for `git status`), used for
  // per-command allow-listing.
  rootCommand: string;
}

/** Confirmation request for a tool provided by an MCP server. */
export interface ToolMcpConfirmationDetails {
  type: 'mcp';
  title: string;
  serverName: string;
  toolName: string;
  toolDisplayName: string;
  onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
}

/** Generic informational confirmation (e.g. approving a set of URLs). */
export interface ToolInfoConfirmationDetails {
  type: 'info';
  title: string;
  onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
  prompt: string;
  urls?: string[];
}

// Discriminated union over the `type` tag above.
export type ToolCallConfirmationDetails =
  | ToolEditConfirmationDetails
  | ToolExecuteConfirmationDetails
  | ToolMcpConfirmationDetails
  | ToolInfoConfirmationDetails;

/** User decisions available in a confirmation prompt. */
export enum ToolConfirmationOutcome {
  ProceedOnce = 'proceed_once',
  ProceedAlways = 'proceed_always',
  ProceedAlwaysServer = 'proceed_always_server',
  ProceedAlwaysTool = 'proceed_always_tool',
  ModifyWithEditor = 'modify_with_editor',
  Cancel = 'cancel',
}
|
||||
86
packages/core/src/tools/web-fetch.test.ts
Normal file
86
packages/core/src/tools/web-fetch.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { WebFetchTool } from './web-fetch.js';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
|
||||
describe('WebFetchTool', () => {
  // Minimal Config stub: only the approval-mode accessors the tool touches.
  const mockConfig = {
    getApprovalMode: vi.fn(),
    setApprovalMode: vi.fn(),
  } as unknown as Config;

  describe('shouldConfirmExecute', () => {
    it('should return confirmation details with the correct prompt and urls', async () => {
      const tool = new WebFetchTool(mockConfig);
      const params = { prompt: 'fetch https://example.com' };
      const confirmationDetails = await tool.shouldConfirmExecute(params);

      // The confirmation must surface the original prompt and the extracted URLs.
      expect(confirmationDetails).toEqual({
        type: 'info',
        title: 'Confirm Web Fetch',
        prompt: 'fetch https://example.com',
        urls: ['https://example.com'],
        onConfirm: expect.any(Function),
      });
    });

    // GitHub /blob/ URLs should be rewritten to raw.githubusercontent.com
    // before being shown for approval.
    it('should convert github urls to raw format', async () => {
      const tool = new WebFetchTool(mockConfig);
      const params = {
        prompt:
          'fetch https://github.com/google/gemini-react/blob/main/README.md',
      };
      const confirmationDetails = await tool.shouldConfirmExecute(params);

      expect(confirmationDetails).toEqual({
        type: 'info',
        title: 'Confirm Web Fetch',
        prompt:
          'fetch https://github.com/google/gemini-react/blob/main/README.md',
        urls: [
          'https://raw.githubusercontent.com/google/gemini-react/main/README.md',
        ],
        onConfirm: expect.any(Function),
      });
    });

    // AUTO_EDIT approval mode means no confirmation prompt at all.
    it('should return false if approval mode is AUTO_EDIT', async () => {
      const tool = new WebFetchTool({
        ...mockConfig,
        getApprovalMode: () => ApprovalMode.AUTO_EDIT,
      } as unknown as Config);
      const params = { prompt: 'fetch https://example.com' };
      const confirmationDetails = await tool.shouldConfirmExecute(params);

      expect(confirmationDetails).toBe(false);
    });

    // Choosing "always proceed" should persist by switching approval mode.
    it('should call setApprovalMode when onConfirm is called with ProceedAlways', async () => {
      const setApprovalMode = vi.fn();
      const tool = new WebFetchTool({
        ...mockConfig,
        setApprovalMode,
      } as unknown as Config);
      const params = { prompt: 'fetch https://example.com' };
      const confirmationDetails = await tool.shouldConfirmExecute(params);

      if (
        confirmationDetails &&
        typeof confirmationDetails === 'object' &&
        'onConfirm' in confirmationDetails
      ) {
        await confirmationDetails.onConfirm(
          ToolConfirmationOutcome.ProceedAlways,
        );
      }

      expect(setApprovalMode).toHaveBeenCalledWith(ApprovalMode.AUTO_EDIT);
    });
  });
});
|
||||
389
packages/core/src/tools/web-fetch.ts
Normal file
389
packages/core/src/tools/web-fetch.ts
Normal file
@@ -0,0 +1,389 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolResult,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
|
||||
import { convert } from 'html-to-text';
|
||||
|
||||
// Per-URL timeout used by the fallback fetch path.
const URL_FETCH_TIMEOUT_MS = 10000;
// Cap on converted text kept per fetched page before it is sent to the model.
const MAX_CONTENT_LENGTH = 50000;
||||
// Helper function to extract URLs from a string
|
||||
function extractUrls(text: string): string[] {
|
||||
const urlRegex = /(https?:\/\/[^\s]+)/g;
|
||||
return text.match(urlRegex) || [];
|
||||
}
|
||||
|
||||
// Interfaces for grounding metadata (similar to web-search.ts)

/** A web source cited by the model's grounded response. */
interface GroundingChunkWeb {
  uri?: string;
  title?: string;
}

/** One entry in the grounding-chunks list; currently only web sources. */
interface GroundingChunkItem {
  web?: GroundingChunkWeb;
}

/** Character span of the response text that a grounding support covers. */
interface GroundingSupportSegment {
  startIndex: number;
  endIndex: number;
  text?: string;
}

/** Links a response segment to the grounding chunks that support it. */
interface GroundingSupportItem {
  segment?: GroundingSupportSegment;
  // Indices into the grounding-chunks array.
  groundingChunkIndices?: number[];
}
|
||||
|
||||
/**
 * Parameters for the WebFetch tool
 */
export interface WebFetchToolParams {
  /**
   * The prompt containing URL(s) (up to 20) and instructions for processing
   * their content. URLs are extracted from this free-form text; there is no
   * separate `urls` parameter.
   */
  prompt: string;
}
|
||||
|
||||
/**
|
||||
* Implementation of the WebFetch tool logic
|
||||
*/
|
||||
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||
static readonly Name: string = 'web_fetch';
|
||||
|
||||
/**
 * Registers the tool's name, display name, model-facing description, and
 * parameter schema (a single required free-form `prompt` string).
 * @param config Runtime configuration (approval mode, Gemini client, etc.).
 */
constructor(private readonly config: Config) {
  super(
    WebFetchTool.Name,
    'WebFetch',
    "Processes content from URL(s), including local and private network addresses (e.g., localhost), embedded in a prompt. Include up to 20 URLs and instructions (e.g., summarize, extract specific data) directly in the 'prompt' parameter.",
    {
      properties: {
        prompt: {
          description:
            // Fixed typo: "as least" -> "at least" in the model-facing text.
            'A comprehensive prompt that includes the URL(s) (up to 20) to fetch and specific instructions on how to process their content (e.g., "Summarize https://example.com/article and extract key points from https://another.com/data"). Must contain at least one URL starting with http:// or https://.',
          type: Type.STRING,
        },
      },
      required: ['prompt'],
      type: Type.OBJECT,
    },
  );
}
|
||||
|
||||
/**
 * Fallback path: fetches each URL directly (instead of relying on the
 * model's built-in URL grounding), converts HTML to plain text, and asks
 * the Gemini client to answer the user's prompt from the fetched content.
 *
 * @param params The original tool parameters (prompt containing the URLs).
 * @param signal Abort signal passed to the final generateContent call.
 *   NOTE(review): it is NOT threaded into the per-URL fetchWithTimeout
 *   calls — confirm whether per-URL fetches should honor cancellation.
 * @returns Model answer on success, or an error-text ToolResult.
 */
private async executeFallback(
  params: WebFetchToolParams,
  signal: AbortSignal,
): Promise<ToolResult> {
  const urls = extractUrls(params.prompt);
  if (urls.length === 0) {
    return {
      llmContent: 'Error: No URL found in the prompt for fallback.',
      returnDisplay: 'Error: No URL found in the prompt for fallback.',
    };
  }

  const results: string[] = [];
  const processedUrls: string[] = [];

  // Process multiple URLs (up to 20 as mentioned in description)
  const urlsToProcess = urls.slice(0, 20);

  for (const originalUrl of urlsToProcess) {
    let url = originalUrl;

    // Convert GitHub blob URL to raw URL
    if (url.includes('github.com') && url.includes('/blob/')) {
      url = url
        .replace('github.com', 'raw.githubusercontent.com')
        .replace('/blob/', '/');
    }

    try {
      const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
      if (!response.ok) {
        throw new Error(
          `Request failed with status code ${response.status} ${response.statusText}`,
        );
      }
      const html = await response.text();
      // Strip links' hrefs and drop images entirely, then cap length so a
      // single huge page cannot dominate the combined prompt.
      const textContent = convert(html, {
        wordwrap: false,
        selectors: [
          { selector: 'a', options: { ignoreHref: true } },
          { selector: 'img', format: 'skip' },
        ],
      }).substring(0, MAX_CONTENT_LENGTH);

      results.push(`Content from ${url}:\n${textContent}`);
      processedUrls.push(url);
    } catch (e) {
      // A failed URL becomes an inline error entry; remaining URLs still run.
      const error = e as Error;
      results.push(`Error fetching ${url}: ${error.message}`);
      processedUrls.push(url);
    }
  }

  try {
    const geminiClient = this.config.getGeminiClient();
    const combinedContent = results.join('\n\n---\n\n');

    // Ensure the total prompt length doesn't exceed limits
    const maxPromptLength = 200000; // Leave room for system instructions
    const promptPrefix = `The user requested the following: "${params.prompt}".

I have fetched the content from the following URL(s). Please use this content to answer the user's request. Do not attempt to access the URL(s) again.

`;

    let finalContent = combinedContent;
    if (promptPrefix.length + combinedContent.length > maxPromptLength) {
      const availableLength = maxPromptLength - promptPrefix.length - 100; // Leave some buffer
      finalContent =
        combinedContent.substring(0, availableLength) +
        '\n\n[Content truncated due to length limits]';
    }

    const fallbackPrompt = promptPrefix + finalContent;

    const result = await geminiClient.generateContent(
      [{ role: 'user', parts: [{ text: fallbackPrompt }] }],
      {},
      signal,
    );
    const resultText = getResponseText(result) || '';
    return {
      llmContent: resultText,
      returnDisplay: `Content from ${processedUrls.length} URL(s) processed using fallback fetch.`,
    };
  } catch (e) {
    const error = e as Error;
    const errorMessage = `Error during fallback processing: ${error.message}`;
    return {
      llmContent: `Error: ${errorMessage}`,
      returnDisplay: `Error: ${errorMessage}`,
    };
  }
}
|
||||
|
||||
validateParams(params: WebFetchToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
if (!params.prompt || params.prompt.trim() === '') {
|
||||
return "The 'prompt' parameter cannot be empty and must contain URL(s) and instructions.";
|
||||
}
|
||||
if (
|
||||
!params.prompt.includes('http://') &&
|
||||
!params.prompt.includes('https://')
|
||||
) {
|
||||
return "The 'prompt' must contain at least one valid URL (starting with http:// or https://).";
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
getDescription(params: WebFetchToolParams): string {
|
||||
const displayPrompt =
|
||||
params.prompt.length > 100
|
||||
? params.prompt.substring(0, 97) + '...'
|
||||
: params.prompt;
|
||||
return `Processing URLs and instructions from prompt: "${displayPrompt}"`;
|
||||
}
|
||||
|
||||
async shouldConfirmExecute(
|
||||
params: WebFetchToolParams,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Perform GitHub URL conversion here to differentiate between user-provided
|
||||
// URL and the actual URL to be fetched.
|
||||
const urls = extractUrls(params.prompt).map((url) => {
|
||||
if (url.includes('github.com') && url.includes('/blob/')) {
|
||||
return url
|
||||
.replace('github.com', 'raw.githubusercontent.com')
|
||||
.replace('/blob/', '/');
|
||||
}
|
||||
return url;
|
||||
});
|
||||
|
||||
const confirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'info',
|
||||
title: `Confirm Web Fetch`,
|
||||
prompt: params.prompt,
|
||||
urls,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
  /**
   * Primary fetch path: asks the Gemini API to retrieve and process the URLs
   * via its server-side urlContext tool, appending citation markers and a
   * source list from grounding metadata. Falls back to executeFallback() for
   * private IPs, OpenAI-backed generators, or unusable API responses.
   *
   * @param params The tool parameters (prompt with embedded URL(s)).
   * @param signal Abort signal forwarded to the model call.
   * @returns The processed content, or an error result.
   */
  async execute(
    params: WebFetchToolParams,
    signal: AbortSignal,
  ): Promise<ToolResult> {
    const validationError = this.validateParams(params);
    if (validationError) {
      return {
        llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
        returnDisplay: validationError,
      };
    }

    const userPrompt = params.prompt;
    const urls = extractUrls(userPrompt);
    // Only the FIRST URL is checked for being private/local; the server-side
    // tool cannot reach private addresses, so fall back to a local fetch.
    const url = urls[0];
    const isPrivate = isPrivateIp(url);

    if (isPrivate) {
      return this.executeFallback(params, signal);
    }

    const geminiClient = this.config.getGeminiClient();
    const contentGenerator = geminiClient.getContentGenerator();

    // Check if using OpenAI content generator - if so, use fallback
    // (urlContext is a Gemini-only server-side tool).
    if (contentGenerator.constructor.name === 'OpenAIContentGenerator') {
      return this.executeFallback(params, signal);
    }

    try {
      const response = await geminiClient.generateContent(
        [{ role: 'user', parts: [{ text: userPrompt }] }],
        { tools: [{ urlContext: {} }] },
        signal, // Pass signal
      );

      console.debug(
        `[WebFetchTool] Full response for prompt "${userPrompt.substring(
          0,
          50,
        )}...":`,
        JSON.stringify(response, null, 2),
      );

      let responseText = getResponseText(response) || '';
      const urlContextMeta = response.candidates?.[0]?.urlContextMetadata;
      const groundingMetadata = response.candidates?.[0]?.groundingMetadata;
      const sources = groundingMetadata?.groundingChunks as
        | GroundingChunkItem[]
        | undefined;
      const groundingSupports = groundingMetadata?.groundingSupports as
        | GroundingSupportItem[]
        | undefined;

      // Error Handling: decide whether the API response is usable at all;
      // any unusable outcome routes to the local fallback below.
      let processingError = false;

      if (
        urlContextMeta?.urlMetadata &&
        urlContextMeta.urlMetadata.length > 0
      ) {
        // Fail only when EVERY retrieval failed; partial success proceeds.
        const allStatuses = urlContextMeta.urlMetadata.map(
          (m) => m.urlRetrievalStatus,
        );
        if (allStatuses.every((s) => s !== 'URL_RETRIEVAL_STATUS_SUCCESS')) {
          processingError = true;
        }
      } else if (!responseText.trim() && !sources?.length) {
        // No URL metadata and no content/sources
        processingError = true;
      }

      if (
        !processingError &&
        !responseText.trim() &&
        (!sources || sources.length === 0)
      ) {
        // Successfully retrieved some URL (or no specific error from urlContextMeta), but no usable text or grounding data.
        processingError = true;
      }

      if (processingError) {
        return this.executeFallback(params, signal);
      }

      // Build a numbered source list and splice "[n]" citation markers into
      // the response text at the segment end indices reported by grounding.
      const sourceListFormatted: string[] = [];
      if (sources && sources.length > 0) {
        sources.forEach((source: GroundingChunkItem, index: number) => {
          const title = source.web?.title || 'Untitled';
          const uri = source.web?.uri || 'Unknown URI'; // Fallback if URI is missing
          sourceListFormatted.push(`[${index + 1}] ${title} (${uri})`);
        });

        if (groundingSupports && groundingSupports.length > 0) {
          const insertions: Array<{ index: number; marker: string }> = [];
          groundingSupports.forEach((support: GroundingSupportItem) => {
            if (support.segment && support.groundingChunkIndices) {
              const citationMarker = support.groundingChunkIndices
                .map((chunkIndex: number) => `[${chunkIndex + 1}]`)
                .join('');
              insertions.push({
                index: support.segment.endIndex,
                marker: citationMarker,
              });
            }
          });

          // Insert from the end backwards so earlier indices stay valid.
          // NOTE(review): endIndex is assumed to be a UTF-16 code-unit offset
          // (matching split('')) — confirm against the API's segment semantics.
          insertions.sort((a, b) => b.index - a.index);
          const responseChars = responseText.split('');
          insertions.forEach((insertion) => {
            responseChars.splice(insertion.index, 0, insertion.marker);
          });
          responseText = responseChars.join('');
        }

        if (sourceListFormatted.length > 0) {
          responseText += `

Sources:
${sourceListFormatted.join('\n')}`;
        }
      }

      const llmContent = responseText;

      console.debug(
        `[WebFetchTool] Formatted tool response for prompt "${userPrompt}:\n\n":`,
        llmContent,
      );

      return {
        llmContent,
        returnDisplay: `Content processed from prompt.`,
      };
    } catch (error: unknown) {
      const errorMessage = `Error processing web content for prompt "${userPrompt.substring(
        0,
        50,
      )}...": ${getErrorMessage(error)}`;
      console.error(errorMessage, error);
      return {
        llmContent: `Error: ${errorMessage}`,
        returnDisplay: `Error: ${errorMessage}`,
      };
    }
  }
|
||||
}
|
||||
197
packages/core/src/tools/web-search.ts
Normal file
197
packages/core/src/tools/web-search.ts
Normal file
@@ -0,0 +1,197 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GroundingMetadata } from '@google/genai';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||
|
||||
/** Web metadata attached to a grounding chunk returned by the API. */
interface GroundingChunkWeb {
  uri?: string;
  title?: string;
}

/** One grounding chunk; only the web variant is consumed here. */
interface GroundingChunkItem {
  web?: GroundingChunkWeb;
  // Other properties might exist if needed in the future
}

/** Character span of the response text that a grounding support refers to. */
interface GroundingSupportSegment {
  startIndex: number;
  endIndex: number;
  text?: string; // text is optional as per the example
}

/** Links a response segment to the grounding chunks that support it. */
interface GroundingSupportItem {
  segment?: GroundingSupportSegment;
  groundingChunkIndices?: number[];
  confidenceScores?: number[]; // Optional as per example
}

/**
 * Parameters for the WebSearchTool.
 */
export interface WebSearchToolParams {
  /**
   * The search query.
   */

  query: string;
}

/**
 * Extends ToolResult to include sources for web search.
 */
export interface WebSearchToolResult extends ToolResult {
  // NOTE(review): this conditional type appears to resolve to a
  // GroundingChunkItem[]-shaped array either way — confirm whether it can be
  // simplified to plain GroundingChunkItem[].
  sources?: GroundingMetadata extends { groundingChunks: GroundingChunkItem[] }
    ? GroundingMetadata['groundingChunks']
    : GroundingChunkItem[];
}
|
||||
|
||||
/**
|
||||
* A tool to perform web searches using Google Search via the Gemini API.
|
||||
*/
|
||||
export class WebSearchTool extends BaseTool<
|
||||
WebSearchToolParams,
|
||||
WebSearchToolResult
|
||||
> {
|
||||
static readonly Name: string = 'google_web_search';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WebSearchTool.Name,
|
||||
'GoogleSearch',
|
||||
'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.',
|
||||
{
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
query: {
|
||||
type: Type.STRING,
|
||||
description: 'The search query to find information on the web.',
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the parameters for the WebSearchTool.
|
||||
* @param params The parameters to validate
|
||||
* @returns An error message string if validation fails, null if valid
|
||||
*/
|
||||
validateParams(params: WebSearchToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
if (!params.query || params.query.trim() === '') {
|
||||
return "The 'query' parameter cannot be empty.";
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
getDescription(params: WebSearchToolParams): string {
|
||||
return `Searching the web for: "${params.query}"`;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: WebSearchToolParams,
|
||||
signal: AbortSignal,
|
||||
): Promise<WebSearchToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
|
||||
returnDisplay: validationError,
|
||||
};
|
||||
}
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
|
||||
try {
|
||||
const response = await geminiClient.generateContent(
|
||||
[{ role: 'user', parts: [{ text: params.query }] }],
|
||||
{ tools: [{ googleSearch: {} }] },
|
||||
signal,
|
||||
);
|
||||
|
||||
const responseText = getResponseText(response);
|
||||
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;
|
||||
const sources = groundingMetadata?.groundingChunks as
|
||||
| GroundingChunkItem[]
|
||||
| undefined;
|
||||
const groundingSupports = groundingMetadata?.groundingSupports as
|
||||
| GroundingSupportItem[]
|
||||
| undefined;
|
||||
|
||||
if (!responseText || !responseText.trim()) {
|
||||
return {
|
||||
llmContent: `No search results or information found for query: "${params.query}"`,
|
||||
returnDisplay: 'No information found.',
|
||||
};
|
||||
}
|
||||
|
||||
let modifiedResponseText = responseText;
|
||||
const sourceListFormatted: string[] = [];
|
||||
|
||||
if (sources && sources.length > 0) {
|
||||
sources.forEach((source: GroundingChunkItem, index: number) => {
|
||||
const title = source.web?.title || 'Untitled';
|
||||
const uri = source.web?.uri || 'No URI';
|
||||
sourceListFormatted.push(`[${index + 1}] ${title} (${uri})`);
|
||||
});
|
||||
|
||||
if (groundingSupports && groundingSupports.length > 0) {
|
||||
const insertions: Array<{ index: number; marker: string }> = [];
|
||||
groundingSupports.forEach((support: GroundingSupportItem) => {
|
||||
if (support.segment && support.groundingChunkIndices) {
|
||||
const citationMarker = support.groundingChunkIndices
|
||||
.map((chunkIndex: number) => `[${chunkIndex + 1}]`)
|
||||
.join('');
|
||||
insertions.push({
|
||||
index: support.segment.endIndex,
|
||||
marker: citationMarker,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Sort insertions by index in descending order to avoid shifting subsequent indices
|
||||
insertions.sort((a, b) => b.index - a.index);
|
||||
|
||||
const responseChars = modifiedResponseText.split(''); // Use new variable
|
||||
insertions.forEach((insertion) => {
|
||||
// Fixed arrow function syntax
|
||||
responseChars.splice(insertion.index, 0, insertion.marker);
|
||||
});
|
||||
modifiedResponseText = responseChars.join(''); // Assign back to modifiedResponseText
|
||||
}
|
||||
|
||||
if (sourceListFormatted.length > 0) {
|
||||
modifiedResponseText +=
|
||||
'\n\nSources:\n' + sourceListFormatted.join('\n'); // Fixed string concatenation
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: `Web search results for "${params.query}":\n\n${modifiedResponseText}`,
|
||||
returnDisplay: `Search results for "${params.query}" returned.`,
|
||||
sources,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = `Error during web search for query "${params.query}": ${getErrorMessage(error)}`;
|
||||
console.error(errorMessage, error);
|
||||
return {
|
||||
llmContent: `Error: ${errorMessage}`,
|
||||
returnDisplay: `Error performing web search.`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
619
packages/core/src/tools/write-file.test.ts
Normal file
619
packages/core/src/tools/write-file.test.ts
Normal file
@@ -0,0 +1,619 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
vi,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { WriteFileTool } from './write-file.js';
|
||||
import {
|
||||
FileDiff,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
} from './tools.js';
|
||||
import { type EditToolParams } from './edit.js';
|
||||
import { ApprovalMode, Config } from '../config/config.js';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import os from 'os';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import {
|
||||
ensureCorrectEdit,
|
||||
ensureCorrectFileContent,
|
||||
CorrectedEditResult,
|
||||
} from '../utils/editCorrector.js';
|
||||
|
||||
// Shared root directory that the tool under test treats as its workspace.
const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root');

// --- MOCKS ---
// Replace the real GeminiClient and edit-corrector modules with vitest mocks
// so the tests below never hit the network or a real model.
vi.mock('../core/client.js');
vi.mock('../utils/editCorrector.js');

let mockGeminiClientInstance: Mocked<GeminiClient>;
// Typed mock functions; their default implementations are (re)set in beforeEach.
const mockEnsureCorrectEdit = vi.fn<typeof ensureCorrectEdit>();
const mockEnsureCorrectFileContent = vi.fn<typeof ensureCorrectFileContent>();

// Wire up the mocked functions to be used by the actual module imports
vi.mocked(ensureCorrectEdit).mockImplementation(mockEnsureCorrectEdit);
vi.mocked(ensureCorrectFileContent).mockImplementation(
  mockEnsureCorrectFileContent,
);

// Mock Config
// Minimal stand-in for Config: only the accessors WriteFileTool exercises are
// meaningful; everything else returns an inert default.
const mockConfigInternal = {
  getTargetDir: () => rootDir,
  getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
  setApprovalMode: vi.fn(),
  getGeminiClient: vi.fn(), // Initialize as a plain mock function
  getApiKey: () => 'test-key',
  getModel: () => 'test-model',
  getSandbox: () => false,
  getDebugMode: () => false,
  getQuestion: () => undefined,
  getFullContext: () => false,
  getToolDiscoveryCommand: () => undefined,
  getToolCallCommand: () => undefined,
  getMcpServerCommand: () => undefined,
  getMcpServers: () => undefined,
  getUserAgent: () => 'test-agent',
  getUserMemory: () => '',
  setUserMemory: vi.fn(),
  getGeminiMdFileCount: () => 0,
  setGeminiMdFileCount: vi.fn(),
  getToolRegistry: () =>
    ({
      registerTool: vi.fn(),
      discoverTools: vi.fn(),
    }) as unknown as ToolRegistry,
};
// Cast through unknown: the partial mock above intentionally omits most of
// the real Config surface.
const mockConfig = mockConfigInternal as unknown as Config;
// --- END MOCKS ---
|
||||
|
||||
describe('WriteFileTool', () => {
|
||||
let tool: WriteFileTool;
|
||||
let tempDir: string;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a unique temporary directory for files created outside the root
|
||||
tempDir = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'write-file-test-external-'),
|
||||
);
|
||||
// Ensure the rootDir for the tool exists
|
||||
if (!fs.existsSync(rootDir)) {
|
||||
fs.mkdirSync(rootDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Setup GeminiClient mock
|
||||
mockGeminiClientInstance = new (vi.mocked(GeminiClient))(
|
||||
mockConfig,
|
||||
) as Mocked<GeminiClient>;
|
||||
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClientInstance);
|
||||
|
||||
// Now that mockGeminiClientInstance is initialized, set the mock implementation for getGeminiClient
|
||||
mockConfigInternal.getGeminiClient.mockReturnValue(
|
||||
mockGeminiClientInstance,
|
||||
);
|
||||
|
||||
tool = new WriteFileTool(mockConfig);
|
||||
|
||||
// Reset mocks before each test
|
||||
mockConfigInternal.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
mockConfigInternal.setApprovalMode.mockClear();
|
||||
mockEnsureCorrectEdit.mockReset();
|
||||
mockEnsureCorrectFileContent.mockReset();
|
||||
|
||||
// Default mock implementations that return valid structures
|
||||
mockEnsureCorrectEdit.mockImplementation(
|
||||
async (
|
||||
filePath: string,
|
||||
_currentContent: string,
|
||||
params: EditToolParams,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal, // Make AbortSignal optional to match usage
|
||||
): Promise<CorrectedEditResult> => {
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve({
|
||||
params: { ...params, new_string: params.new_string ?? '' },
|
||||
occurrences: 1,
|
||||
});
|
||||
},
|
||||
);
|
||||
mockEnsureCorrectFileContent.mockImplementation(
|
||||
async (
|
||||
content: string,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> => {
|
||||
// Make AbortSignal optional
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve(content ?? '');
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Clean up the temporary directories
|
||||
if (fs.existsSync(tempDir)) {
|
||||
fs.rmSync(tempDir, { recursive: true, force: true });
|
||||
}
|
||||
if (fs.existsSync(rootDir)) {
|
||||
fs.rmSync(rootDir, { recursive: true, force: true });
|
||||
}
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
it('should return null for valid absolute path within root', () => {
|
||||
const params = {
|
||||
file_path: path.join(rootDir, 'test.txt'),
|
||||
content: 'hello',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return error for relative path', () => {
|
||||
const params = { file_path: 'test.txt', content: 'hello' };
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
/File path must be absolute/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error for path outside root', () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = {
|
||||
file_path: outsidePath,
|
||||
content: 'hello',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
/File path must be within the root directory/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is a directory', () => {
|
||||
const dirAsFilePath = path.join(rootDir, 'a_directory');
|
||||
fs.mkdirSync(dirAsFilePath);
|
||||
const params = {
|
||||
file_path: dirAsFilePath,
|
||||
content: 'hello',
|
||||
};
|
||||
expect(tool.validateToolParams(params)).toMatch(
|
||||
`Path is a directory, not a file: ${dirAsFilePath}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('_getCorrectedFileContent', () => {
|
||||
it('should call ensureCorrectFileContent for a new file', async () => {
|
||||
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
||||
const proposedContent = 'Proposed new content.';
|
||||
const correctedContent = 'Corrected new content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedContent);
|
||||
expect(result.originalContent).toBe('');
|
||||
expect(result.fileExists).toBe(false);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call ensureCorrectEdit for an existing file', async () => {
|
||||
const filePath = path.join(rootDir, 'existing_corrected_file.txt');
|
||||
const originalContent = 'Original existing content.';
|
||||
const proposedContent = 'Proposed replacement content.';
|
||||
const correctedProposedContent = 'Corrected replacement content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
// Ensure this mock is active and returns the correct structure
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: originalContent,
|
||||
new_string: correctedProposedContent,
|
||||
},
|
||||
occurrences: 1,
|
||||
} as CorrectedEditResult);
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent,
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedProposedContent);
|
||||
expect(result.originalContent).toBe(originalContent);
|
||||
expect(result.fileExists).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return error if reading an existing file fails (e.g. permissions)', async () => {
|
||||
const filePath = path.join(rootDir, 'unreadable_file.txt');
|
||||
const proposedContent = 'some content';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Permission denied');
|
||||
const originalReadFileSync = fs.readFileSync;
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
|
||||
throw readError;
|
||||
});
|
||||
|
||||
// @ts-expect-error _getCorrectedFileContent is private
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(proposedContent);
|
||||
expect(result.originalContent).toBe('');
|
||||
expect(result.fileExists).toBe(true);
|
||||
expect(result.error).toEqual({
|
||||
message: 'Permission denied',
|
||||
code: undefined,
|
||||
});
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
it('should return false if params are invalid (relative path)', async () => {
|
||||
const params = { file_path: 'relative.txt', content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if params are invalid (outside root)', async () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = { file_path: outsidePath, content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if _getCorrectedFileContent returns an error', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_error_file.txt');
|
||||
const params = { file_path: filePath, content: 'test content' };
|
||||
fs.writeFileSync(filePath, 'original', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Simulated read error for confirmation');
|
||||
const originalReadFileSync = fs.readFileSync;
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
|
||||
throw readError;
|
||||
});
|
||||
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
|
||||
it('should request confirmation with diff for a new file (with corrected content)', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_new_file.txt');
|
||||
const proposedContent = 'Proposed new content for confirmation.';
|
||||
const correctedContent = 'Corrected new content for confirmation.';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); // Ensure this mock is active
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Write: ${path.basename(filePath)}`,
|
||||
fileName: 'confirm_new_file.txt',
|
||||
fileDiff: expect.stringContaining(correctedContent),
|
||||
}),
|
||||
);
|
||||
expect(confirmation.fileDiff).toMatch(
|
||||
/--- confirm_new_file.txt\tCurrent/,
|
||||
);
|
||||
expect(confirmation.fileDiff).toMatch(
|
||||
/\+\+\+ confirm_new_file.txt\tProposed/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should request confirmation with diff for an existing file (with corrected content)', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_existing_file.txt');
|
||||
const originalContent = 'Original content for confirmation.';
|
||||
const proposedContent = 'Proposed replacement for confirmation.';
|
||||
const correctedProposedContent =
|
||||
'Corrected replacement for confirmation.';
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: originalContent,
|
||||
new_string: correctedProposedContent,
|
||||
},
|
||||
occurrences: 1,
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent,
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Write: ${path.basename(filePath)}`,
|
||||
fileName: 'confirm_existing_file.txt',
|
||||
fileDiff: expect.stringContaining(correctedProposedContent),
|
||||
}),
|
||||
);
|
||||
expect(confirmation.fileDiff).toMatch(
|
||||
originalContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
  describe('execute', () => {
    const abortSignal = new AbortController().signal;

    it('should return error if params are invalid (relative path)', async () => {
      // Relative paths must be rejected by validation before any filesystem I/O.
      const params = { file_path: 'relative.txt', content: 'test' };
      const result = await tool.execute(params, abortSignal);
      expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
      expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
    });

    it('should return error if params are invalid (path outside root)', async () => {
      // Absolute paths that resolve outside the configured root are invalid.
      const outsidePath = path.resolve(tempDir, 'outside-root.txt');
      const params = { file_path: outsidePath, content: 'test' };
      const result = await tool.execute(params, abortSignal);
      expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
      expect(result.returnDisplay).toMatch(
        /Error: File path must be within the root directory/,
      );
    });

    it('should return error if _getCorrectedFileContent returns an error during execute', async () => {
      const filePath = path.join(rootDir, 'execute_error_file.txt');
      const params = { file_path: filePath, content: 'test content' };
      // Create the file with no permissions so the read path is plausible,
      // then force the actual failure deterministically via the spy below.
      fs.writeFileSync(filePath, 'original', { mode: 0o000 });

      // Make readFileSync throw exactly once so execute() hits the
      // "error checking existing file" branch.
      const readError = new Error('Simulated read error for execute');
      const originalReadFileSync = fs.readFileSync;
      vi.spyOn(fs, 'readFileSync').mockImplementationOnce(() => {
        throw readError;
      });

      const result = await tool.execute(params, abortSignal);
      expect(result.llmContent).toMatch(/Error checking existing file/);
      expect(result.returnDisplay).toMatch(
        /Error checking existing file: Simulated read error for execute/,
      );

      // Restore the real implementation and permissions so later tests
      // (and temp-dir cleanup) are unaffected.
      vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
      fs.chmodSync(filePath, 0o600);
    });

    it('should write a new file with corrected content and return diff', async () => {
      const filePath = path.join(rootDir, 'execute_new_corrected_file.txt');
      const proposedContent = 'Proposed new content for execute.';
      const correctedContent = 'Corrected new content for execute.';
      // New-file path: content is routed through ensureCorrectFileContent.
      mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);

      const params = { file_path: filePath, content: proposedContent };

      const confirmDetails = await tool.shouldConfirmExecute(
        params,
        abortSignal,
      );
      if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
        await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
      }

      const result = await tool.execute(params, abortSignal);

      expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
        proposedContent,
        mockGeminiClientInstance,
        abortSignal,
      );
      expect(result.llmContent).toMatch(
        /Successfully created and wrote to new file/,
      );
      expect(fs.existsSync(filePath)).toBe(true);
      // The corrected (not the originally proposed) content must be on disk.
      expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedContent);
      const display = result.returnDisplay as FileDiff;
      expect(display.fileName).toBe('execute_new_corrected_file.txt');
      expect(display.fileDiff).toMatch(
        /--- execute_new_corrected_file.txt\tOriginal/,
      );
      expect(display.fileDiff).toMatch(
        /\+\+\+ execute_new_corrected_file.txt\tWritten/,
      );
      // Escape regex metacharacters so literal content can be matched.
      expect(display.fileDiff).toMatch(
        correctedContent.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'),
      );
    });

    it('should overwrite an existing file with corrected content and return diff', async () => {
      const filePath = path.join(
        rootDir,
        'execute_existing_corrected_file.txt',
      );
      const initialContent = 'Initial content for execute.';
      const proposedContent = 'Proposed overwrite for execute.';
      const correctedProposedContent = 'Corrected overwrite for execute.';
      fs.writeFileSync(filePath, initialContent, 'utf8');

      // Existing-file path: content is routed through ensureCorrectEdit with
      // the whole current file treated as old_string.
      mockEnsureCorrectEdit.mockResolvedValue({
        params: {
          file_path: filePath,
          old_string: initialContent,
          new_string: correctedProposedContent,
        },
        occurrences: 1,
      });

      const params = { file_path: filePath, content: proposedContent };

      const confirmDetails = await tool.shouldConfirmExecute(
        params,
        abortSignal,
      );
      if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
        await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
      }

      const result = await tool.execute(params, abortSignal);

      expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
        filePath,
        initialContent,
        {
          old_string: initialContent,
          new_string: proposedContent,
          file_path: filePath,
        },
        mockGeminiClientInstance,
        abortSignal,
      );
      expect(result.llmContent).toMatch(/Successfully overwrote file/);
      expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
      const display = result.returnDisplay as FileDiff;
      expect(display.fileName).toBe('execute_existing_corrected_file.txt');
      // Both sides of the diff (old and new content) should appear.
      expect(display.fileDiff).toMatch(
        initialContent.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'),
      );
      expect(display.fileDiff).toMatch(
        correctedProposedContent.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'),
      );
    });

    it('should create directory if it does not exist', async () => {
      const dirPath = path.join(rootDir, 'new_dir_for_write');
      const filePath = path.join(dirPath, 'file_in_new_dir.txt');
      const content = 'Content in new directory';
      mockEnsureCorrectFileContent.mockResolvedValue(content); // Ensure this mock is active

      const params = { file_path: filePath, content };
      // Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
      const confirmDetails = await tool.shouldConfirmExecute(
        params,
        abortSignal,
      );
      if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
        await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
      }

      await tool.execute(params, abortSignal);

      // execute() is expected to create intermediate directories recursively.
      expect(fs.existsSync(dirPath)).toBe(true);
      expect(fs.statSync(dirPath).isDirectory()).toBe(true);
      expect(fs.existsSync(filePath)).toBe(true);
      expect(fs.readFileSync(filePath, 'utf8')).toBe(content);
    });

    it('should include modification message when proposed content is modified', async () => {
      const filePath = path.join(rootDir, 'new_file_modified.txt');
      const content = 'New file content modified by user';
      mockEnsureCorrectFileContent.mockResolvedValue(content);

      // modified_by_user: true should append the user-modification note.
      const params = {
        file_path: filePath,
        content,
        modified_by_user: true,
      };
      const result = await tool.execute(params, abortSignal);

      expect(result.llmContent).toMatch(/User modified the `content`/);
    });

    it('should not include modification message when proposed content is not modified', async () => {
      const filePath = path.join(rootDir, 'new_file_unmodified.txt');
      const content = 'New file content not modified';
      mockEnsureCorrectFileContent.mockResolvedValue(content);

      // modified_by_user: false must not trigger the note.
      const params = {
        file_path: filePath,
        content,
        modified_by_user: false,
      };
      const result = await tool.execute(params, abortSignal);

      expect(result.llmContent).not.toMatch(/User modified the `content`/);
    });

    it('should not include modification message when modified_by_user is not provided', async () => {
      const filePath = path.join(rootDir, 'new_file_unmodified.txt');
      const content = 'New file content not modified';
      mockEnsureCorrectFileContent.mockResolvedValue(content);

      // Omitting modified_by_user behaves like false.
      const params = {
        file_path: filePath,
        content,
      };
      const result = await tool.execute(params, abortSignal);

      expect(result.llmContent).not.toMatch(/User modified the `content`/);
    });
  });
|
||||
});
|
||||
398
packages/core/src/tools/write-file.ts
Normal file
398
packages/core/src/tools/write-file.ts
Normal file
@@ -0,0 +1,398 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import * as Diff from 'diff';
|
||||
import { Config, ApprovalMode } from '../config/config.js';
|
||||
import {
|
||||
BaseTool,
|
||||
ToolResult,
|
||||
FileDiff,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolCallConfirmationDetails,
|
||||
} from './tools.js';
|
||||
import { Type } from '@google/genai';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import {
|
||||
ensureCorrectEdit,
|
||||
ensureCorrectFileContent,
|
||||
} from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
|
||||
import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
|
||||
import { getSpecificMimeType, isWithinRoot } from '../utils/fileUtils.js';
|
||||
import {
|
||||
recordFileOperationMetric,
|
||||
FileOperation,
|
||||
} from '../telemetry/metrics.js';
|
||||
|
||||
/**
 * Parameters for the WriteFile tool.
 */
export interface WriteFileToolParams {
  /**
   * The absolute path to the file to write to.
   * Relative paths are rejected by validation.
   */
  file_path: string;

  /**
   * The content to write to the file.
   */
  content: string;

  /**
   * Whether the proposed content was modified by the user.
   * When true, the tool's success message notes the user modification
   * and echoes the final content back to the model.
   */
  modified_by_user?: boolean;
}
|
||||
|
||||
/**
 * Result of reading the target file and running the proposed content
 * through the edit corrector.
 */
interface GetCorrectedFileContentResult {
  // Current on-disk content; '' when the file is new or could not be read.
  originalContent: string;
  // The proposed content after potential correction.
  correctedContent: string;
  // True when the file exists on disk, even if it was unreadable.
  fileExists: boolean;
  // Set when the file exists but reading it failed (e.g. permissions);
  // in that case correction is skipped and correctedContent is the
  // unmodified proposed content.
  error?: { message: string; code?: string };
}
|
||||
|
||||
/**
|
||||
* Implementation of the WriteFile tool logic
|
||||
*/
|
||||
export class WriteFileTool
|
||||
extends BaseTool<WriteFileToolParams, ToolResult>
|
||||
implements ModifiableTool<WriteFileToolParams>
|
||||
{
|
||||
static readonly Name: string = 'write_file';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
'WriteFile',
|
||||
`Writes content to a specified file in the local filesystem.
|
||||
|
||||
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
|
||||
{
|
||||
properties: {
|
||||
file_path: {
|
||||
description:
|
||||
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
|
||||
type: Type.STRING,
|
||||
},
|
||||
content: {
|
||||
description: 'The content to write to the file.',
|
||||
type: Type.STRING,
|
||||
},
|
||||
},
|
||||
required: ['file_path', 'content'],
|
||||
type: Type.OBJECT,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
validateToolParams(params: WriteFileToolParams): string | null {
|
||||
const errors = SchemaValidator.validate(this.schema.parameters, params);
|
||||
if (errors) {
|
||||
return errors;
|
||||
}
|
||||
|
||||
const filePath = params.file_path;
|
||||
if (!path.isAbsolute(filePath)) {
|
||||
return `File path must be absolute: ${filePath}`;
|
||||
}
|
||||
if (!isWithinRoot(filePath, this.config.getTargetDir())) {
|
||||
return `File path must be within the root directory (${this.config.getTargetDir()}): ${filePath}`;
|
||||
}
|
||||
|
||||
try {
|
||||
// This check should be performed only if the path exists.
|
||||
// If it doesn't exist, it's a new file, which is valid for writing.
|
||||
if (fs.existsSync(filePath)) {
|
||||
const stats = fs.lstatSync(filePath);
|
||||
if (stats.isDirectory()) {
|
||||
return `Path is a directory, not a file: ${filePath}`;
|
||||
}
|
||||
}
|
||||
} catch (statError: unknown) {
|
||||
// If fs.existsSync is true but lstatSync fails (e.g., permissions, race condition where file is deleted)
|
||||
// this indicates an issue with accessing the path that should be reported.
|
||||
return `Error accessing path properties for validation: ${filePath}. Reason: ${statError instanceof Error ? statError.message : String(statError)}`;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
getDescription(params: WriteFileToolParams): string {
|
||||
if (!params.file_path || !params.content) {
|
||||
return `Model did not provide valid parameters for write file tool`;
|
||||
}
|
||||
const relativePath = makeRelative(
|
||||
params.file_path,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
return `Writing to ${shortenPath(relativePath)}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the confirmation prompt for the WriteFile tool.
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
params: WriteFileToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
// If file exists but couldn't be read, we can't show a diff for confirmation.
|
||||
return false;
|
||||
}
|
||||
|
||||
const { originalContent, correctedContent } = correctedContentResult;
|
||||
const relativePath = makeRelative(
|
||||
params.file_path,
|
||||
this.config.getTargetDir(),
|
||||
);
|
||||
const fileName = path.basename(params.file_path);
|
||||
|
||||
const fileDiff = Diff.createPatch(
|
||||
fileName,
|
||||
originalContent, // Original content (empty if new file or unreadable)
|
||||
correctedContent, // Content after potential correction
|
||||
'Current',
|
||||
'Proposed',
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
|
||||
const confirmationDetails: ToolEditConfirmationDetails = {
|
||||
type: 'edit',
|
||||
title: `Confirm Write: ${shortenPath(relativePath)}`,
|
||||
fileName,
|
||||
fileDiff,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(
|
||||
params: WriteFileToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
|
||||
returnDisplay: `Error: ${validationError}`,
|
||||
};
|
||||
}
|
||||
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
const errDetails = correctedContentResult.error;
|
||||
const errorMsg = `Error checking existing file: ${errDetails.message}`;
|
||||
return {
|
||||
llmContent: `Error checking existing file ${params.file_path}: ${errDetails.message}`,
|
||||
returnDisplay: errorMsg,
|
||||
};
|
||||
}
|
||||
|
||||
const {
|
||||
originalContent,
|
||||
correctedContent: fileContent,
|
||||
fileExists,
|
||||
} = correctedContentResult;
|
||||
// fileExists is true if the file existed (and was readable or unreadable but caught by readError).
|
||||
// fileExists is false if the file did not exist (ENOENT).
|
||||
const isNewFile =
|
||||
!fileExists ||
|
||||
(correctedContentResult.error !== undefined &&
|
||||
!correctedContentResult.fileExists);
|
||||
|
||||
try {
|
||||
const dirName = path.dirname(params.file_path);
|
||||
if (!fs.existsSync(dirName)) {
|
||||
fs.mkdirSync(dirName, { recursive: true });
|
||||
}
|
||||
|
||||
fs.writeFileSync(params.file_path, fileContent, 'utf8');
|
||||
|
||||
// Generate diff for display result
|
||||
const fileName = path.basename(params.file_path);
|
||||
// If there was a readError, originalContent in correctedContentResult is '',
|
||||
// but for the diff, we want to show the original content as it was before the write if possible.
|
||||
// However, if it was unreadable, currentContentForDiff will be empty.
|
||||
const currentContentForDiff = correctedContentResult.error
|
||||
? '' // Or some indicator of unreadable content
|
||||
: originalContent;
|
||||
|
||||
const fileDiff = Diff.createPatch(
|
||||
fileName,
|
||||
currentContentForDiff,
|
||||
fileContent,
|
||||
'Original',
|
||||
'Written',
|
||||
DEFAULT_DIFF_OPTIONS,
|
||||
);
|
||||
|
||||
const llmSuccessMessageParts = [
|
||||
isNewFile
|
||||
? `Successfully created and wrote to new file: ${params.file_path}.`
|
||||
: `Successfully overwrote file: ${params.file_path}.`,
|
||||
];
|
||||
if (params.modified_by_user) {
|
||||
llmSuccessMessageParts.push(
|
||||
`User modified the \`content\` to be: ${params.content}`,
|
||||
);
|
||||
}
|
||||
|
||||
const displayResult: FileDiff = { fileDiff, fileName };
|
||||
|
||||
const lines = fileContent.split('\n').length;
|
||||
const mimetype = getSpecificMimeType(params.file_path);
|
||||
const extension = path.extname(params.file_path); // Get extension
|
||||
if (isNewFile) {
|
||||
recordFileOperationMetric(
|
||||
this.config,
|
||||
FileOperation.CREATE,
|
||||
lines,
|
||||
mimetype,
|
||||
extension,
|
||||
);
|
||||
} else {
|
||||
recordFileOperationMetric(
|
||||
this.config,
|
||||
FileOperation.UPDATE,
|
||||
lines,
|
||||
mimetype,
|
||||
extension,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: llmSuccessMessageParts.join(' '),
|
||||
returnDisplay: displayResult,
|
||||
};
|
||||
} catch (error) {
|
||||
const errorMsg = `Error writing to file: ${error instanceof Error ? error.message : String(error)}`;
|
||||
return {
|
||||
llmContent: `Error writing to file ${params.file_path}: ${errorMsg}`,
|
||||
returnDisplay: `Error: ${errorMsg}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async _getCorrectedFileContent(
|
||||
filePath: string,
|
||||
proposedContent: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<GetCorrectedFileContentResult> {
|
||||
let originalContent = '';
|
||||
let fileExists = false;
|
||||
let correctedContent = proposedContent;
|
||||
|
||||
try {
|
||||
originalContent = fs.readFileSync(filePath, 'utf8');
|
||||
fileExists = true; // File exists and was read
|
||||
} catch (err) {
|
||||
if (isNodeError(err) && err.code === 'ENOENT') {
|
||||
fileExists = false;
|
||||
originalContent = '';
|
||||
} else {
|
||||
// File exists but could not be read (permissions, etc.)
|
||||
fileExists = true; // Mark as existing but problematic
|
||||
originalContent = ''; // Can't use its content
|
||||
const error = {
|
||||
message: getErrorMessage(err),
|
||||
code: isNodeError(err) ? err.code : undefined,
|
||||
};
|
||||
// Return early as we can't proceed with content correction meaningfully
|
||||
return { originalContent, correctedContent, fileExists, error };
|
||||
}
|
||||
}
|
||||
|
||||
// If readError is set, we have returned.
|
||||
// So, file was either read successfully (fileExists=true, originalContent set)
|
||||
// or it was ENOENT (fileExists=false, originalContent='').
|
||||
|
||||
if (fileExists) {
|
||||
// This implies originalContent is available
|
||||
const { params: correctedParams } = await ensureCorrectEdit(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent, // Treat entire current content as old_string
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
} else {
|
||||
// This implies new file (ENOENT)
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
return { originalContent, correctedContent, fileExists };
|
||||
}
|
||||
|
||||
getModifyContext(
|
||||
abortSignal: AbortSignal,
|
||||
): ModifyContext<WriteFileToolParams> {
|
||||
return {
|
||||
getFilePath: (params: WriteFileToolParams) => params.file_path,
|
||||
getCurrentContent: async (params: WriteFileToolParams) => {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
return correctedContentResult.originalContent;
|
||||
},
|
||||
getProposedContent: async (params: WriteFileToolParams) => {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
return correctedContentResult.correctedContent;
|
||||
},
|
||||
createUpdatedParams: (
|
||||
_oldContent: string,
|
||||
modifiedProposedContent: string,
|
||||
originalParams: WriteFileToolParams,
|
||||
) => ({
|
||||
...originalParams,
|
||||
content: modifiedProposedContent,
|
||||
modified_by_user: true,
|
||||
}),
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user