pre-release commit

koalazf.99
2025-07-22 19:59:07 +08:00
parent c5dee4bb17
commit a9d6965bef
485 changed files with 111444 additions and 2 deletions

View File

@@ -0,0 +1,41 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export class LruCache<K, V> {
private cache: Map<K, V>;
private maxSize: number;
constructor(maxSize: number) {
this.cache = new Map<K, V>();
this.maxSize = maxSize;
}
get(key: K): V | undefined {
if (!this.cache.has(key)) {
return undefined;
}
const value = this.cache.get(key) as V;
// Move to end to mark as recently used; the has() check above ensures
// falsy values are also refreshed correctly
this.cache.delete(key);
this.cache.set(key, value);
return value;
}
set(key: K, value: V): void {
if (this.cache.has(key)) {
this.cache.delete(key);
} else if (this.cache.size >= this.maxSize) {
const firstKey = this.cache.keys().next().value;
if (firstKey !== undefined) {
this.cache.delete(firstKey);
}
}
this.cache.set(key, value);
}
clear(): void {
this.cache.clear();
}
}
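// Example usage (a minimal sketch; keys and values are hypothetical):
// const cache = new LruCache<string, number>(2);
// cache.set('a', 1);
// cache.set('b', 2);
// cache.get('a');    // marks 'a' as most recently used
// cache.set('c', 3); // evicts 'b', the least recently used entry
// cache.get('b');    // undefined
// cache.get('a');    // 1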

View File

@@ -0,0 +1,148 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs';
import { vi, describe, it, expect, beforeEach } from 'vitest';
import * as fsPromises from 'fs/promises';
import * as gitUtils from './gitUtils.js';
import { bfsFileSearch } from './bfsFileSearch.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
vi.mock('fs');
vi.mock('fs/promises');
vi.mock('./gitUtils.js');
const createMockDirent = (name: string, isFile: boolean): fs.Dirent => {
const dirent = new fs.Dirent();
dirent.name = name;
dirent.isFile = () => isFile;
dirent.isDirectory = () => !isFile;
return dirent;
};
// Type for the specific overload we're using
type ReaddirWithFileTypes = (
path: fs.PathLike,
options: { withFileTypes: true },
) => Promise<fs.Dirent[]>;
describe('bfsFileSearch', () => {
beforeEach(() => {
vi.resetAllMocks();
});
it('should find a file in the root directory', async () => {
const mockFs = vi.mocked(fsPromises);
const mockReaddir = mockFs.readdir as unknown as ReaddirWithFileTypes;
vi.mocked(mockReaddir).mockResolvedValue([
createMockDirent('file1.txt', true),
createMockDirent('file2.txt', true),
]);
const result = await bfsFileSearch('/test', { fileName: 'file1.txt' });
expect(result).toEqual(['/test/file1.txt']);
});
it('should find a file in a subdirectory', async () => {
const mockFs = vi.mocked(fsPromises);
const mockReaddir = mockFs.readdir as unknown as ReaddirWithFileTypes;
vi.mocked(mockReaddir).mockImplementation(async (dir) => {
if (dir === '/test') {
return [createMockDirent('subdir', false)];
}
if (dir === '/test/subdir') {
return [createMockDirent('file1.txt', true)];
}
return [];
});
const result = await bfsFileSearch('/test', { fileName: 'file1.txt' });
expect(result).toEqual(['/test/subdir/file1.txt']);
});
it('should ignore specified directories', async () => {
const mockFs = vi.mocked(fsPromises);
const mockReaddir = mockFs.readdir as unknown as ReaddirWithFileTypes;
vi.mocked(mockReaddir).mockImplementation(async (dir) => {
if (dir === '/test') {
return [
createMockDirent('subdir1', false),
createMockDirent('subdir2', false),
];
}
if (dir === '/test/subdir1') {
return [createMockDirent('file1.txt', true)];
}
if (dir === '/test/subdir2') {
return [createMockDirent('file1.txt', true)];
}
return [];
});
const result = await bfsFileSearch('/test', {
fileName: 'file1.txt',
ignoreDirs: ['subdir2'],
});
expect(result).toEqual(['/test/subdir1/file1.txt']);
});
it('should respect maxDirs limit', async () => {
const mockFs = vi.mocked(fsPromises);
const mockReaddir = mockFs.readdir as unknown as ReaddirWithFileTypes;
vi.mocked(mockReaddir).mockImplementation(async (dir) => {
if (dir === '/test') {
return [
createMockDirent('subdir1', false),
createMockDirent('subdir2', false),
];
}
if (dir === '/test/subdir1') {
return [createMockDirent('file1.txt', true)];
}
if (dir === '/test/subdir2') {
return [createMockDirent('file1.txt', true)];
}
return [];
});
const result = await bfsFileSearch('/test', {
fileName: 'file1.txt',
maxDirs: 2,
});
expect(result).toEqual(['/test/subdir1/file1.txt']);
});
it('should respect .gitignore files', async () => {
const mockFs = vi.mocked(fsPromises);
const mockGitUtils = vi.mocked(gitUtils);
mockGitUtils.isGitRepository.mockReturnValue(true);
const mockReaddir = mockFs.readdir as unknown as ReaddirWithFileTypes;
vi.mocked(mockReaddir).mockImplementation(async (dir) => {
if (dir === '/test') {
return [
createMockDirent('.gitignore', true),
createMockDirent('subdir1', false),
createMockDirent('subdir2', false),
];
}
if (dir === '/test/subdir1') {
return [createMockDirent('file1.txt', true)];
}
if (dir === '/test/subdir2') {
return [createMockDirent('file1.txt', true)];
}
return [];
});
vi.mocked(fs).readFileSync.mockReturnValue('subdir2');
const fileService = new FileDiscoveryService('/test');
const result = await bfsFileSearch('/test', {
fileName: 'file1.txt',
fileService,
});
expect(result).toEqual(['/test/subdir1/file1.txt']);
});
});

View File

@@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import * as path from 'path';
import { Dirent } from 'fs';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
// Simple console logger for now.
// TODO: Integrate with a more robust server-side logger.
const logger = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
debug: (...args: any[]) => console.debug('[DEBUG] [BfsFileSearch]', ...args),
};
interface BfsFileSearchOptions {
fileName: string;
ignoreDirs?: string[];
maxDirs?: number;
debug?: boolean;
fileService?: FileDiscoveryService;
}
/**
* Performs a breadth-first search for a specific file within a directory structure.
*
* @param rootDir The directory to start the search from.
* @param options Configuration for the search.
* @returns A promise that resolves to an array of paths where the file was found.
*/
export async function bfsFileSearch(
rootDir: string,
options: BfsFileSearchOptions,
): Promise<string[]> {
const {
fileName,
ignoreDirs = [],
maxDirs = Infinity,
debug = false,
fileService,
} = options;
const foundFiles: string[] = [];
const queue: string[] = [rootDir];
const visited = new Set<string>();
let scannedDirCount = 0;
while (queue.length > 0 && scannedDirCount < maxDirs) {
const currentDir = queue.shift()!;
if (visited.has(currentDir)) {
continue;
}
visited.add(currentDir);
scannedDirCount++;
if (debug) {
logger.debug(`Scanning [${scannedDirCount}/${maxDirs}]: ${currentDir}`);
}
let entries: Dirent[];
try {
entries = await fs.readdir(currentDir, { withFileTypes: true });
} catch {
// Ignore errors for directories we can't read (e.g., permissions)
continue;
}
for (const entry of entries) {
const fullPath = path.join(currentDir, entry.name);
if (fileService?.shouldGitIgnoreFile(fullPath)) {
continue;
}
if (entry.isDirectory()) {
if (!ignoreDirs.includes(entry.name)) {
queue.push(fullPath);
}
} else if (entry.isFile() && entry.name === fileName) {
foundFiles.push(fullPath);
}
}
}
return foundFiles;
}
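// Example usage (a minimal sketch; the root path and file name are hypothetical):
// const matches = await bfsFileSearch('/repo', {
//   fileName: 'settings.json',
//   ignoreDirs: ['node_modules', '.git'],
//   maxDirs: 1000,
//   fileService: new FileDiscoveryService('/repo'),
// });
// // matches lists every path found breadth-first, shallowest directories first,
// // e.g. ['/repo/settings.json', '/repo/packages/core/settings.json'].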

View File

@@ -0,0 +1,767 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
vi,
describe,
it,
expect,
beforeEach,
Mock,
type Mocked,
} from 'vitest';
import * as fs from 'fs';
import { EditTool } from '../tools/edit.js';
// MOCKS
let callCount = 0;
const mockResponses: any[] = [];
let mockGenerateJson: any;
let mockStartChat: any;
let mockSendMessageStream: any;
vi.mock('fs', () => ({
statSync: vi.fn(),
}));
vi.mock('../core/client.js', () => ({
GeminiClient: vi.fn().mockImplementation(function (
this: any,
_config: Config,
) {
this.generateJson = (...params: any[]) => mockGenerateJson(...params);
this.startChat = (...params: any[]) => mockStartChat(...params);
this.sendMessageStream = (...params: any[]) =>
mockSendMessageStream(...params);
return this;
}),
}));
// END MOCKS
import {
countOccurrences,
ensureCorrectEdit,
ensureCorrectFileContent,
unescapeStringForGeminiBug,
resetEditCorrectorCaches_TEST_ONLY,
} from './editCorrector.js';
import { GeminiClient } from '../core/client.js';
import type { Config } from '../config/config.js';
import { ToolRegistry } from '../tools/tool-registry.js';
vi.mock('../tools/tool-registry.js');
describe('editCorrector', () => {
describe('countOccurrences', () => {
it('should return 0 for empty string', () => {
expect(countOccurrences('', 'a')).toBe(0);
});
it('should return 0 for empty substring', () => {
expect(countOccurrences('abc', '')).toBe(0);
});
it('should return 0 if substring is not found', () => {
expect(countOccurrences('abc', 'd')).toBe(0);
});
it('should return 1 if substring is found once', () => {
expect(countOccurrences('abc', 'b')).toBe(1);
});
it('should return correct count for multiple occurrences', () => {
expect(countOccurrences('ababa', 'a')).toBe(3);
expect(countOccurrences('ababab', 'ab')).toBe(3);
});
it('should count non-overlapping occurrences', () => {
expect(countOccurrences('aaaaa', 'aa')).toBe(2);
expect(countOccurrences('ababab', 'aba')).toBe(1);
});
it('should correctly count occurrences when substring is longer', () => {
expect(countOccurrences('abc', 'abcdef')).toBe(0);
});
it('should be case sensitive', () => {
expect(countOccurrences('abcABC', 'a')).toBe(1);
expect(countOccurrences('abcABC', 'A')).toBe(1);
});
});
describe('unescapeStringForGeminiBug', () => {
it('should unescape common sequences', () => {
expect(unescapeStringForGeminiBug('\\n')).toBe('\n');
expect(unescapeStringForGeminiBug('\\t')).toBe('\t');
expect(unescapeStringForGeminiBug("\\'")).toBe("'");
expect(unescapeStringForGeminiBug('\\"')).toBe('"');
expect(unescapeStringForGeminiBug('\\`')).toBe('`');
});
it('should handle multiple escaped sequences', () => {
expect(unescapeStringForGeminiBug('Hello\\nWorld\\tTest')).toBe(
'Hello\nWorld\tTest',
);
});
it('should not alter already correct sequences', () => {
expect(unescapeStringForGeminiBug('\n')).toBe('\n');
expect(unescapeStringForGeminiBug('Correct string')).toBe(
'Correct string',
);
});
it('should handle mixed correct and incorrect sequences', () => {
expect(unescapeStringForGeminiBug('\\nCorrect\t\\`')).toBe(
'\nCorrect\t`',
);
});
it('should handle backslash followed by actual newline character', () => {
expect(unescapeStringForGeminiBug('\\\n')).toBe('\n');
expect(unescapeStringForGeminiBug('First line\\\nSecond line')).toBe(
'First line\nSecond line',
);
});
it('should handle multiple backslashes before an escapable character (aggressive unescaping)', () => {
expect(unescapeStringForGeminiBug('\\\\n')).toBe('\n');
expect(unescapeStringForGeminiBug('\\\\\\t')).toBe('\t');
expect(unescapeStringForGeminiBug('\\\\\\\\`')).toBe('`');
});
it('should return empty string for empty input', () => {
expect(unescapeStringForGeminiBug('')).toBe('');
});
it('should not alter strings with no targeted escape sequences', () => {
expect(unescapeStringForGeminiBug('abc def')).toBe('abc def');
expect(unescapeStringForGeminiBug('C:\\Folder\\File')).toBe(
'C:\\Folder\\File',
);
});
it('should correctly process strings with some targeted escapes', () => {
expect(unescapeStringForGeminiBug('C:\\Users\\name')).toBe(
'C:\\Users\name',
);
});
it('should handle complex cases with mixed slashes and characters', () => {
expect(
unescapeStringForGeminiBug('\\\\\\\nLine1\\\nLine2\\tTab\\\\`Tick\\"'),
).toBe('\nLine1\nLine2\tTab`Tick"');
});
it('should handle escaped backslashes', () => {
expect(unescapeStringForGeminiBug('\\\\')).toBe('\\');
expect(unescapeStringForGeminiBug('C:\\\\Users')).toBe('C:\\Users');
expect(unescapeStringForGeminiBug('path\\\\to\\\\file')).toBe(
'path\to\\file',
);
});
it('should handle escaped backslashes mixed with other escapes (aggressive unescaping)', () => {
expect(unescapeStringForGeminiBug('line1\\\\\\nline2')).toBe(
'line1\nline2',
);
expect(unescapeStringForGeminiBug('quote\\\\"text\\\\nline')).toBe(
'quote"text\nline',
);
});
});
describe('ensureCorrectEdit', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config;
const abortSignal = new AbortController().signal;
beforeEach(() => {
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
const configParams = {
apiKey: 'test-api-key',
model: 'test-model',
sandbox: false as boolean | string,
targetDir: '/test',
debugMode: false,
question: undefined as string | undefined,
fullContext: false,
coreTools: undefined as string[] | undefined,
toolDiscoveryCommand: undefined as string | undefined,
toolCallCommand: undefined as string | undefined,
mcpServerCommand: undefined as string | undefined,
mcpServers: undefined as Record<string, any> | undefined,
userAgent: 'test-agent',
userMemory: '',
geminiMdFileCount: 0,
alwaysSkipModificationConfirmation: false,
};
mockConfigInstance = {
...configParams,
getApiKey: vi.fn(() => configParams.apiKey),
getModel: vi.fn(() => configParams.model),
getSandbox: vi.fn(() => configParams.sandbox),
getTargetDir: vi.fn(() => configParams.targetDir),
getToolRegistry: vi.fn(() => mockToolRegistry),
getDebugMode: vi.fn(() => configParams.debugMode),
getQuestion: vi.fn(() => configParams.question),
getFullContext: vi.fn(() => configParams.fullContext),
getCoreTools: vi.fn(() => configParams.coreTools),
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
getMcpServers: vi.fn(() => configParams.mcpServers),
getUserAgent: vi.fn(() => configParams.userAgent),
getUserMemory: vi.fn(() => configParams.userMemory),
setUserMemory: vi.fn((mem: string) => {
configParams.userMemory = mem;
}),
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
setGeminiMdFileCount: vi.fn((count: number) => {
configParams.geminiMdFileCount = count;
}),
getAlwaysSkipModificationConfirmation: vi.fn(
() => configParams.alwaysSkipModificationConfirmation,
),
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
configParams.alwaysSkipModificationConfirmation = skip;
}),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
} as unknown as Config;
callCount = 0;
mockResponses.length = 0;
mockGenerateJson = vi
.fn()
.mockImplementation((_contents, _schema, signal) => {
// Reject immediately if the signal is already aborted.
if (signal && signal.aborted) {
return Promise.reject(new Error('Aborted'));
}
const response = mockResponses[callCount];
callCount++;
if (response === undefined) return Promise.resolve({});
return Promise.resolve(response);
});
mockStartChat = vi.fn();
mockSendMessageStream = vi.fn();
mockGeminiClientInstance = new GeminiClient(
mockConfigInstance,
) as Mocked<GeminiClient>;
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
resetEditCorrectorCaches_TEST_ONLY();
});
describe('Scenario Group 1: originalParams.old_string matches currentContent directly', () => {
it('Test 1.1: old_string (no literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
const currentContent = 'This is a test string to find me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find me',
new_string: 'replace with \\"this\\"',
};
mockResponses.push({
corrected_new_string_escaping: 'replace with "this"',
});
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
expect(result.params.old_string).toBe('find me');
expect(result.occurrences).toBe(1);
});
it('Test 1.2: old_string (no literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
const currentContent = 'This is a test string to find me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find me',
new_string: 'replace with this',
};
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
expect(result.params.old_string).toBe('find me');
expect(result.occurrences).toBe(1);
});
it('Test 1.3: old_string (with literal \\), new_string (escaped by Gemini) -> new_string unchanged (still escaped)', async () => {
const currentContent = 'This is a test string to find\\me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find\\me',
new_string: 'replace with \\"this\\"',
};
mockResponses.push({
corrected_new_string_escaping: 'replace with "this"',
});
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
expect(result.params.old_string).toBe('find\\me');
expect(result.occurrences).toBe(1);
});
it('Test 1.4: old_string (with literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
const currentContent = 'This is a test string to find\\me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find\\me',
new_string: 'replace with this',
};
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
expect(result.params.old_string).toBe('find\\me');
expect(result.occurrences).toBe(1);
});
});
describe('Scenario Group 2: originalParams.old_string does NOT match, but unescapeStringForGeminiBug(originalParams.old_string) DOES match', () => {
it('Test 2.1: old_string (over-escaped, no intended literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
const currentContent = 'This is a test string to find "me".';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find \\"me\\"',
new_string: 'replace with \\"this\\"',
};
mockResponses.push({ corrected_new_string: 'replace with "this"' });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
expect(result.params.old_string).toBe('find "me"');
expect(result.occurrences).toBe(1);
});
it('Test 2.2: old_string (over-escaped, no intended literal \\), new_string (correctly formatted) -> new_string unescaped (harmlessly)', async () => {
const currentContent = 'This is a test string to find "me".';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find \\"me\\"',
new_string: 'replace with this',
};
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this');
expect(result.params.old_string).toBe('find "me"');
expect(result.occurrences).toBe(1);
});
it('Test 2.3: old_string (over-escaped, with intended literal \\), new_string (simple) -> new_string corrected', async () => {
const currentContent = 'This is a test string to find \\me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find \\\\me',
new_string: 'replace with foobar',
};
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with foobar');
expect(result.params.old_string).toBe('find \\me');
expect(result.occurrences).toBe(1);
});
});
describe('Scenario Group 3: LLM Correction Path', () => {
it('Test 3.1: old_string (no literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is double unescaped', async () => {
const currentContent = 'This is a test string to corrected find me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find me',
new_string: 'replace with \\\\"this\\\\"',
};
const llmNewString = 'LLM says replace with "that"';
mockResponses.push({ corrected_new_string_escaping: llmNewString });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(llmNewString);
expect(result.params.old_string).toBe('find me');
expect(result.occurrences).toBe(1);
});
it('Test 3.2: old_string (with literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is unescaped once', async () => {
const currentContent = 'This is a test string to corrected find me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find\\me',
new_string: 'replace with \\\\"this\\\\"',
};
const llmCorrectedOldString = 'corrected find me';
const llmNewString = 'LLM says replace with "that"';
mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
mockResponses.push({ corrected_new_string: llmNewString });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.new_string).toBe(llmNewString);
expect(result.params.old_string).toBe(llmCorrectedOldString);
expect(result.occurrences).toBe(1);
});
it('Test 3.3: old_string needs LLM, new_string is fine -> old_string corrected, new_string original', async () => {
const currentContent = 'This is a test string to be corrected.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'fiiind me',
new_string: 'replace with "this"',
};
const llmCorrectedOldString = 'to be corrected';
mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"');
expect(result.params.old_string).toBe(llmCorrectedOldString);
expect(result.occurrences).toBe(1);
});
it('Test 3.4: LLM correction path, correctNewString returns the originalNewString it was passed (which was unescaped) -> final new_string is unescaped', async () => {
const currentContent = 'This is a test string to corrected find me.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find me',
new_string: 'replace with \\\\"this\\\\"',
};
const newStringForLLMAndReturnedByLLM = 'replace with "this"';
mockResponses.push({
corrected_new_string_escaping: newStringForLLMAndReturnedByLLM,
});
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
expect(result.occurrences).toBe(1);
});
});
describe('Scenario Group 4: No Match Found / Multiple Matches', () => {
it('Test 4.1: No version of old_string (original, unescaped, LLM-corrected) matches -> returns original params, 0 occurrences', async () => {
const currentContent = 'This content has nothing to find.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'nonexistent string',
new_string: 'some new string',
};
mockResponses.push({ corrected_target_snippet: 'still nonexistent' });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params).toEqual(originalParams);
expect(result.occurrences).toBe(0);
});
it('Test 4.2: unescapedOldStringAttempt results in >1 occurrences -> returns original params, count occurrences', async () => {
const currentContent =
'This content has find "me" and also find "me" again.';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'find "me"',
new_string: 'some new string',
};
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params).toEqual(originalParams);
expect(result.occurrences).toBe(2);
});
});
describe('Scenario Group 5: Specific unescapeStringForGeminiBug checks (integrated into ensureCorrectEdit)', () => {
it('Test 5.1: old_string needs LLM to become currentContent, new_string also needs correction', async () => {
const currentContent = 'const x = "a\nbc\\"def\\"';
const originalParams = {
file_path: '/test/file.txt',
old_string: 'const x = \\"a\\nbc\\\\"def\\\\"',
new_string: 'const y = \\"new\\nval\\\\"content\\\\"',
};
const expectedFinalNewString = 'const y = "new\nval\\"content\\"';
mockResponses.push({ corrected_target_snippet: currentContent });
mockResponses.push({ corrected_new_string: expectedFinalNewString });
const result = await ensureCorrectEdit(
'/test/file.txt',
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.old_string).toBe(currentContent);
expect(result.params.new_string).toBe(expectedFinalNewString);
expect(result.occurrences).toBe(1);
});
});
describe('Scenario Group 6: Concurrent Edits', () => {
it('Test 6.1: should return early if file was modified by another process', async () => {
const filePath = '/test/file.txt';
const currentContent =
'This content has been modified by someone else.';
const originalParams = {
file_path: filePath,
old_string: 'nonexistent string',
new_string: 'some new string',
};
const now = Date.now();
const lastEditTime = now - 5000; // 5 seconds ago
// Mock the file's modification time to be recent
vi.spyOn(fs, 'statSync').mockReturnValue({
mtimeMs: now,
} as fs.Stats);
// Mock the last edit timestamp from our history to be in the past
const history = [
{
role: 'model',
parts: [
{
functionResponse: {
name: EditTool.Name,
id: `${EditTool.Name}-${lastEditTime}-123`,
response: {
output: {
llmContent: `Successfully modified file: ${filePath}`,
},
},
},
},
],
},
];
(mockGeminiClientInstance.getHistory as Mock).mockResolvedValue(
history,
);
const result = await ensureCorrectEdit(
filePath,
currentContent,
originalParams,
mockGeminiClientInstance,
abortSignal,
);
expect(result.occurrences).toBe(0);
expect(result.params).toEqual(originalParams);
});
});
});
describe('ensureCorrectFileContent', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config;
const abortSignal = new AbortController().signal;
beforeEach(() => {
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
const configParams = {
apiKey: 'test-api-key',
model: 'test-model',
sandbox: false as boolean | string,
targetDir: '/test',
debugMode: false,
question: undefined as string | undefined,
fullContext: false,
coreTools: undefined as string[] | undefined,
toolDiscoveryCommand: undefined as string | undefined,
toolCallCommand: undefined as string | undefined,
mcpServerCommand: undefined as string | undefined,
mcpServers: undefined as Record<string, any> | undefined,
userAgent: 'test-agent',
userMemory: '',
geminiMdFileCount: 0,
alwaysSkipModificationConfirmation: false,
};
mockConfigInstance = {
...configParams,
getApiKey: vi.fn(() => configParams.apiKey),
getModel: vi.fn(() => configParams.model),
getSandbox: vi.fn(() => configParams.sandbox),
getTargetDir: vi.fn(() => configParams.targetDir),
getToolRegistry: vi.fn(() => mockToolRegistry),
getDebugMode: vi.fn(() => configParams.debugMode),
getQuestion: vi.fn(() => configParams.question),
getFullContext: vi.fn(() => configParams.fullContext),
getCoreTools: vi.fn(() => configParams.coreTools),
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
getMcpServers: vi.fn(() => configParams.mcpServers),
getUserAgent: vi.fn(() => configParams.userAgent),
getUserMemory: vi.fn(() => configParams.userMemory),
setUserMemory: vi.fn((mem: string) => {
configParams.userMemory = mem;
}),
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
setGeminiMdFileCount: vi.fn((count: number) => {
configParams.geminiMdFileCount = count;
}),
getAlwaysSkipModificationConfirmation: vi.fn(
() => configParams.alwaysSkipModificationConfirmation,
),
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
configParams.alwaysSkipModificationConfirmation = skip;
}),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
} as unknown as Config;
callCount = 0;
mockResponses.length = 0;
mockGenerateJson = vi
.fn()
.mockImplementation((_contents, _schema, signal) => {
if (signal && signal.aborted) {
return Promise.reject(new Error('Aborted'));
}
const response = mockResponses[callCount];
callCount++;
if (response === undefined) return Promise.resolve({});
return Promise.resolve(response);
});
mockStartChat = vi.fn();
mockSendMessageStream = vi.fn();
mockGeminiClientInstance = new GeminiClient(
mockConfigInstance,
) as Mocked<GeminiClient>;
resetEditCorrectorCaches_TEST_ONLY();
});
it('should return content unchanged if no escaping issues detected', async () => {
const content = 'This is normal content without escaping issues';
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
abortSignal,
);
expect(result).toBe(content);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
});
it('should call correctStringEscaping for potentially escaped content', async () => {
const content = 'console.log(\\"Hello World\\");';
const correctedContent = 'console.log("Hello World");';
mockResponses.push({
corrected_string_escaping: correctedContent,
});
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
abortSignal,
);
expect(result).toBe(correctedContent);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
});
it('should handle correctStringEscaping returning corrected content via correct property name', async () => {
// This test specifically verifies the property name fix
const content = 'const message = \\"Hello\\nWorld\\";';
const correctedContent = 'const message = "Hello\nWorld";';
// Mock the response with the correct property name
mockResponses.push({
corrected_string_escaping: correctedContent,
});
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
abortSignal,
);
expect(result).toBe(correctedContent);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
});
it('should return original content if LLM correction fails', async () => {
const content = 'console.log(\\"Hello World\\");';
// Mock empty response to simulate LLM failure
mockResponses.push({});
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
abortSignal,
);
expect(result).toBe(content);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
});
it('should handle various escape sequences that need correction', async () => {
const content =
'const obj = { name: \\"John\\", age: 30, bio: \\"Developer\\nEngineer\\" };';
const correctedContent =
'const obj = { name: "John", age: 30, bio: "Developer\nEngineer" };';
mockResponses.push({
corrected_string_escaping: correctedContent,
});
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
abortSignal,
);
expect(result).toBe(correctedContent);
});
});
});

View File

@@ -0,0 +1,755 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
Content,
GenerateContentConfig,
SchemaUnion,
Type,
} from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { EditToolParams, EditTool } from '../tools/edit.js';
import { WriteFileTool } from '../tools/write-file.js';
import { ReadFileTool } from '../tools/read-file.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { GrepTool } from '../tools/grep.js';
import { LruCache } from './LruCache.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import {
isFunctionResponse,
isFunctionCall,
} from '../utils/messageInspectors.js';
import * as fs from 'fs';
const EditModel = DEFAULT_GEMINI_FLASH_MODEL;
const EditConfig: GenerateContentConfig = {
thinkingConfig: {
thinkingBudget: 0,
},
};
const MAX_CACHE_SIZE = 50;
// Cache for ensureCorrectEdit results
const editCorrectionCache = new LruCache<string, CorrectedEditResult>(
MAX_CACHE_SIZE,
);
// Cache for ensureCorrectFileContent results
const fileContentCorrectionCache = new LruCache<string, string>(MAX_CACHE_SIZE);
/**
* Defines the structure of the parameters within CorrectedEditResult
*/
interface CorrectedEditParams {
file_path: string;
old_string: string;
new_string: string;
}
/**
* Defines the result structure for ensureCorrectEdit.
*/
export interface CorrectedEditResult {
params: CorrectedEditParams;
occurrences: number;
}
/**
 * Extracts the timestamp from the .id value, which is in the format
 * <tool.name>-<timestamp>-<uuid>
 * @param fcnId the ID value of a functionCall or functionResponse object
 * @returns the timestamp (as a number), or -1 if it could not be extracted
 */
function getTimestampFromFunctionId(fcnId: string): number {
const idParts = fcnId.split('-');
if (idParts.length > 2) {
const timestamp = parseInt(idParts[1], 10);
if (!isNaN(timestamp)) {
return timestamp;
}
}
return -1;
}
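// For example, an ID of the form 'edit_tool-1700000000000-uuid' (hypothetical
// values) yields 1700000000000, while a malformed ID such as 'edit_tool-xyz'
// yields -1.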
/**
 * Looks through the Gemini client history to determine when the most recent
 * edit to a target file occurred. If no edit happened, it returns -1.
 * @param filePath the path to the file
 * @param client the GeminiClient, so that we can get the history
 * @returns the timestamp (as a number) of the last edit, or -1 if no edit was found.
 */
async function findLastEditTimestamp(
filePath: string,
client: GeminiClient,
): Promise<number> {
const history = (await client.getHistory()) ?? [];
// Tools that may reference the file path in their FunctionResponse `output`.
const toolsInResp = new Set([
WriteFileTool.Name,
EditTool.Name,
ReadManyFilesTool.Name,
GrepTool.Name,
]);
// Tools that may reference the file path in their FunctionCall `args`.
const toolsInCall = new Set([...toolsInResp, ReadFileTool.Name]);
// Iterate backwards to find the most recent relevant action.
for (const entry of history.slice().reverse()) {
if (!entry.parts) continue;
for (const part of entry.parts) {
let id: string | undefined;
let content: unknown;
// Check for a relevant FunctionCall with the file path in its arguments.
if (
isFunctionCall(entry) &&
part.functionCall?.name &&
toolsInCall.has(part.functionCall.name)
) {
id = part.functionCall.id;
content = part.functionCall.args;
}
// Check for a relevant FunctionResponse with the file path in its output.
else if (
isFunctionResponse(entry) &&
part.functionResponse?.name &&
toolsInResp.has(part.functionResponse.name)
) {
const { response } = part.functionResponse;
if (response && !('error' in response) && 'output' in response) {
id = part.functionResponse.id;
content = response.output;
}
}
if (!id || content === undefined) continue;
// Use the "blunt hammer" approach to find the file path in the content.
// Note that the tool response data is inconsistent in their formatting
// with successes and errors - so, we just check for the existance
// as the best guess to if error/failed occured with the response.
const stringified = JSON.stringify(content);
if (
!stringified.includes('Error') && // only applicable for functionResponse
!stringified.includes('Failed') && // only applicable for functionResponse
stringified.includes(filePath)
) {
return getTimestampFromFunctionId(id);
}
}
}
return -1;
}
/**
* Attempts to correct edit parameters if the original old_string is not found.
* It tries unescaping, and then LLM-based correction.
* Results are cached to avoid redundant processing.
*
 * @param filePath The path to the file being edited.
 * @param currentContent The current content of the file.
 * @param originalParams The original EditToolParams.
 * @param client The GeminiClient for LLM calls.
 * @param abortSignal Signal used to cancel in-flight LLM calls.
 * @returns A promise resolving to an object containing the (potentially corrected)
 * EditToolParams (as CorrectedEditParams) and the final occurrences count.
*/
export async function ensureCorrectEdit(
filePath: string,
currentContent: string,
originalParams: EditToolParams, // The EditToolParams from edit.ts, without 'corrected'
client: GeminiClient,
abortSignal: AbortSignal,
): Promise<CorrectedEditResult> {
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
const cachedResult = editCorrectionCache.get(cacheKey);
if (cachedResult) {
return cachedResult;
}
let finalNewString = originalParams.new_string;
const newStringPotentiallyEscaped =
unescapeStringForGeminiBug(originalParams.new_string) !==
originalParams.new_string;
const expectedReplacements = originalParams.expected_replacements ?? 1;
let finalOldString = originalParams.old_string;
let occurrences = countOccurrences(currentContent, finalOldString);
if (occurrences === expectedReplacements) {
if (newStringPotentiallyEscaped) {
finalNewString = await correctNewStringEscaping(
client,
finalOldString,
originalParams.new_string,
abortSignal,
);
}
} else if (occurrences > expectedReplacements) {
// More occurrences than expected: return the params as-is; the mismatch
// will fail validation later.
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences,
};
editCorrectionCache.set(cacheKey, result);
return result;
} else {
// occurrences is 0 or some other unexpected state initially
const unescapedOldStringAttempt = unescapeStringForGeminiBug(
originalParams.old_string,
);
occurrences = countOccurrences(currentContent, unescapedOldStringAttempt);
if (occurrences === expectedReplacements) {
finalOldString = unescapedOldStringAttempt;
if (newStringPotentiallyEscaped) {
finalNewString = await correctNewString(
client,
originalParams.old_string, // original old
unescapedOldStringAttempt, // corrected old
originalParams.new_string, // original new (which is potentially escaped)
abortSignal,
);
}
} else if (occurrences === 0) {
if (filePath) {
// In order to keep from clobbering edits made outside our system,
// let's check if there was a more recent edit to the file than what
// our system has done
const lastEditedByUsTime = await findLastEditTimestamp(
filePath,
client,
);
// Add a buffer to account for timing inaccuracies. If the file was
// modified more than two seconds after the last edit tool was run, we
// can assume it was modified by something else.
if (lastEditedByUsTime > 0) {
const stats = fs.statSync(filePath);
const diff = stats.mtimeMs - lastEditedByUsTime;
if (diff > 2000) {
// Hard-coded 2-second threshold
// The file was modified after our last edit, most likely by an
// external process, so abort rather than clobber that change.
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences: 0, // Explicitly 0: the edit is not applied
};
editCorrectionCache.set(cacheKey, result);
return result;
}
}
}
const llmCorrectedOldString = await correctOldStringMismatch(
client,
currentContent,
unescapedOldStringAttempt,
abortSignal,
);
const llmOldOccurrences = countOccurrences(
currentContent,
llmCorrectedOldString,
);
if (llmOldOccurrences === expectedReplacements) {
finalOldString = llmCorrectedOldString;
occurrences = llmOldOccurrences;
if (newStringPotentiallyEscaped) {
const baseNewStringForLLMCorrection = unescapeStringForGeminiBug(
originalParams.new_string,
);
finalNewString = await correctNewString(
client,
originalParams.old_string, // original old
llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction
abortSignal,
);
}
} else {
// LLM correction also failed for old_string
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences: 0, // Explicitly 0 as LLM failed
};
editCorrectionCache.set(cacheKey, result);
return result;
}
} else {
// Unescaping old_string resulted in > 1 occurrence
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences, // This will be > 1
};
editCorrectionCache.set(cacheKey, result);
return result;
}
}
const { targetString, pair } = trimPairIfPossible(
finalOldString,
finalNewString,
currentContent,
expectedReplacements,
);
finalOldString = targetString;
finalNewString = pair;
// Final result construction
const result: CorrectedEditResult = {
params: {
file_path: originalParams.file_path,
old_string: finalOldString,
new_string: finalNewString,
},
occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string
};
editCorrectionCache.set(cacheKey, result);
return result;
}
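// Example usage (a minimal sketch; the file path, params, and a configured
// geminiClient are hypothetical):
// const corrected = await ensureCorrectEdit(
//   '/project/file.txt',
//   currentContent,
//   { file_path: '/project/file.txt', old_string: 'find \\"me\\"', new_string: 'found' },
//   geminiClient,
//   new AbortController().signal,
// );
// // If the unescaped old_string ('find "me"') matches the file exactly once,
// // corrected.params.old_string holds the unescaped form and
// // corrected.occurrences === 1.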
export async function ensureCorrectFileContent(
content: string,
client: GeminiClient,
abortSignal: AbortSignal,
): Promise<string> {
const cachedResult = fileContentCorrectionCache.get(content);
if (cachedResult) {
return cachedResult;
}
const contentPotentiallyEscaped =
unescapeStringForGeminiBug(content) !== content;
if (!contentPotentiallyEscaped) {
fileContentCorrectionCache.set(content, content);
return content;
}
const correctedContent = await correctStringEscaping(
content,
client,
abortSignal,
);
fileContentCorrectionCache.set(content, correctedContent);
return correctedContent;
}
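// Example usage (a minimal sketch; geminiClient is assumed to be configured):
// const fixed = await ensureCorrectFileContent(
//   'console.log(\\"Hello World\\");',
//   geminiClient,
//   new AbortController().signal,
// );
// // -> 'console.log("Hello World");' when the model returns a correction;
// // content with no escaping issues is returned unchanged without an LLM call.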
// Define the expected JSON schema for the LLM response for old_string correction
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_target_snippet: {
type: Type.STRING,
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
},
required: ['corrected_target_snippet'],
};
export async function correctOldStringMismatch(
geminiClient: GeminiClient,
fileContent: string,
problematicSnippet: string,
abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
Task: Analyze the provided file content and the problematic target snippet. Identify the segment in the file content that the snippet was *most likely* intended to match. Output the *exact*, literal text of that segment from the file content. Focus *only* on removing extra escape characters and correcting formatting, whitespace, or minor differences to achieve a PERFECT literal match. The output must be the exact literal text as it appears in the file.
Problematic target snippet:
\`\`\`
${problematicSnippet}
\`\`\`
File Content:
\`\`\`
${fileContent}
\`\`\`
For example, if the problematic target snippet was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and the file content had content that looked like "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", then corrected_target_snippet should likely be "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;" to fix the incorrect escaping to match the original file content.
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_target_snippet.
Return ONLY the corrected target snippet in the specified JSON format with the key 'corrected_target_snippet'. If no clear, unique match can be found, return an empty string for 'corrected_target_snippet'.
`.trim();
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
contents,
OLD_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
if (
result &&
typeof result.corrected_target_snippet === 'string' &&
result.corrected_target_snippet.length > 0
) {
return result.corrected_target_snippet;
} else {
return problematicSnippet;
}
} catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error(
'Error during LLM call for old string snippet correction:',
error,
);
return problematicSnippet;
}
}
// Define the expected JSON schema for the new_string correction LLM response
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string: {
type: Type.STRING,
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
},
required: ['corrected_new_string'],
};
/**
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
*/
export async function correctNewString(
geminiClient: GeminiClient,
originalOldString: string,
correctedOldString: string,
originalNewString: string,
abortSignal: AbortSignal,
): Promise<string> {
if (originalOldString === correctedOldString) {
return originalNewString;
}
const prompt = `
Context: A text replacement operation was planned. The original text to be replaced (original_old_string) was slightly different from the actual text in the file (corrected_old_string). The original_old_string has now been corrected to match the file content.
We now need to adjust the replacement text (original_new_string) so that it makes sense as a replacement for the corrected_old_string, while preserving the original intent of the change.
original_old_string (what was initially intended to be found):
\`\`\`
${originalOldString}
\`\`\`
corrected_old_string (what was actually found in the file and will be replaced):
\`\`\`
${correctedOldString}
\`\`\`
original_new_string (what was intended to replace original_old_string):
\`\`\`
${originalNewString}
\`\`\`
Task: Based on the differences between original_old_string and corrected_old_string, and the content of original_new_string, generate a corrected_new_string. This corrected_new_string should be what original_new_string would have been if it was designed to replace corrected_old_string directly, while maintaining the spirit of the original transformation.
For example, if original_old_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and corrected_old_string is "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", and original_new_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name} \${lastName}\\\\\`\`;", then corrected_new_string should likely be "\nconst greeting = \`Hello ${'\\`'}\${name} \${lastName}${'\\`'}\`;" to fix the incorrect escaping.
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_new_string.
Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string'. If no adjustment is deemed necessary or possible, return the original_new_string.
`.trim();
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
contents,
NEW_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
if (
result &&
typeof result.corrected_new_string === 'string' &&
result.corrected_new_string.length > 0
) {
return result.corrected_new_string;
} else {
return originalNewString;
}
} catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error('Error during LLM call for new_string correction:', error);
return originalNewString;
}
}
const CORRECT_NEW_STRING_ESCAPING_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string_escaping: {
type: Type.STRING,
description:
'The new_string with corrected escaping, ensuring it is a proper replacement for the old_string, especially considering potential over-escaping issues from previous LLM generations.',
},
},
required: ['corrected_new_string_escaping'],
};
export async function correctNewStringEscaping(
geminiClient: GeminiClient,
oldString: string,
potentiallyProblematicNewString: string,
abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily escaped quotes like \\"Hello\\" instead of "Hello").
old_string (this is the exact text that will be replaced):
\`\`\`
${oldString}
\`\`\`
potentially_problematic_new_string (this is the text that should replace old_string, but MIGHT have bad escaping, or might be entirely correct):
\`\`\`
${potentiallyProblematicNewString}
\`\`\`
Task: Analyze the potentially_problematic_new_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure that the new_string, when inserted into the code, will be valid and correctly interpreted.
For example, if old_string is "foo" and potentially_problematic_new_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz".
If potentially_problematic_new_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").
Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_new_string.
`.trim();
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
contents,
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
if (
result &&
typeof result.corrected_new_string_escaping === 'string' &&
result.corrected_new_string_escaping.length > 0
) {
return result.corrected_new_string_escaping;
} else {
return potentiallyProblematicNewString;
}
} catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error(
'Error during LLM call for new_string escaping correction:',
error,
);
return potentiallyProblematicNewString;
}
}
const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_string_escaping: {
type: Type.STRING,
description:
'The string with corrected escaping, ensuring it is valid, especially considering potential over-escaping issues from previous LLM generations.',
},
},
required: ['corrected_string_escaping'],
};
export async function correctStringEscaping(
potentiallyProblematicString: string,
client: GeminiClient,
abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily escaped quotes like \\"Hello\\" instead of "Hello").
potentially_problematic_string (this text MIGHT have bad escaping, or might be entirely correct):
\`\`\`
${potentiallyProblematicString}
\`\`\`
Task: Analyze the potentially_problematic_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the text will be valid and correctly interpreted.
For example, if potentially_problematic_string is "bar\\nbaz", the corrected_string_escaping should be "bar\nbaz".
If potentially_problematic_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").
Return ONLY the corrected string in the specified JSON format with the key 'corrected_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_string.
`.trim();
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await client.generateJson(
contents,
CORRECT_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
if (
result &&
typeof result.corrected_string_escaping === 'string' &&
result.corrected_string_escaping.length > 0
) {
return result.corrected_string_escaping;
} else {
return potentiallyProblematicString;
}
} catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error(
'Error during LLM call for string escaping correction:',
error,
);
return potentiallyProblematicString;
}
}
function trimPairIfPossible(
target: string,
trimIfTargetTrims: string,
currentContent: string,
expectedReplacements: number,
) {
const trimmedTargetString = target.trim();
if (target.length !== trimmedTargetString.length) {
const trimmedTargetOccurrences = countOccurrences(
currentContent,
trimmedTargetString,
);
if (trimmedTargetOccurrences === expectedReplacements) {
const trimmedReactiveString = trimIfTargetTrims.trim();
return {
targetString: trimmedTargetString,
pair: trimmedReactiveString,
};
}
}
return {
targetString: target,
pair: trimIfTargetTrims,
};
}
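// For example (hypothetical values): if the target arrives as '  foo()\n' but
// the file contains exactly one occurrence of 'foo()', the trimmed pair
// ('foo()' plus the trimmed replacement) is returned; otherwise both strings
// pass through unchanged.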
/**
* Unescapes a string that might have been overly escaped by an LLM.
*/
export function unescapeStringForGeminiBug(inputString: string): string {
// Regex explanation:
// \\ : Matches exactly one literal backslash character.
// (n|t|r|'|"|`|\\|\n) : This is a capturing group. It matches one of the following:
// n, t, r, ', ", ` : These match the literal characters 'n', 't', 'r', single quote, double quote, or backtick.
// This handles cases like "\\n", "\\`", etc.
// \\ : This matches a literal backslash. This handles cases like "\\\\" (escaped backslash).
// \n : This matches an actual newline character. This handles cases where the input
// string might have something like "\\\n" (a literal backslash followed by a newline).
// g : Global flag, to replace all occurrences.
return inputString.replace(
/\\+(n|t|r|'|"|`|\\|\n)/g,
(match, capturedChar) => {
// 'match' is the entire erroneous sequence, e.g., if the input (in memory) was "\\\\`", match is "\\\\`".
// 'capturedChar' is the character that determines the true meaning, e.g., '`'.
switch (capturedChar) {
case 'n':
return '\n'; // Correctly escaped: \n (newline character)
case 't':
return '\t'; // Correctly escaped: \t (tab character)
case 'r':
return '\r'; // Correctly escaped: \r (carriage return character)
case "'":
return "'"; // Correctly escaped: ' (apostrophe character)
case '"':
return '"'; // Correctly escaped: " (quotation mark character)
case '`':
return '`'; // Correctly escaped: ` (backtick character)
case '\\': // This handles when 'capturedChar' is a literal backslash
return '\\'; // Replace escaped backslash (e.g., "\\\\") with single backslash
case '\n': // This handles when 'capturedChar' is an actual newline
return '\n'; // Replace the whole erroneous sequence (e.g., "\\\n" in memory) with a clean newline
default:
// This fallback should ideally not be reached if the regex captures correctly.
// It would return the original matched sequence if an unexpected character was captured.
return match;
}
},
);
}
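// Illustrative input/output pairs, mirroring the unit tests (shown as source
// literals):
// unescapeStringForGeminiBug('\\n')         -> '\n'   (over-escaped newline)
// unescapeStringForGeminiBug('\\\\\\\\`')   -> '`'    (run of backslashes collapsed)
// unescapeStringForGeminiBug('C:\\\\Users') -> 'C:\\Users' (escaped backslash)
// unescapeStringForGeminiBug('C:\\Folder')  -> 'C:\\Folder' (no targeted escape; unchanged)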
/**
* Counts occurrences of a substring in a string
*/
export function countOccurrences(str: string, substr: string): number {
if (substr === '') {
return 0;
}
let count = 0;
let pos = str.indexOf(substr);
while (pos !== -1) {
count++;
pos = str.indexOf(substr, pos + substr.length); // Start search after the current match
}
return count;
}
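// For example, countOccurrences('aaaaa', 'aa') === 2: after the match at
// index 0, the next search starts at index 2, so overlaps are not counted.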
export function resetEditCorrectorCaches_TEST_ONLY() {
editCorrectionCache.clear();
fileContentCorrectionCache.clear();
}

View File

@@ -0,0 +1,378 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import {
checkHasEditorType,
getDiffCommand,
openDiff,
allowEditorTypeInSandbox,
isEditorAvailable,
type EditorType,
} from './editor.js';
import { execSync, spawn } from 'child_process';
vi.mock('child_process', () => ({
execSync: vi.fn(),
spawn: vi.fn(),
}));
const originalPlatform = process.platform;
describe('editor utils', () => {
beforeEach(() => {
vi.clearAllMocks();
delete process.env.SANDBOX;
Object.defineProperty(process, 'platform', {
value: originalPlatform,
writable: true,
});
});
afterEach(() => {
vi.restoreAllMocks();
delete process.env.SANDBOX;
Object.defineProperty(process, 'platform', {
value: originalPlatform,
writable: true,
});
});
describe('checkHasEditorType', () => {
const testCases: Array<{
editor: EditorType;
command: string;
win32Command: string;
}> = [
{ editor: 'vscode', command: 'code', win32Command: 'code.cmd' },
{ editor: 'vscodium', command: 'codium', win32Command: 'codium.cmd' },
{ editor: 'windsurf', command: 'windsurf', win32Command: 'windsurf' },
{ editor: 'cursor', command: 'cursor', win32Command: 'cursor' },
{ editor: 'vim', command: 'vim', win32Command: 'vim' },
{ editor: 'neovim', command: 'nvim', win32Command: 'nvim' },
{ editor: 'zed', command: 'zed', win32Command: 'zed' },
];
for (const { editor, command, win32Command } of testCases) {
describe(`${editor}`, () => {
it(`should return true if "${command}" command exists on non-windows`, () => {
Object.defineProperty(process, 'platform', { value: 'linux' });
(execSync as Mock).mockReturnValue(
Buffer.from(`/usr/bin/${command}`),
);
expect(checkHasEditorType(editor)).toBe(true);
expect(execSync).toHaveBeenCalledWith(`command -v ${command}`, {
stdio: 'ignore',
});
});
it(`should return false if "${command}" command does not exist on non-windows`, () => {
Object.defineProperty(process, 'platform', { value: 'linux' });
(execSync as Mock).mockImplementation(() => {
throw new Error();
});
expect(checkHasEditorType(editor)).toBe(false);
});
it(`should return true if "${win32Command}" command exists on windows`, () => {
Object.defineProperty(process, 'platform', { value: 'win32' });
(execSync as Mock).mockReturnValue(
Buffer.from(`C:\\Program Files\\...\\${win32Command}`),
);
expect(checkHasEditorType(editor)).toBe(true);
expect(execSync).toHaveBeenCalledWith(`where.exe ${win32Command}`, {
stdio: 'ignore',
});
});
it(`should return false if "${win32Command}" command does not exist on windows`, () => {
Object.defineProperty(process, 'platform', { value: 'win32' });
(execSync as Mock).mockImplementation(() => {
throw new Error();
});
expect(checkHasEditorType(editor)).toBe(false);
});
});
}
});
describe('getDiffCommand', () => {
const guiEditors: Array<{
editor: EditorType;
command: string;
win32Command: string;
}> = [
{ editor: 'vscode', command: 'code', win32Command: 'code.cmd' },
{ editor: 'vscodium', command: 'codium', win32Command: 'codium.cmd' },
{ editor: 'windsurf', command: 'windsurf', win32Command: 'windsurf' },
{ editor: 'cursor', command: 'cursor', win32Command: 'cursor' },
{ editor: 'zed', command: 'zed', win32Command: 'zed' },
];
for (const { editor, command, win32Command } of guiEditors) {
it(`should return the correct command for ${editor} on non-windows`, () => {
Object.defineProperty(process, 'platform', { value: 'linux' });
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor);
expect(diffCommand).toEqual({
command,
args: ['--wait', '--diff', 'old.txt', 'new.txt'],
});
});
it(`should return the correct command for ${editor} on windows`, () => {
Object.defineProperty(process, 'platform', { value: 'win32' });
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor);
expect(diffCommand).toEqual({
command: win32Command,
args: ['--wait', '--diff', 'old.txt', 'new.txt'],
});
});
}
const terminalEditors: Array<{
editor: EditorType;
command: string;
}> = [
{ editor: 'vim', command: 'vim' },
{ editor: 'neovim', command: 'nvim' },
];
for (const { editor, command } of terminalEditors) {
it(`should return the correct command for ${editor}`, () => {
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor);
expect(diffCommand).toEqual({
command,
args: [
'-d',
'-i',
'NONE',
'-c',
'wincmd h | set readonly | wincmd l',
'-c',
'highlight DiffAdd cterm=bold ctermbg=22 guibg=#005f00 | highlight DiffChange cterm=bold ctermbg=24 guibg=#005f87 | highlight DiffText ctermbg=21 guibg=#0000af | highlight DiffDelete ctermbg=52 guibg=#5f0000',
'-c',
'set showtabline=2 | set tabline=[Instructions]\\ :wqa(save\\ &\\ quit)\\ \\|\\ i/esc(toggle\\ edit\\ mode)',
'-c',
'wincmd h | setlocal statusline=OLD\\ FILE',
'-c',
'wincmd l | setlocal statusline=%#StatusBold#NEW\\ FILE\\ :wqa(save\\ &\\ quit)\\ \\|\\ i/esc(toggle\\ edit\\ mode)',
'-c',
'autocmd WinClosed * wqa',
'old.txt',
'new.txt',
],
});
});
}
it('should return null for an unsupported editor', () => {
// @ts-expect-error Testing unsupported editor
const command = getDiffCommand('old.txt', 'new.txt', 'foobar');
expect(command).toBeNull();
});
});
describe('openDiff', () => {
const spawnEditors: EditorType[] = [
'vscode',
'vscodium',
'windsurf',
'cursor',
'zed',
];
for (const editor of spawnEditors) {
it(`should call spawn for ${editor}`, async () => {
const mockSpawn = {
on: vi.fn((event, cb) => {
if (event === 'close') {
cb(0);
}
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await openDiff('old.txt', 'new.txt', editor);
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
expect(spawn).toHaveBeenCalledWith(
diffCommand.command,
diffCommand.args,
{
stdio: 'inherit',
shell: true,
},
);
expect(mockSpawn.on).toHaveBeenCalledWith(
'close',
expect.any(Function),
);
expect(mockSpawn.on).toHaveBeenCalledWith(
'error',
expect.any(Function),
);
});
it(`should reject if spawn for ${editor} fails`, async () => {
const mockError = new Error('spawn error');
const mockSpawn = {
on: vi.fn((event, cb) => {
if (event === 'error') {
cb(mockError);
}
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await expect(openDiff('old.txt', 'new.txt', editor)).rejects.toThrow(
'spawn error',
);
});
it(`should reject if ${editor} exits with non-zero code`, async () => {
const mockSpawn = {
on: vi.fn((event, cb) => {
if (event === 'close') {
cb(1);
}
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await expect(openDiff('old.txt', 'new.txt', editor)).rejects.toThrow(
`${editor} exited with code 1`,
);
});
}
const execSyncEditors: EditorType[] = ['vim', 'neovim'];
for (const editor of execSyncEditors) {
it(`should call execSync for ${editor} on non-windows`, async () => {
Object.defineProperty(process, 'platform', { value: 'linux' });
await openDiff('old.txt', 'new.txt', editor);
expect(execSync).toHaveBeenCalledTimes(1);
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
const expectedCommand = `${
diffCommand.command
} ${diffCommand.args.map((arg) => `"${arg}"`).join(' ')}`;
expect(execSync).toHaveBeenCalledWith(expectedCommand, {
stdio: 'inherit',
encoding: 'utf8',
});
});
it(`should call execSync for ${editor} on windows`, async () => {
Object.defineProperty(process, 'platform', { value: 'win32' });
await openDiff('old.txt', 'new.txt', editor);
expect(execSync).toHaveBeenCalledTimes(1);
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
const expectedCommand = `${diffCommand.command} ${diffCommand.args.join(
' ',
)}`;
expect(execSync).toHaveBeenCalledWith(expectedCommand, {
stdio: 'inherit',
encoding: 'utf8',
});
});
}
it('should log an error if diff command is not available', async () => {
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
// @ts-expect-error Testing unsupported editor
await openDiff('old.txt', 'new.txt', 'foobar');
expect(consoleErrorSpy).toHaveBeenCalledWith(
'No diff tool available. Install a supported editor.',
);
});
});
describe('allowEditorTypeInSandbox', () => {
it('should allow vim in sandbox mode', () => {
process.env.SANDBOX = 'sandbox';
expect(allowEditorTypeInSandbox('vim')).toBe(true);
});
it('should allow vim when not in sandbox mode', () => {
expect(allowEditorTypeInSandbox('vim')).toBe(true);
});
it('should allow neovim in sandbox mode', () => {
process.env.SANDBOX = 'sandbox';
expect(allowEditorTypeInSandbox('neovim')).toBe(true);
});
it('should allow neovim when not in sandbox mode', () => {
expect(allowEditorTypeInSandbox('neovim')).toBe(true);
});
const guiEditors: EditorType[] = [
'vscode',
'vscodium',
'windsurf',
'cursor',
'zed',
];
for (const editor of guiEditors) {
it(`should not allow ${editor} in sandbox mode`, () => {
process.env.SANDBOX = 'sandbox';
expect(allowEditorTypeInSandbox(editor)).toBe(false);
});
it(`should allow ${editor} when not in sandbox mode`, () => {
expect(allowEditorTypeInSandbox(editor)).toBe(true);
});
}
});
describe('isEditorAvailable', () => {
it('should return false for undefined editor', () => {
expect(isEditorAvailable(undefined)).toBe(false);
});
it('should return false for empty string editor', () => {
expect(isEditorAvailable('')).toBe(false);
});
it('should return false for invalid editor type', () => {
expect(isEditorAvailable('invalid-editor')).toBe(false);
});
it('should return true for vscode when installed and not in sandbox mode', () => {
(execSync as Mock).mockReturnValue(Buffer.from('/usr/bin/code'));
expect(isEditorAvailable('vscode')).toBe(true);
});
it('should return false for vscode when not installed and not in sandbox mode', () => {
(execSync as Mock).mockImplementation(() => {
throw new Error();
});
expect(isEditorAvailable('vscode')).toBe(false);
});
it('should return false for vscode when installed and in sandbox mode', () => {
(execSync as Mock).mockReturnValue(Buffer.from('/usr/bin/code'));
process.env.SANDBOX = 'sandbox';
expect(isEditorAvailable('vscode')).toBe(false);
});
it('should return true for vim when installed and in sandbox mode', () => {
(execSync as Mock).mockReturnValue(Buffer.from('/usr/bin/vim'));
process.env.SANDBOX = 'sandbox';
expect(isEditorAvailable('vim')).toBe(true);
});
it('should return true for neovim when installed and in sandbox mode', () => {
(execSync as Mock).mockReturnValue(Buffer.from('/usr/bin/nvim'));
process.env.SANDBOX = 'sandbox';
expect(isEditorAvailable('neovim')).toBe(true);
});
});
});

View File

@@ -0,0 +1,201 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { execSync, spawn } from 'child_process';
export type EditorType =
| 'vscode'
| 'vscodium'
| 'windsurf'
| 'cursor'
| 'vim'
| 'neovim'
| 'zed';
function isValidEditorType(editor: string): editor is EditorType {
return [
'vscode',
'vscodium',
'windsurf',
'cursor',
'vim',
'neovim',
'zed',
].includes(editor);
}
interface DiffCommand {
command: string;
args: string[];
}
function commandExists(cmd: string): boolean {
try {
execSync(
process.platform === 'win32' ? `where.exe ${cmd}` : `command -v ${cmd}`,
{ stdio: 'ignore' },
);
return true;
} catch {
return false;
}
}
const editorCommands: Record<EditorType, { win32: string; default: string }> = {
vscode: { win32: 'code.cmd', default: 'code' },
vscodium: { win32: 'codium.cmd', default: 'codium' },
windsurf: { win32: 'windsurf', default: 'windsurf' },
cursor: { win32: 'cursor', default: 'cursor' },
vim: { win32: 'vim', default: 'vim' },
neovim: { win32: 'nvim', default: 'nvim' },
zed: { win32: 'zed', default: 'zed' },
};
export function checkHasEditorType(editor: EditorType): boolean {
const commandConfig = editorCommands[editor];
const command =
process.platform === 'win32' ? commandConfig.win32 : commandConfig.default;
return commandExists(command);
}
export function allowEditorTypeInSandbox(editor: EditorType): boolean {
const notUsingSandbox = !process.env.SANDBOX;
if (['vscode', 'vscodium', 'windsurf', 'cursor', 'zed'].includes(editor)) {
return notUsingSandbox;
}
return true;
}
/**
* Check if the editor is valid and can be used.
* Returns false if preferred editor is not set / invalid / not available / not allowed in sandbox.
*/
export function isEditorAvailable(editor: string | undefined): boolean {
if (editor && isValidEditorType(editor)) {
return checkHasEditorType(editor) && allowEditorTypeInSandbox(editor);
}
return false;
}
/**
* Get the diff command for a specific editor.
*/
export function getDiffCommand(
oldPath: string,
newPath: string,
editor: EditorType,
): DiffCommand | null {
if (!isValidEditorType(editor)) {
return null;
}
const commandConfig = editorCommands[editor];
const command =
process.platform === 'win32' ? commandConfig.win32 : commandConfig.default;
switch (editor) {
case 'vscode':
case 'vscodium':
case 'windsurf':
case 'cursor':
case 'zed':
return { command, args: ['--wait', '--diff', oldPath, newPath] };
case 'vim':
case 'neovim':
return {
command,
args: [
'-d',
// skip viminfo file to avoid E138 errors
'-i',
'NONE',
// make the left window read-only and the right window editable
'-c',
'wincmd h | set readonly | wincmd l',
// set up colors for diffs
'-c',
'highlight DiffAdd cterm=bold ctermbg=22 guibg=#005f00 | highlight DiffChange cterm=bold ctermbg=24 guibg=#005f87 | highlight DiffText ctermbg=21 guibg=#0000af | highlight DiffDelete ctermbg=52 guibg=#5f0000',
// Show helpful messages
'-c',
'set showtabline=2 | set tabline=[Instructions]\\ :wqa(save\\ &\\ quit)\\ \\|\\ i/esc(toggle\\ edit\\ mode)',
'-c',
'wincmd h | setlocal statusline=OLD\\ FILE',
'-c',
'wincmd l | setlocal statusline=%#StatusBold#NEW\\ FILE\\ :wqa(save\\ &\\ quit)\\ \\|\\ i/esc(toggle\\ edit\\ mode)',
// Auto close all windows when one is closed
'-c',
'autocmd WinClosed * wqa',
oldPath,
newPath,
],
};
default:
return null;
}
}
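// Illustrative results (a sketch, not part of the original file), assuming a
// non-Windows platform:
//   getDiffCommand('old.txt', 'new.txt', 'vscode')
//     -> { command: 'code', args: ['--wait', '--diff', 'old.txt', 'new.txt'] }
// On win32 the command would instead be 'code.cmd'.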
/**
* Opens a diff tool to compare two files.
 * Terminal-based editors block the parent process by default until the editor exits.
 * GUI-based editors require args such as "--wait" to block the parent process.
*/
export async function openDiff(
oldPath: string,
newPath: string,
editor: EditorType,
): Promise<void> {
const diffCommand = getDiffCommand(oldPath, newPath, editor);
if (!diffCommand) {
console.error('No diff tool available. Install a supported editor.');
return;
}
try {
switch (editor) {
case 'vscode':
case 'vscodium':
case 'windsurf':
case 'cursor':
case 'zed':
// Use spawn for GUI-based editors to avoid blocking the entire process
return new Promise((resolve, reject) => {
const childProcess = spawn(diffCommand.command, diffCommand.args, {
stdio: 'inherit',
shell: true,
});
childProcess.on('close', (code) => {
if (code === 0) {
resolve();
} else {
reject(new Error(`${editor} exited with code ${code}`));
}
});
childProcess.on('error', (error) => {
reject(error);
});
});
case 'vim':
case 'neovim': {
// Use execSync for terminal-based editors
const command =
process.platform === 'win32'
? `${diffCommand.command} ${diffCommand.args.join(' ')}`
: `${diffCommand.command} ${diffCommand.args.map((arg) => `"${arg}"`).join(' ')}`;
execSync(command, {
stdio: 'inherit',
encoding: 'utf8',
});
break;
}
default:
throw new Error(`Unsupported editor: ${editor}`);
}
} catch (error) {
console.error(error);
}
}
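// Example usage (illustrative only): the returned promise settles when the
// editor session ends, so callers can await the user's edits:
//   await openDiff('/tmp/old.txt', '/tmp/new.txt', 'vscode');
//   // resolves once the VS Code diff tab is closed (because of "--wait"),
//   // or rejects if the editor exits with a non-zero code.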

View File

@@ -0,0 +1,219 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
// Use a type alias for SpyInstance as it's not directly exported
type SpyInstance = ReturnType<typeof vi.spyOn>;
import { reportError } from './errorReporting.js';
import fs from 'node:fs/promises';
import os from 'node:os';
// Mock dependencies
vi.mock('node:fs/promises');
vi.mock('node:os');
describe('reportError', () => {
let consoleErrorSpy: SpyInstance;
const MOCK_TMP_DIR = '/tmp';
const MOCK_TIMESTAMP = '2025-01-01T00-00-00-000Z';
beforeEach(() => {
vi.resetAllMocks();
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
(os.tmpdir as Mock).mockReturnValue(MOCK_TMP_DIR);
vi.spyOn(Date.prototype, 'toISOString').mockReturnValue(MOCK_TIMESTAMP);
});
afterEach(() => {
vi.restoreAllMocks();
});
const getExpectedReportPath = (type: string) =>
`${MOCK_TMP_DIR}/gemini-client-error-${type}-${MOCK_TIMESTAMP}.json`;
it('should generate a report and log the path', async () => {
const error = new Error('Test error');
error.stack = 'Test stack';
const baseMessage = 'An error occurred.';
const context = { data: 'test context' };
const type = 'test-type';
const expectedReportPath = getExpectedReportPath(type);
(fs.writeFile as Mock).mockResolvedValue(undefined);
await reportError(error, baseMessage, context, type);
expect(os.tmpdir).toHaveBeenCalledTimes(1);
expect(fs.writeFile).toHaveBeenCalledWith(
expectedReportPath,
JSON.stringify(
{
error: { message: 'Test error', stack: error.stack },
context,
},
null,
2,
),
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Full report available at: ${expectedReportPath}`,
);
});
it('should handle errors that are plain objects with a message property', async () => {
const error = { message: 'Test plain object error' };
const baseMessage = 'Another error.';
const type = 'general';
const expectedReportPath = getExpectedReportPath(type);
(fs.writeFile as Mock).mockResolvedValue(undefined);
await reportError(error, baseMessage);
expect(fs.writeFile).toHaveBeenCalledWith(
expectedReportPath,
JSON.stringify(
{
error: { message: 'Test plain object error' },
},
null,
2,
),
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Full report available at: ${expectedReportPath}`,
);
});
it('should handle string errors', async () => {
const error = 'Just a string error';
const baseMessage = 'String error occurred.';
const type = 'general';
const expectedReportPath = getExpectedReportPath(type);
(fs.writeFile as Mock).mockResolvedValue(undefined);
await reportError(error, baseMessage);
expect(fs.writeFile).toHaveBeenCalledWith(
expectedReportPath,
JSON.stringify(
{
error: { message: 'Just a string error' },
},
null,
2,
),
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Full report available at: ${expectedReportPath}`,
);
});
it('should log fallback message if writing report fails', async () => {
const error = new Error('Main error');
const baseMessage = 'Failed operation.';
const writeError = new Error('Failed to write file');
const context = ['some context'];
const type = 'general';
const expectedReportPath = getExpectedReportPath(type);
(fs.writeFile as Mock).mockRejectedValue(writeError);
await reportError(error, baseMessage, context, type);
expect(fs.writeFile).toHaveBeenCalledWith(
expectedReportPath,
expect.any(String),
); // It still tries to write
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Additionally, failed to write detailed error report:`,
writeError,
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Original error that triggered report generation:',
error,
);
expect(consoleErrorSpy).toHaveBeenCalledWith('Original context:', context);
});
it('should handle stringification failure of report content (e.g. BigInt in context)', async () => {
const error = new Error('Main error');
error.stack = 'Main stack';
const baseMessage = 'Failed operation with BigInt.';
const context = { a: BigInt(1) }; // BigInt cannot be stringified by JSON.stringify
const type = 'bigint-fail';
const stringifyError = new TypeError(
'Do not know how to serialize a BigInt',
);
const expectedMinimalReportPath = getExpectedReportPath(type);
// Simulate JSON.stringify throwing an error for the full report
const originalJsonStringify = JSON.stringify;
let callCount = 0;
vi.spyOn(JSON, 'stringify').mockImplementation((value, replacer, space) => {
callCount++;
if (callCount === 1) {
// First call is for the full report content
throw stringifyError;
}
// Subsequent calls (for minimal report) should succeed
return originalJsonStringify(value, replacer, space);
});
(fs.writeFile as Mock).mockResolvedValue(undefined); // Mock for the minimal report write
await reportError(error, baseMessage, context, type);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Could not stringify report content (likely due to context):`,
stringifyError,
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Original error that triggered report generation:',
error,
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Original context could not be stringified or included in report.',
);
// Check that it attempts to write a minimal report
expect(fs.writeFile).toHaveBeenCalledWith(
expectedMinimalReportPath,
originalJsonStringify(
{ error: { message: error.message, stack: error.stack } },
null,
2,
),
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Partial report (excluding context) available at: ${expectedMinimalReportPath}`,
);
});
it('should generate a report without context if context is not provided', async () => {
const error = new Error('Error without context');
error.stack = 'No context stack';
const baseMessage = 'Simple error.';
const type = 'general';
const expectedReportPath = getExpectedReportPath(type);
(fs.writeFile as Mock).mockResolvedValue(undefined);
await reportError(error, baseMessage, undefined, type);
expect(fs.writeFile).toHaveBeenCalledWith(
expectedReportPath,
JSON.stringify(
{
error: { message: 'Error without context', stack: error.stack },
},
null,
2,
),
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`${baseMessage} Full report available at: ${expectedReportPath}`,
);
});
});

View File

@@ -0,0 +1,117 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import fs from 'node:fs/promises';
import os from 'node:os';
import path from 'node:path';
import { Content } from '@google/genai';
interface ErrorReportData {
error: { message: string; stack?: string } | { message: string };
context?: unknown;
additionalInfo?: Record<string, unknown>;
}
/**
* Generates an error report, writes it to a temporary file, and logs information to console.error.
 * @param error The error object.
 * @param baseMessage The initial message to log to console.error before the report path.
 * @param context The relevant context (e.g., chat history, request contents).
 * @param type A string to identify the type of error (e.g., 'startChat', 'generateJson-api').
*/
export async function reportError(
error: Error | unknown,
baseMessage: string,
context?: Content[] | Record<string, unknown> | unknown[],
type = 'general',
): Promise<void> {
const timestamp = new Date().toISOString().replace(/[:.]/g, '-');
const reportFileName = `gemini-client-error-${type}-${timestamp}.json`;
const reportPath = path.join(os.tmpdir(), reportFileName);
let errorToReport: { message: string; stack?: string };
if (error instanceof Error) {
errorToReport = { message: error.message, stack: error.stack };
} else if (
typeof error === 'object' &&
error !== null &&
'message' in error
) {
errorToReport = {
message: String((error as { message: unknown }).message),
};
} else {
errorToReport = { message: String(error) };
}
const reportContent: ErrorReportData = { error: errorToReport };
if (context) {
reportContent.context = context;
}
let stringifiedReportContent: string;
try {
stringifiedReportContent = JSON.stringify(reportContent, null, 2);
} catch (stringifyError) {
// This can happen if context contains something like BigInt
console.error(
`${baseMessage} Could not stringify report content (likely due to context):`,
stringifyError,
);
console.error('Original error that triggered report generation:', error);
if (context) {
console.error(
'Original context could not be stringified or included in report.',
);
}
// Fallback: try to report only the error if context was the issue
try {
const minimalReportContent = { error: errorToReport };
stringifiedReportContent = JSON.stringify(minimalReportContent, null, 2);
// Still try to write the minimal report
await fs.writeFile(reportPath, stringifiedReportContent);
console.error(
`${baseMessage} Partial report (excluding context) available at: ${reportPath}`,
);
} catch (minimalWriteError) {
console.error(
`${baseMessage} Failed to write even a minimal error report:`,
minimalWriteError,
);
}
return;
}
try {
await fs.writeFile(reportPath, stringifiedReportContent);
console.error(`${baseMessage} Full report available at: ${reportPath}`);
} catch (writeError) {
console.error(
`${baseMessage} Additionally, failed to write detailed error report:`,
writeError,
);
// Log the original error as a fallback if report writing fails
console.error('Original error that triggered report generation:', error);
if (context) {
// Context was stringifiable, but writing the file failed.
// We already have stringifiedReportContent, but it might be too large for console.
// So, we try to log the original context object, and if that fails, its stringified version (truncated).
try {
console.error('Original context:', context);
} catch {
try {
console.error(
'Original context (stringified, truncated):',
JSON.stringify(context).substring(0, 1000),
);
} catch {
console.error('Original context could not be logged or stringified.');
}
}
}
}
}
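// Example usage (illustrative, not part of the original module; `history` is a
// hypothetical Content[] value):
//   await reportError(err, 'Chat initialization failed.', history, 'startChat');
//   // console.error then prints something like:
//   // "Chat initialization failed. Full report available at:
//   //  /tmp/gemini-client-error-startChat-2025-01-01T00-00-00-000Z.json"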

View File

@@ -0,0 +1,62 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GaxiosError } from 'gaxios';
export function isNodeError(error: unknown): error is NodeJS.ErrnoException {
return error instanceof Error && 'code' in error;
}
export function getErrorMessage(error: unknown): string {
if (error instanceof Error) {
return error.message;
}
try {
return String(error);
} catch {
return 'Failed to get error details';
}
}
export class ForbiddenError extends Error {}
export class UnauthorizedError extends Error {}
export class BadRequestError extends Error {}
interface ResponseData {
error?: {
code?: number;
message?: string;
};
}
export function toFriendlyError(error: unknown): unknown {
if (error instanceof GaxiosError) {
const data = parseResponseData(error);
if (data.error && data.error.message && data.error.code) {
switch (data.error.code) {
case 400:
return new BadRequestError(data.error.message);
case 401:
return new UnauthorizedError(data.error.message);
case 403:
// It's important to pass the message here since it might
// explain the cause like "the cloud project you're
// using doesn't have code assist enabled".
return new ForbiddenError(data.error.message);
default:
}
}
}
return error;
}
function parseResponseData(error: GaxiosError): ResponseData {
// Inexplicably, Gaxios sometimes doesn't JSONify the response data.
if (typeof error.response?.data === 'string') {
return JSON.parse(error.response?.data) as ResponseData;
}
return error.response?.data as ResponseData;
}
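// Illustrative behavior of toFriendlyError (not part of the original file):
// a GaxiosError whose parsed response data is
// { error: { code: 403, message: '...' } } comes back as a ForbiddenError
// carrying that message, while unmatched codes (e.g. 429) and non-Gaxios
// errors are returned unchanged.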

View File

@@ -0,0 +1,57 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { getErrorMessage, isNodeError } from './errors.js';
import { URL } from 'url';
const PRIVATE_IP_RANGES = [
/^10\./,
/^127\./,
/^172\.(1[6-9]|2[0-9]|3[0-1])\./,
/^192\.168\./,
/^::1$/,
/^fc00:/,
/^fe80:/,
];
export class FetchError extends Error {
constructor(
message: string,
public code?: string,
) {
super(message);
this.name = 'FetchError';
}
}
export function isPrivateIp(url: string): boolean {
try {
const hostname = new URL(url).hostname;
return PRIVATE_IP_RANGES.some((range) => range.test(hostname));
} catch (_e) {
return false;
}
}
export async function fetchWithTimeout(
url: string,
timeout: number,
): Promise<Response> {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), timeout);
try {
const response = await fetch(url, { signal: controller.signal });
return response;
} catch (error) {
if (isNodeError(error) && error.code === 'ABORT_ERR') {
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
}
throw new FetchError(getErrorMessage(error));
} finally {
clearTimeout(timeoutId);
}
}
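// Example usage (a sketch, not in the original file):
//   const res = await fetchWithTimeout('https://example.com', 5000);
//   // A timed-out request is reported as
//   // FetchError('Request timed out after 5000ms', 'ETIMEDOUT');
//   // other network failures become a plain FetchError without a code.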

View File

@@ -0,0 +1,489 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import * as actualNodeFs from 'node:fs'; // For setup/teardown
import fsPromises from 'node:fs/promises';
import path from 'node:path';
import os from 'node:os';
import mime from 'mime-types';
import {
isWithinRoot,
isBinaryFile,
detectFileType,
processSingleFileContent,
} from './fileUtils.js';
vi.mock('mime-types', () => ({
default: { lookup: vi.fn() },
lookup: vi.fn(),
}));
const mockMimeLookup = mime.lookup as Mock;
describe('fileUtils', () => {
let tempRootDir: string;
const originalProcessCwd = process.cwd;
let testTextFilePath: string;
let testImageFilePath: string;
let testPdfFilePath: string;
let testBinaryFilePath: string;
let nonExistentFilePath: string;
let directoryPath: string;
beforeEach(() => {
vi.resetAllMocks(); // Reset all mocks, including mime.lookup
tempRootDir = actualNodeFs.mkdtempSync(
path.join(os.tmpdir(), 'fileUtils-test-'),
);
process.cwd = vi.fn(() => tempRootDir); // Mock cwd if necessary for relative path logic within tests
testTextFilePath = path.join(tempRootDir, 'test.txt');
testImageFilePath = path.join(tempRootDir, 'image.png');
testPdfFilePath = path.join(tempRootDir, 'document.pdf');
testBinaryFilePath = path.join(tempRootDir, 'app.exe');
nonExistentFilePath = path.join(tempRootDir, 'notfound.txt');
directoryPath = path.join(tempRootDir, 'subdir');
actualNodeFs.mkdirSync(directoryPath, { recursive: true }); // Ensure subdir exists
});
afterEach(() => {
if (actualNodeFs.existsSync(tempRootDir)) {
actualNodeFs.rmSync(tempRootDir, { recursive: true, force: true });
}
process.cwd = originalProcessCwd;
vi.restoreAllMocks(); // Restore any spies
});
describe('isWithinRoot', () => {
const root = path.resolve('/project/root');
it('should return true for paths directly within the root', () => {
expect(isWithinRoot(path.join(root, 'file.txt'), root)).toBe(true);
expect(isWithinRoot(path.join(root, 'subdir', 'file.txt'), root)).toBe(
true,
);
});
it('should return true for the root path itself', () => {
expect(isWithinRoot(root, root)).toBe(true);
});
it('should return false for paths outside the root', () => {
expect(
isWithinRoot(path.resolve('/project/other', 'file.txt'), root),
).toBe(false);
expect(isWithinRoot(path.resolve('/unrelated', 'file.txt'), root)).toBe(
false,
);
});
it('should return false for paths that only partially match the root prefix', () => {
expect(
isWithinRoot(
path.resolve('/project/root-but-actually-different'),
root,
),
).toBe(false);
});
it('should handle paths with trailing slashes correctly', () => {
expect(isWithinRoot(path.join(root, 'file.txt') + path.sep, root)).toBe(
true,
);
expect(isWithinRoot(root + path.sep, root)).toBe(true);
});
it('should handle different path separators (POSIX vs Windows)', () => {
const posixRoot = '/project/root';
const posixPathInside = '/project/root/file.txt';
const posixPathOutside = '/project/other/file.txt';
expect(isWithinRoot(posixPathInside, posixRoot)).toBe(true);
expect(isWithinRoot(posixPathOutside, posixRoot)).toBe(false);
});
it('should return false for a root path that is a sub-path of the path to check', () => {
const pathToCheck = path.resolve('/project/root/sub');
const rootSub = path.resolve('/project/root');
expect(isWithinRoot(pathToCheck, rootSub)).toBe(true);
const pathToCheckSuper = path.resolve('/project/root');
const rootSuper = path.resolve('/project/root/sub');
expect(isWithinRoot(pathToCheckSuper, rootSuper)).toBe(false);
});
});
describe('isBinaryFile', () => {
let filePathForBinaryTest: string;
beforeEach(() => {
filePathForBinaryTest = path.join(tempRootDir, 'binaryCheck.tmp');
});
afterEach(() => {
if (actualNodeFs.existsSync(filePathForBinaryTest)) {
actualNodeFs.unlinkSync(filePathForBinaryTest);
}
});
it('should return false for an empty file', () => {
actualNodeFs.writeFileSync(filePathForBinaryTest, '');
expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
});
it('should return false for a typical text file', () => {
actualNodeFs.writeFileSync(
filePathForBinaryTest,
'Hello, world!\nThis is a test file with normal text content.',
);
expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
});
it('should return true for a file with many null bytes', () => {
const binaryContent = Buffer.from([
0x48, 0x65, 0x00, 0x6c, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00,
]); // "He\0llo\0\0\0\0\0"
actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent);
expect(isBinaryFile(filePathForBinaryTest)).toBe(true);
});
it('should return true for a file with high percentage of non-printable ASCII', () => {
const binaryContent = Buffer.from([
0x41, 0x42, 0x01, 0x02, 0x03, 0x04, 0x05, 0x43, 0x44, 0x06,
]); // AB\x01\x02\x03\x04\x05CD\x06
actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent);
expect(isBinaryFile(filePathForBinaryTest)).toBe(true);
});
it('should return false if file access fails (e.g., ENOENT)', () => {
// Ensure the file does not exist
if (actualNodeFs.existsSync(filePathForBinaryTest)) {
actualNodeFs.unlinkSync(filePathForBinaryTest);
}
expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
});
});
describe('detectFileType', () => {
let filePathForDetectTest: string;
beforeEach(() => {
filePathForDetectTest = path.join(tempRootDir, 'detectType.tmp');
// Default: create as a text file for isBinaryFile fallback
actualNodeFs.writeFileSync(filePathForDetectTest, 'Plain text content');
});
afterEach(() => {
if (actualNodeFs.existsSync(filePathForDetectTest)) {
actualNodeFs.unlinkSync(filePathForDetectTest);
}
vi.restoreAllMocks(); // Restore spies on actualNodeFs
});
it('should detect typescript type by extension (ts)', () => {
expect(detectFileType('file.ts')).toBe('text');
expect(detectFileType('file.test.ts')).toBe('text');
});
it('should detect image type by extension (png)', () => {
mockMimeLookup.mockReturnValueOnce('image/png');
expect(detectFileType('file.png')).toBe('image');
});
it('should detect image type by extension (jpeg)', () => {
mockMimeLookup.mockReturnValueOnce('image/jpeg');
expect(detectFileType('file.jpg')).toBe('image');
});
it('should detect svg type by extension', () => {
expect(detectFileType('image.svg')).toBe('svg');
expect(detectFileType('image.icon.svg')).toBe('svg');
});
it('should detect pdf type by extension', () => {
mockMimeLookup.mockReturnValueOnce('application/pdf');
expect(detectFileType('file.pdf')).toBe('pdf');
});
it('should detect audio type by extension', () => {
mockMimeLookup.mockReturnValueOnce('audio/mpeg');
expect(detectFileType('song.mp3')).toBe('audio');
});
it('should detect video type by extension', () => {
mockMimeLookup.mockReturnValueOnce('video/mp4');
expect(detectFileType('movie.mp4')).toBe('video');
});
it('should detect known binary extensions as binary (e.g. .zip)', () => {
mockMimeLookup.mockReturnValueOnce('application/zip');
expect(detectFileType('archive.zip')).toBe('binary');
});
it('should detect known binary extensions as binary (e.g. .exe)', () => {
mockMimeLookup.mockReturnValueOnce('application/octet-stream'); // Common for .exe
expect(detectFileType('app.exe')).toBe('binary');
});
it('should use isBinaryFile for unknown extensions and detect as binary', () => {
mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type
// Create a file that isBinaryFile will identify as binary
const binaryContent = Buffer.from([
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
]);
actualNodeFs.writeFileSync(filePathForDetectTest, binaryContent);
expect(detectFileType(filePathForDetectTest)).toBe('binary');
});
it('should default to text if mime type is unknown and content is not binary', () => {
mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type
// filePathForDetectTest is already a text file by default from beforeEach
expect(detectFileType(filePathForDetectTest)).toBe('text');
});
});
describe('processSingleFileContent', () => {
beforeEach(() => {
// Ensure files exist for statSync checks before readFile might be mocked
if (actualNodeFs.existsSync(testTextFilePath))
actualNodeFs.unlinkSync(testTextFilePath);
if (actualNodeFs.existsSync(testImageFilePath))
actualNodeFs.unlinkSync(testImageFilePath);
if (actualNodeFs.existsSync(testPdfFilePath))
actualNodeFs.unlinkSync(testPdfFilePath);
if (actualNodeFs.existsSync(testBinaryFilePath))
actualNodeFs.unlinkSync(testBinaryFilePath);
});
it('should read a text file successfully', async () => {
const content = 'Line 1\nLine 2\nLine 3';
actualNodeFs.writeFileSync(testTextFilePath, content);
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
);
expect(result.llmContent).toBe(content);
expect(result.returnDisplay).toBe('');
expect(result.error).toBeUndefined();
});
it('should handle file not found', async () => {
const result = await processSingleFileContent(
nonExistentFilePath,
tempRootDir,
);
expect(result.error).toContain('File not found');
expect(result.returnDisplay).toContain('File not found');
});
it('should handle read errors for text files', async () => {
actualNodeFs.writeFileSync(testTextFilePath, 'content'); // File must exist for initial statSync
const readError = new Error('Simulated read error');
vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError);
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
);
expect(result.error).toContain('Simulated read error');
expect(result.returnDisplay).toContain('Simulated read error');
});
it('should handle read errors for image/pdf files', async () => {
actualNodeFs.writeFileSync(testImageFilePath, 'content'); // File must exist
mockMimeLookup.mockReturnValue('image/png');
const readError = new Error('Simulated image read error');
vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError);
const result = await processSingleFileContent(
testImageFilePath,
tempRootDir,
);
expect(result.error).toContain('Simulated image read error');
expect(result.returnDisplay).toContain('Simulated image read error');
});
it('should process an image file', async () => {
const fakePngData = Buffer.from('fake png data');
actualNodeFs.writeFileSync(testImageFilePath, fakePngData);
mockMimeLookup.mockReturnValue('image/png');
const result = await processSingleFileContent(
testImageFilePath,
tempRootDir,
);
expect(
(result.llmContent as { inlineData: unknown }).inlineData,
).toBeDefined();
expect(
(result.llmContent as { inlineData: { mimeType: string } }).inlineData
.mimeType,
).toBe('image/png');
expect(
(result.llmContent as { inlineData: { data: string } }).inlineData.data,
).toBe(fakePngData.toString('base64'));
expect(result.returnDisplay).toContain('Read image file: image.png');
});
it('should process a PDF file', async () => {
const fakePdfData = Buffer.from('fake pdf data');
actualNodeFs.writeFileSync(testPdfFilePath, fakePdfData);
mockMimeLookup.mockReturnValue('application/pdf');
const result = await processSingleFileContent(
testPdfFilePath,
tempRootDir,
);
expect(
(result.llmContent as { inlineData: unknown }).inlineData,
).toBeDefined();
expect(
(result.llmContent as { inlineData: { mimeType: string } }).inlineData
.mimeType,
).toBe('application/pdf');
expect(
(result.llmContent as { inlineData: { data: string } }).inlineData.data,
).toBe(fakePdfData.toString('base64'));
expect(result.returnDisplay).toContain('Read pdf file: document.pdf');
});
it('should read an SVG file as text when under 1MB', async () => {
const svgContent = `
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<rect width="100" height="100" fill="blue" />
</svg>
`;
const testSvgFilePath = path.join(tempRootDir, 'test.svg');
actualNodeFs.writeFileSync(testSvgFilePath, svgContent, 'utf-8');
mockMimeLookup.mockReturnValue('image/svg+xml');
const result = await processSingleFileContent(
testSvgFilePath,
tempRootDir,
);
expect(result.llmContent).toBe(svgContent);
expect(result.returnDisplay).toContain('Read SVG as text');
});
it('should skip binary files', async () => {
actualNodeFs.writeFileSync(
testBinaryFilePath,
Buffer.from([0x00, 0x01, 0x02]),
);
mockMimeLookup.mockReturnValueOnce('application/octet-stream');
// isBinaryFile will operate on the real file.
const result = await processSingleFileContent(
testBinaryFilePath,
tempRootDir,
);
expect(result.llmContent).toContain(
'Cannot display content of binary file',
);
expect(result.returnDisplay).toContain('Skipped binary file: app.exe');
});
it('should handle path being a directory', async () => {
const result = await processSingleFileContent(directoryPath, tempRootDir);
expect(result.error).toContain('Path is a directory');
expect(result.returnDisplay).toContain('Path is a directory');
});
it('should paginate text files correctly (offset and limit)', async () => {
const lines = Array.from({ length: 20 }, (_, i) => `Line ${i + 1}`);
actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
5,
5,
); // Read lines 6-10
const expectedContent = lines.slice(5, 10).join('\n');
expect(result.llmContent).toContain(expectedContent);
expect(result.llmContent).toContain(
'[File content truncated: showing lines 6-10 of 20 total lines. Use offset/limit parameters to view more.]',
);
expect(result.returnDisplay).toBe('(truncated)');
expect(result.isTruncated).toBe(true);
expect(result.originalLineCount).toBe(20);
expect(result.linesShown).toEqual([6, 10]);
});
it('should handle limit exceeding file length', async () => {
const lines = ['Line 1', 'Line 2'];
actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
0,
10,
);
const expectedContent = lines.join('\n');
expect(result.llmContent).toBe(expectedContent);
expect(result.returnDisplay).toBe('');
expect(result.isTruncated).toBe(false);
expect(result.originalLineCount).toBe(2);
expect(result.linesShown).toEqual([1, 2]);
});
it('should truncate long lines in text files', async () => {
const longLine = 'a'.repeat(2500);
actualNodeFs.writeFileSync(
testTextFilePath,
`Short line\n${longLine}\nAnother short line`,
);
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
);
expect(result.llmContent).toContain('Short line');
expect(result.llmContent).toContain(
longLine.substring(0, 2000) + '... [truncated]',
);
expect(result.llmContent).toContain('Another short line');
expect(result.llmContent).toContain(
'[File content partially truncated: some lines exceeded maximum length of 2000 characters.]',
);
expect(result.isTruncated).toBe(true);
});
it('should return an error if the file size exceeds 20MB', async () => {
// Create a file just over 20MB
const twentyOneMB = 21 * 1024 * 1024;
const buffer = Buffer.alloc(twentyOneMB, 0x61); // Fill with 'a'
actualNodeFs.writeFileSync(testTextFilePath, buffer);
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
);
expect(result.error).toContain('File size exceeds the 20MB limit');
expect(result.returnDisplay).toContain(
'File size exceeds the 20MB limit',
);
expect(result.llmContent).toContain('File size exceeds the 20MB limit');
});
});
});

View File

@@ -0,0 +1,337 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import fs from 'fs';
import path from 'path';
import { PartUnion } from '@google/genai';
import mime from 'mime-types';
// Constants for text file processing
const DEFAULT_MAX_LINES_TEXT_FILE = 2000;
const MAX_LINE_LENGTH_TEXT_FILE = 2000;
// Default values for encoding and separator format
export const DEFAULT_ENCODING: BufferEncoding = 'utf-8';
/**
* Looks up the specific MIME type for a file path.
* @param filePath Path to the file.
 * @returns The specific MIME type string (e.g., 'text/x-python', 'application/javascript') or undefined if not found or ambiguous.
*/
export function getSpecificMimeType(filePath: string): string | undefined {
const lookedUpMime = mime.lookup(filePath);
return typeof lookedUpMime === 'string' ? lookedUpMime : undefined;
}
/**
* Checks if a path is within a given root directory.
* @param pathToCheck The absolute path to check.
* @param rootDirectory The absolute root directory.
* @returns True if the path is within the root directory, false otherwise.
*/
export function isWithinRoot(
pathToCheck: string,
rootDirectory: string,
): boolean {
const normalizedPathToCheck = path.resolve(pathToCheck);
const normalizedRootDirectory = path.resolve(rootDirectory);
// Ensure the rootDirectory path ends with a separator for correct startsWith comparison,
// unless it's the root path itself (e.g., '/' or 'C:\').
const rootWithSeparator =
normalizedRootDirectory === path.sep ||
normalizedRootDirectory.endsWith(path.sep)
? normalizedRootDirectory
: normalizedRootDirectory + path.sep;
return (
normalizedPathToCheck === normalizedRootDirectory ||
normalizedPathToCheck.startsWith(rootWithSeparator)
);
}
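// Illustrative cases (not part of the original module); the trailing-separator
// handling above is what rules out sibling directories that share a prefix:
//   isWithinRoot('/project/root/src/a.ts', '/project/root')   -> true
//   isWithinRoot('/project/root', '/project/root')            -> true
//   isWithinRoot('/project/root-other/a.ts', '/project/root') -> false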
/**
* Determines if a file is likely binary based on content sampling.
* @param filePath Path to the file.
* @returns True if the file appears to be binary.
*/
export function isBinaryFile(filePath: string): boolean {
try {
const fd = fs.openSync(filePath, 'r');
// Read up to 4KB or file size, whichever is smaller
const fileSize = fs.fstatSync(fd).size;
if (fileSize === 0) {
// Empty file is not considered binary for content checking
fs.closeSync(fd);
return false;
}
const bufferSize = Math.min(4096, fileSize);
const buffer = Buffer.alloc(bufferSize);
const bytesRead = fs.readSync(fd, buffer, 0, buffer.length, 0);
fs.closeSync(fd);
if (bytesRead === 0) return false;
let nonPrintableCount = 0;
for (let i = 0; i < bytesRead; i++) {
if (buffer[i] === 0) return true; // Null byte is a strong indicator
if (buffer[i] < 9 || (buffer[i] > 13 && buffer[i] < 32)) {
nonPrintableCount++;
}
}
// If >30% non-printable characters, consider it binary
return nonPrintableCount / bytesRead > 0.3;
} catch {
// If any error occurs (e.g. file not found, permissions),
// treat as not binary here; let higher-level functions handle existence/access errors.
return false;
}
}
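// Sketch of the heuristic above (illustrative, not in the original file):
// only the first 4KB are sampled, a single NUL byte short-circuits to true,
// and otherwise more than 30% non-printable bytes counts as binary:
//   a file containing [0x48, 0x69, 0x00]      -> true  (NUL byte present)
//   a file containing 'Hello, world!\n' UTF-8 -> false (all printable/whitespace)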
/**
* Detects the type of file based on extension and content.
* @param filePath Path to the file.
* @returns 'text', 'image', 'pdf', 'audio', 'video', or 'binary'.
*/
export function detectFileType(
filePath: string,
): 'text' | 'image' | 'pdf' | 'audio' | 'video' | 'binary' | 'svg' {
const ext = path.extname(filePath).toLowerCase();
// The MIME type for ".ts" is MPEG transport stream (a video format), but we
// want to assume these are TypeScript files instead.
if (ext === '.ts') {
return 'text';
}
if (ext === '.svg') {
return 'svg';
}
const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string
if (lookedUpMimeType) {
if (lookedUpMimeType.startsWith('image/')) {
return 'image';
}
if (lookedUpMimeType.startsWith('audio/')) {
return 'audio';
}
if (lookedUpMimeType.startsWith('video/')) {
return 'video';
}
if (lookedUpMimeType === 'application/pdf') {
return 'pdf';
}
}
// Stricter binary check for common non-text extensions before content check
// These are often not well-covered by mime-types or might be misidentified.
if (
[
'.zip',
'.tar',
'.gz',
'.exe',
'.dll',
'.so',
'.class',
'.jar',
'.war',
'.7z',
'.doc',
'.docx',
'.xls',
'.xlsx',
'.ppt',
'.pptx',
'.odt',
'.ods',
'.odp',
'.bin',
'.dat',
'.obj',
'.o',
'.a',
'.lib',
'.wasm',
'.pyc',
'.pyo',
].includes(ext)
) {
return 'binary';
}
// Fallback to content-based check if mime type wasn't conclusive for image/pdf
// and it's not a known binary extension.
if (isBinaryFile(filePath)) {
return 'binary';
}
return 'text';
}
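// Illustrative classifications (not part of the original module):
//   detectFileType('main.ts')     -> 'text'   (the MPEG-TS MIME mapping is overridden)
//   detectFileType('logo.svg')    -> 'svg'
//   detectFileType('photo.jpg')   -> 'image'  (via mime.lookup)
//   detectFileType('archive.zip') -> 'binary' (known binary extension)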
export interface ProcessedFileReadResult {
llmContent: PartUnion; // string for text, Part for image/pdf/unreadable binary
returnDisplay: string;
error?: string; // Optional error message for the LLM if file processing failed
isTruncated?: boolean; // For text files, indicates if content was truncated
originalLineCount?: number; // For text files
linesShown?: [number, number]; // For text files [startLine, endLine] (1-based for display)
}
/**
* Reads and processes a single file, handling text, images, and PDFs.
* @param filePath Absolute path to the file.
* @param rootDirectory Absolute path to the project root for relative path display.
* @param offset Optional offset for text files (0-based line number).
* @param limit Optional limit for text files (number of lines to read).
* @returns ProcessedFileReadResult object.
*/
export async function processSingleFileContent(
filePath: string,
rootDirectory: string,
offset?: number,
limit?: number,
): Promise<ProcessedFileReadResult> {
try {
if (!fs.existsSync(filePath)) {
// Sync check is acceptable before async read
return {
llmContent: '',
returnDisplay: 'File not found.',
error: `File not found: ${filePath}`,
};
}
const stats = await fs.promises.stat(filePath);
if (stats.isDirectory()) {
return {
llmContent: '',
returnDisplay: 'Path is a directory.',
error: `Path is a directory, not a file: ${filePath}`,
};
}
const fileSizeInBytes = stats.size;
// 20MB limit
const maxFileSize = 20 * 1024 * 1024;
if (fileSizeInBytes > maxFileSize) {
throw new Error(
`File size exceeds the 20MB limit: ${filePath} (${(
fileSizeInBytes /
(1024 * 1024)
).toFixed(2)}MB)`,
);
}
const fileType = detectFileType(filePath);
const relativePathForDisplay = path
.relative(rootDirectory, filePath)
.replace(/\\/g, '/');
switch (fileType) {
case 'binary': {
return {
llmContent: `Cannot display content of binary file: ${relativePathForDisplay}`,
returnDisplay: `Skipped binary file: ${relativePathForDisplay}`,
};
}
case 'svg': {
const SVG_MAX_SIZE_BYTES = 1 * 1024 * 1024;
if (stats.size > SVG_MAX_SIZE_BYTES) {
return {
llmContent: `Cannot display content of SVG file larger than 1MB: ${relativePathForDisplay}`,
returnDisplay: `Skipped large SVG file (>1MB): ${relativePathForDisplay}`,
};
}
const content = await fs.promises.readFile(filePath, 'utf8');
return {
llmContent: content,
returnDisplay: `Read SVG as text: ${relativePathForDisplay}`,
};
}
case 'text': {
const content = await fs.promises.readFile(filePath, 'utf8');
const lines = content.split('\n');
const originalLineCount = lines.length;
const startLine = offset || 0;
const effectiveLimit =
limit === undefined ? DEFAULT_MAX_LINES_TEXT_FILE : limit;
// Ensure endLine does not exceed originalLineCount
const endLine = Math.min(startLine + effectiveLimit, originalLineCount);
// Ensure selectedLines doesn't try to slice beyond array bounds if startLine is too high
const actualStartLine = Math.min(startLine, originalLineCount);
const selectedLines = lines.slice(actualStartLine, endLine);
let linesWereTruncatedInLength = false;
const formattedLines = selectedLines.map((line) => {
if (line.length > MAX_LINE_LENGTH_TEXT_FILE) {
linesWereTruncatedInLength = true;
return (
line.substring(0, MAX_LINE_LENGTH_TEXT_FILE) + '... [truncated]'
);
}
return line;
});
const contentRangeTruncated = endLine < originalLineCount;
const isTruncated = contentRangeTruncated || linesWereTruncatedInLength;
let llmTextContent = '';
if (contentRangeTruncated) {
llmTextContent += `[File content truncated: showing lines ${actualStartLine + 1}-${endLine} of ${originalLineCount} total lines. Use offset/limit parameters to view more.]\n`;
} else if (linesWereTruncatedInLength) {
llmTextContent += `[File content partially truncated: some lines exceeded maximum length of ${MAX_LINE_LENGTH_TEXT_FILE} characters.]\n`;
}
llmTextContent += formattedLines.join('\n');
return {
llmContent: llmTextContent,
returnDisplay: isTruncated ? '(truncated)' : '',
isTruncated,
originalLineCount,
linesShown: [actualStartLine + 1, endLine],
};
}
case 'image':
case 'pdf':
case 'audio':
case 'video': {
const contentBuffer = await fs.promises.readFile(filePath);
const base64Data = contentBuffer.toString('base64');
return {
llmContent: {
inlineData: {
data: base64Data,
mimeType: mime.lookup(filePath) || 'application/octet-stream',
},
},
returnDisplay: `Read ${fileType} file: ${relativePathForDisplay}`,
};
}
default: {
// Should not happen with current detectFileType logic
const exhaustiveCheck: never = fileType;
return {
llmContent: `Unhandled file type: ${exhaustiveCheck}`,
returnDisplay: `Skipped unhandled file type: ${relativePathForDisplay}`,
error: `Unhandled file type for ${filePath}`,
};
}
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error);
const displayPath = path
.relative(rootDirectory, filePath)
.replace(/\\/g, '/');
return {
llmContent: `Error reading file ${displayPath}: ${errorMessage}`,
returnDisplay: `Error reading file ${displayPath}: ${errorMessage}`,
error: `Error reading file ${filePath}: ${errorMessage}`,
};
}
}
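// Example usage (a sketch, not part of the original file): read lines 101-150
// of a text file relative to a hypothetical project root '/repo':
//   const result = await processSingleFileContent('/repo/notes.txt', '/repo', 100, 50);
//   // result.linesShown would be [101, 150]; result.isTruncated is true and
//   // result.returnDisplay is '(truncated)' if the file has more lines.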

View File

@@ -0,0 +1,144 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { Config } from '../config/config.js';
import {
setSimulate429,
disableSimulationAfterFallback,
shouldSimulate429,
createSimulated429Error,
resetRequestCounter,
} from './testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
describe('Flash Fallback Integration', () => {
let config: Config;
beforeEach(() => {
config = new Config({
sessionId: 'test-session',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: 'gemini-2.5-pro',
});
// Reset simulation state for each test
setSimulate429(false);
resetRequestCounter();
});
it('should automatically accept fallback', async () => {
// Set up a minimal flash fallback handler for testing
const flashFallbackHandler = async (): Promise<boolean> => true;
config.setFlashFallbackHandler(flashFallbackHandler);
// Call the handler directly to test
const result = await config.flashFallbackHandler!(
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
// Verify it automatically accepts
expect(result).toBe(true);
});
it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => {
let fallbackCalled = false;
let fallbackModel = '';
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
const mockApiCall = vi
.fn()
.mockRejectedValueOnce(createSimulated429Error())
.mockRejectedValueOnce(createSimulated429Error())
.mockResolvedValueOnce('success after fallback');
// Mock fallback handler
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
fallbackCalled = true;
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
return fallbackModel;
});
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
const result = await retryWithBackoff(mockApiCall, {
maxAttempts: 2,
initialDelayMs: 1,
maxDelayMs: 10,
shouldRetry: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
// Verify fallback was triggered
expect(fallbackCalled).toBe(true);
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
expect(mockFallbackHandler).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
expect.any(Error),
);
expect(result).toBe('success after fallback');
// Should have: 2 failures, then fallback triggered, then 1 success after retry reset
expect(mockApiCall).toHaveBeenCalledTimes(3);
});
it('should not trigger fallback for API key users', async () => {
let fallbackCalled = false;
// Mock function that simulates 429 errors
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
// Mock fallback handler
const mockFallbackHandler = vi.fn(async () => {
fallbackCalled = true;
return DEFAULT_GEMINI_FLASH_MODEL;
});
// Test with API key auth type - should not trigger fallback
try {
await retryWithBackoff(mockApiCall, {
maxAttempts: 5,
initialDelayMs: 10,
maxDelayMs: 100,
shouldRetry: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
authType: AuthType.USE_GEMINI, // API key auth type
});
} catch (error) {
// Expected to throw after max attempts
expect((error as Error).message).toContain('Rate limit exceeded');
}
// Verify fallback was NOT triggered for API key users
expect(fallbackCalled).toBe(false);
expect(mockFallbackHandler).not.toHaveBeenCalled();
});
it('should properly disable simulation state after fallback', () => {
// Enable simulation
setSimulate429(true);
// Verify simulation is enabled
expect(shouldSimulate429()).toBe(true);
// Disable simulation after fallback
disableSimulationAfterFallback();
// Verify simulation is now disabled
expect(shouldSimulate429()).toBe(false);
});
});

View File

@@ -0,0 +1,323 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import {
getResponseText,
getResponseTextFromParts,
getFunctionCalls,
getFunctionCallsFromParts,
getFunctionCallsAsJson,
getFunctionCallsFromPartsAsJson,
getStructuredResponse,
getStructuredResponseFromParts,
} from './generateContentResponseUtilities.js';
import {
GenerateContentResponse,
Part,
FinishReason,
SafetyRating,
} from '@google/genai';
const mockTextPart = (text: string): Part => ({ text });
const mockFunctionCallPart = (
name: string,
args?: Record<string, unknown>,
): Part => ({
functionCall: { name, args: args ?? {} },
});
const mockResponse = (
parts: Part[],
finishReason: FinishReason = FinishReason.STOP,
safetyRatings: SafetyRating[] = [],
): GenerateContentResponse => ({
candidates: [
{
content: {
parts,
role: 'model',
},
index: 0,
finishReason,
safetyRatings,
},
],
promptFeedback: {
safetyRatings: [],
},
text: undefined,
data: undefined,
functionCalls: undefined,
executableCode: undefined,
codeExecutionResult: undefined,
});
const minimalMockResponse = (
candidates: GenerateContentResponse['candidates'],
): GenerateContentResponse => ({
candidates,
promptFeedback: { safetyRatings: [] },
text: undefined,
data: undefined,
functionCalls: undefined,
executableCode: undefined,
codeExecutionResult: undefined,
});
describe('generateContentResponseUtilities', () => {
describe('getResponseText', () => {
it('should return undefined for no candidates', () => {
expect(getResponseText(minimalMockResponse(undefined))).toBeUndefined();
});
it('should return undefined for empty candidates array', () => {
expect(getResponseText(minimalMockResponse([]))).toBeUndefined();
});
it('should return undefined for no parts', () => {
const response = mockResponse([]);
expect(getResponseText(response)).toBeUndefined();
});
it('should extract text from a single text part', () => {
const response = mockResponse([mockTextPart('Hello')]);
expect(getResponseText(response)).toBe('Hello');
});
it('should concatenate text from multiple text parts', () => {
const response = mockResponse([
mockTextPart('Hello '),
mockTextPart('World'),
]);
expect(getResponseText(response)).toBe('Hello World');
});
it('should ignore function call parts', () => {
const response = mockResponse([
mockTextPart('Hello '),
mockFunctionCallPart('testFunc'),
mockTextPart('World'),
]);
expect(getResponseText(response)).toBe('Hello World');
});
it('should return undefined if only function call parts exist', () => {
const response = mockResponse([
mockFunctionCallPart('testFunc'),
mockFunctionCallPart('anotherFunc'),
]);
expect(getResponseText(response)).toBeUndefined();
});
});
describe('getResponseTextFromParts', () => {
it('should return undefined for no parts', () => {
expect(getResponseTextFromParts([])).toBeUndefined();
});
it('should extract text from a single text part', () => {
expect(getResponseTextFromParts([mockTextPart('Hello')])).toBe('Hello');
});
it('should concatenate text from multiple text parts', () => {
expect(
getResponseTextFromParts([
mockTextPart('Hello '),
mockTextPart('World'),
]),
).toBe('Hello World');
});
it('should ignore function call parts', () => {
expect(
getResponseTextFromParts([
mockTextPart('Hello '),
mockFunctionCallPart('testFunc'),
mockTextPart('World'),
]),
).toBe('Hello World');
});
it('should return undefined if only function call parts exist', () => {
expect(
getResponseTextFromParts([
mockFunctionCallPart('testFunc'),
mockFunctionCallPart('anotherFunc'),
]),
).toBeUndefined();
});
});
describe('getFunctionCalls', () => {
it('should return undefined for no candidates', () => {
expect(getFunctionCalls(minimalMockResponse(undefined))).toBeUndefined();
});
it('should return undefined for empty candidates array', () => {
expect(getFunctionCalls(minimalMockResponse([]))).toBeUndefined();
});
it('should return undefined for no parts', () => {
const response = mockResponse([]);
expect(getFunctionCalls(response)).toBeUndefined();
});
it('should extract a single function call', () => {
const func = { name: 'testFunc', args: { a: 1 } };
const response = mockResponse([
mockFunctionCallPart(func.name, func.args),
]);
expect(getFunctionCalls(response)).toEqual([func]);
});
it('should extract multiple function calls', () => {
const func1 = { name: 'testFunc1', args: { a: 1 } };
const func2 = { name: 'testFunc2', args: { b: 2 } };
const response = mockResponse([
mockFunctionCallPart(func1.name, func1.args),
mockFunctionCallPart(func2.name, func2.args),
]);
expect(getFunctionCalls(response)).toEqual([func1, func2]);
});
it('should ignore text parts', () => {
const func = { name: 'testFunc', args: { a: 1 } };
const response = mockResponse([
mockTextPart('Some text'),
mockFunctionCallPart(func.name, func.args),
mockTextPart('More text'),
]);
expect(getFunctionCalls(response)).toEqual([func]);
});
it('should return undefined if only text parts exist', () => {
const response = mockResponse([
mockTextPart('Some text'),
mockTextPart('More text'),
]);
expect(getFunctionCalls(response)).toBeUndefined();
});
});
describe('getFunctionCallsFromParts', () => {
it('should return undefined for no parts', () => {
expect(getFunctionCallsFromParts([])).toBeUndefined();
});
it('should extract a single function call', () => {
const func = { name: 'testFunc', args: { a: 1 } };
expect(
getFunctionCallsFromParts([mockFunctionCallPart(func.name, func.args)]),
).toEqual([func]);
});
it('should extract multiple function calls', () => {
const func1 = { name: 'testFunc1', args: { a: 1 } };
const func2 = { name: 'testFunc2', args: { b: 2 } };
expect(
getFunctionCallsFromParts([
mockFunctionCallPart(func1.name, func1.args),
mockFunctionCallPart(func2.name, func2.args),
]),
).toEqual([func1, func2]);
});
it('should ignore text parts', () => {
const func = { name: 'testFunc', args: { a: 1 } };
expect(
getFunctionCallsFromParts([
mockTextPart('Some text'),
mockFunctionCallPart(func.name, func.args),
mockTextPart('More text'),
]),
).toEqual([func]);
});
it('should return undefined if only text parts exist', () => {
expect(
getFunctionCallsFromParts([
mockTextPart('Some text'),
mockTextPart('More text'),
]),
).toBeUndefined();
});
});
describe('getFunctionCallsAsJson', () => {
it('should return JSON string of function calls', () => {
const func1 = { name: 'testFunc1', args: { a: 1 } };
const func2 = { name: 'testFunc2', args: { b: 2 } };
const response = mockResponse([
mockFunctionCallPart(func1.name, func1.args),
mockTextPart('text in between'),
mockFunctionCallPart(func2.name, func2.args),
]);
const expectedJson = JSON.stringify([func1, func2], null, 2);
expect(getFunctionCallsAsJson(response)).toBe(expectedJson);
});
it('should return undefined if no function calls', () => {
const response = mockResponse([mockTextPart('Hello')]);
expect(getFunctionCallsAsJson(response)).toBeUndefined();
});
});
describe('getFunctionCallsFromPartsAsJson', () => {
it('should return JSON string of function calls from parts', () => {
const func1 = { name: 'testFunc1', args: { a: 1 } };
const func2 = { name: 'testFunc2', args: { b: 2 } };
const parts = [
mockFunctionCallPart(func1.name, func1.args),
mockTextPart('text in between'),
mockFunctionCallPart(func2.name, func2.args),
];
const expectedJson = JSON.stringify([func1, func2], null, 2);
expect(getFunctionCallsFromPartsAsJson(parts)).toBe(expectedJson);
});
it('should return undefined if no function calls in parts', () => {
const parts = [mockTextPart('Hello')];
expect(getFunctionCallsFromPartsAsJson(parts)).toBeUndefined();
});
});
describe('getStructuredResponse', () => {
it('should return only text if only text exists', () => {
const response = mockResponse([mockTextPart('Hello World')]);
expect(getStructuredResponse(response)).toBe('Hello World');
});
it('should return only function call JSON if only function calls exist', () => {
const func = { name: 'testFunc', args: { data: 'payload' } };
const response = mockResponse([
mockFunctionCallPart(func.name, func.args),
]);
const expectedJson = JSON.stringify([func], null, 2);
expect(getStructuredResponse(response)).toBe(expectedJson);
});
it('should return text and function call JSON if both exist', () => {
const text = 'Consider this data:';
const func = { name: 'processData', args: { item: 42 } };
const response = mockResponse([
mockTextPart(text),
mockFunctionCallPart(func.name, func.args),
]);
const expectedJson = JSON.stringify([func], null, 2);
expect(getStructuredResponse(response)).toBe(`${text}\n${expectedJson}`);
});
it('should return undefined if neither text nor function calls exist', () => {
const response = mockResponse([]);
expect(getStructuredResponse(response)).toBeUndefined();
});
});
describe('getStructuredResponseFromParts', () => {
it('should return only text if only text exists in parts', () => {
const parts = [mockTextPart('Hello World')];
expect(getStructuredResponseFromParts(parts)).toBe('Hello World');
});
it('should return only function call JSON if only function calls exist in parts', () => {
const func = { name: 'testFunc', args: { data: 'payload' } };
const parts = [mockFunctionCallPart(func.name, func.args)];
const expectedJson = JSON.stringify([func], null, 2);
expect(getStructuredResponseFromParts(parts)).toBe(expectedJson);
});
it('should return text and function call JSON if both exist in parts', () => {
const text = 'Consider this data:';
const func = { name: 'processData', args: { item: 42 } };
const parts = [
mockTextPart(text),
mockFunctionCallPart(func.name, func.args),
];
const expectedJson = JSON.stringify([func], null, 2);
expect(getStructuredResponseFromParts(parts)).toBe(
`${text}\n${expectedJson}`,
);
});
it('should return undefined if neither text nor function calls exist in parts', () => {
const parts: Part[] = [];
expect(getStructuredResponseFromParts(parts)).toBeUndefined();
});
});
});

View File

@@ -0,0 +1,119 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
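/**
 * Concatenates the text parts of the first candidate's content, or returns
 * undefined if the response contains no text parts.
 */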
export function getResponseText(
response: GenerateContentResponse,
): string | undefined {
const parts = response.candidates?.[0]?.content?.parts;
if (!parts) {
return undefined;
}
const textSegments = parts
.map((part) => part.text)
.filter((text): text is string => typeof text === 'string');
if (textSegments.length === 0) {
return undefined;
}
return textSegments.join('');
}
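/**
 * Concatenates the text parts of the given parts array, or returns undefined
 * if there are no text parts.
 */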
export function getResponseTextFromParts(parts: Part[]): string | undefined {
if (!parts) {
return undefined;
}
const textSegments = parts
.map((part) => part.text)
.filter((text): text is string => typeof text === 'string');
if (textSegments.length === 0) {
return undefined;
}
return textSegments.join('');
}
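/**
 * Extracts the function calls from the first candidate's content, or returns
 * undefined if there are none.
 */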
export function getFunctionCalls(
response: GenerateContentResponse,
): FunctionCall[] | undefined {
const parts = response.candidates?.[0]?.content?.parts;
if (!parts) {
return undefined;
}
const functionCallParts = parts
.filter((part) => !!part.functionCall)
.map((part) => part.functionCall as FunctionCall);
return functionCallParts.length > 0 ? functionCallParts : undefined;
}
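/**
 * Extracts the function calls from the given parts array, or returns
 * undefined if there are none.
 */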
export function getFunctionCallsFromParts(
parts: Part[],
): FunctionCall[] | undefined {
if (!parts) {
return undefined;
}
const functionCallParts = parts
.filter((part) => !!part.functionCall)
.map((part) => part.functionCall as FunctionCall);
return functionCallParts.length > 0 ? functionCallParts : undefined;
}
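/**
 * Returns the response's function calls serialized as pretty-printed JSON
 * (2-space indent), or undefined if there are none.
 */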
export function getFunctionCallsAsJson(
response: GenerateContentResponse,
): string | undefined {
const functionCalls = getFunctionCalls(response);
if (!functionCalls) {
return undefined;
}
return JSON.stringify(functionCalls, null, 2);
}
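/**
 * Returns the parts' function calls serialized as pretty-printed JSON
 * (2-space indent), or undefined if there are none.
 */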
export function getFunctionCallsFromPartsAsJson(
parts: Part[],
): string | undefined {
const functionCalls = getFunctionCallsFromParts(parts);
if (!functionCalls) {
return undefined;
}
return JSON.stringify(functionCalls, null, 2);
}
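/**
 * Combines the response text and the function-call JSON (each when present)
 * into a single string separated by a newline; undefined if neither exists.
 */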
export function getStructuredResponse(
response: GenerateContentResponse,
): string | undefined {
const textContent = getResponseText(response);
const functionCallsJson = getFunctionCallsAsJson(response);
if (textContent && functionCallsJson) {
return `${textContent}\n${functionCallsJson}`;
}
if (textContent) {
return textContent;
}
if (functionCallsJson) {
return functionCallsJson;
}
return undefined;
}
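/**
 * Combines the parts' text and function-call JSON (each when present) into a
 * single string separated by a newline; undefined if neither exists.
 */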
export function getStructuredResponseFromParts(
parts: Part[],
): string | undefined {
const textContent = getResponseTextFromParts(parts);
const functionCallsJson = getFunctionCallsFromPartsAsJson(parts);
if (textContent && functionCallsJson) {
return `${textContent}\n${functionCallsJson}`;
}
if (textContent) {
return textContent;
}
if (functionCallsJson) {
return functionCallsJson;
}
return undefined;
}

View File

@@ -0,0 +1,344 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
import fsPromises from 'fs/promises';
import * as fs from 'fs';
import { Dirent as FSDirent } from 'fs';
import * as nodePath from 'path';
import { getFolderStructure } from './getFolderStructure.js';
import * as gitUtils from './gitUtils.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
vi.mock('path', async (importOriginal) => {
const original = (await importOriginal()) as typeof nodePath;
return {
...original,
resolve: vi.fn((str) => str),
// Other path functions (basename, join, normalize, etc.) will use original implementation
};
});
vi.mock('fs/promises');
vi.mock('fs');
vi.mock('./gitUtils.js');
// Import 'path' again here, it will be the mocked version
import * as path from 'path';
// Helper to create Dirent-like objects for mocking fs.readdir
const createDirent = (name: string, type: 'file' | 'dir'): FSDirent => ({
name,
isFile: () => type === 'file',
isDirectory: () => type === 'dir',
isBlockDevice: () => false,
isCharacterDevice: () => false,
isSymbolicLink: () => false,
isFIFO: () => false,
isSocket: () => false,
path: '',
parentPath: '',
});
describe('getFolderStructure', () => {
beforeEach(() => {
vi.resetAllMocks();
    // path.resolve is a vi.fn() from the top-level vi.mock. Its implementation
    // may be cleared by vi.resetAllMocks(), so re-apply it for each test:
(path.resolve as Mock).mockImplementation((str: string) => str);
// Re-apply/define the mock implementation for fsPromises.readdir for each test
(fsPromises.readdir as Mock).mockImplementation(
async (dirPath: string | Buffer | URL) => {
// path.normalize here will use the mocked path module.
// Since normalize is spread from original, it should be the real one.
const normalizedPath = path.normalize(dirPath.toString());
if (mockFsStructure[normalizedPath]) {
return mockFsStructure[normalizedPath];
}
throw Object.assign(
new Error(
`ENOENT: no such file or directory, scandir '${normalizedPath}'`,
),
{ code: 'ENOENT' },
);
},
);
});
afterEach(() => {
vi.restoreAllMocks(); // Restores spies (like fsPromises.readdir) and resets vi.fn mocks (like path.resolve)
});
const mockFsStructure: Record<string, FSDirent[]> = {
'/testroot': [
createDirent('file1.txt', 'file'),
createDirent('subfolderA', 'dir'),
createDirent('emptyFolder', 'dir'),
createDirent('.hiddenfile', 'file'),
createDirent('node_modules', 'dir'),
],
'/testroot/subfolderA': [
createDirent('fileA1.ts', 'file'),
createDirent('fileA2.js', 'file'),
createDirent('subfolderB', 'dir'),
],
'/testroot/subfolderA/subfolderB': [createDirent('fileB1.md', 'file')],
'/testroot/emptyFolder': [],
'/testroot/node_modules': [createDirent('somepackage', 'dir')],
'/testroot/manyFilesFolder': Array.from({ length: 10 }, (_, i) =>
createDirent(`file-${i}.txt`, 'file'),
),
'/testroot/manyFolders': Array.from({ length: 5 }, (_, i) =>
createDirent(`folder-${i}`, 'dir'),
),
...Array.from({ length: 5 }, (_, i) => ({
[`/testroot/manyFolders/folder-${i}`]: [
createDirent('child.txt', 'file'),
],
})).reduce((acc, val) => ({ ...acc, ...val }), {}),
'/testroot/deepFolders': [createDirent('level1', 'dir')],
'/testroot/deepFolders/level1': [createDirent('level2', 'dir')],
'/testroot/deepFolders/level1/level2': [createDirent('level3', 'dir')],
'/testroot/deepFolders/level1/level2/level3': [
createDirent('file.txt', 'file'),
],
};
it('should return basic folder structure', async () => {
const structure = await getFolderStructure('/testroot/subfolderA');
const expected = `
Showing up to 200 items (files + folders).
/testroot/subfolderA/
├───fileA1.ts
├───fileA2.js
└───subfolderB/
└───fileB1.md
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should handle an empty folder', async () => {
const structure = await getFolderStructure('/testroot/emptyFolder');
const expected = `
Showing up to 200 items (files + folders).
/testroot/emptyFolder/
`.trim();
expect(structure.trim()).toBe(expected.trim());
});
it('should ignore folders specified in ignoredFolders (default)', async () => {
const structure = await getFolderStructure('/testroot');
const expected = `
Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached.
/testroot/
├───.hiddenfile
├───file1.txt
├───emptyFolder/
├───node_modules/...
└───subfolderA/
├───fileA1.ts
├───fileA2.js
└───subfolderB/
└───fileB1.md
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should ignore folders specified in custom ignoredFolders', async () => {
const structure = await getFolderStructure('/testroot', {
ignoredFolders: new Set(['subfolderA', 'node_modules']),
});
const expected = `
Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached.
/testroot/
├───.hiddenfile
├───file1.txt
├───emptyFolder/
├───node_modules/...
└───subfolderA/...
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should filter files by fileIncludePattern', async () => {
const structure = await getFolderStructure('/testroot/subfolderA', {
fileIncludePattern: /\.ts$/,
});
const expected = `
Showing up to 200 items (files + folders).
/testroot/subfolderA/
├───fileA1.ts
└───subfolderB/
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should handle maxItems truncation for files within a folder', async () => {
const structure = await getFolderStructure('/testroot/subfolderA', {
maxItems: 3,
});
const expected = `
Showing up to 3 items (files + folders).
/testroot/subfolderA/
├───fileA1.ts
├───fileA2.js
└───subfolderB/
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should handle maxItems truncation for subfolders', async () => {
const structure = await getFolderStructure('/testroot/manyFolders', {
maxItems: 4,
});
const expectedRevised = `
Showing up to 4 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (4 items) was reached.
/testroot/manyFolders/
├───folder-0/
├───folder-1/
├───folder-2/
├───folder-3/
└───...
`.trim();
expect(structure.trim()).toBe(expectedRevised);
});
it('should handle maxItems that only allows the root folder itself', async () => {
const structure = await getFolderStructure('/testroot/subfolderA', {
maxItems: 1,
});
const expectedRevisedMax1 = `
Showing up to 1 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (1 items) was reached.
/testroot/subfolderA/
├───fileA1.ts
├───...
└───...
`.trim();
expect(structure.trim()).toBe(expectedRevisedMax1);
});
it('should handle non-existent directory', async () => {
// Temporarily make fsPromises.readdir throw ENOENT for this specific path
const originalReaddir = fsPromises.readdir;
(fsPromises.readdir as Mock).mockImplementation(
async (p: string | Buffer | URL) => {
if (p === '/nonexistent') {
throw Object.assign(new Error('ENOENT'), { code: 'ENOENT' });
}
return originalReaddir(p);
},
);
const structure = await getFolderStructure('/nonexistent');
expect(structure).toContain(
'Error: Could not read directory "/nonexistent"',
);
});
it('should handle deep folder structure within limits', async () => {
const structure = await getFolderStructure('/testroot/deepFolders', {
maxItems: 10,
});
const expected = `
Showing up to 10 items (files + folders).
/testroot/deepFolders/
└───level1/
└───level2/
└───level3/
└───file.txt
`.trim();
expect(structure.trim()).toBe(expected);
});
it('should truncate deep folder structure if maxItems is small', async () => {
const structure = await getFolderStructure('/testroot/deepFolders', {
maxItems: 3,
});
const expected = `
Showing up to 3 items (files + folders).
/testroot/deepFolders/
└───level1/
└───level2/
└───level3/
`.trim();
expect(structure.trim()).toBe(expected);
});
});
describe('getFolderStructure gitignore', () => {
beforeEach(() => {
vi.resetAllMocks();
(path.resolve as Mock).mockImplementation((str: string) => str);
(fsPromises.readdir as Mock).mockImplementation(async (p) => {
const path = p.toString();
if (path === '/test/project') {
return [
createDirent('file1.txt', 'file'),
createDirent('node_modules', 'dir'),
createDirent('ignored.txt', 'file'),
createDirent('.qwen', 'dir'),
] as any;
}
if (path === '/test/project/node_modules') {
return [createDirent('some-package', 'dir')] as any;
}
      if (path === '/test/project/.qwen') {
return [
createDirent('config.yaml', 'file'),
createDirent('logs.json', 'file'),
] as any;
}
return [];
});
(fs.readFileSync as Mock).mockImplementation((p) => {
const path = p.toString();
if (path === '/test/project/.gitignore') {
return 'ignored.txt\nnode_modules/\n.qwen/\n!/.qwen/config.yaml';
}
return '';
});
vi.mocked(gitUtils.isGitRepository).mockReturnValue(true);
});
it('should ignore files and folders specified in .gitignore', async () => {
const fileService = new FileDiscoveryService('/test/project');
const structure = await getFolderStructure('/test/project', {
fileService,
});
expect(structure).not.toContain('ignored.txt');
expect(structure).toContain('node_modules/...');
expect(structure).not.toContain('logs.json');
});
it('should not ignore files if respectGitIgnore is false', async () => {
const fileService = new FileDiscoveryService('/test/project');
const structure = await getFolderStructure('/test/project', {
fileService,
respectGitIgnore: false,
});
expect(structure).toContain('ignored.txt');
// node_modules is still ignored by default
expect(structure).toContain('node_modules/...');
});
});

View File

@@ -0,0 +1,347 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import { Dirent } from 'fs';
import * as path from 'path';
import { getErrorMessage, isNodeError } from './errors.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
const MAX_ITEMS = 200;
const TRUNCATION_INDICATOR = '...';
const DEFAULT_IGNORED_FOLDERS = new Set(['node_modules', '.git', 'dist']);
// --- Interfaces ---
/** Options for customizing folder structure retrieval. */
interface FolderStructureOptions {
/** Maximum number of files and folders combined to display. Defaults to 200. */
maxItems?: number;
/** Set of folder names to ignore completely. Case-sensitive. */
ignoredFolders?: Set<string>;
/** Optional regex to filter included files by name. */
fileIncludePattern?: RegExp;
/** For filtering files. */
fileService?: FileDiscoveryService;
/** Whether to use .gitignore patterns. */
respectGitIgnore?: boolean;
}
// Define a type for the merged options where fileIncludePattern remains optional
type MergedFolderStructureOptions = Required<
Omit<FolderStructureOptions, 'fileIncludePattern' | 'fileService'>
> & {
fileIncludePattern?: RegExp;
fileService?: FileDiscoveryService;
};
/** Represents the full, unfiltered information about a folder and its contents. */
interface FullFolderInfo {
name: string;
path: string;
files: string[];
subFolders: FullFolderInfo[];
totalChildren: number; // Number of files and subfolders included from this folder during BFS scan
totalFiles: number; // Number of files included from this folder during BFS scan
isIgnored?: boolean; // Flag to easily identify ignored folders later
hasMoreFiles?: boolean; // Indicates if files were truncated for this specific folder
hasMoreSubfolders?: boolean; // Indicates if subfolders were truncated for this specific folder
}
// --- Interfaces ---
// --- Helper Functions ---
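/**
 * Reads the directory structure breadth-first, honoring maxItems, ignored
 * folders, an optional file-include pattern, and gitignore rules. Returns
 * null if the root directory cannot be found.
 */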
async function readFullStructure(
rootPath: string,
options: MergedFolderStructureOptions,
): Promise<FullFolderInfo | null> {
const rootName = path.basename(rootPath);
const rootNode: FullFolderInfo = {
name: rootName,
path: rootPath,
files: [],
subFolders: [],
totalChildren: 0,
totalFiles: 0,
};
const queue: Array<{ folderInfo: FullFolderInfo; currentPath: string }> = [
{ folderInfo: rootNode, currentPath: rootPath },
];
let currentItemCount = 0;
  // Note: the root folder itself is not counted toward maxItems; only its contents are.
const processedPaths = new Set<string>(); // To avoid processing same path if symlinks create loops
while (queue.length > 0) {
const { folderInfo, currentPath } = queue.shift()!;
if (processedPaths.has(currentPath)) {
continue;
}
processedPaths.add(currentPath);
if (currentItemCount >= options.maxItems) {
// If the root itself caused us to exceed, we can't really show anything.
// Otherwise, this folder won't be processed further.
// The parent that queued this would have set its own hasMoreSubfolders flag.
continue;
}
let entries: Dirent[];
try {
const rawEntries = await fs.readdir(currentPath, { withFileTypes: true });
// Sort entries alphabetically by name for consistent processing order
entries = rawEntries.sort((a, b) => a.name.localeCompare(b.name));
} catch (error: unknown) {
if (
isNodeError(error) &&
(error.code === 'EACCES' || error.code === 'ENOENT')
) {
console.warn(
`Warning: Could not read directory ${currentPath}: ${error.message}`,
);
if (currentPath === rootPath && error.code === 'ENOENT') {
return null; // Root directory itself not found
}
// For other EACCES/ENOENT on subdirectories, just skip them.
continue;
}
throw error;
}
const filesInCurrentDir: string[] = [];
const subFoldersInCurrentDir: FullFolderInfo[] = [];
// Process files first in the current directory
for (const entry of entries) {
if (entry.isFile()) {
if (currentItemCount >= options.maxItems) {
folderInfo.hasMoreFiles = true;
break;
}
const fileName = entry.name;
const filePath = path.join(currentPath, fileName);
if (options.respectGitIgnore && options.fileService) {
if (options.fileService.shouldGitIgnoreFile(filePath)) {
continue;
}
}
if (
!options.fileIncludePattern ||
options.fileIncludePattern.test(fileName)
) {
filesInCurrentDir.push(fileName);
currentItemCount++;
folderInfo.totalFiles++;
folderInfo.totalChildren++;
}
}
}
folderInfo.files = filesInCurrentDir;
// Then process directories and queue them
for (const entry of entries) {
if (entry.isDirectory()) {
// Check if adding this directory ITSELF would meet or exceed maxItems
// (currentItemCount refers to items *already* added before this one)
if (currentItemCount >= options.maxItems) {
folderInfo.hasMoreSubfolders = true;
break; // Already at limit, cannot add this folder or any more
}
        // If adding this folder would hit the limit exactly while it may still
        // have children, showing '...' for the parent is clearer. The simple
        // rule used here: if an item cannot be added, mark the parent as
        // truncated and stop.
const subFolderName = entry.name;
const subFolderPath = path.join(currentPath, subFolderName);
let isIgnoredByGit = false;
if (options.respectGitIgnore && options.fileService) {
if (options.fileService.shouldGitIgnoreFile(subFolderPath)) {
isIgnoredByGit = true;
}
}
if (options.ignoredFolders.has(subFolderName) || isIgnoredByGit) {
const ignoredSubFolder: FullFolderInfo = {
name: subFolderName,
path: subFolderPath,
files: [],
subFolders: [],
totalChildren: 0,
totalFiles: 0,
isIgnored: true,
};
subFoldersInCurrentDir.push(ignoredSubFolder);
currentItemCount++; // Count the ignored folder itself
folderInfo.totalChildren++; // Also counts towards parent's children
continue;
}
const subFolderNode: FullFolderInfo = {
name: subFolderName,
path: subFolderPath,
files: [],
subFolders: [],
totalChildren: 0,
totalFiles: 0,
};
subFoldersInCurrentDir.push(subFolderNode);
currentItemCount++;
folderInfo.totalChildren++; // Counts towards parent's children
// Add to queue for processing its children later
queue.push({ folderInfo: subFolderNode, currentPath: subFolderPath });
}
}
folderInfo.subFolders = subFoldersInCurrentDir;
}
return rootNode;
}
/**
 * Recursively formats a folder node and its children into tree-style lines.
 * @param node The current node in the structure to render.
 * @param currentIndent The indentation string preceding this node's connector.
 * @param isLastChildOfParent Whether this node is the last among its siblings.
 * @param isProcessingRootNode Whether this node is the conceptual root, whose
 *   name is printed as a header rather than with a connector.
 * @param builder Array to build the string lines.
 */
function formatStructure(
node: FullFolderInfo,
currentIndent: string,
isLastChildOfParent: boolean,
isProcessingRootNode: boolean,
builder: string[],
): void {
const connector = isLastChildOfParent ? '└───' : '├───';
// The root node of the structure (the one passed initially to getFolderStructure)
// is not printed with a connector line itself, only its name as a header.
// Its children are printed relative to that conceptual root.
// Ignored root nodes ARE printed with a connector.
if (!isProcessingRootNode || node.isIgnored) {
builder.push(
`${currentIndent}${connector}${node.name}/${node.isIgnored ? TRUNCATION_INDICATOR : ''}`,
);
}
// Determine the indent for the children of *this* node.
// If *this* node was the root of the whole structure, its children start with no indent before their connectors.
// Otherwise, children's indent extends from the current node's indent.
const indentForChildren = isProcessingRootNode
? ''
: currentIndent + (isLastChildOfParent ? ' ' : '│ ');
// Render files of the current node
const fileCount = node.files.length;
for (let i = 0; i < fileCount; i++) {
const isLastFileAmongSiblings =
i === fileCount - 1 &&
node.subFolders.length === 0 &&
!node.hasMoreSubfolders;
const fileConnector = isLastFileAmongSiblings ? '└───' : '├───';
builder.push(`${indentForChildren}${fileConnector}${node.files[i]}`);
}
if (node.hasMoreFiles) {
const isLastIndicatorAmongSiblings =
node.subFolders.length === 0 && !node.hasMoreSubfolders;
const fileConnector = isLastIndicatorAmongSiblings ? '└───' : '├───';
builder.push(`${indentForChildren}${fileConnector}${TRUNCATION_INDICATOR}`);
}
// Render subfolders of the current node
const subFolderCount = node.subFolders.length;
for (let i = 0; i < subFolderCount; i++) {
const isLastSubfolderAmongSiblings =
i === subFolderCount - 1 && !node.hasMoreSubfolders;
// Children are never the root node being processed initially.
formatStructure(
node.subFolders[i],
indentForChildren,
isLastSubfolderAmongSiblings,
false,
builder,
);
}
if (node.hasMoreSubfolders) {
builder.push(`${indentForChildren}└───${TRUNCATION_INDICATOR}`);
}
}
// --- Main Exported Function ---
/**
* Generates a string representation of a directory's structure,
* limiting the number of items displayed. Ignored folders are shown
* followed by '...' instead of their contents.
*
* @param directory The absolute or relative path to the directory.
* @param options Optional configuration settings.
* @returns A promise resolving to the formatted folder structure string.
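 *
 * @example
 * // Illustrative usage (the path and options here are hypothetical):
 * //   const tree = await getFolderStructure('/my/project', { maxItems: 100 });
 * //   console.log(tree); // summary line, blank line, then the rendered tree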
*/
export async function getFolderStructure(
directory: string,
options?: FolderStructureOptions,
): Promise<string> {
const resolvedPath = path.resolve(directory);
const mergedOptions: MergedFolderStructureOptions = {
maxItems: options?.maxItems ?? MAX_ITEMS,
ignoredFolders: options?.ignoredFolders ?? DEFAULT_IGNORED_FOLDERS,
fileIncludePattern: options?.fileIncludePattern,
fileService: options?.fileService,
respectGitIgnore: options?.respectGitIgnore ?? true,
};
try {
// 1. Read the structure using BFS, respecting maxItems
const structureRoot = await readFullStructure(resolvedPath, mergedOptions);
if (!structureRoot) {
return `Error: Could not read directory "${resolvedPath}". Check path and permissions.`;
}
// 2. Format the structure into a string
const structureLines: string[] = [];
// Pass true for isRoot for the initial call
formatStructure(structureRoot, '', true, true, structureLines);
// 3. Build the final output string
const displayPath = resolvedPath.replace(/\\/g, '/');
let disclaimer = '';
// Check if truncation occurred anywhere or if ignored folders are present.
// A simple check: if any node indicates more files/subfolders, or is ignored.
let truncationOccurred = false;
function checkForTruncation(node: FullFolderInfo) {
if (node.hasMoreFiles || node.hasMoreSubfolders || node.isIgnored) {
truncationOccurred = true;
}
if (!truncationOccurred) {
for (const sub of node.subFolders) {
checkForTruncation(sub);
if (truncationOccurred) break;
}
}
}
checkForTruncation(structureRoot);
if (truncationOccurred) {
disclaimer = `Folders or files indicated with ${TRUNCATION_INDICATOR} contain more items not shown, were ignored, or the display limit (${mergedOptions.maxItems} items) was reached.`;
}
const summary =
`Showing up to ${mergedOptions.maxItems} items (files + folders). ${disclaimer}`.trim();
const output = `${summary}\n\n${displayPath}/\n${structureLines.join('\n')}`;
return output;
} catch (error: unknown) {
console.error(`Error getting folder structure for ${resolvedPath}:`, error);
return `Error processing directory "${resolvedPath}": ${getErrorMessage(error)}`;
}
}

View File

@@ -0,0 +1,175 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest';
import { GitIgnoreParser } from './gitIgnoreParser.js';
import * as fs from 'fs';
import * as path from 'path';
import { isGitRepository } from './gitUtils.js';
// Mock fs module
vi.mock('fs');
// Mock gitUtils module
vi.mock('./gitUtils.js');
describe('GitIgnoreParser', () => {
let parser: GitIgnoreParser;
const mockProjectRoot = '/test/project';
beforeEach(() => {
parser = new GitIgnoreParser(mockProjectRoot);
// Reset mocks before each test
vi.mocked(fs.readFileSync).mockClear();
vi.mocked(isGitRepository).mockReturnValue(true);
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('initialization', () => {
it('should initialize without errors when no .gitignore exists', () => {
expect(() => parser.loadGitRepoPatterns()).not.toThrow();
});
it('should load .gitignore patterns when file exists', () => {
const gitignoreContent = `
# Comment
node_modules/
*.log
/dist
.env
`;
vi.mocked(fs.readFileSync).mockReturnValueOnce(gitignoreContent);
parser.loadGitRepoPatterns();
expect(parser.getPatterns()).toEqual([
'.git',
'node_modules/',
'*.log',
'/dist',
'.env',
]);
expect(parser.isIgnored('node_modules/some-lib')).toBe(true);
expect(parser.isIgnored('src/app.log')).toBe(true);
expect(parser.isIgnored('dist/index.js')).toBe(true);
expect(parser.isIgnored('.env')).toBe(true);
});
it('should handle git exclude file', () => {
vi.mocked(fs.readFileSync).mockImplementation((filePath) => {
if (
filePath === path.join(mockProjectRoot, '.git', 'info', 'exclude')
) {
return 'temp/\n*.tmp';
}
throw new Error('ENOENT');
});
parser.loadGitRepoPatterns();
expect(parser.getPatterns()).toEqual(['.git', 'temp/', '*.tmp']);
expect(parser.isIgnored('temp/file.txt')).toBe(true);
expect(parser.isIgnored('src/file.tmp')).toBe(true);
});
it('should handle custom patterns file name', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
vi.mocked(fs.readFileSync).mockImplementation((filePath) => {
if (filePath === path.join(mockProjectRoot, '.geminiignore')) {
return 'temp/\n*.tmp';
}
throw new Error('ENOENT');
});
parser.loadPatterns('.geminiignore');
expect(parser.getPatterns()).toEqual(['temp/', '*.tmp']);
expect(parser.isIgnored('temp/file.txt')).toBe(true);
expect(parser.isIgnored('src/file.tmp')).toBe(true);
});
it('should initialize without errors when no .geminiignore exists', () => {
expect(() => parser.loadPatterns('.geminiignore')).not.toThrow();
});
});
describe('isIgnored', () => {
beforeEach(() => {
const gitignoreContent = `
node_modules/
*.log
/dist
/.env
src/*.tmp
!src/important.tmp
`;
vi.mocked(fs.readFileSync).mockReturnValueOnce(gitignoreContent);
parser.loadGitRepoPatterns();
});
it('should always ignore .git directory', () => {
expect(parser.isIgnored('.git')).toBe(true);
expect(parser.isIgnored('.git/config')).toBe(true);
expect(parser.isIgnored(path.join(mockProjectRoot, '.git', 'HEAD'))).toBe(
true,
);
});
it('should ignore files matching patterns', () => {
expect(parser.isIgnored('node_modules/package/index.js')).toBe(true);
expect(parser.isIgnored('app.log')).toBe(true);
expect(parser.isIgnored('logs/app.log')).toBe(true);
expect(parser.isIgnored('dist/bundle.js')).toBe(true);
expect(parser.isIgnored('.env')).toBe(true);
expect(parser.isIgnored('config/.env')).toBe(false); // .env is anchored to root
});
it('should ignore files with path-specific patterns', () => {
expect(parser.isIgnored('src/temp.tmp')).toBe(true);
expect(parser.isIgnored('other/temp.tmp')).toBe(false);
});
it('should handle negation patterns', () => {
expect(parser.isIgnored('src/important.tmp')).toBe(false);
});
it('should not ignore files that do not match patterns', () => {
expect(parser.isIgnored('src/index.ts')).toBe(false);
expect(parser.isIgnored('README.md')).toBe(false);
});
it('should handle absolute paths correctly', () => {
const absolutePath = path.join(mockProjectRoot, 'node_modules', 'lib');
expect(parser.isIgnored(absolutePath)).toBe(true);
});
it('should handle paths outside project root by not ignoring them', () => {
const outsidePath = path.resolve(mockProjectRoot, '../other/file.txt');
expect(parser.isIgnored(outsidePath)).toBe(false);
});
it('should handle relative paths correctly', () => {
expect(parser.isIgnored('node_modules/some-package')).toBe(true);
expect(parser.isIgnored('../some/other/file.txt')).toBe(false);
});
it('should normalize path separators on Windows', () => {
expect(parser.isIgnored('node_modules\\package')).toBe(true);
expect(parser.isIgnored('src\\temp.tmp')).toBe(true);
});
});
describe('getIgnoredPatterns', () => {
it('should return the raw patterns added', () => {
const gitignoreContent = '*.log\n!important.log';
vi.mocked(fs.readFileSync).mockReturnValueOnce(gitignoreContent);
parser.loadGitRepoPatterns();
expect(parser.getPatterns()).toEqual(['.git', '*.log', '!important.log']);
});
});
});

View File

@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs';
import * as path from 'path';
import ignore, { type Ignore } from 'ignore';
import { isGitRepository } from './gitUtils.js';
export interface GitIgnoreFilter {
isIgnored(filePath: string): boolean;
getPatterns(): string[];
}
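/**
 * Parses gitignore-style pattern files and answers ignore queries relative to
 * a project root.
 *
 * Illustrative usage (paths are hypothetical):
 *   const parser = new GitIgnoreParser('/my/project');
 *   parser.loadGitRepoPatterns();
 *   parser.isIgnored('node_modules/lib'); // => true once a matching pattern is loaded
 */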
export class GitIgnoreParser implements GitIgnoreFilter {
private projectRoot: string;
private ig: Ignore = ignore();
private patterns: string[] = [];
constructor(projectRoot: string) {
this.projectRoot = path.resolve(projectRoot);
}
loadGitRepoPatterns(): void {
if (!isGitRepository(this.projectRoot)) return;
// Always ignore .git directory regardless of .gitignore content
this.addPatterns(['.git']);
const patternFiles = ['.gitignore', path.join('.git', 'info', 'exclude')];
for (const pf of patternFiles) {
this.loadPatterns(pf);
}
}
loadPatterns(patternsFileName: string): void {
const patternsFilePath = path.join(this.projectRoot, patternsFileName);
let content: string;
try {
content = fs.readFileSync(patternsFilePath, 'utf-8');
} catch (_error) {
// ignore file not found
return;
}
const patterns = (content ?? '')
.split('\n')
.map((p) => p.trim())
.filter((p) => p !== '' && !p.startsWith('#'));
this.addPatterns(patterns);
}
private addPatterns(patterns: string[]) {
this.ig.add(patterns);
this.patterns.push(...patterns);
}
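  /**
   * Returns true if the given path (absolute or relative to the project root)
   * matches a loaded ignore pattern; paths outside the project root are never
   * ignored.
   */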
isIgnored(filePath: string): boolean {
const relativePath = path.isAbsolute(filePath)
? path.relative(this.projectRoot, filePath)
: filePath;
if (relativePath === '' || relativePath.startsWith('..')) {
return false;
}
let normalizedPath = relativePath.replace(/\\/g, '/');
if (normalizedPath.startsWith('./')) {
normalizedPath = normalizedPath.substring(2);
}
return this.ig.ignores(normalizedPath);
}
getPatterns(): string[] {
return this.patterns;
}
}

View File

@@ -0,0 +1,73 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs';
import * as path from 'path';
/**
* Checks if a directory is within a git repository
* @param directory The directory to check
* @returns true if the directory is in a git repository, false otherwise
*/
export function isGitRepository(directory: string): boolean {
try {
let currentDir = path.resolve(directory);
while (true) {
const gitDir = path.join(currentDir, '.git');
// Check if .git exists (either as directory or file for worktrees)
if (fs.existsSync(gitDir)) {
return true;
}
const parentDir = path.dirname(currentDir);
// If we've reached the root directory, stop searching
if (parentDir === currentDir) {
break;
}
currentDir = parentDir;
}
return false;
} catch (_error) {
// If any filesystem error occurs, assume not a git repo
return false;
}
}
/**
* Finds the root directory of a git repository
* @param directory Starting directory to search from
* @returns The git repository root path, or null if not in a git repository
*/
export function findGitRoot(directory: string): string | null {
try {
let currentDir = path.resolve(directory);
while (true) {
const gitDir = path.join(currentDir, '.git');
if (fs.existsSync(gitDir)) {
return currentDir;
}
const parentDir = path.dirname(currentDir);
if (parentDir === currentDir) {
break;
}
currentDir = parentDir;
}
return null;
} catch (_error) {
return null;
}
}

View File

@@ -0,0 +1,607 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, Mocked } from 'vitest';
import * as fsPromises from 'fs/promises';
import * as fsSync from 'fs';
import { Stats, Dirent } from 'fs';
import * as os from 'os';
import * as path from 'path';
import { loadServerHierarchicalMemory } from './memoryDiscovery.js';
import {
GEMINI_CONFIG_DIR,
setGeminiMdFilename,
getCurrentGeminiMdFilename,
DEFAULT_CONTEXT_FILENAME,
} from '../tools/memoryTool.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
const ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST = DEFAULT_CONTEXT_FILENAME;
// Mock the entire fs/promises module
vi.mock('fs/promises');
// Mock the parts of fsSync we might use (like constants or existsSync if needed)
vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof fsSync>();
return {
...actual, // Spread actual to get all exports, including Stats and Dirent if they are classes/constructors
constants: { ...actual.constants }, // Preserve constants
};
});
vi.mock('os');
describe('loadServerHierarchicalMemory', () => {
const mockFs = fsPromises as Mocked<typeof fsPromises>;
const mockOs = os as Mocked<typeof os>;
const CWD = '/test/project/src';
const PROJECT_ROOT = '/test/project';
const USER_HOME = '/test/userhome';
let GLOBAL_GEMINI_DIR: string;
let GLOBAL_GEMINI_FILE: string; // Defined in beforeEach
const fileService = new FileDiscoveryService(PROJECT_ROOT);
beforeEach(() => {
vi.resetAllMocks();
// Set environment variables to indicate test environment
process.env.NODE_ENV = 'test';
process.env.VITEST = 'true';
setGeminiMdFilename(DEFAULT_CONTEXT_FILENAME); // Use defined const
mockOs.homedir.mockReturnValue(USER_HOME);
// Define these here to use potentially reset/updated values from imports
GLOBAL_GEMINI_DIR = path.join(USER_HOME, GEMINI_CONFIG_DIR);
GLOBAL_GEMINI_FILE = path.join(
GLOBAL_GEMINI_DIR,
getCurrentGeminiMdFilename(), // Use current filename
);
mockFs.stat.mockRejectedValue(new Error('File not found'));
mockFs.readdir.mockResolvedValue([]);
mockFs.readFile.mockRejectedValue(new Error('File not found'));
mockFs.access.mockRejectedValue(new Error('File not found'));
});
it('should return empty memory and count if no context files are found', async () => {
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
expect(memoryContent).toBe('');
expect(fileCount).toBe(0);
});
it('should load only the global context file if present and others are not (default filename)', async () => {
const globalDefaultFile = path.join(
GLOBAL_GEMINI_DIR,
DEFAULT_CONTEXT_FILENAME,
);
mockFs.access.mockImplementation(async (p) => {
if (p === globalDefaultFile) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === globalDefaultFile) {
return 'Global memory content';
}
throw new Error('File not found');
});
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
expect(memoryContent).toBe(
`--- Context from: ${path.relative(CWD, globalDefaultFile)} ---\nGlobal memory content\n--- End of Context from: ${path.relative(CWD, globalDefaultFile)} ---`,
);
expect(fileCount).toBe(1);
expect(mockFs.readFile).toHaveBeenCalledWith(globalDefaultFile, 'utf-8');
});
it('should load only the global custom context file if present and filename is changed', async () => {
const customFilename = 'CUSTOM_AGENTS.md';
setGeminiMdFilename(customFilename);
const globalCustomFile = path.join(GLOBAL_GEMINI_DIR, customFilename);
mockFs.access.mockImplementation(async (p) => {
if (p === globalCustomFile) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === globalCustomFile) {
return 'Global custom memory';
}
throw new Error('File not found');
});
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
expect(memoryContent).toBe(
`--- Context from: ${path.relative(CWD, globalCustomFile)} ---\nGlobal custom memory\n--- End of Context from: ${path.relative(CWD, globalCustomFile)} ---`,
);
expect(fileCount).toBe(1);
expect(mockFs.readFile).toHaveBeenCalledWith(globalCustomFile, 'utf-8');
});
it('should load context files by upward traversal with custom filename', async () => {
const customFilename = 'PROJECT_CONTEXT.md';
setGeminiMdFilename(customFilename);
const projectRootCustomFile = path.join(PROJECT_ROOT, customFilename);
const srcCustomFile = path.join(CWD, customFilename);
mockFs.stat.mockImplementation(async (p) => {
if (p === path.join(PROJECT_ROOT, '.git')) {
return { isDirectory: () => true } as Stats;
}
throw new Error('File not found');
});
mockFs.access.mockImplementation(async (p) => {
if (p === projectRootCustomFile || p === srcCustomFile) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === projectRootCustomFile) {
return 'Project root custom memory';
}
if (p === srcCustomFile) {
return 'Src directory custom memory';
}
throw new Error('File not found');
});
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const expectedContent =
`--- Context from: ${path.relative(CWD, projectRootCustomFile)} ---\nProject root custom memory\n--- End of Context from: ${path.relative(CWD, projectRootCustomFile)} ---\n\n` +
`--- Context from: ${customFilename} ---\nSrc directory custom memory\n--- End of Context from: ${customFilename} ---`;
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(2);
expect(mockFs.readFile).toHaveBeenCalledWith(
projectRootCustomFile,
'utf-8',
);
expect(mockFs.readFile).toHaveBeenCalledWith(srcCustomFile, 'utf-8');
});
it('should load context files by downward traversal with custom filename', async () => {
const customFilename = 'LOCAL_CONTEXT.md';
setGeminiMdFilename(customFilename);
const subDir = path.join(CWD, 'subdir');
const subDirCustomFile = path.join(subDir, customFilename);
const cwdCustomFile = path.join(CWD, customFilename);
mockFs.access.mockImplementation(async (p) => {
if (p === cwdCustomFile || p === subDirCustomFile) return undefined;
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === cwdCustomFile) return 'CWD custom memory';
if (p === subDirCustomFile) return 'Subdir custom memory';
throw new Error('File not found');
});
mockFs.readdir.mockImplementation((async (
p: fsSync.PathLike,
): Promise<Dirent[]> => {
if (p === CWD) {
return [
{
name: customFilename,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
{
name: 'subdir',
isFile: () => false,
isDirectory: () => true,
} as Dirent,
] as Dirent[];
}
if (p === subDir) {
return [
{
name: customFilename,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
] as Dirent[];
}
return [] as Dirent[];
}) as unknown as typeof fsPromises.readdir);
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const expectedContent =
`--- Context from: ${customFilename} ---\nCWD custom memory\n--- End of Context from: ${customFilename} ---\n\n` +
`--- Context from: ${path.join('subdir', customFilename)} ---\nSubdir custom memory\n--- End of Context from: ${path.join('subdir', customFilename)} ---`;
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(2);
});
it('should load ORIGINAL_GEMINI_MD_FILENAME files by upward traversal from CWD to project root', async () => {
const projectRootGeminiFile = path.join(
PROJECT_ROOT,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const srcGeminiFile = path.join(
CWD,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
mockFs.stat.mockImplementation(async (p) => {
if (p === path.join(PROJECT_ROOT, '.git')) {
return { isDirectory: () => true } as Stats;
}
throw new Error('File not found');
});
mockFs.access.mockImplementation(async (p) => {
if (p === projectRootGeminiFile || p === srcGeminiFile) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === projectRootGeminiFile) {
return 'Project root memory';
}
if (p === srcGeminiFile) {
return 'Src directory memory';
}
throw new Error('File not found');
});
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const expectedContent =
`--- Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\nProject root memory\n--- End of Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\n\n` +
`--- Context from: ${ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST} ---\nSrc directory memory\n--- End of Context from: ${ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST} ---`;
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(2);
expect(mockFs.readFile).toHaveBeenCalledWith(
projectRootGeminiFile,
'utf-8',
);
expect(mockFs.readFile).toHaveBeenCalledWith(srcGeminiFile, 'utf-8');
});
it('should load ORIGINAL_GEMINI_MD_FILENAME files by downward traversal from CWD', async () => {
const subDir = path.join(CWD, 'subdir');
const subDirGeminiFile = path.join(
subDir,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const cwdGeminiFile = path.join(
CWD,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
mockFs.access.mockImplementation(async (p) => {
if (p === cwdGeminiFile || p === subDirGeminiFile) return undefined;
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === cwdGeminiFile) return 'CWD memory';
if (p === subDirGeminiFile) return 'Subdir memory';
throw new Error('File not found');
});
mockFs.readdir.mockImplementation((async (
p: fsSync.PathLike,
): Promise<Dirent[]> => {
if (p === CWD) {
return [
{
name: ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
{
name: 'subdir',
isFile: () => false,
isDirectory: () => true,
} as Dirent,
] as Dirent[];
}
if (p === subDir) {
return [
{
name: ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
] as Dirent[];
}
return [] as Dirent[];
}) as unknown as typeof fsPromises.readdir);
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const expectedContent =
`--- Context from: ${ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST} ---\nCWD memory\n--- End of Context from: ${ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST} ---\n\n` +
`--- Context from: ${path.join('subdir', ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST)} ---\nSubdir memory\n--- End of Context from: ${path.join('subdir', ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST)} ---`;
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(2);
});
it('should load and correctly order global, upward, and downward ORIGINAL_GEMINI_MD_FILENAME files', async () => {
setGeminiMdFilename(ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST); // Explicitly set for this test
const globalFileToUse = path.join(
GLOBAL_GEMINI_DIR,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const projectParentDir = path.dirname(PROJECT_ROOT);
const projectParentGeminiFile = path.join(
projectParentDir,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const projectRootGeminiFile = path.join(
PROJECT_ROOT,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const cwdGeminiFile = path.join(
CWD,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const subDir = path.join(CWD, 'sub');
const subDirGeminiFile = path.join(
subDir,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
mockFs.stat.mockImplementation(async (p) => {
if (p === path.join(PROJECT_ROOT, '.git')) {
return { isDirectory: () => true } as Stats;
} else if (p === path.join(PROJECT_ROOT, '.gemini')) {
return { isDirectory: () => true } as Stats;
}
throw new Error('File not found');
});
mockFs.access.mockImplementation(async (p) => {
if (
p === globalFileToUse || // Use the dynamically set global file path
p === projectParentGeminiFile ||
p === projectRootGeminiFile ||
p === cwdGeminiFile ||
p === subDirGeminiFile
) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === globalFileToUse) return 'Global memory'; // Use the dynamically set global file path
if (p === projectParentGeminiFile) return 'Project parent memory';
if (p === projectRootGeminiFile) return 'Project root memory';
if (p === cwdGeminiFile) return 'CWD memory';
if (p === subDirGeminiFile) return 'Subdir memory';
throw new Error('File not found');
});
mockFs.readdir.mockImplementation((async (
p: fsSync.PathLike,
): Promise<Dirent[]> => {
if (p === CWD) {
return [
{
name: 'sub',
isFile: () => false,
isDirectory: () => true,
} as Dirent,
] as Dirent[];
}
if (p === subDir) {
return [
{
name: ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
] as Dirent[];
}
return [] as Dirent[];
}) as unknown as typeof fsPromises.readdir);
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const relPathGlobal = path.relative(CWD, GLOBAL_GEMINI_FILE);
const relPathProjectParent = path.relative(CWD, projectParentGeminiFile);
const relPathProjectRoot = path.relative(CWD, projectRootGeminiFile);
const relPathCwd = ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST;
const relPathSubDir = path.join(
'sub',
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
const expectedContent = [
`--- Context from: ${relPathGlobal} ---\nGlobal memory\n--- End of Context from: ${relPathGlobal} ---`,
`--- Context from: ${relPathProjectParent} ---\nProject parent memory\n--- End of Context from: ${relPathProjectParent} ---`,
`--- Context from: ${relPathProjectRoot} ---\nProject root memory\n--- End of Context from: ${relPathProjectRoot} ---`,
`--- Context from: ${relPathCwd} ---\nCWD memory\n--- End of Context from: ${relPathCwd} ---`,
`--- Context from: ${relPathSubDir} ---\nSubdir memory\n--- End of Context from: ${relPathSubDir} ---`,
].join('\n\n');
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(5);
});
it('should ignore specified directories during downward scan', async () => {
const ignoredDir = path.join(CWD, 'node_modules');
const ignoredDirGeminiFile = path.join(
ignoredDir,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
    );
const regularSubDir = path.join(CWD, 'my_code');
const regularSubDirGeminiFile = path.join(
regularSubDir,
ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
);
mockFs.access.mockImplementation(async (p) => {
if (p === regularSubDirGeminiFile) return undefined;
if (p === ignoredDirGeminiFile)
throw new Error('Should not access ignored file');
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === regularSubDirGeminiFile) return 'My code memory';
throw new Error('File not found');
});
mockFs.readdir.mockImplementation((async (
p: fsSync.PathLike,
): Promise<Dirent[]> => {
if (p === CWD) {
return [
{
name: 'node_modules',
isFile: () => false,
isDirectory: () => true,
} as Dirent,
{
name: 'my_code',
isFile: () => false,
isDirectory: () => true,
} as Dirent,
] as Dirent[];
}
if (p === regularSubDir) {
return [
{
name: ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST,
isFile: () => true,
isDirectory: () => false,
} as Dirent,
] as Dirent[];
}
if (p === ignoredDir) {
return [] as Dirent[];
}
return [] as Dirent[];
}) as unknown as typeof fsPromises.readdir);
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
);
const expectedContent = `--- Context from: ${path.join('my_code', ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST)} ---\nMy code memory\n--- End of Context from: ${path.join('my_code', ORIGINAL_GEMINI_MD_FILENAME_CONST_FOR_TEST)} ---`;
expect(memoryContent).toBe(expectedContent);
expect(fileCount).toBe(1);
expect(mockFs.readFile).not.toHaveBeenCalledWith(
ignoredDirGeminiFile,
'utf-8',
);
});
it('should respect MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY during downward scan', async () => {
const consoleDebugSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
const dirNames: Dirent[] = [];
for (let i = 0; i < 250; i++) {
dirNames.push({
name: `deep_dir_${i}`,
isFile: () => false,
isDirectory: () => true,
} as Dirent);
}
mockFs.readdir.mockImplementation((async (
p: fsSync.PathLike,
): Promise<Dirent[]> => {
if (p === CWD) return dirNames;
if (p.toString().startsWith(path.join(CWD, 'deep_dir_')))
return [] as Dirent[];
return [] as Dirent[];
}) as unknown as typeof fsPromises.readdir);
mockFs.access.mockRejectedValue(new Error('not found'));
await loadServerHierarchicalMemory(CWD, true, fileService);
expect(consoleDebugSpy).toHaveBeenCalledWith(
expect.stringContaining('[DEBUG] [BfsFileSearch]'),
expect.stringContaining('Scanning [200/200]:'),
);
consoleDebugSpy.mockRestore();
});
it('should load extension context file paths', async () => {
const extensionFilePath = '/test/extensions/ext1/GEMINI.md';
mockFs.access.mockImplementation(async (p) => {
if (p === extensionFilePath) {
return undefined;
}
throw new Error('File not found');
});
mockFs.readFile.mockImplementation(async (p) => {
if (p === extensionFilePath) {
return 'Extension memory content';
}
throw new Error('File not found');
});
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
CWD,
false,
fileService,
[extensionFilePath],
);
expect(memoryContent).toBe(
`--- Context from: ${path.relative(CWD, extensionFilePath)} ---\nExtension memory content\n--- End of Context from: ${path.relative(CWD, extensionFilePath)} ---`,
);
expect(fileCount).toBe(1);
expect(mockFs.readFile).toHaveBeenCalledWith(extensionFilePath, 'utf-8');
});
});

View File

@@ -0,0 +1,319 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import * as fsSync from 'fs';
import * as path from 'path';
import { homedir } from 'os';
import { bfsFileSearch } from './bfsFileSearch.js';
import {
GEMINI_CONFIG_DIR,
getAllGeminiMdFilenames,
} from '../tools/memoryTool.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { processImports } from './memoryImportProcessor.js';
// Simple console logger, similar to the one previously in CLI's config.ts
// TODO: Integrate with a more robust server-side logger if available/appropriate.
const logger = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
debug: (...args: any[]) =>
console.debug('[DEBUG] [MemoryDiscovery]', ...args),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
warn: (...args: any[]) => console.warn('[WARN] [MemoryDiscovery]', ...args),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (...args: any[]) =>
console.error('[ERROR] [MemoryDiscovery]', ...args),
};
const MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY = 200;
interface GeminiFileContent {
filePath: string;
content: string | null;
}
async function findProjectRoot(startDir: string): Promise<string | null> {
let currentDir = path.resolve(startDir);
while (true) {
const gitPath = path.join(currentDir, '.git');
try {
const stats = await fs.stat(gitPath);
if (stats.isDirectory()) {
return currentDir;
}
} catch (error: unknown) {
// Don't log ENOENT errors as they're expected when .git doesn't exist
// Also don't log errors in test environments, which often have mocked fs
const isENOENT =
typeof error === 'object' &&
error !== null &&
'code' in error &&
(error as { code: string }).code === 'ENOENT';
// Only log unexpected errors in non-test environments
// process.env.NODE_ENV === 'test' or VITEST are common test indicators
const isTestEnv = process.env.NODE_ENV === 'test' || process.env.VITEST;
if (!isENOENT && !isTestEnv) {
if (typeof error === 'object' && error !== null && 'code' in error) {
const fsError = error as { code: string; message: string };
logger.warn(
`Error checking for .git directory at ${gitPath}: ${fsError.message}`,
);
} else {
logger.warn(
`Non-standard error checking for .git directory at ${gitPath}: ${String(error)}`,
);
}
}
}
const parentDir = path.dirname(currentDir);
if (parentDir === currentDir) {
return null;
}
currentDir = parentDir;
}
}
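/*
 * Example (a sketch): for a repository rooted at /repo (i.e. /repo/.git exists),
 *
 *   await findProjectRoot('/repo/packages/core/src'); // -> '/repo'
 *   await findProjectRoot('/tmp/scratch');            // -> null (no .git found)
 */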
async function getGeminiMdFilePathsInternal(
currentWorkingDirectory: string,
userHomePath: string,
debugMode: boolean,
fileService: FileDiscoveryService,
extensionContextFilePaths: string[] = [],
): Promise<string[]> {
const allPaths = new Set<string>();
const geminiMdFilenames = getAllGeminiMdFilenames();
for (const geminiMdFilename of geminiMdFilenames) {
const resolvedCwd = path.resolve(currentWorkingDirectory);
const resolvedHome = path.resolve(userHomePath);
const globalMemoryPath = path.join(
resolvedHome,
GEMINI_CONFIG_DIR,
geminiMdFilename,
);
if (debugMode)
logger.debug(
`Searching for ${geminiMdFilename} starting from CWD: ${resolvedCwd}`,
);
if (debugMode) logger.debug(`User home directory: ${resolvedHome}`);
try {
await fs.access(globalMemoryPath, fsSync.constants.R_OK);
allPaths.add(globalMemoryPath);
if (debugMode)
logger.debug(
`Found readable global ${geminiMdFilename}: ${globalMemoryPath}`,
);
} catch {
if (debugMode)
logger.debug(
`Global ${geminiMdFilename} not found or not readable: ${globalMemoryPath}`,
);
}
const projectRoot = await findProjectRoot(resolvedCwd);
if (debugMode)
logger.debug(`Determined project root: ${projectRoot ?? 'None'}`);
const upwardPaths: string[] = [];
let currentDir = resolvedCwd;
// Determine the directory that signifies the top of the project or user-specific space.
const ultimateStopDir = projectRoot
? path.dirname(projectRoot)
: path.dirname(resolvedHome);
while (currentDir && currentDir !== path.dirname(currentDir)) {
// Loop until filesystem root or currentDir is empty
if (debugMode) {
logger.debug(
`Checking for ${geminiMdFilename} in (upward scan): ${currentDir}`,
);
}
// Skip the global .gemini directory itself during upward scan from CWD,
// as global is handled separately and explicitly first.
if (currentDir === path.join(resolvedHome, GEMINI_CONFIG_DIR)) {
if (debugMode) {
logger.debug(
`Upward scan reached global config dir path, stopping upward search here: ${currentDir}`,
);
}
break;
}
const potentialPath = path.join(currentDir, geminiMdFilename);
try {
await fs.access(potentialPath, fsSync.constants.R_OK);
// Add to upwardPaths only if it's not the already added globalMemoryPath
if (potentialPath !== globalMemoryPath) {
upwardPaths.unshift(potentialPath);
if (debugMode) {
logger.debug(
`Found readable upward ${geminiMdFilename}: ${potentialPath}`,
);
}
}
} catch {
if (debugMode) {
logger.debug(
`Upward ${geminiMdFilename} not found or not readable in: ${currentDir}`,
);
}
}
// Stop condition: if currentDir is the ultimateStopDir, break after this iteration.
if (currentDir === ultimateStopDir) {
if (debugMode)
logger.debug(
`Reached ultimate stop directory for upward scan: ${currentDir}`,
);
break;
}
currentDir = path.dirname(currentDir);
}
upwardPaths.forEach((p) => allPaths.add(p));
const downwardPaths = await bfsFileSearch(resolvedCwd, {
fileName: geminiMdFilename,
maxDirs: MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY,
debug: debugMode,
fileService,
});
downwardPaths.sort(); // Sort for consistent ordering, though hierarchy might be more complex
if (debugMode && downwardPaths.length > 0)
logger.debug(
`Found downward ${geminiMdFilename} files (sorted): ${JSON.stringify(
downwardPaths,
)}`,
);
// Add downward paths only if they haven't been included already (e.g. from upward scan)
for (const dPath of downwardPaths) {
allPaths.add(dPath);
}
}
// Add extension context file paths
for (const extensionPath of extensionContextFilePaths) {
allPaths.add(extensionPath);
}
const finalPaths = Array.from(allPaths);
if (debugMode)
logger.debug(
`Final ordered ${getAllGeminiMdFilenames()} paths to read: ${JSON.stringify(
finalPaths,
)}`,
);
return finalPaths;
}
async function readGeminiMdFiles(
filePaths: string[],
debugMode: boolean,
): Promise<GeminiFileContent[]> {
const results: GeminiFileContent[] = [];
for (const filePath of filePaths) {
try {
const content = await fs.readFile(filePath, 'utf-8');
// Process imports in the content
const processedContent = await processImports(
content,
path.dirname(filePath),
debugMode,
);
results.push({ filePath, content: processedContent });
if (debugMode)
logger.debug(
`Successfully read and processed imports: ${filePath} (Length: ${processedContent.length})`,
);
} catch (error: unknown) {
const isTestEnv = process.env.NODE_ENV === 'test' || process.env.VITEST;
if (!isTestEnv) {
const message = error instanceof Error ? error.message : String(error);
logger.warn(
`Warning: Could not read ${getAllGeminiMdFilenames()} file at ${filePath}. Error: ${message}`,
);
}
results.push({ filePath, content: null }); // Still include it with null content
if (debugMode) logger.debug(`Failed to read: ${filePath}`);
}
}
return results;
}
function concatenateInstructions(
instructionContents: GeminiFileContent[],
// CWD is needed to resolve relative paths for display markers
currentWorkingDirectoryForDisplay: string,
): string {
return instructionContents
.filter((item) => typeof item.content === 'string')
.map((item) => {
const trimmedContent = (item.content as string).trim();
if (trimmedContent.length === 0) {
return null;
}
const displayPath = path.isAbsolute(item.filePath)
? path.relative(currentWorkingDirectoryForDisplay, item.filePath)
: item.filePath;
return `--- Context from: ${displayPath} ---\n${trimmedContent}\n--- End of Context from: ${displayPath} ---`;
})
.filter((block): block is string => block !== null)
.join('\n\n');
}
/**
* Loads hierarchical GEMINI.md files and concatenates their content.
* This function is intended for use by the server.
*/
export async function loadServerHierarchicalMemory(
currentWorkingDirectory: string,
debugMode: boolean,
fileService: FileDiscoveryService,
extensionContextFilePaths: string[] = [],
): Promise<{ memoryContent: string; fileCount: number }> {
if (debugMode)
logger.debug(
`Loading server hierarchical memory for CWD: ${currentWorkingDirectory}`,
);
// For the server, homedir() refers to the server process's home.
// This is consistent with how MemoryTool already finds the global path.
const userHomePath = homedir();
const filePaths = await getGeminiMdFilePathsInternal(
currentWorkingDirectory,
userHomePath,
debugMode,
fileService,
extensionContextFilePaths,
);
if (filePaths.length === 0) {
if (debugMode) logger.debug('No GEMINI.md files found in hierarchy.');
return { memoryContent: '', fileCount: 0 };
}
const contentsWithPaths = await readGeminiMdFiles(filePaths, debugMode);
// Pass CWD for relative path display in concatenated content
const combinedInstructions = concatenateInstructions(
contentsWithPaths,
currentWorkingDirectory,
);
if (debugMode)
logger.debug(
`Combined instructions length: ${combinedInstructions.length}`,
);
if (debugMode && combinedInstructions.length > 0)
logger.debug(
`Combined instructions (snippet): ${combinedInstructions.substring(0, 500)}...`,
);
return { memoryContent: combinedInstructions, fileCount: filePaths.length };
}
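/*
 * Example usage (a sketch; assumes FileDiscoveryService is constructed with the
 * project root, as elsewhere in this package):
 *
 *   const fileService = new FileDiscoveryService(process.cwd());
 *   const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
 *     process.cwd(),
 *     false, // debugMode
 *     fileService,
 *   );
 *   console.log(`Loaded ${fileCount} context file(s), ${memoryContent.length} chars`);
 */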

View File

@@ -0,0 +1,257 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import * as fs from 'fs/promises';
import * as path from 'path';
import { processImports, validateImportPath } from './memoryImportProcessor.js';
// Mock fs/promises
vi.mock('fs/promises');
const mockedFs = vi.mocked(fs);
// Mock console methods to capture warnings
const originalConsoleWarn = console.warn;
const originalConsoleError = console.error;
const originalConsoleDebug = console.debug;
describe('memoryImportProcessor', () => {
beforeEach(() => {
vi.clearAllMocks();
// Mock console methods
console.warn = vi.fn();
console.error = vi.fn();
console.debug = vi.fn();
});
afterEach(() => {
// Restore console methods
console.warn = originalConsoleWarn;
console.error = originalConsoleError;
console.debug = originalConsoleDebug;
});
describe('processImports', () => {
it('should process basic md file imports', async () => {
const content = 'Some content @./test.md more content';
const basePath = '/test/path';
const importedContent = '# Imported Content\nThis is imported.';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile.mockResolvedValue(importedContent);
const result = await processImports(content, basePath, true);
expect(result).toContain('<!-- Imported from: ./test.md -->');
expect(result).toContain(importedContent);
expect(result).toContain('<!-- End of import from: ./test.md -->');
expect(mockedFs.readFile).toHaveBeenCalledWith(
path.resolve(basePath, './test.md'),
'utf-8',
);
});
it('should warn and fail for non-md file imports', async () => {
const content = 'Some content @./instructions.txt more content';
const basePath = '/test/path';
const result = await processImports(content, basePath, true);
expect(console.warn).toHaveBeenCalledWith(
'[WARN] [ImportProcessor]',
'Import processor only supports .md files. Attempting to import non-md file: ./instructions.txt. This will fail.',
);
expect(result).toContain(
'<!-- Import failed: ./instructions.txt - Only .md files are supported -->',
);
expect(mockedFs.readFile).not.toHaveBeenCalled();
});
it('should handle circular imports', async () => {
const content = 'Content @./circular.md more content';
const basePath = '/test/path';
const circularContent = 'Circular @./main.md content';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile.mockResolvedValue(circularContent);
// Set up the import state to simulate we're already processing main.md
const importState = {
processedFiles: new Set<string>(),
maxDepth: 10,
currentDepth: 0,
currentFile: '/test/path/main.md', // Simulate we're processing main.md
};
const result = await processImports(content, basePath, true, importState);
// The circular import should be detected when processing the nested import
expect(result).toContain('<!-- Circular import detected: ./main.md -->');
});
it('should handle file not found errors', async () => {
const content = 'Content @./nonexistent.md more content';
const basePath = '/test/path';
mockedFs.access.mockRejectedValue(new Error('File not found'));
const result = await processImports(content, basePath, true);
expect(result).toContain(
'<!-- Import failed: ./nonexistent.md - File not found -->',
);
expect(console.error).toHaveBeenCalledWith(
'[ERROR] [ImportProcessor]',
'Failed to import ./nonexistent.md: File not found',
);
});
it('should respect max depth limit', async () => {
const content = 'Content @./deep.md more content';
const basePath = '/test/path';
const deepContent = 'Deep @./deeper.md content';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile.mockResolvedValue(deepContent);
const importState = {
processedFiles: new Set<string>(),
maxDepth: 1,
currentDepth: 1,
};
const result = await processImports(content, basePath, true, importState);
expect(console.warn).toHaveBeenCalledWith(
'[WARN] [ImportProcessor]',
'Maximum import depth (1) reached. Stopping import processing.',
);
expect(result).toBe(content);
});
it('should handle nested imports recursively', async () => {
const content = 'Main @./nested.md content';
const basePath = '/test/path';
const nestedContent = 'Nested @./inner.md content';
const innerContent = 'Inner content';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile
.mockResolvedValueOnce(nestedContent)
.mockResolvedValueOnce(innerContent);
const result = await processImports(content, basePath, true);
expect(result).toContain('<!-- Imported from: ./nested.md -->');
expect(result).toContain('<!-- Imported from: ./inner.md -->');
expect(result).toContain(innerContent);
});
it('should handle absolute paths in imports', async () => {
const content = 'Content @/absolute/path/file.md more content';
const basePath = '/test/path';
const importedContent = 'Absolute path content';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile.mockResolvedValue(importedContent);
const result = await processImports(content, basePath, true);
expect(result).toContain(
'<!-- Import failed: /absolute/path/file.md - Path traversal attempt -->',
);
});
it('should handle multiple imports in same content', async () => {
const content = 'Start @./first.md middle @./second.md end';
const basePath = '/test/path';
const firstContent = 'First content';
const secondContent = 'Second content';
mockedFs.access.mockResolvedValue(undefined);
mockedFs.readFile
.mockResolvedValueOnce(firstContent)
.mockResolvedValueOnce(secondContent);
const result = await processImports(content, basePath, true);
expect(result).toContain('<!-- Imported from: ./first.md -->');
expect(result).toContain('<!-- Imported from: ./second.md -->');
expect(result).toContain(firstContent);
expect(result).toContain(secondContent);
});
});
describe('validateImportPath', () => {
it('should reject URLs', () => {
expect(
validateImportPath('https://example.com/file.md', '/base', [
'/allowed',
]),
).toBe(false);
expect(
validateImportPath('http://example.com/file.md', '/base', ['/allowed']),
).toBe(false);
expect(
validateImportPath('file:///path/to/file.md', '/base', ['/allowed']),
).toBe(false);
});
it('should allow paths within allowed directories', () => {
expect(validateImportPath('./file.md', '/base', ['/base'])).toBe(true);
expect(validateImportPath('../file.md', '/base', ['/allowed'])).toBe(
false,
);
expect(
validateImportPath('/allowed/sub/file.md', '/base', ['/allowed']),
).toBe(true);
});
it('should reject paths outside allowed directories', () => {
expect(
validateImportPath('/forbidden/file.md', '/base', ['/allowed']),
).toBe(false);
expect(validateImportPath('../../../file.md', '/base', ['/base'])).toBe(
false,
);
});
it('should handle multiple allowed directories', () => {
expect(
validateImportPath('./file.md', '/base', ['/allowed1', '/allowed2']),
).toBe(false);
expect(
validateImportPath('/allowed1/file.md', '/base', [
'/allowed1',
'/allowed2',
]),
).toBe(true);
expect(
validateImportPath('/allowed2/file.md', '/base', [
'/allowed1',
'/allowed2',
]),
).toBe(true);
});
it('should handle relative paths correctly', () => {
expect(validateImportPath('file.md', '/base', ['/base'])).toBe(true);
expect(validateImportPath('./file.md', '/base', ['/base'])).toBe(true);
expect(validateImportPath('../file.md', '/base', ['/parent'])).toBe(
false,
);
});
it('should handle absolute paths correctly', () => {
expect(
validateImportPath('/allowed/file.md', '/base', ['/allowed']),
).toBe(true);
expect(
validateImportPath('/forbidden/file.md', '/base', ['/allowed']),
).toBe(false);
});
});
});

View File

@@ -0,0 +1,214 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import * as path from 'path';
// Simple console logger for import processing
const logger = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
debug: (...args: any[]) =>
console.debug('[DEBUG] [ImportProcessor]', ...args),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
warn: (...args: any[]) => console.warn('[WARN] [ImportProcessor]', ...args),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (...args: any[]) =>
console.error('[ERROR] [ImportProcessor]', ...args),
};
/**
* Interface for tracking import processing state to prevent circular imports
*/
interface ImportState {
processedFiles: Set<string>;
maxDepth: number;
currentDepth: number;
currentFile?: string; // Track the current file being processed
}
/**
* Processes import statements in GEMINI.md content
* Supports @path/to/file.md syntax for importing content from other files
*
* @param content - The content to process for imports
* @param basePath - The directory path where the current file is located
* @param debugMode - Whether to enable debug logging
* @param importState - State tracking for circular import prevention
* @returns Processed content with imports resolved
*/
export async function processImports(
content: string,
basePath: string,
debugMode: boolean = false,
importState: ImportState = {
processedFiles: new Set(),
maxDepth: 10,
currentDepth: 0,
},
): Promise<string> {
if (importState.currentDepth >= importState.maxDepth) {
if (debugMode) {
logger.warn(
`Maximum import depth (${importState.maxDepth}) reached. Stopping import processing.`,
);
}
return content;
}
// Regex to match @path/to/file imports (supports any file extension)
// Supports both @path/to/file.md and @./path/to/file.md syntax
const importRegex = /@([./]?[^\s\n]+\.[^\s\n]+)/g;
let processedContent = content;
let match: RegExpExecArray | null;
// Process all imports in the content
while ((match = importRegex.exec(content)) !== null) {
const importPath = match[1];
// Validate import path to prevent path traversal attacks
if (!validateImportPath(importPath, basePath, [basePath])) {
processedContent = processedContent.replace(
match[0],
`<!-- Import failed: ${importPath} - Path traversal attempt -->`,
);
continue;
}
// Check if the import is for a non-md file and warn
if (!importPath.endsWith('.md')) {
logger.warn(
`Import processor only supports .md files. Attempting to import non-md file: ${importPath}. This will fail.`,
);
// Replace the import with a warning comment
processedContent = processedContent.replace(
match[0],
`<!-- Import failed: ${importPath} - Only .md files are supported -->`,
);
continue;
}
const fullPath = path.resolve(basePath, importPath);
if (debugMode) {
logger.debug(`Processing import: ${importPath} -> ${fullPath}`);
}
// Check for circular imports - if we're already processing this file
if (importState.currentFile === fullPath) {
if (debugMode) {
logger.warn(`Circular import detected: ${importPath}`);
}
// Replace the import with a warning comment
processedContent = processedContent.replace(
match[0],
`<!-- Circular import detected: ${importPath} -->`,
);
continue;
}
// Check if we've already processed this file in this import chain
if (importState.processedFiles.has(fullPath)) {
if (debugMode) {
logger.warn(`File already processed in this chain: ${importPath}`);
}
// Replace the import with a warning comment
processedContent = processedContent.replace(
match[0],
`<!-- File already processed: ${importPath} -->`,
);
continue;
}
// Check for potential circular imports by looking at the import chain
if (importState.currentFile) {
const currentFileDir = path.dirname(importState.currentFile);
const potentialCircularPath = path.resolve(currentFileDir, importPath);
if (potentialCircularPath === importState.currentFile) {
if (debugMode) {
logger.warn(`Circular import detected: ${importPath}`);
}
// Replace the import with a warning comment
processedContent = processedContent.replace(
match[0],
`<!-- Circular import detected: ${importPath} -->`,
);
continue;
}
}
try {
// Check if the file exists
await fs.access(fullPath);
// Read the imported file content
const importedContent = await fs.readFile(fullPath, 'utf-8');
if (debugMode) {
logger.debug(`Successfully read imported file: ${fullPath}`);
}
// Recursively process imports in the imported content
const processedImportedContent = await processImports(
importedContent,
path.dirname(fullPath),
debugMode,
{
...importState,
processedFiles: new Set([...importState.processedFiles, fullPath]),
currentDepth: importState.currentDepth + 1,
currentFile: fullPath, // Set the current file being processed
},
);
// Replace the import statement with the processed content
processedContent = processedContent.replace(
match[0],
`<!-- Imported from: ${importPath} -->\n${processedImportedContent}\n<!-- End of import from: ${importPath} -->`,
);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
if (debugMode) {
logger.error(`Failed to import ${importPath}: ${errorMessage}`);
}
// Replace the import with an error comment
processedContent = processedContent.replace(
match[0],
`<!-- Import failed: ${importPath} - ${errorMessage} -->`,
);
}
}
return processedContent;
}
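/*
 * Example (a sketch): given content that lives in /repo/docs and imports a
 * sibling file, the import statement is replaced by the file's processed
 * content between marker comments:
 *
 *   const processed = await processImports('See @./style.md for rules', '/repo/docs');
 *   // -> 'See <!-- Imported from: ./style.md -->\n...\n<!-- End of import from: ./style.md --> for rules'
 */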
/**
* Validates import paths to ensure they are safe and within allowed directories
*
* @param importPath - The import path to validate
* @param basePath - The base directory for resolving relative paths
* @param allowedDirectories - Array of allowed directory paths
* @returns Whether the import path is valid
*/
export function validateImportPath(
  importPath: string,
  basePath: string,
  allowedDirectories: string[],
): boolean {
  // Reject URLs
  if (/^(file|https?):\/\//.test(importPath)) {
    return false;
  }
  const resolvedPath = path.resolve(basePath, importPath);
  return allowedDirectories.some((allowedDir) => {
    const normalizedAllowedDir = path.resolve(allowedDir);
    // Require the resolved path to be the allowed directory itself or to live
    // inside it. A bare startsWith() check would also accept sibling
    // directories that merely share the prefix (e.g. /allowed-evil vs /allowed).
    return (
      resolvedPath === normalizedAllowedDir ||
      resolvedPath.startsWith(normalizedAllowedDir + path.sep)
    );
  });
}
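/*
 * Example (a sketch): relative imports are resolved against basePath before
 * the allow-list check, so only paths inside an allowed directory pass:
 *
 *   validateImportPath('./notes.md', '/repo/docs', ['/repo/docs']);              // -> true
 *   validateImportPath('../secrets.md', '/repo/docs', ['/repo/docs']);           // -> false
 *   validateImportPath('https://example.com/x.md', '/repo/docs', ['/repo/docs']); // -> false
 */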

View File

@@ -0,0 +1,23 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { Content } from '@google/genai';
/**
 * Returns true for a user turn whose parts are all function responses
 * (tool results). Note that `every` is vacuously true for an empty parts
 * array, so callers should apply this check to non-empty turns.
 */
export function isFunctionResponse(content: Content): boolean {
return (
content.role === 'user' &&
!!content.parts &&
content.parts.every((part) => !!part.functionResponse)
);
}
/**
 * Returns true for a model turn whose parts are all function calls.
 */
export function isFunctionCall(content: Content): boolean {
return (
content.role === 'model' &&
!!content.parts &&
content.parts.every((part) => !!part.functionCall)
);
}
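/*
 * Example (a sketch): a user turn consisting solely of tool results is a
 * function response, signalling that the model should speak next:
 *
 *   isFunctionResponse({
 *     role: 'user',
 *     parts: [{ functionResponse: { name: 'read_file', response: {} } }],
 *   }); // -> true
 */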

View File

@@ -0,0 +1,253 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
import { Content, GoogleGenAI, Models } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
import { GeminiChat } from '../core/geminiChat.js';
// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
vi.mock('../config/config.js');
// Define mocks for GoogleGenAI and Models instances that will be used across tests
const mockModelsInstance = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as Models;
const mockGoogleGenAIInstance = {
getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance),
// Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods
} as unknown as GoogleGenAI;
vi.mock('@google/genai', async () => {
const actualGenAI =
await vi.importActual<typeof import('@google/genai')>('@google/genai');
return {
...actualGenAI,
GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance
// If Models is instantiated directly in GeminiChat, mock its constructor too
// For now, assuming Models instance is obtained via getGenerativeModel
};
});
describe('checkNextSpeaker', () => {
let chatInstance: GeminiChat;
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
const abortSignal = new AbortController().signal;
beforeEach(() => {
MockConfig = vi.mocked(Config);
const mockConfigInstance = new MockConfig(
'test-api-key',
'gemini-pro',
false,
'.',
false,
undefined,
false,
undefined,
undefined,
undefined,
);
mockGeminiClient = new GeminiClient(mockConfigInstance);
// Reset mocks before each test to ensure test isolation
vi.mocked(mockModelsInstance.generateContent).mockReset();
vi.mocked(mockModelsInstance.generateContentStream).mockReset();
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
chatInstance = new GeminiChat(
mockConfigInstance,
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
{},
[], // initial history
);
// Spy on getHistory for chatInstance
vi.spyOn(chatInstance, 'getHistory');
});
afterEach(() => {
vi.clearAllMocks();
});
it('should return null if history is empty', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([]);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
});
it('should return null if the last speaker was the user', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'user', parts: [{ text: 'Hello' }] },
] as Content[]);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
});
it("should return { next_speaker: 'model' } when model intends to continue", async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'I will now do something.' }] },
] as Content[]);
const mockApiResponse: NextSpeakerResponse = {
reasoning: 'Model stated it will do something.',
next_speaker: 'model',
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
});
it("should return { next_speaker: 'user' } when model asks a question", async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'What would you like to do?' }] },
] as Content[]);
const mockApiResponse: NextSpeakerResponse = {
reasoning: 'Model asked a question.',
next_speaker: 'user',
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse);
});
it("should return { next_speaker: 'user' } when model makes a statement", async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'This is a statement.' }] },
] as Content[]);
const mockApiResponse: NextSpeakerResponse = {
reasoning: 'Model made a statement, awaiting user input.',
next_speaker: 'user',
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse);
});
it('should return null if geminiClient.generateJson throws an error', async () => {
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'Some model output.' }] },
] as Content[]);
(mockGeminiClient.generateJson as Mock).mockRejectedValue(
new Error('API Error'),
);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
consoleWarnSpy.mockRestore();
});
it('should return null if geminiClient.generateJson returns invalid JSON (missing next_speaker)', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'Some model output.' }] },
] as Content[]);
(mockGeminiClient.generateJson as Mock).mockResolvedValue({
reasoning: 'This is incomplete.',
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
});
it('should return null if geminiClient.generateJson returns a non-string next_speaker', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'Some model output.' }] },
] as Content[]);
(mockGeminiClient.generateJson as Mock).mockResolvedValue({
reasoning: 'Model made a statement, awaiting user input.',
next_speaker: 123, // Invalid type
} as unknown as NextSpeakerResponse);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
});
it('should return null if geminiClient.generateJson returns an invalid next_speaker string value', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'Some model output.' }] },
] as Content[]);
(mockGeminiClient.generateJson as Mock).mockResolvedValue({
reasoning: 'Model made a statement, awaiting user input.',
next_speaker: 'neither', // Invalid enum value
} as unknown as NextSpeakerResponse);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull();
});
it('should call generateJson with DEFAULT_GEMINI_FLASH_MODEL', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'model', parts: [{ text: 'Some model output.' }] },
] as Content[]);
const mockApiResponse: NextSpeakerResponse = {
reasoning: 'Model made a statement, awaiting user input.',
next_speaker: 'user',
};
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
await checkNextSpeaker(chatInstance, mockGeminiClient, abortSignal);
expect(mockGeminiClient.generateJson).toHaveBeenCalled();
const generateJsonCall = (mockGeminiClient.generateJson as Mock).mock
.calls[0];
expect(generateJsonCall[3]).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
});

View File

@@ -0,0 +1,153 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { Content, SchemaUnion, Type } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js';
const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
**Decision Rules (apply in order):**
1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next.
2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next.
3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next.
**Output Format:**
Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
\`\`\`json
{
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn."
},
"next_speaker": {
"type": "string",
"enum": ["user", "model"],
"description": "Who should speak next based *only* on the preceding turn and the decision rules."
}
},
"required": ["next_speaker", "reasoning"]
}
\`\`\`
`;
const RESPONSE_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
reasoning: {
type: Type.STRING,
description:
"Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.",
},
next_speaker: {
type: Type.STRING,
enum: ['user', 'model'],
description:
'Who should speak next based *only* on the preceding turn and the decision rules',
},
},
required: ['reasoning', 'next_speaker'],
};
export interface NextSpeakerResponse {
reasoning: string;
next_speaker: 'user' | 'model';
}
export async function checkNextSpeaker(
chat: GeminiChat,
geminiClient: GeminiClient,
abortSignal: AbortSignal,
): Promise<NextSpeakerResponse | null> {
  // We need to capture the curated history because there are many moments when the model will return invalid turns
  // that, when passed back up to the endpoint, will break subsequent calls. An example is when the model responds
  // with an empty part collection: if you were to send that message back to the server, it would respond with a 400
  // indicating that model part collections MUST have content.
const curatedHistory = chat.getHistory(/* curated */ true);
// Ensure there's a model response to analyze
if (curatedHistory.length === 0) {
// Cannot determine next speaker if history is empty.
return null;
}
const comprehensiveHistory = chat.getHistory();
// If comprehensiveHistory is empty, there is no last message to check.
// This case should ideally be caught by the curatedHistory.length check earlier,
// but as a safeguard:
if (comprehensiveHistory.length === 0) {
return null;
}
const lastComprehensiveMessage =
comprehensiveHistory[comprehensiveHistory.length - 1];
// If the last message is a user message containing only function_responses,
// then the model should speak next.
if (
lastComprehensiveMessage &&
isFunctionResponse(lastComprehensiveMessage)
) {
return {
reasoning:
'The last message was a function response, so the model should speak next.',
next_speaker: 'model',
};
}
if (
lastComprehensiveMessage &&
lastComprehensiveMessage.role === 'model' &&
lastComprehensiveMessage.parts &&
lastComprehensiveMessage.parts.length === 0
) {
    // Patch the empty model turn in place so it is no longer an empty part
    // collection (which the API rejects with a 400, per the note above).
    lastComprehensiveMessage.parts.push({ text: '' });
return {
reasoning:
'The last message was a filler model message with no content (nothing for user to act on), model should speak next.',
next_speaker: 'model',
};
}
// Things checked out. Let's proceed to potentially making an LLM request.
const lastMessage = curatedHistory[curatedHistory.length - 1];
if (!lastMessage || lastMessage.role !== 'model') {
// Cannot determine next speaker if the last turn wasn't from the model
// or if history is empty.
return null;
}
const contents: Content[] = [
...curatedHistory,
{ role: 'user', parts: [{ text: CHECK_PROMPT }] },
];
try {
const parsedResponse = (await geminiClient.generateJson(
contents,
RESPONSE_SCHEMA,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
)) as unknown as NextSpeakerResponse;
if (
parsedResponse &&
parsedResponse.next_speaker &&
['user', 'model'].includes(parsedResponse.next_speaker)
) {
return parsedResponse;
}
return null;
} catch (error) {
console.warn(
'Failed to talk to Gemini endpoint when seeing if conversation should continue.',
error,
);
return null;
}
}
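/*
 * Example usage (a sketch; `chat` and `geminiClient` are supplied by the caller):
 *
 *   const controller = new AbortController();
 *   const next = await checkNextSpeaker(chat, geminiClient, controller.signal);
 *   if (next?.next_speaker === 'model') {
 *     // The model indicated it has more to do; the caller may auto-continue
 *     // the turn instead of waiting for user input.
 *   }
 */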

View File

@@ -0,0 +1,362 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import fs from 'node:fs/promises';
import { openaiLogger } from './openaiLogger.js';
/**
* OpenAI API usage analytics
*
* This utility analyzes OpenAI API logs to provide insights into API usage
* patterns, costs, and performance.
*/
export class OpenAIAnalytics {
/**
* Calculate statistics for OpenAI API usage
* @param days Number of days to analyze (default: 7)
*/
static async calculateStats(days: number = 7): Promise<{
totalRequests: number;
successRate: number;
avgResponseTime: number;
requestsByModel: Record<string, number>;
tokenUsage: {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
estimatedCost: number;
errorRates: Record<string, number>;
timeDistribution: Record<string, number>;
}> {
const logs = await openaiLogger.getLogFiles();
const now = new Date();
const cutoffDate = new Date(now.getTime() - days * 24 * 60 * 60 * 1000);
let totalRequests = 0;
let successfulRequests = 0;
    // NOTE: response latency is not recorded in the current log format, so
    // totalResponseTime stays 0 and avgResponseTime is always reported as 0.
    const totalResponseTime = 0;
const requestsByModel: Record<string, number> = {};
const tokenUsage = { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
const errorTypes: Record<string, number> = {};
const hourDistribution: Record<string, number> = {};
// Initialize hour distribution (0-23)
for (let i = 0; i < 24; i++) {
const hour = i.toString().padStart(2, '0');
hourDistribution[hour] = 0;
}
// Model pricing estimates (per 1000 tokens)
const pricing: Record<string, { input: number; output: number }> = {
'gpt-4': { input: 0.03, output: 0.06 },
'gpt-4-32k': { input: 0.06, output: 0.12 },
'gpt-4-1106-preview': { input: 0.01, output: 0.03 },
'gpt-4-0125-preview': { input: 0.01, output: 0.03 },
'gpt-4-0613': { input: 0.03, output: 0.06 },
'gpt-4-32k-0613': { input: 0.06, output: 0.12 },
'gpt-3.5-turbo': { input: 0.0015, output: 0.002 },
'gpt-3.5-turbo-16k': { input: 0.003, output: 0.004 },
'gpt-3.5-turbo-0613': { input: 0.0015, output: 0.002 },
'gpt-3.5-turbo-16k-0613': { input: 0.003, output: 0.004 },
};
// Default pricing for unknown models
const defaultPricing = { input: 0.01, output: 0.03 };
let estimatedCost = 0;
for (const logFile of logs) {
try {
const logData = await openaiLogger.readLogFile(logFile);
// Type guard to check if logData has the expected structure
if (!isObjectWith<{ timestamp: string }>(logData, ['timestamp'])) {
continue; // Skip malformed logs
}
const logDate = new Date(logData.timestamp);
// Skip if log is older than the cutoff date
if (logDate < cutoffDate) {
continue;
}
totalRequests++;
const hour = logDate.getUTCHours().toString().padStart(2, '0');
hourDistribution[hour]++;
// Check if request was successful
if (
isObjectWith<{ response?: unknown; error?: unknown }>(logData, [
'response',
'error',
]) &&
logData.response &&
!logData.error
) {
successfulRequests++;
// Extract model if available
const model = getModelFromLog(logData);
if (model) {
requestsByModel[model] = (requestsByModel[model] || 0) + 1;
}
// Extract token usage if available
const usage = getTokenUsageFromLog(logData);
if (usage) {
tokenUsage.promptTokens += usage.prompt_tokens || 0;
tokenUsage.completionTokens += usage.completion_tokens || 0;
tokenUsage.totalTokens += usage.total_tokens || 0;
// Calculate cost if model is known
const modelName = model || 'unknown';
const modelPricing = pricing[modelName] || defaultPricing;
const inputCost =
((usage.prompt_tokens || 0) / 1000) * modelPricing.input;
const outputCost =
((usage.completion_tokens || 0) / 1000) * modelPricing.output;
estimatedCost += inputCost + outputCost;
}
} else if (
isObjectWith<{ error?: unknown }>(logData, ['error']) &&
logData.error
) {
// Categorize errors
const errorType = getErrorTypeFromLog(logData);
errorTypes[errorType] = (errorTypes[errorType] || 0) + 1;
}
} catch (error) {
console.error(`Error processing log file ${logFile}:`, error);
}
}
// Calculate success rate and average response time
const successRate =
totalRequests > 0 ? (successfulRequests / totalRequests) * 100 : 0;
const avgResponseTime =
totalRequests > 0 ? totalResponseTime / totalRequests : 0;
// Calculate error rates as percentages
const errorRates: Record<string, number> = {};
for (const [errorType, count] of Object.entries(errorTypes)) {
errorRates[errorType] =
totalRequests > 0 ? (count / totalRequests) * 100 : 0;
}
return {
totalRequests,
successRate,
avgResponseTime,
requestsByModel,
tokenUsage,
estimatedCost,
errorRates,
timeDistribution: hourDistribution,
};
}
/**
* Generate a human-readable report of OpenAI API usage
* @param days Number of days to include in the report
*/
static async generateReport(days: number = 7): Promise<string> {
const stats = await this.calculateStats(days);
let report = `# OpenAI API Usage Report\n`;
report += `## Last ${days} days (${new Date().toISOString().split('T')[0]})\n\n`;
report += `### Overview\n`;
report += `- Total Requests: ${stats.totalRequests}\n`;
report += `- Success Rate: ${stats.successRate.toFixed(2)}%\n`;
report += `- Estimated Cost: $${stats.estimatedCost.toFixed(2)}\n\n`;
report += `### Token Usage\n`;
report += `- Prompt Tokens: ${stats.tokenUsage.promptTokens.toLocaleString()}\n`;
report += `- Completion Tokens: ${stats.tokenUsage.completionTokens.toLocaleString()}\n`;
report += `- Total Tokens: ${stats.tokenUsage.totalTokens.toLocaleString()}\n\n`;
report += `### Models Used\n`;
const sortedModels = Object.entries(stats.requestsByModel) as Array<
[string, number]
>;
sortedModels.sort((a, b) => b[1] - a[1]);
for (const [model, count] of sortedModels) {
const percentage = ((count / stats.totalRequests) * 100).toFixed(1);
report += `- ${model}: ${count} requests (${percentage}%)\n`;
}
if (Object.keys(stats.errorRates).length > 0) {
report += `\n### Error Types\n`;
const sortedErrors = Object.entries(stats.errorRates) as Array<
[string, number]
>;
sortedErrors.sort((a, b) => b[1] - a[1]);
for (const [errorType, rate] of sortedErrors) {
report += `- ${errorType}: ${rate.toFixed(1)}%\n`;
}
}
report += `\n### Usage by Hour (UTC)\n`;
report += `\`\`\`\n`;
const maxRequests = Math.max(...Object.values(stats.timeDistribution));
const scale = 40; // max bar length
for (let i = 0; i < 24; i++) {
const hour = i.toString().padStart(2, '0');
const requests = stats.timeDistribution[hour] || 0;
const barLength =
maxRequests > 0 ? Math.round((requests / maxRequests) * scale) : 0;
const bar = '█'.repeat(barLength);
report += `${hour}:00 ${bar.padEnd(scale)} ${requests}\n`;
}
report += `\`\`\`\n`;
return report;
}
/**
* Save an analytics report to a file
* @param days Number of days to include
* @param outputPath File path for the report (defaults to logs/openai/analytics.md)
*/
static async saveReport(
days: number = 7,
outputPath?: string,
): Promise<string> {
const report = await this.generateReport(days);
const reportPath =
outputPath || path.join(process.cwd(), 'logs', 'openai', 'analytics.md');
await fs.writeFile(reportPath, report, 'utf-8');
return reportPath;
}
}
function isObjectWith<T extends object>(
obj: unknown,
keys: Array<keyof T>,
): obj is T {
return (
typeof obj === 'object' && obj !== null && keys.every((key) => key in obj)
);
}
/**
* Extract the model name from a log entry
*/
function getModelFromLog(logData: unknown): string | undefined {
if (
isObjectWith<{
request?: { model?: string };
response?: { model?: string; modelVersion?: string };
}>(logData, ['request', 'response'])
) {
const data = logData as {
request?: { model?: string };
response?: { model?: string; modelVersion?: string };
};
if (data.request && data.request.model) return data.request.model;
if (data.response && data.response.model) return data.response.model;
if (data.response && data.response.modelVersion)
return data.response.modelVersion;
}
return undefined;
}
/**
* Extract token usage information from a log entry
*/
function getTokenUsageFromLog(logData: unknown):
| {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
}
| undefined {
  type UsageLog = {
    response?: {
      usage?: {
        prompt_tokens?: number;
        completion_tokens?: number;
        total_tokens?: number;
      };
      usageMetadata?: {
        promptTokenCount?: number;
        candidatesTokenCount?: number;
        totalTokenCount?: number;
      };
    };
  };
  if (isObjectWith<UsageLog>(logData, ['response'])) {
    const data = logData as UsageLog;
    // OpenAI-style usage blocks are already in the expected snake_case shape.
    if (data.response && data.response.usage) return data.response.usage;
if (data.response && data.response.usageMetadata) {
const metadata = data.response.usageMetadata;
return {
prompt_tokens: metadata.promptTokenCount,
completion_tokens: metadata.candidatesTokenCount,
total_tokens: metadata.totalTokenCount,
};
}
}
return undefined;
}
/**
* Extract and categorize error types from a log entry
*/
function getErrorTypeFromLog(logData: unknown): string {
if (isObjectWith<{ error?: { message?: string } }>(logData, ['error'])) {
const data = logData as { error?: { message?: string } };
if (data.error) {
const errorMsg = data.error.message || '';
if (errorMsg.includes('rate limit')) return 'rate_limit';
if (errorMsg.includes('timeout')) return 'timeout';
if (errorMsg.includes('authentication')) return 'authentication';
if (errorMsg.includes('quota')) return 'quota_exceeded';
if (errorMsg.includes('invalid')) return 'invalid_request';
if (errorMsg.includes('not available')) return 'model_unavailable';
if (errorMsg.includes('content filter')) return 'content_filtered';
return 'other';
}
}
return 'unknown';
}
// CLI interface when script is run directly
if (import.meta.url === `file://${process.argv[1]}`) {
async function main() {
const args = process.argv.slice(2);
const days = args[0] ? parseInt(args[0], 10) : 7;
try {
const reportPath = await OpenAIAnalytics.saveReport(days);
console.log(`Analytics report saved to: ${reportPath}`);
// Also print to console
const report = await OpenAIAnalytics.generateReport(days);
console.log(report);
} catch (error) {
console.error('Error generating analytics report:', error);
}
}
main().catch(console.error);
}
export default OpenAIAnalytics;

View File

@@ -0,0 +1,199 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { openaiLogger } from './openaiLogger.js';
/**
* CLI utility for viewing and managing OpenAI logs
*/
export class OpenAILogViewer {
/**
* List all available OpenAI logs
* @param limit Optional limit on the number of logs to display
*/
static async listLogs(limit?: number): Promise<void> {
try {
const logs = await openaiLogger.getLogFiles(limit);
if (logs.length === 0) {
console.log('No OpenAI logs found');
return;
}
console.log(`Found ${logs.length} OpenAI logs:`);
for (let i = 0; i < logs.length; i++) {
const filePath = logs[i];
const filename = path.basename(filePath);
const logData = await openaiLogger.readLogFile(filePath);
// Type guard for logData
if (typeof logData !== 'object' || logData === null) {
console.log(`${i + 1}. ${filename} - Invalid log data`);
continue;
}
const data = logData as Record<string, unknown>;
// Format the log entry summary
const requestType = getRequestType(data.request);
const status = data.error ? 'ERROR' : 'OK';
console.log(
`${i + 1}. ${filename} - ${requestType} - ${status} - ${data.timestamp}`,
);
}
} catch (error) {
console.error('Error listing logs:', error);
}
}
/**
* View details of a specific log file
* @param identifier Either a log index (1-based) or a filename
*/
static async viewLog(identifier: number | string): Promise<void> {
try {
let logFile: string | undefined;
const logs = await openaiLogger.getLogFiles();
if (logs.length === 0) {
console.log('No OpenAI logs found');
return;
}
if (typeof identifier === 'number') {
// Adjust for 1-based indexing
if (identifier < 1 || identifier > logs.length) {
console.error(
`Invalid log index. Please provide a number between 1 and ${logs.length}`,
);
return;
}
logFile = logs[identifier - 1];
} else {
// Find by filename
logFile = logs.find((log) => path.basename(log) === identifier);
if (!logFile) {
console.error(`Log file '${identifier}' not found`);
return;
}
}
const logData = await openaiLogger.readLogFile(logFile);
console.log(JSON.stringify(logData, null, 2));
} catch (error) {
console.error('Error viewing log:', error);
}
}
/**
* Clean up old logs, keeping only the most recent ones
* @param keepCount Number of recent logs to keep
*/
static async cleanupLogs(keepCount: number = 50): Promise<void> {
try {
const allLogs = await openaiLogger.getLogFiles();
if (allLogs.length === 0) {
console.log('No OpenAI logs found');
return;
}
if (allLogs.length <= keepCount) {
console.log(`Only ${allLogs.length} logs exist, no cleanup needed`);
return;
}
const logsToDelete = allLogs.slice(keepCount);
const fs = await import('node:fs/promises');
for (const log of logsToDelete) {
await fs.unlink(log);
}
console.log(
`Deleted ${logsToDelete.length} old log files. Kept ${keepCount} most recent logs.`,
);
} catch (error) {
console.error('Error cleaning up logs:', error);
}
}
}
/**
* Helper function to determine the type of request in a log
*/
function getRequestType(request: unknown): string {
if (!request) return 'unknown';
if (typeof request !== 'object' || request === null) return 'unknown';
const req = request as Record<string, unknown>;
if (req.contents) {
return 'generate_content';
} else if (typeof req.model === 'string' && req.model.includes('embedding')) {
return 'embedding';
} else if (req.input) {
return 'embedding';
} else if ('countTokens' in req || 'contents' in req) {
return 'count_tokens';
}
return 'api_call';
}
// CLI interface when script is run directly
if (import.meta.url === `file://${process.argv[1]}`) {
async function main() {
const args = process.argv.slice(2);
const command = args[0]?.toLowerCase();
switch (command) {
case 'list': {
const limit = args[1] ? parseInt(args[1], 10) : undefined;
await OpenAILogViewer.listLogs(limit);
break;
}
case 'view': {
const identifier = args[1];
if (!identifier) {
console.error('Please provide a log index or filename to view');
process.exit(1);
}
await OpenAILogViewer.viewLog(
isNaN(Number(identifier)) ? identifier : Number(identifier),
);
break;
}
case 'cleanup': {
const keepCount = args[1] ? parseInt(args[1], 10) : 50;
await OpenAILogViewer.cleanupLogs(keepCount);
break;
}
default:
console.log('OpenAI Log Viewer');
console.log('----------------');
console.log('Commands:');
console.log(
' list [limit] - List all logs, optionally limiting to the specified number',
);
console.log(
' view <index|file> - View a specific log by index number or filename',
);
console.log(
' cleanup [keepCount] - Remove old logs, keeping only the specified number (default: 50)',
);
break;
}
}
main().catch(console.error);
}
export default OpenAILogViewer;

View File

@@ -0,0 +1,135 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as path from 'node:path';
import { promises as fs } from 'node:fs';
import { v4 as uuidv4 } from 'uuid';
import * as os from 'os';
/**
* Logger specifically for OpenAI API requests and responses
*/
export class OpenAILogger {
private logDir: string;
private initialized: boolean = false;
/**
* Creates a new OpenAI logger
* @param customLogDir Optional custom log directory path
*/
constructor(customLogDir?: string) {
this.logDir = customLogDir || path.join(process.cwd(), 'logs', 'openai');
}
/**
* Initialize the logger by creating the log directory if it doesn't exist
*/
async initialize(): Promise<void> {
if (this.initialized) return;
try {
await fs.mkdir(this.logDir, { recursive: true });
this.initialized = true;
} catch (error) {
console.error('Failed to initialize OpenAI logger:', error);
throw new Error(`Failed to initialize OpenAI logger: ${error}`);
}
}
/**
* Logs an OpenAI API request and its response
* @param request The request sent to OpenAI
* @param response The response received from OpenAI
* @param error Optional error if the request failed
* @returns The file path where the log was written
*/
async logInteraction(
request: unknown,
response?: unknown,
error?: Error,
): Promise<string> {
if (!this.initialized) {
await this.initialize();
}
const timestamp = new Date().toISOString().replace(/:/g, '-');
const id = uuidv4().slice(0, 8);
const filename = `openai-${timestamp}-${id}.json`;
const filePath = path.join(this.logDir, filename);
const logData = {
timestamp: new Date().toISOString(),
request,
response: response || null,
error: error
? {
message: error.message,
stack: error.stack,
}
: null,
system: {
hostname: os.hostname(),
platform: os.platform(),
release: os.release(),
nodeVersion: process.version,
},
};
try {
await fs.writeFile(filePath, JSON.stringify(logData, null, 2), 'utf-8');
return filePath;
} catch (writeError) {
console.error('Failed to write OpenAI log file:', writeError);
throw new Error(`Failed to write OpenAI log file: ${writeError}`);
}
}
/**
* Get all logged interactions
* @param limit Optional limit on the number of log files to return (sorted by most recent first)
* @returns Array of log file paths
*/
async getLogFiles(limit?: number): Promise<string[]> {
if (!this.initialized) {
await this.initialize();
}
try {
const files = await fs.readdir(this.logDir);
const logFiles = files
.filter((file) => file.startsWith('openai-') && file.endsWith('.json'))
.map((file) => path.join(this.logDir, file))
.sort()
.reverse();
return limit ? logFiles.slice(0, limit) : logFiles;
} catch (error) {
if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
return [];
}
console.error('Failed to read OpenAI log directory:', error);
return [];
}
}
/**
* Read a specific log file
* @param filePath The path to the log file
* @returns The log file content
*/
async readLogFile(filePath: string): Promise<unknown> {
try {
const content = await fs.readFile(filePath, 'utf-8');
return JSON.parse(content);
} catch (error) {
console.error(`Failed to read log file ${filePath}:`, error);
throw new Error(`Failed to read log file: ${error}`);
}
}
}
// Create a singleton instance for easy import
export const openaiLogger = new OpenAILogger();
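/*
 * Example usage (a sketch; `client` stands in for a hypothetical OpenAI SDK
 * instance -- any request/response shapes can be logged):
 *
 *   const request = { model: 'gpt-4', messages: [{ role: 'user', content: 'hi' }] };
 *   try {
 *     const response = await client.chat.completions.create(request);
 *     await openaiLogger.logInteraction(request, response);
 *   } catch (error) {
 *     await openaiLogger.logInteraction(request, undefined, error as Error);
 *     throw error;
 *   }
 */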

View File

@@ -0,0 +1,160 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import os from 'os';
import * as crypto from 'crypto';
export const GEMINI_DIR = '.qwen';
export const GOOGLE_ACCOUNTS_FILENAME = 'google_accounts.json';
const TMP_DIR_NAME = 'tmp';
/**
* Replaces the home directory with a tilde.
* @param path - The path to tildeify.
* @returns The tildeified path.
*/
export function tildeifyPath(path: string): string {
const homeDir = os.homedir();
if (path.startsWith(homeDir)) {
return path.replace(homeDir, '~');
}
return path;
}
/**
* Shortens a path string if it exceeds maxLen, prioritizing the start and end segments.
* Example: /path/to/a/very/long/file.txt -> /path/.../long/file.txt
*/
export function shortenPath(filePath: string, maxLen: number = 35): string {
if (filePath.length <= maxLen) {
return filePath;
}
const parsedPath = path.parse(filePath);
const root = parsedPath.root;
const separator = path.sep;
// Get segments of the path *after* the root
const relativePath = filePath.substring(root.length);
const segments = relativePath.split(separator).filter((s) => s !== ''); // Filter out empty segments
// Handle cases with no segments after root (e.g., "/", "C:\") or only one segment
if (segments.length <= 1) {
// Fallback to simple start/end truncation for very short paths or single segments
const keepLen = Math.floor((maxLen - 3) / 2);
// Ensure keepLen is not negative if maxLen is very small
if (keepLen <= 0) {
return filePath.substring(0, maxLen - 3) + '...';
}
const start = filePath.substring(0, keepLen);
const end = filePath.substring(filePath.length - keepLen);
return `${start}...${end}`;
}
const firstDir = segments[0];
const lastSegment = segments[segments.length - 1];
const startComponent = root + firstDir;
const endPartSegments: string[] = [];
// Base length: startComponent + separator + "..." + separator + lastSegment
let currentLength =
startComponent.length + separator.length * 2 + 3 + lastSegment.length;
// Iterate backwards through the middle segments (excluding the first and last)
for (let i = segments.length - 2; i >= 1; i--) {
const segment = segments[i];
// Length needed if we add this segment: current + separator + segment
const lengthWithSegment = currentLength + separator.length + segment.length;
if (lengthWithSegment <= maxLen) {
endPartSegments.unshift(segment); // Add to the beginning of the end part
currentLength = lengthWithSegment;
} else {
break;
}
}
// Construct the final path: start + separator + "..." + separator + end
const endPart = [...endPartSegments, lastSegment].join(separator);
const result = `${startComponent}${separator}...${separator}${endPart}`;
// As a final check, if the result is somehow still too long,
// truncate the result string from the beginning, prefixing with "...".
if (result.length > maxLen) {
return '...' + result.substring(result.length - (maxLen - 3));
}
return result;
}
/**
* Calculates the relative path from a root directory to a target path.
* Ensures both paths are resolved before calculating.
* Returns '.' if the target path is the same as the root directory.
*
* @param targetPath The absolute or relative path to make relative.
* @param rootDirectory The absolute path of the directory to make the target path relative to.
* @returns The relative path from rootDirectory to targetPath.
*/
export function makeRelative(
targetPath: string,
rootDirectory: string,
): string {
const resolvedTargetPath = path.resolve(targetPath);
const resolvedRootDirectory = path.resolve(rootDirectory);
const relativePath = path.relative(resolvedRootDirectory, resolvedTargetPath);
// If the paths are the same, path.relative returns '', return '.' instead
return relativePath || '.';
}
/**
* Escapes spaces in a file path.
*/
export function escapePath(filePath: string): string {
let result = '';
for (let i = 0; i < filePath.length; i++) {
// Only escape spaces that are not already escaped.
if (filePath[i] === ' ' && (i === 0 || filePath[i - 1] !== '\\')) {
result += '\\ ';
} else {
result += filePath[i];
}
}
return result;
}
/**
* Unescapes spaces in a file path.
*/
export function unescapePath(filePath: string): string {
return filePath.replace(/\\ /g, ' ');
}
/**
* Generates a unique hash for a project based on its root path.
* @param projectRoot The absolute path to the project's root directory.
* @returns A SHA256 hash of the project root path.
*/
export function getProjectHash(projectRoot: string): string {
return crypto.createHash('sha256').update(projectRoot).digest('hex');
}
/**
* Generates a unique temporary directory path for a project.
* @param projectRoot The absolute path to the project's root directory.
* @returns The path to the project's temporary directory.
*/
export function getProjectTempDir(projectRoot: string): string {
const hash = getProjectHash(projectRoot);
return path.join(os.homedir(), GEMINI_DIR, TMP_DIR_NAME, hash);
}
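A few illustrative calls for the helpers above; the paths are hypothetical, and the exact shortened form depends on what fits within maxLen:
tildeifyPath('/home/alice/projects/app'); // '~/projects/app' when homedir is /home/alice
escapePath('/tmp/my file.txt'); // '/tmp/my\\ file.txt'
unescapePath('/tmp/my\\ file.txt'); // '/tmp/my file.txt'
getProjectTempDir('/home/alice/projects/app');
// -> <homedir>/.qwen/tmp/<sha256 of the project root path>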

View File

@@ -0,0 +1,112 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export interface ApiError {
error: {
code: number;
message: string;
status: string;
details: unknown[];
};
}
interface StructuredError {
message: string;
status?: number;
}
export function isApiError(error: unknown): error is ApiError {
return (
typeof error === 'object' &&
error !== null &&
'error' in error &&
typeof (error as ApiError).error === 'object' &&
'message' in (error as ApiError).error
);
}
export function isStructuredError(error: unknown): error is StructuredError {
return (
typeof error === 'object' &&
error !== null &&
'message' in error &&
typeof (error as StructuredError).message === 'string'
);
}
export function isProQuotaExceededError(error: unknown): boolean {
// Check for Pro quota exceeded errors by looking for the specific pattern
// This will match patterns like:
// - "Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'"
// - "Quota exceeded for quota metric 'Gemini 2.5-preview Pro Requests'"
// We use string methods instead of regex to avoid ReDoS vulnerabilities
const checkMessage = (message: string): boolean =>
message.includes("Quota exceeded for quota metric 'Gemini") &&
message.includes("Pro Requests'");
if (typeof error === 'string') {
return checkMessage(error);
}
if (isStructuredError(error)) {
return checkMessage(error.message);
}
if (isApiError(error)) {
return checkMessage(error.error.message);
}
// Check if it's a Gaxios error with response data
if (error && typeof error === 'object' && 'response' in error) {
const gaxiosError = error as {
response?: {
data?: unknown;
};
};
if (gaxiosError.response && gaxiosError.response.data) {
console.debug(
'[DEBUG] isProQuotaExceededError - checking response data:',
gaxiosError.response.data,
);
if (typeof gaxiosError.response.data === 'string') {
return checkMessage(gaxiosError.response.data);
}
if (
typeof gaxiosError.response.data === 'object' &&
gaxiosError.response.data !== null &&
'error' in gaxiosError.response.data
) {
const errorData = gaxiosError.response.data as {
error?: { message?: string };
};
return checkMessage(errorData.error?.message || '');
}
}
}
console.debug(
'[DEBUG] isProQuotaExceededError - no matching error format for:',
error,
);
return false;
}
export function isGenericQuotaExceededError(error: unknown): boolean {
if (typeof error === 'string') {
return error.includes('Quota exceeded for quota metric');
}
if (isStructuredError(error)) {
return error.message.includes('Quota exceeded for quota metric');
}
if (isApiError(error)) {
return error.error.message.includes('Quota exceeded for quota metric');
}
return false;
}
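A quick sketch of how these detectors classify inputs; the messages are sample strings only:
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
); // true: matches both the metric prefix and the Pro suffix
isGenericQuotaExceededError({
message: "Quota exceeded for quota metric 'Embeddings Requests'",
}); // true: structured error carrying the generic quota message
isProQuotaExceededError('Internal server error'); // false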

View File

@@ -0,0 +1,407 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { retryWithBackoff } from './retry.js';
import { setSimulate429 } from './testUtils.js';
// Define an interface for the error with a status property
interface HttpError extends Error {
status?: number;
}
// Helper to create a mock function that fails a certain number of times
const createFailingFunction = (
failures: number,
successValue: string = 'success',
) => {
let attempts = 0;
return vi.fn(async () => {
attempts++;
if (attempts <= failures) {
// Simulate a retryable error
const error: HttpError = new Error(`Simulated error attempt ${attempts}`);
error.status = 500; // Simulate a server error
throw error;
}
return successValue;
});
};
// Custom error for testing non-retryable conditions
class NonRetryableError extends Error {
constructor(message: string) {
super(message);
this.name = 'NonRetryableError';
}
}
describe('retryWithBackoff', () => {
beforeEach(() => {
vi.useFakeTimers();
// Disable 429 simulation for tests
setSimulate429(false);
// Suppress unhandled promise rejection warnings for tests that expect errors
console.warn = vi.fn();
});
afterEach(() => {
vi.restoreAllMocks();
vi.useRealTimers();
});
it('should return the result on the first attempt if successful', async () => {
const mockFn = createFailingFunction(0);
const result = await retryWithBackoff(mockFn);
expect(result).toBe('success');
expect(mockFn).toHaveBeenCalledTimes(1);
});
it('should retry and succeed if failures are within maxAttempts', async () => {
const mockFn = createFailingFunction(2);
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 10,
});
await vi.runAllTimersAsync(); // Ensure all delays and retries complete
const result = await promise;
expect(result).toBe('success');
expect(mockFn).toHaveBeenCalledTimes(3);
});
it('should throw an error if all attempts fail', async () => {
const mockFn = createFailingFunction(3);
// 1. Start the retryable operation, which returns a promise.
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 10,
});
// 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*.
// This ensures a 'catch' handler is present before the promise can reject.
// The result is a new promise that resolves when the assertion is met.
const assertionPromise = expect(promise).rejects.toThrow(
'Simulated error attempt 3',
);
// 3. Now, advance the timers. This will trigger the retries and the
// eventual rejection. The handler attached in step 2 will catch it.
await vi.runAllTimersAsync();
// 4. Await the assertion promise itself to ensure the test was successful.
await assertionPromise;
// 5. Finally, assert the number of calls.
expect(mockFn).toHaveBeenCalledTimes(3);
});
it('should not retry if shouldRetry returns false', async () => {
const mockFn = vi.fn(async () => {
throw new NonRetryableError('Non-retryable error');
});
const shouldRetry = (error: Error) => !(error instanceof NonRetryableError);
const promise = retryWithBackoff(mockFn, {
shouldRetry,
initialDelayMs: 10,
});
await expect(promise).rejects.toThrow('Non-retryable error');
expect(mockFn).toHaveBeenCalledTimes(1);
});
it('should use default shouldRetry if not provided, retrying on 429', async () => {
const mockFn = vi.fn(async () => {
const error = new Error('Too Many Requests') as any;
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 2,
initialDelayMs: 10,
});
// Attach the rejection expectation *before* running timers
const assertionPromise =
expect(promise).rejects.toThrow('Too Many Requests');
// Run timers to trigger retries and eventual rejection
await vi.runAllTimersAsync();
// Await the assertion
await assertionPromise;
expect(mockFn).toHaveBeenCalledTimes(2);
});
it('should use default shouldRetry if not provided, not retrying on 400', async () => {
const mockFn = vi.fn(async () => {
const error = new Error('Bad Request') as any;
error.status = 400;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 2,
initialDelayMs: 10,
});
await expect(promise).rejects.toThrow('Bad Request');
expect(mockFn).toHaveBeenCalledTimes(1);
});
it('should respect maxDelayMs', async () => {
const mockFn = createFailingFunction(3);
const setTimeoutSpy = vi.spyOn(global, 'setTimeout');
const promise = retryWithBackoff(mockFn, {
maxAttempts: 4,
initialDelayMs: 100,
maxDelayMs: 250, // Max delay is less than 100 * 2 * 2 = 400
});
await vi.advanceTimersByTimeAsync(1000); // Advance well past all delays
await promise;
const delays = setTimeoutSpy.mock.calls.map((call) => call[1] as number);
// Delays should be around initial, initial*2, maxDelay (due to cap)
// Jitter makes exact assertion hard, so we check ranges / caps
expect(delays.length).toBe(3);
expect(delays[0]).toBeGreaterThanOrEqual(100 * 0.7);
expect(delays[0]).toBeLessThanOrEqual(100 * 1.3);
expect(delays[1]).toBeGreaterThanOrEqual(200 * 0.7);
expect(delays[1]).toBeLessThanOrEqual(200 * 1.3);
// The third delay should be capped by maxDelayMs (250ms), accounting for jitter
expect(delays[2]).toBeGreaterThanOrEqual(250 * 0.7);
expect(delays[2]).toBeLessThanOrEqual(250 * 1.3);
});
it('should handle jitter correctly, ensuring varied delays', async () => {
let mockFn = createFailingFunction(5);
const setTimeoutSpy = vi.spyOn(global, 'setTimeout');
// Run retryWithBackoff multiple times to observe jitter
const runRetry = () =>
retryWithBackoff(mockFn, {
maxAttempts: 2, // Only one retry, so one delay
initialDelayMs: 100,
maxDelayMs: 1000,
});
// We expect rejections as mockFn fails 5 times
const promise1 = runRetry();
// Attach the rejection expectation *before* running timers
const assertionPromise1 = expect(promise1).rejects.toThrow();
await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry
await assertionPromise1;
const firstDelaySet = setTimeoutSpy.mock.calls.map(
(call) => call[1] as number,
);
setTimeoutSpy.mockClear(); // Clear calls for the next run
// Reset mockFn to reset its internal attempt counter for the next run
mockFn = createFailingFunction(5); // Re-initialize with 5 failures
const promise2 = runRetry();
// Attach the rejection expectation *before* running timers
const assertionPromise2 = expect(promise2).rejects.toThrow();
await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry
await assertionPromise2;
const secondDelaySet = setTimeoutSpy.mock.calls.map(
(call) => call[1] as number,
);
// Check that the delays are not exactly the same due to jitter
// This is a probabilistic test, but with +/-30% jitter, it's highly likely they differ.
if (firstDelaySet.length > 0 && secondDelaySet.length > 0) {
// Check the first delay of each set
expect(firstDelaySet[0]).not.toBe(secondDelaySet[0]);
} else {
// If somehow no delays were captured (e.g. test setup issue), fail explicitly
throw new Error('Delays were not captured for jitter test');
}
// Ensure delays are within the expected jitter range [70, 130] for initialDelayMs = 100
[...firstDelaySet, ...secondDelaySet].forEach((d) => {
expect(d).toBeGreaterThanOrEqual(100 * 0.7);
expect(d).toBeLessThanOrEqual(100 * 1.3);
});
});
describe('Flash model fallback for OAuth users', () => {
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackOccurred) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
// Advance all timers to complete retries
await vi.runAllTimersAsync();
// Should succeed after fallback
await expect(promise).resolves.toBe('success');
// Verify callback was called with correct auth type
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
});
it('should NOT trigger fallback for API key users', async () => {
const fallbackCallback = vi.fn();
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'gemini-api-key',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail after all retries without fallback
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
// Callback should not be called for API key users
expect(fallbackCallback).not.toHaveBeenCalled();
});
it('should reset attempt counter and continue after successful fallback', async () => {
let fallbackCalled = false;
const fallbackCallback = vi.fn().mockImplementation(async () => {
fallbackCalled = true;
return 'gemini-2.5-flash';
});
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackCalled) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
expect(fallbackCallback).toHaveBeenCalledOnce();
});
it('should continue with original error if fallback is rejected', async () => {
const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail with original error when fallback is rejected
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
expect(fallbackCallback).toHaveBeenCalledWith(
'oauth-personal',
expect.any(Error),
);
});
it('should handle mixed error types (only count consecutive 429s)', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let attempts = 0;
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
attempts++;
if (fallbackOccurred) {
return 'success';
}
if (attempts === 1) {
// First attempt: 500 error (resets consecutive count)
const error: HttpError = new Error('Server error');
error.status = 500;
throw error;
} else {
// Remaining attempts: 429 errors
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 5,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should trigger fallback after 2 consecutive 429s (attempts 2-3)
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
});

View File

@@ -0,0 +1,335 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { AuthType } from '../core/contentGenerator.js';
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
} from './quotaErrorDetection.js';
export interface RetryOptions {
maxAttempts: number;
initialDelayMs: number;
maxDelayMs: number;
shouldRetry: (error: Error) => boolean;
onPersistent429?: (
authType?: string,
error?: unknown,
) => Promise<string | boolean | null>;
authType?: string;
}
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
maxAttempts: 5,
initialDelayMs: 5000,
maxDelayMs: 30000, // 30 seconds
shouldRetry: defaultShouldRetry,
};
/**
* Default predicate function to determine if a retry should be attempted.
* Retries on 429 (Too Many Requests) and 5xx server errors.
* @param error The error object.
* @returns True if the error is a transient error, false otherwise.
*/
function defaultShouldRetry(error: Error | unknown): boolean {
// Check for common transient error status codes either in message or a status property
if (error && typeof (error as { status?: number }).status === 'number') {
const status = (error as { status: number }).status;
if (status === 429 || (status >= 500 && status < 600)) {
return true;
}
}
if (error instanceof Error && error.message) {
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
}
return false;
}
/**
* Delays execution for a specified number of milliseconds.
* @param ms The number of milliseconds to delay.
* @returns A promise that resolves after the delay.
*/
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
/**
* Retries a function with exponential backoff and jitter.
* @param fn The asynchronous function to retry.
* @param options Optional retry configuration.
* @returns A promise that resolves with the result of the function if successful.
* @throws The last error encountered if all attempts fail.
*/
export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options?: Partial<RetryOptions>,
): Promise<T> {
const {
maxAttempts,
initialDelayMs,
maxDelayMs,
onPersistent429,
authType,
shouldRetry,
} = {
...DEFAULT_RETRY_OPTIONS,
...options,
};
let attempt = 0;
let currentDelay = initialDelayMs;
let consecutive429Count = 0;
while (attempt < maxAttempts) {
attempt++;
try {
return await fn();
} catch (error) {
const errorStatus = getErrorStatus(error);
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
isProQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check for generic quota exceeded error (but not Pro, which was handled above) - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
!isProQuotaExceededError(error) &&
isGenericQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Track consecutive 429 errors
if (errorStatus === 429) {
consecutive429Count++;
} else {
consecutive429Count = 0;
}
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 2 &&
onPersistent429 &&
authType === AuthType.LOGIN_WITH_GOOGLE
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
throw error;
}
const { delayDurationMs, errorStatus: delayErrorStatus } =
getDelayDurationAndStatus(error);
if (delayDurationMs > 0) {
// Respect Retry-After header if present and parsed
console.warn(
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
error,
);
await delay(delayDurationMs);
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
currentDelay = initialDelayMs;
} else {
// Fallback to exponential backoff with jitter
logRetryAttempt(attempt, error, errorStatus);
// Add jitter: +/- 30% of currentDelay
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
const delayWithJitter = Math.max(0, currentDelay + jitter);
await delay(delayWithJitter);
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
}
}
}
// This line should theoretically be unreachable due to the throw in the catch block.
// Added for type safety and to satisfy the compiler that a promise is always returned.
throw new Error('Retry attempts exhausted');
}
/**
* Extracts the HTTP status code from an error object.
* @param error The error object.
* @returns The HTTP status code, or undefined if not found.
*/
function getErrorStatus(error: unknown): number | undefined {
if (typeof error === 'object' && error !== null) {
if ('status' in error && typeof error.status === 'number') {
return error.status;
}
// Check for error.response.status (common in axios errors)
if (
'response' in error &&
typeof (error as { response?: unknown }).response === 'object' &&
(error as { response?: unknown }).response !== null
) {
const response = (
error as { response: { status?: unknown; headers?: unknown } }
).response;
if ('status' in response && typeof response.status === 'number') {
return response.status;
}
}
}
return undefined;
}
/**
* Extracts the Retry-After delay from an error object's headers.
* @param error The error object.
* @returns The delay in milliseconds, or 0 if not found or invalid.
*/
function getRetryAfterDelayMs(error: unknown): number {
if (typeof error === 'object' && error !== null) {
// Check for error.response.headers (common in axios errors)
if (
'response' in error &&
typeof (error as { response?: unknown }).response === 'object' &&
(error as { response?: unknown }).response !== null
) {
const response = (error as { response: { headers?: unknown } }).response;
if (
'headers' in response &&
typeof response.headers === 'object' &&
response.headers !== null
) {
const headers = response.headers as { 'retry-after'?: unknown };
const retryAfterHeader = headers['retry-after'];
if (typeof retryAfterHeader === 'string') {
const retryAfterSeconds = parseInt(retryAfterHeader, 10);
if (!isNaN(retryAfterSeconds)) {
return retryAfterSeconds * 1000;
}
// It might be an HTTP date
const retryAfterDate = new Date(retryAfterHeader);
if (!isNaN(retryAfterDate.getTime())) {
return Math.max(0, retryAfterDate.getTime() - Date.now());
}
}
}
}
}
return 0;
}
/**
* Determines the delay duration based on the error, prioritizing Retry-After header.
* @param error The error object.
* @returns An object containing the delay duration in milliseconds and the error status.
*/
function getDelayDurationAndStatus(error: unknown): {
delayDurationMs: number;
errorStatus: number | undefined;
} {
const errorStatus = getErrorStatus(error);
let delayDurationMs = 0;
if (errorStatus === 429) {
delayDurationMs = getRetryAfterDelayMs(error);
}
return { delayDurationMs, errorStatus };
}
/**
* Logs a message for a retry attempt when using exponential backoff.
* @param attempt The current attempt number.
* @param error The error that caused the retry.
* @param errorStatus The HTTP status code of the error, if available.
*/
function logRetryAttempt(
attempt: number,
error: unknown,
errorStatus?: number,
): void {
let message = `Attempt ${attempt} failed. Retrying with backoff...`;
if (errorStatus) {
message = `Attempt ${attempt} failed with status ${errorStatus}. Retrying with backoff...`;
}
if (errorStatus === 429) {
console.warn(message, error);
} else if (errorStatus && errorStatus >= 500 && errorStatus < 600) {
console.error(message, error);
} else if (error instanceof Error) {
// Fallback for errors that might not have a status but have a message
if (error.message.includes('429')) {
console.warn(
`Attempt ${attempt} failed with 429 error (no Retry-After header). Retrying with backoff...`,
error,
);
} else if (error.message.match(/5\d{2}/)) {
console.error(
`Attempt ${attempt} failed with 5xx error. Retrying with backoff...`,
error,
);
} else {
console.warn(message, error); // Default to warn for other errors
}
} else {
console.warn(message, error); // Default to warn if error type is unknown
}
}
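A minimal sketch of calling retryWithBackoff; fetchModelResponse stands in for any flaky async call and is not part of this module:
const result = await retryWithBackoff(() => fetchModelResponse(prompt), {
maxAttempts: 3,
initialDelayMs: 1000,
maxDelayMs: 8000,
});
// With no shouldRetry override, 429s and 5xx errors are retried by default.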

View File

@@ -0,0 +1,73 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { safeJsonStringify } from './safeJsonStringify.js';
describe('safeJsonStringify', () => {
it('should stringify normal objects without issues', () => {
const obj = { name: 'test', value: 42 };
const result = safeJsonStringify(obj);
expect(result).toBe('{"name":"test","value":42}');
});
it('should handle circular references by replacing them with [Circular]', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const obj: any = { name: 'test' };
obj.circular = obj; // Create circular reference
const result = safeJsonStringify(obj);
expect(result).toBe('{"name":"test","circular":"[Circular]"}');
});
it('should handle complex circular structures like HttpsProxyAgent', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const agent: any = {
sockets: {},
options: { host: 'example.com' },
};
agent.sockets['example.com'] = [{ agent }];
const result = safeJsonStringify(agent);
expect(result).toContain('[Circular]');
expect(result).toContain('example.com');
});
it('should respect the space parameter for formatting', () => {
const obj = { name: 'test', value: 42 };
const result = safeJsonStringify(obj, 2);
expect(result).toBe('{\n "name": "test",\n "value": 42\n}');
});
it('should handle circular references with formatting', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const obj: any = { name: 'test' };
obj.circular = obj;
const result = safeJsonStringify(obj, 2);
expect(result).toBe('{\n "name": "test",\n "circular": "[Circular]"\n}');
});
it('should handle arrays with circular references', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const arr: any[] = [{ id: 1 }];
arr[0].parent = arr; // Create circular reference
const result = safeJsonStringify(arr);
expect(result).toBe('[{"id":1,"parent":"[Circular]"}]');
});
it('should handle null and undefined values', () => {
expect(safeJsonStringify(null)).toBe('null');
expect(safeJsonStringify(undefined)).toBe(undefined);
});
it('should handle primitive values', () => {
expect(safeJsonStringify('test')).toBe('"test"');
expect(safeJsonStringify(42)).toBe('42');
expect(safeJsonStringify(true)).toBe('true');
});
});

View File

@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Safely stringifies an object to JSON, handling circular references by replacing them with [Circular].
*
* @param obj - The object to stringify
* @param space - Optional space parameter for formatting (defaults to no formatting)
* @returns JSON string with circular references replaced by [Circular]
*/
export function safeJsonStringify(
obj: unknown,
space?: string | number,
): string {
const seen = new WeakSet();
return JSON.stringify(
obj,
(key, value) => {
if (typeof value === 'object' && value !== null) {
if (seen.has(value)) {
return '[Circular]';
}
seen.add(value);
}
return value;
},
space,
);
}
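One sketch of where this helps: serializing objects with self references, which would make plain JSON.stringify throw. The agent object here is hypothetical:
const agent: { name: string; self?: unknown } = { name: 'proxy' };
agent.self = agent;
safeJsonStringify(agent); // '{"name":"proxy","self":"[Circular]"}'
Note that because visited objects are never removed from the WeakSet, any repeated reference is replaced with [Circular], not just true cycles; for logging purposes that trade-off is usually acceptable.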

View File

@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { Schema } from '@google/genai';
import * as ajv from 'ajv';
const ajValidator = new ajv.Ajv();
/**
* Simple utility to validate objects against JSON Schemas
*/
export class SchemaValidator {
/**
* Returns null if the data conforms to the schema described by schema (or if schema
* is null). Otherwise, returns a string describing the error.
*/
static validate(schema: Schema | undefined, data: unknown): string | null {
if (!schema) {
return null;
}
if (typeof data !== 'object' || data === null) {
return 'Value of params must be an object';
}
const validate = ajValidator.compile(this.toObjectSchema(schema));
const valid = validate(data);
if (!valid && validate.errors) {
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
}
return null;
}
/**
* Converts @google/genai's Schema to an object compatible with ajv.
* This is necessary because @google/genai represents types as an enum (with
* UPPERCASE values) and minItems and minLength as strings, when they should be numbers.
*/
private static toObjectSchema(schema: Schema): object {
const newSchema: Record<string, unknown> = { ...schema };
if (newSchema.anyOf && Array.isArray(newSchema.anyOf)) {
newSchema.anyOf = newSchema.anyOf.map((v) => this.toObjectSchema(v));
}
if (newSchema.items) {
newSchema.items = this.toObjectSchema(newSchema.items);
}
if (newSchema.properties && typeof newSchema.properties === 'object') {
const newProperties: Record<string, unknown> = {};
for (const [key, value] of Object.entries(newSchema.properties)) {
newProperties[key] = this.toObjectSchema(value as Schema);
}
newSchema.properties = newProperties;
}
if (newSchema.type) {
newSchema.type = String(newSchema.type).toLowerCase();
}
if (newSchema.minItems) {
newSchema.minItems = Number(newSchema.minItems);
}
if (newSchema.minLength) {
newSchema.minLength = Number(newSchema.minLength);
}
return newSchema;
}
}
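A hedged example of validating tool params against a Schema; Type comes from @google/genai, and the error text shown is approximate ajv output:
import { Type } from '@google/genai';

const error = SchemaValidator.validate(
{
type: Type.OBJECT,
properties: { path: { type: Type.STRING } },
required: ['path'],
},
{ path: 42 },
);
// error is roughly "params/path must be string"
SchemaValidator.validate(undefined, { anything: true }); // null: no schema, nothing to check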

View File

@@ -0,0 +1,9 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { randomUUID } from 'crypto';
export const sessionId = randomUUID();

View File

@@ -0,0 +1,208 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import {
summarizeToolOutput,
llmSummarizer,
defaultSummarizer,
} from './summarizer.js';
import { ToolResult } from '../tools/tools.js';
// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
vi.mock('../config/config.js');
describe('summarizers', () => {
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
const abortSignal = new AbortController().signal;
beforeEach(() => {
MockConfig = vi.mocked(Config);
const mockConfigInstance = new MockConfig(
'test-api-key',
'gemini-pro',
false,
'.',
false,
undefined,
false,
undefined,
undefined,
undefined,
);
mockGeminiClient = new GeminiClient(mockConfigInstance);
(mockGeminiClient.generateContent as Mock) = vi.fn();
vi.spyOn(console, 'error').mockImplementation(() => {});
});
afterEach(() => {
vi.clearAllMocks();
(console.error as Mock).mockRestore();
});
describe('summarizeToolOutput', () => {
it('should return original text if it is shorter than maxLength', async () => {
const shortText = 'This is a short text.';
const result = await summarizeToolOutput(
shortText,
mockGeminiClient,
abortSignal,
2000,
);
expect(result).toBe(shortText);
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
});
it('should return original text if it is empty', async () => {
const emptyText = '';
const result = await summarizeToolOutput(
emptyText,
mockGeminiClient,
abortSignal,
2000,
);
expect(result).toBe(emptyText);
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
});
it('should call generateContent if text is longer than maxLength', async () => {
const longText = 'This is a very long text.'.repeat(200);
const summary = 'This is a summary.';
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
const result = await summarizeToolOutput(
longText,
mockGeminiClient,
abortSignal,
2000,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
expect(result).toBe(summary);
});
it('should return original text if generateContent throws an error', async () => {
const longText = 'This is a very long text.'.repeat(200);
const error = new Error('API Error');
(mockGeminiClient.generateContent as Mock).mockRejectedValue(error);
const result = await summarizeToolOutput(
longText,
mockGeminiClient,
abortSignal,
2000,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
expect(result).toBe(longText);
expect(console.error).toHaveBeenCalledWith(
'Failed to summarize tool output.',
error,
);
});
it('should construct the correct prompt for summarization', async () => {
const longText = 'This is a very long text.'.repeat(200);
const summary = 'This is a summary.';
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
await summarizeToolOutput(longText, mockGeminiClient, abortSignal, 1000);
const expectedPrompt = `Summarize the following tool output to be a maximum of 1000 characters. The summary should be concise and capture the main points of the tool output.
The summarization should be done based on the content that is provided. Here are the basic rules to follow:
1. If the text is a directory listing or any output that is structural, use the history of the conversation to understand the context. Using this context try to understand what information we need from the tool output and return that as a response.
2. If the text is text content and there is nothing structural that we need, summarize the text.
3. If the text is the output of a shell command, use the history of the conversation to understand the context. Using this context try to understand what information we need from the tool output and return a summarization along with the stack trace of any error within the <error></error> tags. The stack trace should be complete and not truncated. If there are warnings, you should include them in the summary within <warning></warning> tags.
Text to summarize:
"${longText}"
Return the summary string which should first contain an overall summarization of text followed by the full stack trace of errors and warnings in the tool output.
`;
const calledWith = (mockGeminiClient.generateContent as Mock).mock
.calls[0];
const contents = calledWith[0];
expect(contents[0].parts[0].text).toBe(expectedPrompt);
});
});
describe('llmSummarizer', () => {
it('should summarize tool output using summarizeToolOutput', async () => {
const toolResult: ToolResult = {
llmContent: 'This is a very long text.'.repeat(200),
returnDisplay: '',
};
const summary = 'This is a summary.';
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
const result = await llmSummarizer(
toolResult,
mockGeminiClient,
abortSignal,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
expect(result).toBe(summary);
});
it('should handle different llmContent types', async () => {
const longText = 'This is a very long text.'.repeat(200);
const toolResult: ToolResult = {
llmContent: [{ text: longText }],
returnDisplay: '',
};
const summary = 'This is a summary.';
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
const result = await llmSummarizer(
toolResult,
mockGeminiClient,
abortSignal,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
const calledWith = (mockGeminiClient.generateContent as Mock).mock
.calls[0];
const contents = calledWith[0];
expect(contents[0].parts[0].text).toContain(`"${longText}"`);
expect(result).toBe(summary);
});
});
describe('defaultSummarizer', () => {
it('should stringify the llmContent', async () => {
const toolResult: ToolResult = {
llmContent: { text: 'some data' },
returnDisplay: '',
};
const result = await defaultSummarizer(
toolResult,
mockGeminiClient,
abortSignal,
);
expect(result).toBe(JSON.stringify({ text: 'some data' }));
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,131 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { ToolResult } from '../tools/tools.js';
import {
Content,
GenerateContentConfig,
GenerateContentResponse,
} from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { PartListUnion } from '@google/genai';
/**
* A function that summarizes the result of a tool execution.
*
* @param result The result of the tool execution.
* @returns The summary of the result.
*/
export type Summarizer = (
result: ToolResult,
geminiClient: GeminiClient,
abortSignal: AbortSignal,
) => Promise<string>;
/**
* The default summarizer for tool results.
*
* @param result The result of the tool execution.
* @param geminiClient The Gemini client to use for summarization.
* @param abortSignal The abort signal to use for summarization.
* @returns The summary of the result.
*/
export const defaultSummarizer: Summarizer = (
result: ToolResult,
_geminiClient: GeminiClient,
_abortSignal: AbortSignal,
) => Promise.resolve(JSON.stringify(result.llmContent));
// TODO: Move both these functions to utils
function partToString(part: PartListUnion): string {
if (!part) {
return '';
}
if (typeof part === 'string') {
return part;
}
if (Array.isArray(part)) {
return part.map(partToString).join('');
}
if ('text' in part) {
return part.text ?? '';
}
return '';
}
function getResponseText(response: GenerateContentResponse): string | null {
if (response.candidates && response.candidates.length > 0) {
const candidate = response.candidates[0];
if (
candidate.content &&
candidate.content.parts &&
candidate.content.parts.length > 0
) {
return candidate.content.parts
.filter((part) => part.text)
.map((part) => part.text)
.join('');
}
}
return null;
}
const toolOutputSummarizerModel = DEFAULT_GEMINI_FLASH_MODEL;
const toolOutputSummarizerConfig: GenerateContentConfig = {
maxOutputTokens: 2000,
};
const SUMMARIZE_TOOL_OUTPUT_PROMPT = `Summarize the following tool output to be a maximum of {maxLength} characters. The summary should be concise and capture the main points of the tool output.
The summarization should be done based on the content that is provided. Here are the basic rules to follow:
1. If the text is a directory listing or any output that is structural, use the history of the conversation to understand the context. Using this context try to understand what information we need from the tool output and return that as a response.
2. If the text is text content and there is nothing structural that we need, summarize the text.
3. If the text is the output of a shell command, use the history of the conversation to understand the context. Using this context try to understand what information we need from the tool output and return a summarization along with the stack trace of any error within the <error></error> tags. The stack trace should be complete and not truncated. If there are warnings, you should include them in the summary within <warning></warning> tags.
Text to summarize:
"{textToSummarize}"
Return the summary string which should first contain an overall summarization of text followed by the full stack trace of errors and warnings in the tool output.
`;
export const llmSummarizer: Summarizer = (result, geminiClient, abortSignal) =>
summarizeToolOutput(
partToString(result.llmContent),
geminiClient,
abortSignal,
);
export async function summarizeToolOutput(
textToSummarize: string,
geminiClient: GeminiClient,
abortSignal: AbortSignal,
maxLength: number = 2000,
): Promise<string> {
if (!textToSummarize || textToSummarize.length < maxLength) {
return textToSummarize;
}
const prompt = SUMMARIZE_TOOL_OUTPUT_PROMPT.replace(
'{maxLength}',
String(maxLength),
).replace('{textToSummarize}', textToSummarize);
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const parsedResponse = (await geminiClient.generateContent(
contents,
toolOutputSummarizerConfig,
abortSignal,
toolOutputSummarizerModel,
)) as unknown as GenerateContentResponse;
return getResponseText(parsedResponse) || textToSummarize;
} catch (error) {
console.error('Failed to summarize tool output.', error);
return textToSummarize;
}
}
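A usage sketch, assuming an existing GeminiClient instance named client and a long tool output string:
const controller = new AbortController();
const summary = await summarizeToolOutput(
veryLongShellOutput, // hypothetical string well over maxLength characters
client,
controller.signal,
2000,
);
// On any failure (or short input) the original text is returned unchanged.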

View File

@@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Testing utilities for simulating 429 errors in unit tests
*/
let requestCounter = 0;
let simulate429Enabled = false;
let simulate429AfterRequests = 0;
let simulate429ForAuthType: string | undefined;
let fallbackOccurred = false;
/**
* Check if we should simulate a 429 error for the current request
*/
export function shouldSimulate429(authType?: string): boolean {
if (!simulate429Enabled || fallbackOccurred) {
return false;
}
// If auth type filter is set, only simulate for that auth type
if (simulate429ForAuthType && authType !== simulate429ForAuthType) {
return false;
}
requestCounter++;
// If afterRequests is set, only simulate after that many requests
if (simulate429AfterRequests > 0) {
return requestCounter > simulate429AfterRequests;
}
// Otherwise, simulate for every request
return true;
}
/**
* Reset the request counter (useful for tests)
*/
export function resetRequestCounter(): void {
requestCounter = 0;
}
/**
* Disable 429 simulation after successful fallback
*/
export function disableSimulationAfterFallback(): void {
fallbackOccurred = true;
}
/**
* Create a simulated 429 error response
*/
export function createSimulated429Error(): Error {
const error = new Error('Rate limit exceeded (simulated)') as Error & {
status: number;
};
error.status = 429;
return error;
}
/**
* Reset simulation state when switching auth methods
*/
export function resetSimulationState(): void {
fallbackOccurred = false;
resetRequestCounter();
}
/**
* Enable/disable 429 simulation programmatically (for tests)
*/
export function setSimulate429(
enabled: boolean,
afterRequests = 0,
forAuthType?: string,
): void {
simulate429Enabled = enabled;
simulate429AfterRequests = afterRequests;
simulate429ForAuthType = forAuthType;
fallbackOccurred = false; // Reset fallback state when simulation is re-enabled
resetRequestCounter();
}
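One plausible call site for these helpers inside a content generator under test; the surrounding generator code is hypothetical:
setSimulate429(true, 2, 'oauth-personal'); // start failing after 2 requests for this auth type
if (shouldSimulate429(authType)) {
throw createSimulated429Error();
}
// ...and once a model fallback succeeds:
disableSimulationAfterFallback();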

View File

@@ -0,0 +1,237 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach, Mock } from 'vitest';
import {
cacheGoogleAccount,
getCachedGoogleAccount,
clearCachedGoogleAccount,
getLifetimeGoogleAccounts,
} from './user_account.js';
import * as fs from 'node:fs';
import * as os from 'node:os';
import path from 'node:path';
vi.mock('os', async (importOriginal) => {
const os = await importOriginal<typeof import('os')>();
return {
...os,
homedir: vi.fn(),
};
});
describe('user_account', () => {
let tempHomeDir: string;
const accountsFile = () =>
path.join(tempHomeDir, '.qwen', 'google_accounts.json');
beforeEach(() => {
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'),
);
(os.homedir as Mock).mockReturnValue(tempHomeDir);
});
afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true });
vi.clearAllMocks();
});
describe('cacheGoogleAccount', () => {
it('should create directory and write initial account file', async () => {
await cacheGoogleAccount('test1@google.com');
// Verify Google Account ID was cached
expect(fs.existsSync(accountsFile())).toBe(true);
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify({ active: 'test1@google.com', old: [] }, null, 2),
);
});
it('should update active account and move previous to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test2@google.com', old: ['test1@google.com'] },
null,
2,
),
);
await cacheGoogleAccount('test3@google.com');
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
},
null,
2,
),
);
});
it('should not add a duplicate to the old list', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
await cacheGoogleAccount('test2@google.com');
await cacheGoogleAccount('test1@google.com');
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
});
it('should handle corrupted JSON by starting fresh', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'not valid json');
const consoleDebugSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await cacheGoogleAccount('test1@google.com');
expect(consoleDebugSpy).toHaveBeenCalled();
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
active: 'test1@google.com',
old: [],
});
});
});
describe('getCachedGoogleAccount', () => {
it('should return the active account if file exists and is valid', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'active@google.com', old: [] }, null, 2),
);
const account = getCachedGoogleAccount();
expect(account).toBe('active@google.com');
});
it('should return null if file does not exist', () => {
const account = getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null if file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
const account = getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null and log if file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '{ "active": "test@google.com"'); // Invalid JSON
const consoleDebugSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
const account = getCachedGoogleAccount();
expect(account).toBeNull();
expect(consoleDebugSpy).toHaveBeenCalled();
});
});
describe('clearCachedGoogleAccount', () => {
it('should set active to null and move it to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'active@google.com', old: ['old1@google.com'] },
null,
2,
),
);
await clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['old1@google.com', 'active@google.com']);
});
it('should handle empty file gracefully', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
await clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual([]);
});
});
describe('getLifetimeGoogleAccounts', () => {
it('should return 0 if the file does not exist', () => {
expect(getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
expect(getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'invalid json');
const consoleDebugSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
expect(getLifetimeGoogleAccounts()).toBe(0);
expect(consoleDebugSpy).toHaveBeenCalled();
});
it('should return 1 if there is only an active account', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'test1@google.com', old: [] }),
);
expect(getLifetimeGoogleAccounts()).toBe(1);
});
it('should correctly count old accounts when active is null', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: null,
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(getLifetimeGoogleAccounts()).toBe(2);
});
it('should correctly count both active and old accounts', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(getLifetimeGoogleAccounts()).toBe(3);
});
});
});

View File

@@ -0,0 +1,115 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { promises as fsp, existsSync, readFileSync } from 'node:fs';
import * as os from 'os';
import { GEMINI_DIR, GOOGLE_ACCOUNTS_FILENAME } from './paths.js';
interface UserAccounts {
active: string | null;
old: string[];
}
function getGoogleAccountsCachePath(): string {
return path.join(os.homedir(), GEMINI_DIR, GOOGLE_ACCOUNTS_FILENAME);
}
async function readAccounts(filePath: string): Promise<UserAccounts> {
try {
const content = await fsp.readFile(filePath, 'utf-8');
if (!content.trim()) {
return { active: null, old: [] };
}
return JSON.parse(content) as UserAccounts;
} catch (error) {
if (error instanceof Error && 'code' in error && error.code === 'ENOENT') {
// File doesn't exist, which is fine.
return { active: null, old: [] };
}
// File is corrupted or not valid JSON, start with a fresh object.
console.debug('Could not parse accounts file, starting fresh.', error);
return { active: null, old: [] };
}
}
export async function cacheGoogleAccount(email: string): Promise<void> {
const filePath = getGoogleAccountsCachePath();
await fsp.mkdir(path.dirname(filePath), { recursive: true });
const accounts = await readAccounts(filePath);
if (accounts.active && accounts.active !== email) {
if (!accounts.old.includes(accounts.active)) {
accounts.old.push(accounts.active);
}
}
// If the new email was in the old list, remove it
accounts.old = accounts.old.filter((oldEmail) => oldEmail !== email);
accounts.active = email;
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
}
export function getCachedGoogleAccount(): string | null {
try {
const filePath = getGoogleAccountsCachePath();
if (existsSync(filePath)) {
const content = readFileSync(filePath, 'utf-8').trim();
if (!content) {
return null;
}
const accounts: UserAccounts = JSON.parse(content);
return accounts.active;
}
return null;
} catch (error) {
console.debug('Error reading cached Google Account:', error);
return null;
}
}
export function getLifetimeGoogleAccounts(): number {
try {
const filePath = getGoogleAccountsCachePath();
if (!existsSync(filePath)) {
return 0;
}
const content = readFileSync(filePath, 'utf-8').trim();
if (!content) {
return 0;
}
const accounts: UserAccounts = JSON.parse(content);
let count = accounts.old.length;
if (accounts.active) {
count++;
}
return count;
} catch (error) {
console.debug('Error reading lifetime Google Accounts:', error);
return 0;
}
}
export async function clearCachedGoogleAccount(): Promise<void> {
const filePath = getGoogleAccountsCachePath();
if (!existsSync(filePath)) {
return;
}
const accounts = await readAccounts(filePath);
if (accounts.active) {
if (!accounts.old.includes(accounts.active)) {
accounts.old.push(accounts.active);
}
accounts.active = null;
}
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
}
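A quick sketch of the cache lifecycle; the emails are placeholders:
await cacheGoogleAccount('a@example.com');
await cacheGoogleAccount('b@example.com'); // 'a@example.com' moves to the old list
getCachedGoogleAccount(); // 'b@example.com'
getLifetimeGoogleAccounts(); // 2: one active plus one old
await clearCachedGoogleAccount(); // active becomes null, 'b@example.com' joins old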

View File

@@ -0,0 +1,24 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { getInstallationId } from './user_id.js';
describe('user_id', () => {
describe('getInstallationId', () => {
it('should return a valid UUID format string', () => {
const installationId = getInstallationId();
expect(installationId).toBeDefined();
expect(typeof installationId).toBe('string');
expect(installationId.length).toBeGreaterThan(0);
// Should return the same ID on subsequent calls (consistent)
const secondCall = getInstallationId();
expect(secondCall).toBe(installationId);
});
});
});

View File

@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as os from 'os';
import * as fs from 'fs';
import * as path from 'path';
import { randomUUID } from 'crypto';
import { GEMINI_DIR } from './paths.js';
const homeDir = os.homedir() ?? '';
const geminiDir = path.join(homeDir, GEMINI_DIR);
const installationIdFile = path.join(geminiDir, 'installation_id');
function ensureGeminiDirExists() {
if (!fs.existsSync(geminiDir)) {
fs.mkdirSync(geminiDir, { recursive: true });
}
}
function readInstallationIdFromFile(): string | null {
if (fs.existsSync(installationIdFile)) {
const installationId = fs.readFileSync(installationIdFile, 'utf-8').trim();
return installationId || null;
}
return null;
}
function writeInstallationIdToFile(installationId: string) {
fs.writeFileSync(installationIdFile, installationId, 'utf-8');
}
/**
* Retrieves the installation ID from a file, creating it if it doesn't exist.
* This ID is used for unique user installation tracking.
* @returns A UUID string for the user, or a fixed fallback ID if the file cannot be accessed.
*/
export function getInstallationId(): string {
try {
ensureGeminiDirExists();
let installationId = readInstallationIdFromFile();
if (!installationId) {
installationId = randomUUID();
writeInstallationIdToFile(installationId);
}
return installationId;
} catch (error) {
console.error(
'Error accessing installation ID file, generating ephemeral ID:',
error,
);
return '123456789';
}
}
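In short, the contract is a stable, persisted identifier; a minimal sketch:
const id = getInstallationId(); // same value on every run, stored in ~/.qwen/installation_id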