Mirror of https://github.com/QwenLM/qwen-code.git (synced 2025-12-20 16:57:46 +00:00)

Commit 21fba832d1: Rename server->core (#638)
Committed by: GitHub
Parent: c81148a0cc
packages/core/src/utils/LruCache.ts (new file, 41 lines)
@@ -0,0 +1,41 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

export class LruCache<K, V> {
  private cache: Map<K, V>;
  private maxSize: number;

  constructor(maxSize: number) {
    this.cache = new Map<K, V>();
    this.maxSize = maxSize;
  }

  get(key: K): V | undefined {
    const value = this.cache.get(key);
    // Explicit undefined check so falsy values (0, '', false) are still refreshed.
    if (value !== undefined) {
      // Move to end to mark as recently used
      this.cache.delete(key);
      this.cache.set(key, value);
    }
    return value;
  }

  set(key: K, value: V): void {
    if (this.cache.has(key)) {
      this.cache.delete(key);
    } else if (this.cache.size >= this.maxSize) {
      const firstKey = this.cache.keys().next().value;
      if (firstKey !== undefined) {
        this.cache.delete(firstKey);
      }
    }
    this.cache.set(key, value);
  }

  clear(): void {
    this.cache.clear();
  }
}
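A brief usage sketch (not part of the commit) may help clarify the eviction order: JavaScript's Map iterates keys in insertion order, so after get() re-inserts a key, the first key returned by keys() is the least recently used entry, which is exactly what set() evicts.

// Hypothetical usage of the LruCache above, illustrating eviction order.
import { LruCache } from './LruCache.js';

const cache = new LruCache<string, number>(2);
cache.set('a', 1);
cache.set('b', 2);
cache.get('a'); // refreshes 'a'; 'b' becomes least recently used
cache.set('c', 3); // at capacity: evicts 'b', the oldest entry
console.log(cache.get('b')); // undefined (evicted)
console.log(cache.get('a')); // 1
console.log(cache.get('c')); // 3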
packages/core/src/utils/editCorrector.test.ts (new file, 503 lines)
@@ -0,0 +1,503 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

/* eslint-disable @typescript-eslint/no-explicit-any */
import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest';

// MOCKS
let callCount = 0;
const mockResponses: any[] = [];

let mockGenerateJson: any;
let mockStartChat: any;
let mockSendMessageStream: any;

vi.mock('../core/client.js', () => ({
  GeminiClient: vi.fn().mockImplementation(function (
    this: any,
    _config: Config,
  ) {
    this.generateJson = (...params: any[]) => mockGenerateJson(...params);
    this.startChat = (...params: any[]) => mockStartChat(...params);
    this.sendMessageStream = (...params: any[]) =>
      mockSendMessageStream(...params);
    return this;
  }),
}));
// END MOCKS

import {
  countOccurrences,
  ensureCorrectEdit,
  unescapeStringForGeminiBug,
  resetEditCorrectorCaches_TEST_ONLY,
} from './editCorrector.js';
import { GeminiClient } from '../core/client.js';
import type { Config } from '../config/config.js';
import { ToolRegistry } from '../tools/tool-registry.js';

vi.mock('../tools/tool-registry.js');

describe('editCorrector', () => {
  describe('countOccurrences', () => {
    it('should return 0 for empty string', () => {
      expect(countOccurrences('', 'a')).toBe(0);
    });
    it('should return 0 for empty substring', () => {
      expect(countOccurrences('abc', '')).toBe(0);
    });
    it('should return 0 if substring is not found', () => {
      expect(countOccurrences('abc', 'd')).toBe(0);
    });
    it('should return 1 if substring is found once', () => {
      expect(countOccurrences('abc', 'b')).toBe(1);
    });
    it('should return correct count for multiple occurrences', () => {
      expect(countOccurrences('ababa', 'a')).toBe(3);
      expect(countOccurrences('ababab', 'ab')).toBe(3);
    });
    it('should count non-overlapping occurrences', () => {
      expect(countOccurrences('aaaaa', 'aa')).toBe(2);
      expect(countOccurrences('ababab', 'aba')).toBe(1);
    });
    it('should correctly count occurrences when substring is longer', () => {
      expect(countOccurrences('abc', 'abcdef')).toBe(0);
    });
    it('should be case sensitive', () => {
      expect(countOccurrences('abcABC', 'a')).toBe(1);
      expect(countOccurrences('abcABC', 'A')).toBe(1);
    });
  });

  describe('unescapeStringForGeminiBug', () => {
    it('should unescape common sequences', () => {
      expect(unescapeStringForGeminiBug('\\n')).toBe('\n');
      expect(unescapeStringForGeminiBug('\\t')).toBe('\t');
      expect(unescapeStringForGeminiBug("\\'")).toBe("'");
      expect(unescapeStringForGeminiBug('\\"')).toBe('"');
      expect(unescapeStringForGeminiBug('\\`')).toBe('`');
    });
    it('should handle multiple escaped sequences', () => {
      expect(unescapeStringForGeminiBug('Hello\\nWorld\\tTest')).toBe(
        'Hello\nWorld\tTest',
      );
    });
    it('should not alter already correct sequences', () => {
      expect(unescapeStringForGeminiBug('\n')).toBe('\n');
      expect(unescapeStringForGeminiBug('Correct string')).toBe(
        'Correct string',
      );
    });
    it('should handle mixed correct and incorrect sequences', () => {
      expect(unescapeStringForGeminiBug('\\nCorrect\t\\`')).toBe(
        '\nCorrect\t`',
      );
    });
    it('should handle backslash followed by actual newline character', () => {
      expect(unescapeStringForGeminiBug('\\\n')).toBe('\n');
      expect(unescapeStringForGeminiBug('First line\\\nSecond line')).toBe(
        'First line\nSecond line',
      );
    });
    it('should handle multiple backslashes before an escapable character', () => {
      expect(unescapeStringForGeminiBug('\\\\n')).toBe('\n');
      expect(unescapeStringForGeminiBug('\\\\\\t')).toBe('\t');
      expect(unescapeStringForGeminiBug('\\\\\\\\`')).toBe('`');
    });
    it('should return empty string for empty input', () => {
      expect(unescapeStringForGeminiBug('')).toBe('');
    });
    it('should not alter strings with no targeted escape sequences', () => {
      expect(unescapeStringForGeminiBug('abc def')).toBe('abc def');
      expect(unescapeStringForGeminiBug('C:\\Folder\\File')).toBe(
        'C:\\Folder\\File',
      );
    });
    it('should correctly process strings with some targeted escapes', () => {
      expect(unescapeStringForGeminiBug('C:\\Users\\name')).toBe(
        'C:\\Users\name',
      );
    });
    it('should handle complex cases with mixed slashes and characters', () => {
      expect(
        unescapeStringForGeminiBug('\\\\\\\nLine1\\\nLine2\\tTab\\\\`Tick\\"'),
      ).toBe('\nLine1\nLine2\tTab`Tick"');
    });
  });

  describe('ensureCorrectEdit', () => {
    let mockGeminiClientInstance: Mocked<GeminiClient>;
    let mockToolRegistry: Mocked<ToolRegistry>;
    let mockConfigInstance: Config;
    const abortSignal = new AbortController().signal;

    beforeEach(() => {
      mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
      const configParams = {
        apiKey: 'test-api-key',
        model: 'test-model',
        sandbox: false as boolean | string,
        targetDir: '/test',
        debugMode: false,
        question: undefined as string | undefined,
        fullContext: false,
        coreTools: undefined as string[] | undefined,
        toolDiscoveryCommand: undefined as string | undefined,
        toolCallCommand: undefined as string | undefined,
        mcpServerCommand: undefined as string | undefined,
        mcpServers: undefined as Record<string, any> | undefined,
        userAgent: 'test-agent',
        userMemory: '',
        geminiMdFileCount: 0,
        alwaysSkipModificationConfirmation: false,
      };
      mockConfigInstance = {
        ...configParams,
        getApiKey: vi.fn(() => configParams.apiKey),
        getModel: vi.fn(() => configParams.model),
        getSandbox: vi.fn(() => configParams.sandbox),
        getTargetDir: vi.fn(() => configParams.targetDir),
        getToolRegistry: vi.fn(() => mockToolRegistry),
        getDebugMode: vi.fn(() => configParams.debugMode),
        getQuestion: vi.fn(() => configParams.question),
        getFullContext: vi.fn(() => configParams.fullContext),
        getCoreTools: vi.fn(() => configParams.coreTools),
        getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
        getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
        getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
        getMcpServers: vi.fn(() => configParams.mcpServers),
        getUserAgent: vi.fn(() => configParams.userAgent),
        getUserMemory: vi.fn(() => configParams.userMemory),
        setUserMemory: vi.fn((mem: string) => {
          configParams.userMemory = mem;
        }),
        getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
        setGeminiMdFileCount: vi.fn((count: number) => {
          configParams.geminiMdFileCount = count;
        }),
        getAlwaysSkipModificationConfirmation: vi.fn(
          () => configParams.alwaysSkipModificationConfirmation,
        ),
        setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
          configParams.alwaysSkipModificationConfirmation = skip;
        }),
      } as unknown as Config;

      callCount = 0;
      mockResponses.length = 0;
      mockGenerateJson = vi
        .fn()
        .mockImplementation((_contents, _schema, signal) => {
          // Check if the signal is aborted. If so, throw an error or return a specific response.
          if (signal && signal.aborted) {
            return Promise.reject(new Error('Aborted')); // Or some other specific error/response
          }
          const response = mockResponses[callCount];
          callCount++;
          if (response === undefined) return Promise.resolve({});
          return Promise.resolve(response);
        });
      mockStartChat = vi.fn();
      mockSendMessageStream = vi.fn();

      mockGeminiClientInstance = new GeminiClient(
        mockConfigInstance,
      ) as Mocked<GeminiClient>;
      resetEditCorrectorCaches_TEST_ONLY();
    });

    describe('Scenario Group 1: originalParams.old_string matches currentContent directly', () => {
      it('Test 1.1: old_string (no literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
        const currentContent = 'This is a test string to find me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find me',
          new_string: 'replace with \\"this\\"',
        };
        mockResponses.push({
          corrected_new_string_escaping: 'replace with "this"',
        });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe('replace with "this"');
        expect(result.params.old_string).toBe('find me');
        expect(result.occurrences).toBe(1);
      });
      it('Test 1.2: old_string (no literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
        const currentContent = 'This is a test string to find me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find me',
          new_string: 'replace with this',
        };
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(0);
        expect(result.params.new_string).toBe('replace with this');
        expect(result.params.old_string).toBe('find me');
        expect(result.occurrences).toBe(1);
      });
      it('Test 1.3: old_string (with literal \\), new_string (escaped by Gemini) -> new_string unchanged (still escaped)', async () => {
        const currentContent = 'This is a test string to find\\me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find\\me',
          new_string: 'replace with \\"this\\"',
        };
        mockResponses.push({
          corrected_new_string_escaping: 'replace with "this"',
        });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe('replace with "this"');
        expect(result.params.old_string).toBe('find\\me');
        expect(result.occurrences).toBe(1);
      });
      it('Test 1.4: old_string (with literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
        const currentContent = 'This is a test string to find\\me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find\\me',
          new_string: 'replace with this',
        };
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(0);
        expect(result.params.new_string).toBe('replace with this');
        expect(result.params.old_string).toBe('find\\me');
        expect(result.occurrences).toBe(1);
      });
    });

    describe('Scenario Group 2: originalParams.old_string does NOT match, but unescapeStringForGeminiBug(originalParams.old_string) DOES match', () => {
      it('Test 2.1: old_string (over-escaped, no intended literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
        const currentContent = 'This is a test string to find "me".';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find \\"me\\"',
          new_string: 'replace with \\"this\\"',
        };
        mockResponses.push({ corrected_new_string: 'replace with "this"' });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe('replace with "this"');
        expect(result.params.old_string).toBe('find "me"');
        expect(result.occurrences).toBe(1);
      });
      it('Test 2.2: old_string (over-escaped, no intended literal \\), new_string (correctly formatted) -> new_string unescaped (harmlessly)', async () => {
        const currentContent = 'This is a test string to find "me".';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find \\"me\\"',
          new_string: 'replace with this',
        };
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(0);
        expect(result.params.new_string).toBe('replace with this');
        expect(result.params.old_string).toBe('find "me"');
        expect(result.occurrences).toBe(1);
      });
      it('Test 2.3: old_string (over-escaped, with intended literal \\), new_string (simple) -> new_string corrected', async () => {
        const currentContent = 'This is a test string to find \\me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find \\\\me',
          new_string: 'replace with foobar',
        };
        mockResponses.push({
          corrected_target_snippet: 'find \\me',
        });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe('replace with foobar');
        expect(result.params.old_string).toBe('find \\me');
        expect(result.occurrences).toBe(1);
      });
    });

    describe('Scenario Group 3: LLM Correction Path', () => {
      it('Test 3.1: old_string (no literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is double unescaped', async () => {
        const currentContent = 'This is a test string to corrected find me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find me',
          new_string: 'replace with \\\\"this\\\\"',
        };
        const llmNewString = 'LLM says replace with "that"';
        mockResponses.push({ corrected_new_string_escaping: llmNewString });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe(llmNewString);
        expect(result.params.old_string).toBe('find me');
        expect(result.occurrences).toBe(1);
      });
      it('Test 3.2: old_string (with literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is unescaped once', async () => {
        const currentContent = 'This is a test string to corrected find me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find\\me',
          new_string: 'replace with \\\\"this\\\\"',
        };
        const llmCorrectedOldString = 'corrected find me';
        const llmNewString = 'LLM says replace with "that"';
        mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
        mockResponses.push({ corrected_new_string: llmNewString });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(2);
        expect(result.params.new_string).toBe(llmNewString);
        expect(result.params.old_string).toBe(llmCorrectedOldString);
        expect(result.occurrences).toBe(1);
      });
      it('Test 3.3: old_string needs LLM, new_string is fine -> old_string corrected, new_string original', async () => {
        const currentContent = 'This is a test string to be corrected.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'fiiind me',
          new_string: 'replace with "this"',
        };
        const llmCorrectedOldString = 'to be corrected';
        mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe('replace with "this"');
        expect(result.params.old_string).toBe(llmCorrectedOldString);
        expect(result.occurrences).toBe(1);
      });
      it('Test 3.4: LLM correction path, correctNewString returns the originalNewString it was passed (which was unescaped) -> final new_string is unescaped', async () => {
        const currentContent = 'This is a test string to corrected find me.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find me',
          new_string: 'replace with \\\\"this\\\\"',
        };
        const newStringForLLMAndReturnedByLLM = 'replace with "this"';
        mockResponses.push({
          corrected_new_string_escaping: newStringForLLMAndReturnedByLLM,
        });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
        expect(result.occurrences).toBe(1);
      });
    });

    describe('Scenario Group 4: No Match Found / Multiple Matches', () => {
      it('Test 4.1: No version of old_string (original, unescaped, LLM-corrected) matches -> returns original params, 0 occurrences', async () => {
        const currentContent = 'This content has nothing to find.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'nonexistent string',
          new_string: 'some new string',
        };
        mockResponses.push({ corrected_target_snippet: 'still nonexistent' });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(1);
        expect(result.params).toEqual(originalParams);
        expect(result.occurrences).toBe(0);
      });
      it('Test 4.2: unescapedOldStringAttempt results in >1 occurrences -> returns original params, count occurrences', async () => {
        const currentContent =
          'This content has find "me" and also find "me" again.';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'find "me"',
          new_string: 'some new string',
        };
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(0);
        expect(result.params).toEqual(originalParams);
        expect(result.occurrences).toBe(2);
      });
    });

    describe('Scenario Group 5: Specific unescapeStringForGeminiBug checks (integrated into ensureCorrectEdit)', () => {
      it('Test 5.1: old_string needs LLM to become currentContent, new_string also needs correction', async () => {
        const currentContent = 'const x = "a\\nbc\\\\"def\\\\"';
        const originalParams = {
          file_path: '/test/file.txt',
          old_string: 'const x = \\\\"a\\\\nbc\\\\\\\\"def\\\\\\\\"',
          new_string: 'const y = \\\\"new\\\\nval\\\\\\\\"content\\\\\\\\"',
        };
        const expectedFinalNewString = 'const y = "new\\nval\\\\"content\\\\"';
        mockResponses.push({ corrected_target_snippet: currentContent });
        mockResponses.push({ corrected_new_string: expectedFinalNewString });
        const result = await ensureCorrectEdit(
          currentContent,
          originalParams,
          mockGeminiClientInstance,
          abortSignal,
        );
        expect(mockGenerateJson).toHaveBeenCalledTimes(2);
        expect(result.params.old_string).toBe(currentContent);
        expect(result.params.new_string).toBe(expectedFinalNewString);
        expect(result.occurrences).toBe(1);
      });
    });
  });
});
packages/core/src/utils/editCorrector.ts (new file, 593 lines)
@@ -0,0 +1,593 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  Content,
  GenerateContentConfig,
  SchemaUnion,
  Type,
} from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { EditToolParams } from '../tools/edit.js';
import { LruCache } from './LruCache.js';

const EditModel = 'gemini-2.5-flash-preview-04-17';
const EditConfig: GenerateContentConfig = {
  thinkingConfig: {
    thinkingBudget: 0,
  },
};

const MAX_CACHE_SIZE = 50;

// Cache for ensureCorrectEdit results
const editCorrectionCache = new LruCache<string, CorrectedEditResult>(
  MAX_CACHE_SIZE,
);

// Cache for ensureCorrectFileContent results
const fileContentCorrectionCache = new LruCache<string, string>(MAX_CACHE_SIZE);

/**
 * Defines the structure of the parameters within CorrectedEditResult
 */
interface CorrectedEditParams {
  file_path: string;
  old_string: string;
  new_string: string;
}

/**
 * Defines the result structure for ensureCorrectEdit.
 */
export interface CorrectedEditResult {
  params: CorrectedEditParams;
  occurrences: number;
}

/**
 * Attempts to correct edit parameters if the original old_string is not found.
 * It tries unescaping, and then LLM-based correction.
 * Results are cached to avoid redundant processing.
 *
 * @param currentContent The current content of the file.
 * @param originalParams The original EditToolParams
 * @param client The GeminiClient for LLM calls.
 * @returns A promise resolving to an object containing the (potentially corrected)
 *   EditToolParams (as CorrectedEditParams) and the final occurrences count.
 */
export async function ensureCorrectEdit(
  currentContent: string,
  originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without 'corrected'
  client: GeminiClient,
  abortSignal: AbortSignal,
): Promise<CorrectedEditResult> {
  const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
  const cachedResult = editCorrectionCache.get(cacheKey);
  if (cachedResult) {
    return cachedResult;
  }

  let finalNewString = originalParams.new_string;
  const newStringPotentiallyEscaped =
    unescapeStringForGeminiBug(originalParams.new_string) !==
    originalParams.new_string;

  let finalOldString = originalParams.old_string;
  let occurrences = countOccurrences(currentContent, finalOldString);

  if (occurrences === 1) {
    if (newStringPotentiallyEscaped) {
      finalNewString = await correctNewStringEscaping(
        client,
        finalOldString,
        originalParams.new_string,
        abortSignal,
      );
    }
  } else if (occurrences > 1) {
    const result: CorrectedEditResult = {
      params: { ...originalParams },
      occurrences,
    };
    editCorrectionCache.set(cacheKey, result);
    return result;
  } else {
    // occurrences is 0 or some other unexpected state initially
    const unescapedOldStringAttempt = unescapeStringForGeminiBug(
      originalParams.old_string,
    );
    occurrences = countOccurrences(currentContent, unescapedOldStringAttempt);

    if (occurrences === 1) {
      finalOldString = unescapedOldStringAttempt;
      if (newStringPotentiallyEscaped) {
        finalNewString = await correctNewString(
          client,
          originalParams.old_string, // original old
          unescapedOldStringAttempt, // corrected old
          originalParams.new_string, // original new (which is potentially escaped)
          abortSignal,
        );
      }
    } else if (occurrences === 0) {
      const llmCorrectedOldString = await correctOldStringMismatch(
        client,
        currentContent,
        unescapedOldStringAttempt,
        abortSignal,
      );
      const llmOldOccurrences = countOccurrences(
        currentContent,
        llmCorrectedOldString,
      );

      if (llmOldOccurrences === 1) {
        finalOldString = llmCorrectedOldString;
        occurrences = llmOldOccurrences;

        if (newStringPotentiallyEscaped) {
          const baseNewStringForLLMCorrection = unescapeStringForGeminiBug(
            originalParams.new_string,
          );
          finalNewString = await correctNewString(
            client,
            originalParams.old_string, // original old
            llmCorrectedOldString, // corrected old
            baseNewStringForLLMCorrection, // base new for correction
            abortSignal,
          );
        }
      } else {
        // LLM correction also failed for old_string
        const result: CorrectedEditResult = {
          params: { ...originalParams },
          occurrences: 0, // Explicitly 0 as LLM failed
        };
        editCorrectionCache.set(cacheKey, result);
        return result;
      }
    } else {
      // Unescaping old_string resulted in > 1 occurrences
      const result: CorrectedEditResult = {
        params: { ...originalParams },
        occurrences, // This will be > 1
      };
      editCorrectionCache.set(cacheKey, result);
      return result;
    }
  }

  const { targetString, pair } = trimPairIfPossible(
    finalOldString,
    finalNewString,
    currentContent,
  );
  finalOldString = targetString;
  finalNewString = pair;

  // Final result construction
  const result: CorrectedEditResult = {
    params: {
      file_path: originalParams.file_path,
      old_string: finalOldString,
      new_string: finalNewString,
    },
    occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string
  };
  editCorrectionCache.set(cacheKey, result);
  return result;
}

export async function ensureCorrectFileContent(
  content: string,
  client: GeminiClient,
  abortSignal: AbortSignal,
): Promise<string> {
  const cachedResult = fileContentCorrectionCache.get(content);
  if (cachedResult) {
    return cachedResult;
  }

  const contentPotentiallyEscaped =
    unescapeStringForGeminiBug(content) !== content;
  if (!contentPotentiallyEscaped) {
    fileContentCorrectionCache.set(content, content);
    return content;
  }

  const correctedContent = await correctStringEscaping(
    content,
    client,
    abortSignal,
  );
  fileContentCorrectionCache.set(content, correctedContent);
  return correctedContent;
}

// Define the expected JSON schema for the LLM response for old_string correction
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    corrected_target_snippet: {
      type: Type.STRING,
      description:
        'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
    },
  },
  required: ['corrected_target_snippet'],
};

export async function correctOldStringMismatch(
  geminiClient: GeminiClient,
  fileContent: string,
  problematicSnippet: string,
  abortSignal: AbortSignal,
): Promise<string> {
  const prompt = `
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.

Task: Analyze the provided file content and the problematic target snippet. Identify the segment in the file content that the snippet was *most likely* intended to match. Output the *exact*, literal text of that segment from the file content. Focus *only* on removing extra escape characters and correcting formatting, whitespace, or minor differences to achieve a PERFECT literal match. The output must be the exact literal text as it appears in the file.

Problematic target snippet:
\`\`\`
${problematicSnippet}
\`\`\`

File Content:
\`\`\`
${fileContent}
\`\`\`

For example, if the problematic target snippet was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and the file content had content that looked like "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", then corrected_target_snippet should likely be "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;" to fix the incorrect escaping to match the original file content.
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_target_snippet.

Return ONLY the corrected target snippet in the specified JSON format with the key 'corrected_target_snippet'. If no clear, unique match can be found, return an empty string for 'corrected_target_snippet'.
`.trim();

  const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];

  try {
    const result = await geminiClient.generateJson(
      contents,
      OLD_STRING_CORRECTION_SCHEMA,
      abortSignal,
      EditModel,
      EditConfig,
    );

    if (
      result &&
      typeof result.corrected_target_snippet === 'string' &&
      result.corrected_target_snippet.length > 0
    ) {
      return result.corrected_target_snippet;
    } else {
      return problematicSnippet;
    }
  } catch (error) {
    if (abortSignal.aborted) {
      throw error;
    }

    console.error(
      'Error during LLM call for old string snippet correction:',
      error,
    );

    return problematicSnippet;
  }
}

// Define the expected JSON schema for the new_string correction LLM response
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    corrected_new_string: {
      type: Type.STRING,
      description:
        'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
    },
  },
  required: ['corrected_new_string'],
};

/**
 * Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
 */
export async function correctNewString(
  geminiClient: GeminiClient,
  originalOldString: string,
  correctedOldString: string,
  originalNewString: string,
  abortSignal: AbortSignal,
): Promise<string> {
  if (originalOldString === correctedOldString) {
    return originalNewString;
  }

  const prompt = `
Context: A text replacement operation was planned. The original text to be replaced (original_old_string) was slightly different from the actual text in the file (corrected_old_string). The original_old_string has now been corrected to match the file content.
We now need to adjust the replacement text (original_new_string) so that it makes sense as a replacement for the corrected_old_string, while preserving the original intent of the change.

original_old_string (what was initially intended to be found):
\`\`\`
${originalOldString}
\`\`\`

corrected_old_string (what was actually found in the file and will be replaced):
\`\`\`
${correctedOldString}
\`\`\`

original_new_string (what was intended to replace original_old_string):
\`\`\`
${originalNewString}
\`\`\`

Task: Based on the differences between original_old_string and corrected_old_string, and the content of original_new_string, generate a corrected_new_string. This corrected_new_string should be what original_new_string would have been if it was designed to replace corrected_old_string directly, while maintaining the spirit of the original transformation.

For example, if original_old_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and corrected_old_string is "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", and original_new_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name} \${lastName}\\\\\`\`;", then corrected_new_string should likely be "\nconst greeting = \`Hello ${'\\`'}\${name} \${lastName}${'\\`'}\`;" to fix the incorrect escaping.
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_new_string.

Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string'. If no adjustment is deemed necessary or possible, return the original_new_string.
`.trim();

  const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];

  try {
    const result = await geminiClient.generateJson(
      contents,
      NEW_STRING_CORRECTION_SCHEMA,
      abortSignal,
      EditModel,
      EditConfig,
    );

    if (
      result &&
      typeof result.corrected_new_string === 'string' &&
      result.corrected_new_string.length > 0
    ) {
      return result.corrected_new_string;
    } else {
      return originalNewString;
    }
  } catch (error) {
    if (abortSignal.aborted) {
      throw error;
    }

    console.error('Error during LLM call for new_string correction:', error);
    return originalNewString;
  }
}

const CORRECT_NEW_STRING_ESCAPING_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    corrected_new_string_escaping: {
      type: Type.STRING,
      description:
        'The new_string with corrected escaping, ensuring it is a proper replacement for the old_string, especially considering potential over-escaping issues from previous LLM generations.',
    },
  },
  required: ['corrected_new_string_escaping'],
};

export async function correctNewStringEscaping(
  geminiClient: GeminiClient,
  oldString: string,
  potentiallyProblematicNewString: string,
  abortSignal: AbortSignal,
): Promise<string> {
  const prompt = `
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessary quotes like \\"Hello\\" instead of "Hello").

old_string (this is the exact text that will be replaced):
\`\`\`
${oldString}
\`\`\`

potentially_problematic_new_string (this is the text that should replace old_string, but MIGHT have bad escaping, or might be entirely correct):
\`\`\`
${potentiallyProblematicNewString}
\`\`\`

Task: Analyze the potentially_problematic_new_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the new_string, when inserted into the code, will be valid and correctly interpreted.

For example, if old_string is "foo" and potentially_problematic_new_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz".
If potentially_problematic_new_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").

Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_new_string.
`.trim();

  const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];

  try {
    const result = await geminiClient.generateJson(
      contents,
      CORRECT_NEW_STRING_ESCAPING_SCHEMA,
      abortSignal,
      EditModel,
      EditConfig,
    );

    if (
      result &&
      typeof result.corrected_new_string_escaping === 'string' &&
      result.corrected_new_string_escaping.length > 0
    ) {
      return result.corrected_new_string_escaping;
    } else {
      return potentiallyProblematicNewString;
    }
  } catch (error) {
    if (abortSignal.aborted) {
      throw error;
    }

    console.error(
      'Error during LLM call for new_string escaping correction:',
      error,
    );
    return potentiallyProblematicNewString;
  }
}

const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    corrected_string_escaping: {
      type: Type.STRING,
      description:
        'The string with corrected escaping, ensuring it is valid, especially considering potential over-escaping issues from previous LLM generations.',
    },
  },
  required: ['corrected_string_escaping'],
};

export async function correctStringEscaping(
  potentiallyProblematicString: string,
  client: GeminiClient,
  abortSignal: AbortSignal,
): Promise<string> {
  const prompt = `
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessary quotes like \\"Hello\\" instead of "Hello").

potentially_problematic_string (this text MIGHT have bad escaping, or might be entirely correct):
\`\`\`
${potentiallyProblematicString}
\`\`\`

Task: Analyze the potentially_problematic_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the text will be valid and correctly interpreted.

For example, if potentially_problematic_string is "bar\\nbaz", the corrected_string_escaping should be "bar\nbaz".
If potentially_problematic_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").

Return ONLY the corrected string in the specified JSON format with the key 'corrected_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_string.
`.trim();

  const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];

  try {
    const result = await client.generateJson(
      contents,
      CORRECT_STRING_ESCAPING_SCHEMA,
      abortSignal,
      EditModel,
      EditConfig,
    );

    // Note: the schema above requires 'corrected_string_escaping', so that is
    // the key checked here (the original draft mistakenly read the
    // 'corrected_new_string_escaping' key from the other schema).
    if (
      result &&
      typeof result.corrected_string_escaping === 'string' &&
      result.corrected_string_escaping.length > 0
    ) {
      return result.corrected_string_escaping;
    } else {
      return potentiallyProblematicString;
    }
  } catch (error) {
    if (abortSignal.aborted) {
      throw error;
    }

    console.error(
      'Error during LLM call for string escaping correction:',
      error,
    );
    return potentiallyProblematicString;
  }
}

function trimPairIfPossible(
  target: string,
  trimIfTargetTrims: string,
  currentContent: string,
) {
  const trimmedTargetString = target.trim();
  if (target.length !== trimmedTargetString.length) {
    const trimmedTargetOccurrences = countOccurrences(
      currentContent,
      trimmedTargetString,
    );

    if (trimmedTargetOccurrences === 1) {
      const trimmedReactiveString = trimIfTargetTrims.trim();
      return {
        targetString: trimmedTargetString,
        pair: trimmedReactiveString,
      };
    }
  }

  return {
    targetString: target,
    pair: trimIfTargetTrims,
  };
}

/**
 * Unescapes a string that might have been overly escaped by an LLM.
 */
export function unescapeStringForGeminiBug(inputString: string): string {
  // Regex explanation:
  // \\+ : Matches one or more literal backslash characters.
  // (n|t|r|'|"|`|\n) : This is a capturing group. It matches one of the following:
  //   n, t, r, ', ", ` : These match the literal characters 'n', 't', 'r', single quote, double quote, or backtick.
  //     This handles cases like "\\n", "\\\\`", etc.
  //   \n : This matches an actual newline character. This handles cases where the input
  //     string might have something like "\\\n" (a literal backslash followed by a newline).
  // g : Global flag, to replace all occurrences.

  return inputString.replace(/\\+(n|t|r|'|"|`|\n)/g, (match, capturedChar) => {
    // 'match' is the entire erroneous sequence, e.g., if the input (in memory) was "\\\\`", match is "\\\\`".
    // 'capturedChar' is the character that determines the true meaning, e.g., '`'.

    switch (capturedChar) {
      case 'n':
        return '\n'; // Correctly escaped: \n (newline character)
      case 't':
        return '\t'; // Correctly escaped: \t (tab character)
      case 'r':
        return '\r'; // Correctly escaped: \r (carriage return character)
      case "'":
        return "'"; // Correctly escaped: ' (apostrophe character)
      case '"':
        return '"'; // Correctly escaped: " (quotation mark character)
      case '`':
        return '`'; // Correctly escaped: ` (backtick character)
      case '\n': // This handles when 'capturedChar' is an actual newline
        return '\n'; // Replace the whole erroneous sequence (e.g., "\\\n" in memory) with a clean newline
      default:
        // This fallback should ideally not be reached if the regex captures correctly.
        // It would return the original matched sequence if an unexpected character was captured.
        return match;
    }
  });
}

/**
 * Counts occurrences of a substring in a string
 */
export function countOccurrences(str: string, substr: string): number {
  if (substr === '') {
    return 0;
  }
  let count = 0;
  let pos = str.indexOf(substr);
  while (pos !== -1) {
    count++;
    pos = str.indexOf(substr, pos + substr.length); // Start search after the current match
  }
  return count;
}

export function resetEditCorrectorCaches_TEST_ONLY() {
  editCorrectionCache.clear();
  fileContentCorrectionCache.clear();
}
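To make the control flow above concrete, here is a minimal caller sketch (hypothetical, not part of this commit; `applyEdit` and its wiring are illustrative assumptions) showing how an edit tool might consume ensureCorrectEdit's result: the edit is applied only when exactly one occurrence of the final old_string is found.

// Hypothetical caller, assuming the imports resolve as in this package.
import { ensureCorrectEdit } from './editCorrector.js';
import { GeminiClient } from '../core/client.js';
import { EditToolParams } from '../tools/edit.js';

async function applyEdit(
  currentContent: string,
  params: EditToolParams,
  client: GeminiClient,
  signal: AbortSignal,
): Promise<string | null> {
  const { params: corrected, occurrences } = await ensureCorrectEdit(
    currentContent,
    params,
    client,
    signal,
  );
  // 0 occurrences: nothing to replace; >1: ambiguous target. Both are rejected.
  if (occurrences !== 1) {
    return null;
  }
  // Safe to use String.replace: exactly one match exists.
  return currentContent.replace(corrected.old_string, corrected.new_string);
}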
packages/core/src/utils/errorReporting.test.ts (new file, 220 lines)
@@ -0,0 +1,220 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';

// Use a type alias for SpyInstance as it's not directly exported
type SpyInstance = ReturnType<typeof vi.spyOn>;
import { reportError } from './errorReporting.js';
import fs from 'node:fs/promises';
import os from 'node:os';

// Mock dependencies
vi.mock('node:fs/promises');
vi.mock('node:os');

describe('reportError', () => {
  let consoleErrorSpy: SpyInstance;
  const MOCK_TMP_DIR = '/tmp';
  const MOCK_TIMESTAMP = '2025-01-01T00-00-00-000Z';

  beforeEach(() => {
    vi.resetAllMocks();
    consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
    (os.tmpdir as Mock).mockReturnValue(MOCK_TMP_DIR);
    vi.spyOn(Date.prototype, 'toISOString').mockReturnValue(MOCK_TIMESTAMP);
  });

  afterEach(() => {
    consoleErrorSpy.mockRestore();
    vi.restoreAllMocks();
  });

  const getExpectedReportPath = (type: string) =>
    `${MOCK_TMP_DIR}/gemini-client-error-${type}-${MOCK_TIMESTAMP}.json`;

  it('should generate a report and log the path', async () => {
    const error = new Error('Test error');
    error.stack = 'Test stack';
    const baseMessage = 'An error occurred.';
    const context = { data: 'test context' };
    const type = 'test-type';
    const expectedReportPath = getExpectedReportPath(type);

    (fs.writeFile as Mock).mockResolvedValue(undefined);

    await reportError(error, baseMessage, context, type);

    expect(os.tmpdir).toHaveBeenCalledTimes(1);
    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedReportPath,
      JSON.stringify(
        {
          error: { message: 'Test error', stack: error.stack },
          context,
        },
        null,
        2,
      ),
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Full report available at: ${expectedReportPath}`,
    );
  });

  it('should handle errors that are plain objects with a message property', async () => {
    const error = { message: 'Test plain object error' };
    const baseMessage = 'Another error.';
    const type = 'general';
    const expectedReportPath = getExpectedReportPath(type);

    (fs.writeFile as Mock).mockResolvedValue(undefined);
    await reportError(error, baseMessage);

    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedReportPath,
      JSON.stringify(
        {
          error: { message: 'Test plain object error' },
        },
        null,
        2,
      ),
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Full report available at: ${expectedReportPath}`,
    );
  });

  it('should handle string errors', async () => {
    const error = 'Just a string error';
    const baseMessage = 'String error occurred.';
    const type = 'general';
    const expectedReportPath = getExpectedReportPath(type);

    (fs.writeFile as Mock).mockResolvedValue(undefined);
    await reportError(error, baseMessage);

    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedReportPath,
      JSON.stringify(
        {
          error: { message: 'Just a string error' },
        },
        null,
        2,
      ),
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Full report available at: ${expectedReportPath}`,
    );
  });

  it('should log fallback message if writing report fails', async () => {
    const error = new Error('Main error');
    const baseMessage = 'Failed operation.';
    const writeError = new Error('Failed to write file');
    const context = ['some context'];
    const type = 'general';
    const expectedReportPath = getExpectedReportPath(type);

    (fs.writeFile as Mock).mockRejectedValue(writeError);

    await reportError(error, baseMessage, context, type);

    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedReportPath,
      expect.any(String),
    ); // It still tries to write
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Additionally, failed to write detailed error report:`,
      writeError,
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      'Original error that triggered report generation:',
      error,
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith('Original context:', context);
  });

  it('should handle stringification failure of report content (e.g. BigInt in context)', async () => {
    const error = new Error('Main error');
    error.stack = 'Main stack';
    const baseMessage = 'Failed operation with BigInt.';
    const context = { a: BigInt(1) }; // BigInt cannot be stringified by JSON.stringify
    const type = 'bigint-fail';
    const stringifyError = new TypeError(
      'Do not know how to serialize a BigInt',
    );
    const expectedMinimalReportPath = getExpectedReportPath(type);

    // Simulate JSON.stringify throwing an error for the full report
    const originalJsonStringify = JSON.stringify;
    let callCount = 0;
    vi.spyOn(JSON, 'stringify').mockImplementation((value, replacer, space) => {
      callCount++;
      if (callCount === 1) {
        // First call is for the full report content
        throw stringifyError;
      }
      // Subsequent calls (for minimal report) should succeed
      return originalJsonStringify(value, replacer, space);
    });

    (fs.writeFile as Mock).mockResolvedValue(undefined); // Mock for the minimal report write

    await reportError(error, baseMessage, context, type);

    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Could not stringify report content (likely due to context):`,
      stringifyError,
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      'Original error that triggered report generation:',
      error,
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      'Original context could not be stringified or included in report.',
    );
    // Check that it attempts to write a minimal report
    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedMinimalReportPath,
      originalJsonStringify(
        { error: { message: error.message, stack: error.stack } },
        null,
        2,
      ),
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Partial report (excluding context) available at: ${expectedMinimalReportPath}`,
    );
  });

  it('should generate a report without context if context is not provided', async () => {
    const error = new Error('Error without context');
    error.stack = 'No context stack';
    const baseMessage = 'Simple error.';
    const type = 'general';
    const expectedReportPath = getExpectedReportPath(type);

    (fs.writeFile as Mock).mockResolvedValue(undefined);
    await reportError(error, baseMessage, undefined, type);

    expect(fs.writeFile).toHaveBeenCalledWith(
      expectedReportPath,
      JSON.stringify(
        {
          error: { message: 'Error without context', stack: error.stack },
        },
        null,
        2,
      ),
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `${baseMessage} Full report available at: ${expectedReportPath}`,
    );
  });
});
117
packages/core/src/utils/errorReporting.ts
Normal file
117
packages/core/src/utils/errorReporting.ts
Normal file
@@ -0,0 +1,117 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import fs from 'node:fs/promises';
import os from 'node:os';
import path from 'node:path';
import { Content } from '@google/genai';

interface ErrorReportData {
  error: { message: string; stack?: string } | { message: string };
  context?: unknown;
  additionalInfo?: Record<string, unknown>;
}

/**
 * Generates an error report, writes it to a temporary file, and logs information to console.error.
 * @param error The error object.
 * @param baseMessage The initial message to log to console.error before the report path.
 * @param context The relevant context (e.g., chat history, request contents).
 * @param type A string to identify the type of error (e.g., 'startChat', 'generateJson-api').
 */
export async function reportError(
  error: Error | unknown,
  baseMessage: string,
  context?: Content[] | Record<string, unknown> | unknown[],
  type = 'general',
): Promise<void> {
  const timestamp = new Date().toISOString().replace(/[:.]/g, '-');
  const reportFileName = `gemini-client-error-${type}-${timestamp}.json`;
  const reportPath = path.join(os.tmpdir(), reportFileName);

  let errorToReport: { message: string; stack?: string };
  if (error instanceof Error) {
    errorToReport = { message: error.message, stack: error.stack };
  } else if (
    typeof error === 'object' &&
    error !== null &&
    'message' in error
  ) {
    errorToReport = {
      message: String((error as { message: unknown }).message),
    };
  } else {
    errorToReport = { message: String(error) };
  }

  const reportContent: ErrorReportData = { error: errorToReport };

  if (context) {
    reportContent.context = context;
  }

  let stringifiedReportContent: string;
  try {
    stringifiedReportContent = JSON.stringify(reportContent, null, 2);
  } catch (stringifyError) {
    // This can happen if context contains something like BigInt
    console.error(
      `${baseMessage} Could not stringify report content (likely due to context):`,
      stringifyError,
    );
    console.error('Original error that triggered report generation:', error);
    if (context) {
      console.error(
        'Original context could not be stringified or included in report.',
      );
    }
    // Fallback: try to report only the error if context was the issue
    try {
      const minimalReportContent = { error: errorToReport };
      stringifiedReportContent = JSON.stringify(minimalReportContent, null, 2);
      // Still try to write the minimal report
      await fs.writeFile(reportPath, stringifiedReportContent);
      console.error(
        `${baseMessage} Partial report (excluding context) available at: ${reportPath}`,
      );
    } catch (minimalWriteError) {
      console.error(
        `${baseMessage} Failed to write even a minimal error report:`,
        minimalWriteError,
      );
    }
    return;
  }

  try {
    await fs.writeFile(reportPath, stringifiedReportContent);
    console.error(`${baseMessage} Full report available at: ${reportPath}`);
  } catch (writeError) {
    console.error(
      `${baseMessage} Additionally, failed to write detailed error report:`,
      writeError,
    );
    // Log the original error as a fallback if report writing fails
    console.error('Original error that triggered report generation:', error);
    if (context) {
      // Context was stringifiable, but writing the file failed.
      // We already have stringifiedReportContent, but it might be too large for console.
      // So, we try to log the original context object, and if that fails, its stringified version (truncated).
      try {
        console.error('Original context:', context);
      } catch {
        try {
          console.error(
            'Original context (stringified, truncated):',
            JSON.stringify(context).substring(0, 1000),
          );
        } catch {
          console.error('Original context could not be logged or stringified.');
        }
      }
    }
  }
}
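For context, a minimal sketch of how a caller might invoke reportError from a catch block; the surrounding function, its history argument, and the message string are illustrative assumptions, not part of this commit:

import { reportError } from './errorReporting.js';

// Hypothetical caller: report a failure with its chat history as context.
async function startChatWithReporting(history: unknown[]): Promise<void> {
  try {
    // ... the operation that may throw goes here ...
  } catch (error) {
    // Writes gemini-client-error-startChat-<timestamp>.json under os.tmpdir()
    // and logs the report path via console.error.
    await reportError(error, 'Failed to start chat:', history, 'startChat');
  }
}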
22
packages/core/src/utils/errors.ts
Normal file
@@ -0,0 +1,22 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

export function isNodeError(error: unknown): error is NodeJS.ErrnoException {
  return error instanceof Error && 'code' in error;
}

export function getErrorMessage(error: unknown): string {
  if (error instanceof Error) {
    return error.message;
  } else {
    try {
      const errorMessage = String(error);
      return errorMessage;
    } catch {
      return 'Failed to get error details';
    }
  }
}
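A short sketch of how these two helpers combine when handling filesystem errors; the fs call and messages are hypothetical, not from this commit:

import fs from 'node:fs/promises';
import { isNodeError, getErrorMessage } from './errors.js';

async function readOptionalFile(filePath: string): Promise<string | undefined> {
  try {
    return await fs.readFile(filePath, 'utf-8');
  } catch (error) {
    if (isNodeError(error) && error.code === 'ENOENT') {
      // The type guard narrows to NodeJS.ErrnoException, so .code is typed.
      return undefined; // A missing file is fine for an optional read.
    }
    // getErrorMessage is safe even for non-Error throwables.
    throw new Error(`Failed to read ${filePath}: ${getErrorMessage(error)}`);
  }
}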
431
packages/core/src/utils/fileUtils.test.ts
Normal file
@@ -0,0 +1,431 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  describe,
  it,
  expect,
  vi,
  beforeEach,
  afterEach,
  type Mock,
} from 'vitest';

import * as actualNodeFs from 'node:fs'; // For setup/teardown
import fsPromises from 'node:fs/promises';
import path from 'node:path';
import os from 'node:os';
import mime from 'mime-types';

import {
  isWithinRoot,
  isBinaryFile,
  detectFileType,
  processSingleFileContent,
} from './fileUtils.js';

vi.mock('mime-types', () => ({
  default: { lookup: vi.fn() },
  lookup: vi.fn(),
}));

const mockMimeLookup = mime.lookup as Mock;

describe('fileUtils', () => {
  let tempRootDir: string;
  const originalProcessCwd = process.cwd;

  let testTextFilePath: string;
  let testImageFilePath: string;
  let testPdfFilePath: string;
  let testBinaryFilePath: string;
  let nonExistentFilePath: string;
  let directoryPath: string;

  beforeEach(() => {
    vi.resetAllMocks(); // Reset all mocks, including mime.lookup

    tempRootDir = actualNodeFs.mkdtempSync(
      path.join(os.tmpdir(), 'fileUtils-test-'),
    );
    process.cwd = vi.fn(() => tempRootDir); // Mock cwd if necessary for relative path logic within tests

    testTextFilePath = path.join(tempRootDir, 'test.txt');
    testImageFilePath = path.join(tempRootDir, 'image.png');
    testPdfFilePath = path.join(tempRootDir, 'document.pdf');
    testBinaryFilePath = path.join(tempRootDir, 'app.exe');
    nonExistentFilePath = path.join(tempRootDir, 'notfound.txt');
    directoryPath = path.join(tempRootDir, 'subdir');

    actualNodeFs.mkdirSync(directoryPath, { recursive: true }); // Ensure subdir exists
  });

  afterEach(() => {
    if (actualNodeFs.existsSync(tempRootDir)) {
      actualNodeFs.rmSync(tempRootDir, { recursive: true, force: true });
    }
    process.cwd = originalProcessCwd;
    vi.restoreAllMocks(); // Restore any spies
  });

  describe('isWithinRoot', () => {
    const root = path.resolve('/project/root');

    it('should return true for paths directly within the root', () => {
      expect(isWithinRoot(path.join(root, 'file.txt'), root)).toBe(true);
      expect(isWithinRoot(path.join(root, 'subdir', 'file.txt'), root)).toBe(
        true,
      );
    });

    it('should return true for the root path itself', () => {
      expect(isWithinRoot(root, root)).toBe(true);
    });

    it('should return false for paths outside the root', () => {
      expect(
        isWithinRoot(path.resolve('/project/other', 'file.txt'), root),
      ).toBe(false);
      expect(isWithinRoot(path.resolve('/unrelated', 'file.txt'), root)).toBe(
        false,
      );
    });

    it('should return false for paths that only partially match the root prefix', () => {
      expect(
        isWithinRoot(
          path.resolve('/project/root-but-actually-different'),
          root,
        ),
      ).toBe(false);
    });

    it('should handle paths with trailing slashes correctly', () => {
      expect(isWithinRoot(path.join(root, 'file.txt') + path.sep, root)).toBe(
        true,
      );
      expect(isWithinRoot(root + path.sep, root)).toBe(true);
    });

    it('should handle different path separators (POSIX vs Windows)', () => {
      const posixRoot = '/project/root';
      const posixPathInside = '/project/root/file.txt';
      const posixPathOutside = '/project/other/file.txt';
      expect(isWithinRoot(posixPathInside, posixRoot)).toBe(true);
      expect(isWithinRoot(posixPathOutside, posixRoot)).toBe(false);
    });

    it('should return false for a root path that is a sub-path of the path to check', () => {
      const pathToCheck = path.resolve('/project/root/sub');
      const rootSub = path.resolve('/project/root');
      expect(isWithinRoot(pathToCheck, rootSub)).toBe(true);

      const pathToCheckSuper = path.resolve('/project/root');
      const rootSuper = path.resolve('/project/root/sub');
      expect(isWithinRoot(pathToCheckSuper, rootSuper)).toBe(false);
    });
  });

  describe('isBinaryFile', () => {
    let filePathForBinaryTest: string;

    beforeEach(() => {
      filePathForBinaryTest = path.join(tempRootDir, 'binaryCheck.tmp');
    });

    afterEach(() => {
      if (actualNodeFs.existsSync(filePathForBinaryTest)) {
        actualNodeFs.unlinkSync(filePathForBinaryTest);
      }
    });

    it('should return false for an empty file', () => {
      actualNodeFs.writeFileSync(filePathForBinaryTest, '');
      expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
    });

    it('should return false for a typical text file', () => {
      actualNodeFs.writeFileSync(
        filePathForBinaryTest,
        'Hello, world!\nThis is a test file with normal text content.',
      );
      expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
    });

    it('should return true for a file with many null bytes', () => {
      const binaryContent = Buffer.from([
        0x48, 0x65, 0x00, 0x6c, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00,
      ]); // "He\0lo\0\0\0\0\0"
      actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent);
      expect(isBinaryFile(filePathForBinaryTest)).toBe(true);
    });

    it('should return true for a file with a high percentage of non-printable ASCII', () => {
      const binaryContent = Buffer.from([
        0x41, 0x42, 0x01, 0x02, 0x03, 0x04, 0x05, 0x43, 0x44, 0x06,
      ]); // AB\x01\x02\x03\x04\x05CD\x06
      actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent);
      expect(isBinaryFile(filePathForBinaryTest)).toBe(true);
    });

    it('should return false if file access fails (e.g., ENOENT)', () => {
      // Ensure the file does not exist
      if (actualNodeFs.existsSync(filePathForBinaryTest)) {
        actualNodeFs.unlinkSync(filePathForBinaryTest);
      }
      expect(isBinaryFile(filePathForBinaryTest)).toBe(false);
    });
  });

  describe('detectFileType', () => {
    let filePathForDetectTest: string;

    beforeEach(() => {
      filePathForDetectTest = path.join(tempRootDir, 'detectType.tmp');
      // Default: create as a text file for the isBinaryFile fallback
      actualNodeFs.writeFileSync(filePathForDetectTest, 'Plain text content');
    });

    afterEach(() => {
      if (actualNodeFs.existsSync(filePathForDetectTest)) {
        actualNodeFs.unlinkSync(filePathForDetectTest);
      }
      vi.restoreAllMocks(); // Restore spies on actualNodeFs
    });

    it('should detect image type by extension (png)', () => {
      mockMimeLookup.mockReturnValueOnce('image/png');
      expect(detectFileType('file.png')).toBe('image');
    });

    it('should detect image type by extension (jpeg)', () => {
      mockMimeLookup.mockReturnValueOnce('image/jpeg');
      expect(detectFileType('file.jpg')).toBe('image');
    });

    it('should detect pdf type by extension', () => {
      mockMimeLookup.mockReturnValueOnce('application/pdf');
      expect(detectFileType('file.pdf')).toBe('pdf');
    });

    it('should detect known binary extensions as binary (e.g. .zip)', () => {
      mockMimeLookup.mockReturnValueOnce('application/zip');
      expect(detectFileType('archive.zip')).toBe('binary');
    });
    it('should detect known binary extensions as binary (e.g. .exe)', () => {
      mockMimeLookup.mockReturnValueOnce('application/octet-stream'); // Common for .exe
      expect(detectFileType('app.exe')).toBe('binary');
    });

    it('should use isBinaryFile for unknown extensions and detect as binary', () => {
      mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type
      // Create a file that isBinaryFile will identify as binary
      const binaryContent = Buffer.from([
        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
      ]);
      actualNodeFs.writeFileSync(filePathForDetectTest, binaryContent);
      expect(detectFileType(filePathForDetectTest)).toBe('binary');
    });

    it('should default to text if mime type is unknown and content is not binary', () => {
      mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type
      // filePathForDetectTest is already a text file by default from beforeEach
      expect(detectFileType(filePathForDetectTest)).toBe('text');
    });
  });

  describe('processSingleFileContent', () => {
    beforeEach(() => {
      // Clean up any leftover files so each test creates exactly what it needs
      if (actualNodeFs.existsSync(testTextFilePath))
        actualNodeFs.unlinkSync(testTextFilePath);
      if (actualNodeFs.existsSync(testImageFilePath))
        actualNodeFs.unlinkSync(testImageFilePath);
      if (actualNodeFs.existsSync(testPdfFilePath))
        actualNodeFs.unlinkSync(testPdfFilePath);
      if (actualNodeFs.existsSync(testBinaryFilePath))
        actualNodeFs.unlinkSync(testBinaryFilePath);
    });

    it('should read a text file successfully', async () => {
      const content = 'Line 1\\nLine 2\\nLine 3';
      actualNodeFs.writeFileSync(testTextFilePath, content);
      const result = await processSingleFileContent(
        testTextFilePath,
        tempRootDir,
      );
      expect(result.llmContent).toBe(content);
      expect(result.returnDisplay).toBe('');
      expect(result.error).toBeUndefined();
    });

    it('should handle file not found', async () => {
      const result = await processSingleFileContent(
        nonExistentFilePath,
        tempRootDir,
      );
      expect(result.error).toContain('File not found');
      expect(result.returnDisplay).toContain('File not found');
    });

    it('should handle read errors for text files', async () => {
      actualNodeFs.writeFileSync(testTextFilePath, 'content'); // File must exist for the initial statSync
      const readError = new Error('Simulated read error');
      vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError);

      const result = await processSingleFileContent(
        testTextFilePath,
        tempRootDir,
      );
      expect(result.error).toContain('Simulated read error');
      expect(result.returnDisplay).toContain('Simulated read error');
    });

    it('should handle read errors for image/pdf files', async () => {
      actualNodeFs.writeFileSync(testImageFilePath, 'content'); // File must exist
      mockMimeLookup.mockReturnValue('image/png');
      const readError = new Error('Simulated image read error');
      vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError);

      const result = await processSingleFileContent(
        testImageFilePath,
        tempRootDir,
      );
      expect(result.error).toContain('Simulated image read error');
      expect(result.returnDisplay).toContain('Simulated image read error');
    });

    it('should process an image file', async () => {
      const fakePngData = Buffer.from('fake png data');
      actualNodeFs.writeFileSync(testImageFilePath, fakePngData);
      mockMimeLookup.mockReturnValue('image/png');
      const result = await processSingleFileContent(
        testImageFilePath,
        tempRootDir,
      );
      expect(
        (result.llmContent as { inlineData: unknown }).inlineData,
      ).toBeDefined();
      expect(
        (result.llmContent as { inlineData: { mimeType: string } }).inlineData
          .mimeType,
      ).toBe('image/png');
      expect(
        (result.llmContent as { inlineData: { data: string } }).inlineData.data,
      ).toBe(fakePngData.toString('base64'));
      expect(result.returnDisplay).toContain('Read image file: image.png');
    });

    it('should process a PDF file', async () => {
      const fakePdfData = Buffer.from('fake pdf data');
      actualNodeFs.writeFileSync(testPdfFilePath, fakePdfData);
      mockMimeLookup.mockReturnValue('application/pdf');
      const result = await processSingleFileContent(
        testPdfFilePath,
        tempRootDir,
      );
      expect(
        (result.llmContent as { inlineData: unknown }).inlineData,
      ).toBeDefined();
      expect(
        (result.llmContent as { inlineData: { mimeType: string } }).inlineData
          .mimeType,
      ).toBe('application/pdf');
      expect(
        (result.llmContent as { inlineData: { data: string } }).inlineData.data,
      ).toBe(fakePdfData.toString('base64'));
      expect(result.returnDisplay).toContain('Read pdf file: document.pdf');
    });

    it('should skip binary files', async () => {
      actualNodeFs.writeFileSync(
        testBinaryFilePath,
        Buffer.from([0x00, 0x01, 0x02]),
      );
      mockMimeLookup.mockReturnValueOnce('application/octet-stream');
      // isBinaryFile will operate on the real file.

      const result = await processSingleFileContent(
        testBinaryFilePath,
        tempRootDir,
      );
      expect(result.llmContent).toContain(
        'Cannot display content of binary file',
      );
      expect(result.returnDisplay).toContain('Skipped binary file: app.exe');
    });

    it('should handle path being a directory', async () => {
      const result = await processSingleFileContent(directoryPath, tempRootDir);
      expect(result.error).toContain('Path is a directory');
      expect(result.returnDisplay).toContain('Path is a directory');
    });

    it('should paginate text files correctly (offset and limit)', async () => {
      const lines = Array.from({ length: 20 }, (_, i) => `Line ${i + 1}`);
      actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));

      const result = await processSingleFileContent(
        testTextFilePath,
        tempRootDir,
        5,
        5,
      ); // Read lines 6-10
      const expectedContent = lines.slice(5, 10).join('\n');

      expect(result.llmContent).toContain(expectedContent);
      expect(result.llmContent).toContain(
        '[File content truncated: showing lines 6-10 of 20 total lines. Use offset/limit parameters to view more.]',
      );
      expect(result.returnDisplay).toBe('(truncated)');
      expect(result.isTruncated).toBe(true);
      expect(result.originalLineCount).toBe(20);
      expect(result.linesShown).toEqual([6, 10]);
    });

    it('should handle limit exceeding file length', async () => {
      const lines = ['Line 1', 'Line 2'];
      actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));

      const result = await processSingleFileContent(
        testTextFilePath,
        tempRootDir,
        0,
        10,
      );
      const expectedContent = lines.join('\n');

      expect(result.llmContent).toBe(expectedContent);
      expect(result.returnDisplay).toBe('');
      expect(result.isTruncated).toBe(false);
      expect(result.originalLineCount).toBe(2);
      expect(result.linesShown).toEqual([1, 2]);
    });

    it('should truncate long lines in text files', async () => {
      const longLine = 'a'.repeat(2500);
      actualNodeFs.writeFileSync(
        testTextFilePath,
        `Short line\n${longLine}\nAnother short line`,
      );

      const result = await processSingleFileContent(
        testTextFilePath,
        tempRootDir,
      );

      expect(result.llmContent).toContain('Short line');
      expect(result.llmContent).toContain(
        longLine.substring(0, 2000) + '... [truncated]',
      );
      expect(result.llmContent).toContain('Another short line');
      expect(result.llmContent).toContain(
        '[File content partially truncated: some lines exceeded maximum length of 2000 characters.]',
      );
      expect(result.isTruncated).toBe(true);
    });
  });
});
280
packages/core/src/utils/fileUtils.ts
Normal file
@@ -0,0 +1,280 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import fs from 'fs';
import path from 'path';
import { PartUnion } from '@google/genai';
import mime from 'mime-types';

// Constants for text file processing
const DEFAULT_MAX_LINES_TEXT_FILE = 2000;
const MAX_LINE_LENGTH_TEXT_FILE = 2000;

// Default values for encoding and separator format
export const DEFAULT_ENCODING: BufferEncoding = 'utf-8';

/**
 * Checks if a path is within a given root directory.
 * @param pathToCheck The absolute path to check.
 * @param rootDirectory The absolute root directory.
 * @returns True if the path is within the root directory, false otherwise.
 */
export function isWithinRoot(
  pathToCheck: string,
  rootDirectory: string,
): boolean {
  const normalizedPathToCheck = path.normalize(pathToCheck);
  const normalizedRootDirectory = path.normalize(rootDirectory);

  // Ensure the rootDirectory path ends with a separator for correct startsWith comparison,
  // unless it's the root path itself (e.g., '/' or 'C:\').
  const rootWithSeparator =
    normalizedRootDirectory === path.sep ||
    normalizedRootDirectory.endsWith(path.sep)
      ? normalizedRootDirectory
      : normalizedRootDirectory + path.sep;

  return (
    normalizedPathToCheck === normalizedRootDirectory ||
    normalizedPathToCheck.startsWith(rootWithSeparator)
  );
}

/**
 * Determines if a file is likely binary based on content sampling.
 * @param filePath Path to the file.
 * @returns True if the file appears to be binary.
 */
export function isBinaryFile(filePath: string): boolean {
  try {
    const fd = fs.openSync(filePath, 'r');
    // Read up to 4KB or file size, whichever is smaller
    const fileSize = fs.fstatSync(fd).size;
    if (fileSize === 0) {
      // Empty file is not considered binary for content checking
      fs.closeSync(fd);
      return false;
    }
    const bufferSize = Math.min(4096, fileSize);
    const buffer = Buffer.alloc(bufferSize);
    const bytesRead = fs.readSync(fd, buffer, 0, buffer.length, 0);
    fs.closeSync(fd);

    if (bytesRead === 0) return false;

    let nonPrintableCount = 0;
    for (let i = 0; i < bytesRead; i++) {
      if (buffer[i] === 0) return true; // Null byte is a strong indicator
      if (buffer[i] < 9 || (buffer[i] > 13 && buffer[i] < 32)) {
        nonPrintableCount++;
      }
    }
    // If >30% non-printable characters, consider it binary
    return nonPrintableCount / bytesRead > 0.3;
  } catch {
    // If any error occurs (e.g. file not found, permissions),
    // treat as not binary here; let higher-level functions handle existence/access errors.
    return false;
  }
}

/**
 * Detects the type of file based on extension and content.
 * @param filePath Path to the file.
 * @returns 'text', 'image', 'pdf', or 'binary'.
 */
export function detectFileType(
  filePath: string,
): 'text' | 'image' | 'pdf' | 'binary' {
  const ext = path.extname(filePath).toLowerCase();
  const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string

  if (lookedUpMimeType && lookedUpMimeType.startsWith('image/')) {
    return 'image';
  }
  if (lookedUpMimeType && lookedUpMimeType === 'application/pdf') {
    return 'pdf';
  }

  // Stricter binary check for common non-text extensions before content check.
  // These are often not well-covered by mime-types or might be misidentified.
  if (
    [
      '.zip',
      '.tar',
      '.gz',
      '.exe',
      '.dll',
      '.so',
      '.class',
      '.jar',
      '.war',
      '.7z',
      '.doc',
      '.docx',
      '.xls',
      '.xlsx',
      '.ppt',
      '.pptx',
      '.odt',
      '.ods',
      '.odp',
      '.bin',
      '.dat',
      '.obj',
      '.o',
      '.a',
      '.lib',
      '.wasm',
      '.pyc',
      '.pyo',
    ].includes(ext)
  ) {
    return 'binary';
  }

  // Fall back to a content-based check if the mime type wasn't conclusive for image/pdf
  // and it's not a known binary extension.
  if (isBinaryFile(filePath)) {
    return 'binary';
  }

  return 'text';
}

export interface ProcessedFileReadResult {
  llmContent: PartUnion; // string for text, Part for image/pdf/unreadable binary
  returnDisplay: string;
  error?: string; // Optional error message for the LLM if file processing failed
  isTruncated?: boolean; // For text files, indicates if content was truncated
  originalLineCount?: number; // For text files
  linesShown?: [number, number]; // For text files [startLine, endLine] (1-based for display)
}

/**
 * Reads and processes a single file, handling text, images, and PDFs.
 * @param filePath Absolute path to the file.
 * @param rootDirectory Absolute path to the project root for relative path display.
 * @param offset Optional offset for text files (0-based line number).
 * @param limit Optional limit for text files (number of lines to read).
 * @returns ProcessedFileReadResult object.
 */
export async function processSingleFileContent(
  filePath: string,
  rootDirectory: string,
  offset?: number,
  limit?: number,
): Promise<ProcessedFileReadResult> {
  try {
    if (!fs.existsSync(filePath)) {
      // Sync check is acceptable before async read
      return {
        llmContent: '',
        returnDisplay: 'File not found.',
        error: `File not found: ${filePath}`,
      };
    }
    const stats = fs.statSync(filePath); // Sync check
    if (stats.isDirectory()) {
      return {
        llmContent: '',
        returnDisplay: 'Path is a directory.',
        error: `Path is a directory, not a file: ${filePath}`,
      };
    }

    const fileType = detectFileType(filePath);
    const relativePathForDisplay = path
      .relative(rootDirectory, filePath)
      .replace(/\\/g, '/');

    switch (fileType) {
      case 'binary': {
        return {
          llmContent: `Cannot display content of binary file: ${relativePathForDisplay}`,
          returnDisplay: `Skipped binary file: ${relativePathForDisplay}`,
        };
      }
      case 'text': {
        const content = await fs.promises.readFile(filePath, 'utf8');
        const lines = content.split('\n');
        const originalLineCount = lines.length;

        const startLine = offset || 0;
        const effectiveLimit =
          limit === undefined ? DEFAULT_MAX_LINES_TEXT_FILE : limit;
        // Ensure endLine does not exceed originalLineCount
        const endLine = Math.min(startLine + effectiveLimit, originalLineCount);
        // Ensure selectedLines doesn't try to slice beyond array bounds if startLine is too high
        const actualStartLine = Math.min(startLine, originalLineCount);
        const selectedLines = lines.slice(actualStartLine, endLine);

        let linesWereTruncatedInLength = false;
        const formattedLines = selectedLines.map((line) => {
          if (line.length > MAX_LINE_LENGTH_TEXT_FILE) {
            linesWereTruncatedInLength = true;
            return (
              line.substring(0, MAX_LINE_LENGTH_TEXT_FILE) + '... [truncated]'
            );
          }
          return line;
        });

        const contentRangeTruncated = endLine < originalLineCount;
        const isTruncated = contentRangeTruncated || linesWereTruncatedInLength;

        let llmTextContent = '';
        if (contentRangeTruncated) {
          llmTextContent += `[File content truncated: showing lines ${actualStartLine + 1}-${endLine} of ${originalLineCount} total lines. Use offset/limit parameters to view more.]\n`;
        } else if (linesWereTruncatedInLength) {
          llmTextContent += `[File content partially truncated: some lines exceeded maximum length of ${MAX_LINE_LENGTH_TEXT_FILE} characters.]\n`;
        }
        llmTextContent += formattedLines.join('\n');

        return {
          llmContent: llmTextContent,
          returnDisplay: isTruncated ? '(truncated)' : '',
          isTruncated,
          originalLineCount,
          linesShown: [actualStartLine + 1, endLine],
        };
      }
      case 'image':
      case 'pdf': {
        const contentBuffer = await fs.promises.readFile(filePath);
        const base64Data = contentBuffer.toString('base64');
        return {
          llmContent: {
            inlineData: {
              data: base64Data,
              mimeType: mime.lookup(filePath) || 'application/octet-stream',
            },
          },
          returnDisplay: `Read ${fileType} file: ${relativePathForDisplay}`,
        };
      }
      default: {
        // Should not happen with the current detectFileType logic
        const exhaustiveCheck: never = fileType;
        return {
          llmContent: `Unhandled file type: ${exhaustiveCheck}`,
          returnDisplay: `Skipped unhandled file type: ${relativePathForDisplay}`,
          error: `Unhandled file type for ${filePath}`,
        };
      }
    }
  } catch (error) {
    const errorMessage = error instanceof Error ? error.message : String(error);
    const displayPath = path
      .relative(rootDirectory, filePath)
      .replace(/\\/g, '/');
    return {
      llmContent: `Error reading file ${displayPath}: ${errorMessage}`,
      returnDisplay: `Error reading file ${displayPath}: ${errorMessage}`,
      error: `Error reading file ${filePath}: ${errorMessage}`,
    };
  }
}
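As a usage sketch of the pagination path above (the file path, offset, and limit are illustrative, not from the commit):

import { processSingleFileContent } from './fileUtils.js';

async function demo(): Promise<void> {
  // Read lines 101-150 of a large text file; offset is 0-based.
  const result = await processSingleFileContent(
    '/repo/logs/app.log', // hypothetical absolute path
    '/repo',
    100,
    50,
  );
  if (result.error) {
    console.error(result.error);
  } else if (result.isTruncated) {
    const [start, end] = result.linesShown ?? [0, 0];
    console.log(`Showing lines ${start}-${end} of ${result.originalLineCount}`);
  }
}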
17
packages/core/src/utils/generateContentResponseUtilities.ts
Normal file
@@ -0,0 +1,17 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { GenerateContentResponse } from '@google/genai';

export function getResponseText(
  response: GenerateContentResponse,
): string | undefined {
  return (
    response.candidates?.[0]?.content?.parts
      ?.map((part) => part.text)
      .join('') || undefined
  );
}
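The helper's behavior can be seen with a hand-built response object (a stand-in for a real API result, used here only for illustration):

import { GenerateContentResponse } from '@google/genai';
import { getResponseText } from './generateContentResponseUtilities.js';

const fakeResponse = {
  candidates: [
    { content: { parts: [{ text: 'Hello, ' }, { text: 'world' }] } },
  ],
} as unknown as GenerateContentResponse;

console.log(getResponseText(fakeResponse)); // 'Hello, world'
// A response with no candidates (or no text parts) yields undefined, not ''.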
278
packages/core/src/utils/getFolderStructure.test.ts
Normal file
@@ -0,0 +1,278 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
import fsPromises from 'fs/promises';
import { Dirent as FSDirent } from 'fs';
import * as nodePath from 'path';
import { getFolderStructure } from './getFolderStructure.js';

vi.mock('path', async (importOriginal) => {
  const original = (await importOriginal()) as typeof nodePath;
  return {
    ...original,
    resolve: vi.fn((str) => str),
    // Other path functions (basename, join, normalize, etc.) will use the original implementation
  };
});

vi.mock('fs/promises');

// Import 'path' again here; it will be the mocked version
import * as path from 'path';

// Helper to create Dirent-like objects for mocking fs.readdir
const createDirent = (name: string, type: 'file' | 'dir'): FSDirent => ({
  name,
  isFile: () => type === 'file',
  isDirectory: () => type === 'dir',
  isBlockDevice: () => false,
  isCharacterDevice: () => false,
  isSymbolicLink: () => false,
  isFIFO: () => false,
  isSocket: () => false,
  parentPath: '',
  path: '',
});

describe('getFolderStructure', () => {
  beforeEach(() => {
    vi.resetAllMocks();

    // path.resolve is now a vi.fn() due to the top-level vi.mock.
    // We ensure its implementation is set for each test (or rely on the one from vi.mock).
    // vi.resetAllMocks() clears call history but not the implementation set by vi.fn() in vi.mock.
    // If we needed to change it per test, we would do it here:
    (path.resolve as Mock).mockImplementation((str: string) => str);

    // Re-apply/define the mock implementation for fsPromises.readdir for each test
    (fsPromises.readdir as Mock).mockImplementation(
      async (dirPath: string | Buffer | URL) => {
        // path.normalize here will use the mocked path module.
        // Since normalize is spread from the original, it should be the real one.
        const normalizedPath = path.normalize(dirPath.toString());
        if (mockFsStructure[normalizedPath]) {
          return mockFsStructure[normalizedPath];
        }
        throw Object.assign(
          new Error(
            `ENOENT: no such file or directory, scandir '${normalizedPath}'`,
          ),
          { code: 'ENOENT' },
        );
      },
    );
  });

  afterEach(() => {
    vi.restoreAllMocks(); // Restores spies (like fsPromises.readdir) and resets vi.fn mocks (like path.resolve)
  });

  const mockFsStructure: Record<string, FSDirent[]> = {
    '/testroot': [
      createDirent('file1.txt', 'file'),
      createDirent('subfolderA', 'dir'),
      createDirent('emptyFolder', 'dir'),
      createDirent('.hiddenfile', 'file'),
      createDirent('node_modules', 'dir'),
    ],
    '/testroot/subfolderA': [
      createDirent('fileA1.ts', 'file'),
      createDirent('fileA2.js', 'file'),
      createDirent('subfolderB', 'dir'),
    ],
    '/testroot/subfolderA/subfolderB': [createDirent('fileB1.md', 'file')],
    '/testroot/emptyFolder': [],
    '/testroot/node_modules': [createDirent('somepackage', 'dir')],
    '/testroot/manyFilesFolder': Array.from({ length: 10 }, (_, i) =>
      createDirent(`file-${i}.txt`, 'file'),
    ),
    '/testroot/manyFolders': Array.from({ length: 5 }, (_, i) =>
      createDirent(`folder-${i}`, 'dir'),
    ),
    ...Array.from({ length: 5 }, (_, i) => ({
      [`/testroot/manyFolders/folder-${i}`]: [
        createDirent('child.txt', 'file'),
      ],
    })).reduce((acc, val) => ({ ...acc, ...val }), {}),
    '/testroot/deepFolders': [createDirent('level1', 'dir')],
    '/testroot/deepFolders/level1': [createDirent('level2', 'dir')],
    '/testroot/deepFolders/level1/level2': [createDirent('level3', 'dir')],
    '/testroot/deepFolders/level1/level2/level3': [
      createDirent('file.txt', 'file'),
    ],
  };

  it('should return basic folder structure', async () => {
    const structure = await getFolderStructure('/testroot/subfolderA');
    const expected = `
Showing up to 200 items (files + folders).

/testroot/subfolderA/
├───fileA1.ts
├───fileA2.js
└───subfolderB/
    └───fileB1.md
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should handle an empty folder', async () => {
    const structure = await getFolderStructure('/testroot/emptyFolder');
    const expected = `
Showing up to 200 items (files + folders).

/testroot/emptyFolder/
`.trim();
    expect(structure.trim()).toBe(expected.trim());
  });

  it('should ignore folders specified in ignoredFolders (default)', async () => {
    const structure = await getFolderStructure('/testroot');
    const expected = `
Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached.

/testroot/
├───.hiddenfile
├───file1.txt
├───emptyFolder/
├───node_modules/...
└───subfolderA/
    ├───fileA1.ts
    ├───fileA2.js
    └───subfolderB/
        └───fileB1.md
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should ignore folders specified in custom ignoredFolders', async () => {
    const structure = await getFolderStructure('/testroot', {
      ignoredFolders: new Set(['subfolderA', 'node_modules']),
    });
    const expected = `
Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached.

/testroot/
├───.hiddenfile
├───file1.txt
├───emptyFolder/
├───node_modules/...
└───subfolderA/...
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should filter files by fileIncludePattern', async () => {
    const structure = await getFolderStructure('/testroot/subfolderA', {
      fileIncludePattern: /\.ts$/,
    });
    const expected = `
Showing up to 200 items (files + folders).

/testroot/subfolderA/
├───fileA1.ts
└───subfolderB/
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should handle maxItems truncation for files within a folder', async () => {
    const structure = await getFolderStructure('/testroot/subfolderA', {
      maxItems: 3,
    });
    const expected = `
Showing up to 3 items (files + folders).

/testroot/subfolderA/
├───fileA1.ts
├───fileA2.js
└───subfolderB/
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should handle maxItems truncation for subfolders', async () => {
    const structure = await getFolderStructure('/testroot/manyFolders', {
      maxItems: 4,
    });
    const expectedRevised = `
Showing up to 4 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (4 items) was reached.

/testroot/manyFolders/
├───folder-0/
├───folder-1/
├───folder-2/
├───folder-3/
└───...
`.trim();
    expect(structure.trim()).toBe(expectedRevised);
  });

  it('should handle maxItems that only allows the root folder itself', async () => {
    const structure = await getFolderStructure('/testroot/subfolderA', {
      maxItems: 1,
    });
    const expectedRevisedMax1 = `
Showing up to 1 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (1 items) was reached.

/testroot/subfolderA/
├───fileA1.ts
├───...
└───...
`.trim();
    expect(structure.trim()).toBe(expectedRevisedMax1);
  });

  it('should handle non-existent directory', async () => {
    // Temporarily make fsPromises.readdir throw ENOENT for this specific path
    const originalReaddir = fsPromises.readdir;
    (fsPromises.readdir as Mock).mockImplementation(
      async (p: string | Buffer | URL) => {
        if (p === '/nonexistent') {
          throw Object.assign(new Error('ENOENT'), { code: 'ENOENT' });
        }
        return originalReaddir(p);
      },
    );

    const structure = await getFolderStructure('/nonexistent');
    expect(structure).toContain(
      'Error: Could not read directory "/nonexistent"',
    );
  });

  it('should handle deep folder structure within limits', async () => {
    const structure = await getFolderStructure('/testroot/deepFolders', {
      maxItems: 10,
    });
    const expected = `
Showing up to 10 items (files + folders).

/testroot/deepFolders/
└───level1/
    └───level2/
        └───level3/
            └───file.txt
`.trim();
    expect(structure.trim()).toBe(expected);
  });

  it('should truncate deep folder structure if maxItems is small', async () => {
    const structure = await getFolderStructure('/testroot/deepFolders', {
      maxItems: 3,
    });
    const expected = `
Showing up to 3 items (files + folders).

/testroot/deepFolders/
└───level1/
    └───level2/
        └───level3/
`.trim();
    expect(structure.trim()).toBe(expected);
  });
});
325
packages/core/src/utils/getFolderStructure.ts
Normal file
@@ -0,0 +1,325 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import * as fs from 'fs/promises';
import { Dirent } from 'fs';
import * as path from 'path';
import { getErrorMessage, isNodeError } from './errors.js';

const MAX_ITEMS = 200;
const TRUNCATION_INDICATOR = '...';
const DEFAULT_IGNORED_FOLDERS = new Set(['node_modules', '.git', 'dist']);

// --- Interfaces ---

/** Options for customizing folder structure retrieval. */
interface FolderStructureOptions {
  /** Maximum number of files and folders combined to display. Defaults to 200. */
  maxItems?: number;
  /** Set of folder names to ignore completely. Case-sensitive. */
  ignoredFolders?: Set<string>;
  /** Optional regex to filter included files by name. */
  fileIncludePattern?: RegExp;
}

// Define a type for the merged options where fileIncludePattern remains optional
type MergedFolderStructureOptions = Required<
  Omit<FolderStructureOptions, 'fileIncludePattern'>
> & {
  fileIncludePattern?: RegExp;
};

/** Represents the full, unfiltered information about a folder and its contents. */
interface FullFolderInfo {
  name: string;
  path: string;
  files: string[];
  subFolders: FullFolderInfo[];
  totalChildren: number; // Number of files and subfolders included from this folder during BFS scan
  totalFiles: number; // Number of files included from this folder during BFS scan
  isIgnored?: boolean; // Flag to easily identify ignored folders later
  hasMoreFiles?: boolean; // Indicates if files were truncated for this specific folder
  hasMoreSubfolders?: boolean; // Indicates if subfolders were truncated for this specific folder
}

// --- Helper Functions ---

/**
 * Reads the directory structure using BFS, respecting maxItems.
 */
async function readFullStructure(
  rootPath: string,
  options: MergedFolderStructureOptions,
): Promise<FullFolderInfo | null> {
  const rootName = path.basename(rootPath);
  const rootNode: FullFolderInfo = {
    name: rootName,
    path: rootPath,
    files: [],
    subFolders: [],
    totalChildren: 0,
    totalFiles: 0,
  };

  const queue: Array<{ folderInfo: FullFolderInfo; currentPath: string }> = [
    { folderInfo: rootNode, currentPath: rootPath },
  ];
  let currentItemCount = 0;
  // Count the root node itself as one item if we are not just listing its content

  const processedPaths = new Set<string>(); // To avoid processing the same path if symlinks create loops

  while (queue.length > 0) {
    const { folderInfo, currentPath } = queue.shift()!;

    if (processedPaths.has(currentPath)) {
      continue;
    }
    processedPaths.add(currentPath);

    if (currentItemCount >= options.maxItems) {
      // If the root itself caused us to exceed, we can't really show anything.
      // Otherwise, this folder won't be processed further.
      // The parent that queued this would have set its own hasMoreSubfolders flag.
      continue;
    }

    let entries: Dirent[];
    try {
      const rawEntries = await fs.readdir(currentPath, { withFileTypes: true });
      // Sort entries alphabetically by name for consistent processing order
      entries = rawEntries.sort((a, b) => a.name.localeCompare(b.name));
    } catch (error: unknown) {
      if (
        isNodeError(error) &&
        (error.code === 'EACCES' || error.code === 'ENOENT')
      ) {
        console.warn(
          `Warning: Could not read directory ${currentPath}: ${error.message}`,
        );
        if (currentPath === rootPath && error.code === 'ENOENT') {
          return null; // Root directory itself not found
        }
        // For other EACCES/ENOENT errors on subdirectories, just skip them.
        continue;
      }
      throw error;
    }

    const filesInCurrentDir: string[] = [];
    const subFoldersInCurrentDir: FullFolderInfo[] = [];

    // Process files first in the current directory
    for (const entry of entries) {
      if (entry.isFile()) {
        if (currentItemCount >= options.maxItems) {
          folderInfo.hasMoreFiles = true;
          break;
        }
        const fileName = entry.name;
        if (
          !options.fileIncludePattern ||
          options.fileIncludePattern.test(fileName)
        ) {
          filesInCurrentDir.push(fileName);
          currentItemCount++;
          folderInfo.totalFiles++;
          folderInfo.totalChildren++;
        }
      }
    }
    folderInfo.files = filesInCurrentDir;

    // Then process directories and queue them
    for (const entry of entries) {
      if (entry.isDirectory()) {
        // Check if adding this directory ITSELF would meet or exceed maxItems
        // (currentItemCount refers to items *already* added before this one)
        if (currentItemCount >= options.maxItems) {
          folderInfo.hasMoreSubfolders = true;
          break; // Already at the limit, cannot add this folder or any more
        }
        // If adding THIS folder makes us hit the limit exactly, and it might have children,
        // it's better to show '...' for the parent, unless this is the very last item slot.
        // This logic is tricky. Let's try something simpler: if we can't add this item, mark and break.

        const subFolderName = entry.name;
        const subFolderPath = path.join(currentPath, subFolderName);

        if (options.ignoredFolders.has(subFolderName)) {
          const ignoredSubFolder: FullFolderInfo = {
            name: subFolderName,
            path: subFolderPath,
            files: [],
            subFolders: [],
            totalChildren: 0,
            totalFiles: 0,
            isIgnored: true,
          };
          subFoldersInCurrentDir.push(ignoredSubFolder);
          currentItemCount++; // Count the ignored folder itself
          folderInfo.totalChildren++; // Also counts towards the parent's children
          continue;
        }

        const subFolderNode: FullFolderInfo = {
          name: subFolderName,
          path: subFolderPath,
          files: [],
          subFolders: [],
          totalChildren: 0,
          totalFiles: 0,
        };
        subFoldersInCurrentDir.push(subFolderNode);
        currentItemCount++;
        folderInfo.totalChildren++; // Counts towards the parent's children

        // Add to the queue for processing its children later
        queue.push({ folderInfo: subFolderNode, currentPath: subFolderPath });
      }
    }
    folderInfo.subFolders = subFoldersInCurrentDir;
  }

  return rootNode;
}

/**
 * Formats the folder structure into tree-style lines.
 * @param node The current node in the reduced structure.
 * @param currentIndent The current indentation string.
 * @param isLastChildOfParent Whether this node is the last child of its parent.
 * @param isProcessingRootNode Whether this node is the root passed to getFolderStructure.
 * @param builder Array to build the string lines.
 */
function formatStructure(
  node: FullFolderInfo,
  currentIndent: string,
  isLastChildOfParent: boolean,
  isProcessingRootNode: boolean,
  builder: string[],
): void {
  const connector = isLastChildOfParent ? '└───' : '├───';

  // The root node of the structure (the one passed initially to getFolderStructure)
  // is not printed with a connector line itself, only its name as a header.
  // Its children are printed relative to that conceptual root.
  // Ignored root nodes ARE printed with a connector.
  if (!isProcessingRootNode || node.isIgnored) {
    builder.push(
      `${currentIndent}${connector}${node.name}/${node.isIgnored ? TRUNCATION_INDICATOR : ''}`,
    );
  }

  // Determine the indent for the children of *this* node.
  // If *this* node was the root of the whole structure, its children start with no indent before their connectors.
  // Otherwise, children's indent extends from the current node's indent.
  const indentForChildren = isProcessingRootNode
    ? ''
    : currentIndent + (isLastChildOfParent ? '    ' : '│   ');

  // Render files of the current node
  const fileCount = node.files.length;
  for (let i = 0; i < fileCount; i++) {
    const isLastFileAmongSiblings =
      i === fileCount - 1 &&
      node.subFolders.length === 0 &&
      !node.hasMoreSubfolders;
    const fileConnector = isLastFileAmongSiblings ? '└───' : '├───';
    builder.push(`${indentForChildren}${fileConnector}${node.files[i]}`);
  }
  if (node.hasMoreFiles) {
    const isLastIndicatorAmongSiblings =
      node.subFolders.length === 0 && !node.hasMoreSubfolders;
    const fileConnector = isLastIndicatorAmongSiblings ? '└───' : '├───';
    builder.push(`${indentForChildren}${fileConnector}${TRUNCATION_INDICATOR}`);
  }

  // Render subfolders of the current node
  const subFolderCount = node.subFolders.length;
  for (let i = 0; i < subFolderCount; i++) {
    const isLastSubfolderAmongSiblings =
      i === subFolderCount - 1 && !node.hasMoreSubfolders;
    // Children are never the root node being processed initially.
    formatStructure(
      node.subFolders[i],
      indentForChildren,
      isLastSubfolderAmongSiblings,
      false,
      builder,
    );
  }
  if (node.hasMoreSubfolders) {
    builder.push(`${indentForChildren}└───${TRUNCATION_INDICATOR}`);
  }
}

// --- Main Exported Function ---

/**
 * Generates a string representation of a directory's structure,
 * limiting the number of items displayed. Ignored folders are shown
 * followed by '...' instead of their contents.
 *
 * @param directory The absolute or relative path to the directory.
 * @param options Optional configuration settings.
 * @returns A promise resolving to the formatted folder structure string.
 */
export async function getFolderStructure(
  directory: string,
  options?: FolderStructureOptions,
): Promise<string> {
  const resolvedPath = path.resolve(directory);
  const mergedOptions: MergedFolderStructureOptions = {
    maxItems: options?.maxItems ?? MAX_ITEMS,
    ignoredFolders: options?.ignoredFolders ?? DEFAULT_IGNORED_FOLDERS,
    fileIncludePattern: options?.fileIncludePattern,
  };

  try {
    // 1. Read the structure using BFS, respecting maxItems
    const structureRoot = await readFullStructure(resolvedPath, mergedOptions);

    if (!structureRoot) {
      return `Error: Could not read directory "${resolvedPath}". Check path and permissions.`;
    }

    // 2. Format the structure into a string
    const structureLines: string[] = [];
    // Pass true for isProcessingRootNode for the initial call
    formatStructure(structureRoot, '', true, true, structureLines);

    // 3. Build the final output string
    const displayPath = resolvedPath.replace(/\\/g, '/');

    let disclaimer = '';
    // Check if truncation occurred anywhere or if ignored folders are present.
    // A simple check: if any node indicates more files/subfolders, or is ignored.
    let truncationOccurred = false;
    function checkForTruncation(node: FullFolderInfo) {
      if (node.hasMoreFiles || node.hasMoreSubfolders || node.isIgnored) {
        truncationOccurred = true;
      }
      if (!truncationOccurred) {
        for (const sub of node.subFolders) {
          checkForTruncation(sub);
          if (truncationOccurred) break;
        }
      }
    }
    checkForTruncation(structureRoot);

    if (truncationOccurred) {
      disclaimer = `Folders or files indicated with ${TRUNCATION_INDICATOR} contain more items not shown, were ignored, or the display limit (${mergedOptions.maxItems} items) was reached.`;
    }

    const summary =
      `Showing up to ${mergedOptions.maxItems} items (files + folders). ${disclaimer}`.trim();

    return `${summary}\n\n${displayPath}/\n${structureLines.join('\n')}`;
  } catch (error: unknown) {
    console.error(`Error getting folder structure for ${resolvedPath}:`, error);
    return `Error processing directory "${resolvedPath}": ${getErrorMessage(error)}`;
  }
}
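A brief sketch of calling getFolderStructure with custom options; the directory and option values here are illustrative, not from the commit:

import { getFolderStructure } from './getFolderStructure.js';

async function printTree(): Promise<void> {
  // Hypothetical: cap output at 50 items and list only TypeScript files.
  const tree = await getFolderStructure('/repo/packages/core', {
    maxItems: 50,
    fileIncludePattern: /\.ts$/,
  });
  // Output starts with a summary line ('Showing up to 50 items ...'),
  // then the ├───/└─── tree; ignored folders render as 'node_modules/...'.
  console.log(tree);
}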
382
packages/core/src/utils/memoryDiscovery.test.ts
Normal file
@@ -0,0 +1,382 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  vi,
  describe,
  it,
  expect,
  beforeEach,
  // afterEach, // Removed unused import
  Mocked,
} from 'vitest';
import * as fsPromises from 'fs/promises';
import * as fsSync from 'fs'; // For constants
import { Stats, Dirent } from 'fs'; // Import types directly from 'fs'
import * as os from 'os';
import * as path from 'path';
import { loadServerHierarchicalMemory } from './memoryDiscovery.js';
import { GEMINI_CONFIG_DIR, GEMINI_MD_FILENAME } from '../tools/memoryTool.js';

// Mock the entire fs/promises module
vi.mock('fs/promises');
// Mock the parts of fsSync we might use (like constants or existsSync if needed)
vi.mock('fs', async (importOriginal) => {
  const actual = await importOriginal<typeof fsSync>();
  return {
    ...actual, // Spread actual to get all exports, including Stats and Dirent if they are classes/constructors
    constants: { ...actual.constants }, // Preserve constants
    // Mock other fsSync functions if directly used by memoryDiscovery, e.g., existsSync
    // existsSync: vi.fn(),
  };
});
vi.mock('os');

describe('loadServerHierarchicalMemory', () => {
  const mockFs = fsPromises as Mocked<typeof fsPromises>;
  const mockOs = os as Mocked<typeof os>;

  const CWD = '/test/project/src';
  const PROJECT_ROOT = '/test/project';
  const USER_HOME = '/test/userhome';
  const GLOBAL_GEMINI_DIR = path.join(USER_HOME, GEMINI_CONFIG_DIR);
  const GLOBAL_GEMINI_FILE = path.join(GLOBAL_GEMINI_DIR, GEMINI_MD_FILENAME);

  beforeEach(() => {
    vi.resetAllMocks();

    mockOs.homedir.mockReturnValue(USER_HOME);
    mockFs.stat.mockRejectedValue(new Error('File not found'));
    mockFs.readdir.mockResolvedValue([]);
    mockFs.readFile.mockRejectedValue(new Error('File not found'));
    mockFs.access.mockRejectedValue(new Error('File not found'));
  });

  it('should return empty memory and count if no GEMINI.md files are found', async () => {
    const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
      CWD,
      false,
    );
    expect(memoryContent).toBe('');
    expect(fileCount).toBe(0);
  });

  it('should load only the global GEMINI.md if present and others are not', async () => {
    mockFs.access.mockImplementation(async (p) => {
      if (p === GLOBAL_GEMINI_FILE) {
        return undefined;
      }
      throw new Error('File not found');
    });
    mockFs.readFile.mockImplementation(async (p) => {
      if (p === GLOBAL_GEMINI_FILE) {
        return 'Global memory content';
      }
      throw new Error('File not found');
    });

    const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
      CWD,
      false,
    );

    expect(memoryContent).toBe(
      `--- Context from: ${path.relative(CWD, GLOBAL_GEMINI_FILE)} ---\nGlobal memory content\n--- End of Context from: ${path.relative(CWD, GLOBAL_GEMINI_FILE)} ---`,
    );
    expect(fileCount).toBe(1);
    expect(mockFs.readFile).toHaveBeenCalledWith(GLOBAL_GEMINI_FILE, 'utf-8');
  });

  it('should load GEMINI.md files by upward traversal from CWD to project root', async () => {
    const projectRootGeminiFile = path.join(PROJECT_ROOT, GEMINI_MD_FILENAME);
    const srcGeminiFile = path.join(CWD, GEMINI_MD_FILENAME);

    mockFs.stat.mockImplementation(async (p) => {
      if (p === path.join(PROJECT_ROOT, '.git')) {
        return { isDirectory: () => true } as Stats;
      }
      throw new Error('File not found');
    });

    mockFs.access.mockImplementation(async (p) => {
      if (p === projectRootGeminiFile || p === srcGeminiFile) {
        return undefined;
      }
      throw new Error('File not found');
    });

    mockFs.readFile.mockImplementation(async (p) => {
      if (p === projectRootGeminiFile) {
        return 'Project root memory';
      }
      if (p === srcGeminiFile) {
        return 'Src directory memory';
      }
      throw new Error('File not found');
    });

    const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
      CWD,
      false,
    );
    const expectedContent =
      `--- Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\nProject root memory\n--- End of Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\n\n` +
      `--- Context from: ${GEMINI_MD_FILENAME} ---\nSrc directory memory\n--- End of Context from: ${GEMINI_MD_FILENAME} ---`;

    expect(memoryContent).toBe(expectedContent);
    expect(fileCount).toBe(2);
    expect(mockFs.readFile).toHaveBeenCalledWith(
      projectRootGeminiFile,
      'utf-8',
    );
    expect(mockFs.readFile).toHaveBeenCalledWith(srcGeminiFile, 'utf-8');
  });

  it('should load GEMINI.md files by downward traversal from CWD', async () => {
    const subDir = path.join(CWD, 'subdir');
    const subDirGeminiFile = path.join(subDir, GEMINI_MD_FILENAME);
    const cwdGeminiFile = path.join(CWD, GEMINI_MD_FILENAME);

    mockFs.access.mockImplementation(async (p) => {
      if (p === cwdGeminiFile || p === subDirGeminiFile) return undefined;
      throw new Error('File not found');
    });

    mockFs.readFile.mockImplementation(async (p) => {
      if (p === cwdGeminiFile) return 'CWD memory';
      if (p === subDirGeminiFile) return 'Subdir memory';
      throw new Error('File not found');
    });

    mockFs.readdir.mockImplementation((async (
      p: fsSync.PathLike,
    ): Promise<Dirent[]> => {
      if (p === CWD) {
        return [
          {
            name: GEMINI_MD_FILENAME,
            isFile: () => true,
            isDirectory: () => false,
          },
          { name: 'subdir', isFile: () => false, isDirectory: () => true },
        ] as Dirent[];
      }
      if (p === subDir) {
        return [
          {
            name: GEMINI_MD_FILENAME,
            isFile: () => true,
            isDirectory: () => false,
          },
        ] as Dirent[];
      }
      return [] as Dirent[];
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    }) as any);

    const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
      CWD,
      false,
    );
    const expectedContent =
      `--- Context from: ${GEMINI_MD_FILENAME} ---\nCWD memory\n--- End of Context from: ${GEMINI_MD_FILENAME} ---\n\n` +
      `--- Context from: ${path.join('subdir', GEMINI_MD_FILENAME)} ---\nSubdir memory\n--- End of Context from: ${path.join('subdir', GEMINI_MD_FILENAME)} ---`;

    expect(memoryContent).toBe(expectedContent);
    expect(fileCount).toBe(2);
  });

  it('should load and correctly order global, upward, and downward GEMINI.md files', async () => {
    const projectParentDir = path.dirname(PROJECT_ROOT);
    const projectParentGeminiFile = path.join(
      projectParentDir,
      GEMINI_MD_FILENAME,
    );
    const projectRootGeminiFile = path.join(PROJECT_ROOT, GEMINI_MD_FILENAME);
    const cwdGeminiFile = path.join(CWD, GEMINI_MD_FILENAME);
    const subDir = path.join(CWD, 'sub');
    const subDirGeminiFile = path.join(subDir, GEMINI_MD_FILENAME);

    mockFs.stat.mockImplementation(async (p) => {
      if (p === path.join(PROJECT_ROOT, '.git')) {
        return { isDirectory: () => true } as Stats;
      }
      throw new Error('File not found');
|
||||
});
|
||||
|
||||
mockFs.access.mockImplementation(async (p) => {
|
||||
if (
|
||||
p === GLOBAL_GEMINI_FILE ||
|
||||
p === projectParentGeminiFile ||
|
||||
p === projectRootGeminiFile ||
|
||||
p === cwdGeminiFile ||
|
||||
p === subDirGeminiFile
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
throw new Error('File not found');
|
||||
});
|
||||
|
||||
mockFs.readFile.mockImplementation(async (p) => {
|
||||
if (p === GLOBAL_GEMINI_FILE) return 'Global memory';
|
||||
if (p === projectParentGeminiFile) return 'Project parent memory';
|
||||
if (p === projectRootGeminiFile) return 'Project root memory';
|
||||
if (p === cwdGeminiFile) return 'CWD memory';
|
||||
if (p === subDirGeminiFile) return 'Subdir memory';
|
||||
throw new Error('File not found');
|
||||
});
|
||||
|
||||
mockFs.readdir.mockImplementation((async (
|
||||
p: fsSync.PathLike,
|
||||
): Promise<Dirent[]> => {
|
||||
if (p === CWD) {
|
||||
return [
|
||||
{ name: 'sub', isFile: () => false, isDirectory: () => true },
|
||||
] as Dirent[];
|
||||
}
|
||||
if (p === subDir) {
|
||||
return [
|
||||
{
|
||||
name: GEMINI_MD_FILENAME,
|
||||
isFile: () => true,
|
||||
isDirectory: () => false,
|
||||
},
|
||||
] as Dirent[];
|
||||
}
|
||||
return [] as Dirent[];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
|
||||
CWD,
|
||||
false,
|
||||
);
|
||||
|
||||
const relPathGlobal = path.relative(CWD, GLOBAL_GEMINI_FILE);
|
||||
const relPathProjectParent = path.relative(CWD, projectParentGeminiFile);
|
||||
const relPathProjectRoot = path.relative(CWD, projectRootGeminiFile);
|
||||
const relPathCwd = GEMINI_MD_FILENAME;
|
||||
const relPathSubDir = path.join('sub', GEMINI_MD_FILENAME);
|
||||
|
||||
const expectedContent = [
|
||||
`--- Context from: ${relPathGlobal} ---\nGlobal memory\n--- End of Context from: ${relPathGlobal} ---`,
|
||||
`--- Context from: ${relPathProjectParent} ---\nProject parent memory\n--- End of Context from: ${relPathProjectParent} ---`,
|
||||
`--- Context from: ${relPathProjectRoot} ---\nProject root memory\n--- End of Context from: ${relPathProjectRoot} ---`,
|
||||
`--- Context from: ${relPathCwd} ---\nCWD memory\n--- End of Context from: ${relPathCwd} ---`,
|
||||
`--- Context from: ${relPathSubDir} ---\nSubdir memory\n--- End of Context from: ${relPathSubDir} ---`,
|
||||
].join('\n\n');
|
||||
|
||||
expect(memoryContent).toBe(expectedContent);
|
||||
expect(fileCount).toBe(5);
|
||||
});
|
||||
|
||||
it('should ignore specified directories during downward scan', async () => {
|
||||
const ignoredDir = path.join(CWD, 'node_modules');
|
||||
const ignoredDirGeminiFile = path.join(ignoredDir, GEMINI_MD_FILENAME);
|
||||
const regularSubDir = path.join(CWD, 'my_code');
|
||||
const regularSubDirGeminiFile = path.join(
|
||||
regularSubDir,
|
||||
GEMINI_MD_FILENAME,
|
||||
);
|
||||
|
||||
mockFs.access.mockImplementation(async (p) => {
|
||||
if (p === regularSubDirGeminiFile) return undefined;
|
||||
if (p === ignoredDirGeminiFile)
|
||||
throw new Error('Should not access ignored file');
|
||||
throw new Error('File not found');
|
||||
});
|
||||
|
||||
mockFs.readFile.mockImplementation(async (p) => {
|
||||
if (p === regularSubDirGeminiFile) return 'My code memory';
|
||||
throw new Error('File not found');
|
||||
});
|
||||
|
||||
mockFs.readdir.mockImplementation((async (
|
||||
p: fsSync.PathLike,
|
||||
): Promise<Dirent[]> => {
|
||||
if (p === CWD) {
|
||||
return [
|
||||
{
|
||||
name: 'node_modules',
|
||||
isFile: () => false,
|
||||
isDirectory: () => true,
|
||||
},
|
||||
{ name: 'my_code', isFile: () => false, isDirectory: () => true },
|
||||
] as Dirent[];
|
||||
}
|
||||
if (p === regularSubDir) {
|
||||
return [
|
||||
{
|
||||
name: GEMINI_MD_FILENAME,
|
||||
isFile: () => true,
|
||||
isDirectory: () => false,
|
||||
},
|
||||
] as Dirent[];
|
||||
}
|
||||
if (p === ignoredDir) {
|
||||
return [
|
||||
{
|
||||
name: GEMINI_MD_FILENAME,
|
||||
isFile: () => true,
|
||||
isDirectory: () => false,
|
||||
},
|
||||
] as Dirent[];
|
||||
}
|
||||
return [] as Dirent[];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
|
||||
CWD,
|
||||
false,
|
||||
);
|
||||
|
||||
const expectedContent = `--- Context from: ${path.join('my_code', GEMINI_MD_FILENAME)} ---\nMy code memory\n--- End of Context from: ${path.join('my_code', GEMINI_MD_FILENAME)} ---`;
|
||||
|
||||
expect(memoryContent).toBe(expectedContent);
|
||||
expect(fileCount).toBe(1);
|
||||
expect(mockFs.readFile).not.toHaveBeenCalledWith(
|
||||
ignoredDirGeminiFile,
|
||||
'utf-8',
|
||||
);
|
||||
});
|
||||
|
||||
it('should respect MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY during downward scan', async () => {
|
||||
const consoleDebugSpy = vi
|
||||
.spyOn(console, 'debug')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const dirNames: Dirent[] = [];
|
||||
for (let i = 0; i < 250; i++) {
|
||||
dirNames.push({
|
||||
name: `deep_dir_${i}`,
|
||||
isFile: () => false,
|
||||
isDirectory: () => true,
|
||||
} as Dirent);
|
||||
}
|
||||
|
||||
mockFs.readdir.mockImplementation((async (
|
||||
p: fsSync.PathLike,
|
||||
): Promise<Dirent[]> => {
|
||||
if (p === CWD) return dirNames;
|
||||
if (p.toString().startsWith(path.join(CWD, 'deep_dir_')))
|
||||
return [] as Dirent[];
|
||||
return [] as Dirent[];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
mockFs.access.mockRejectedValue(new Error('not found'));
|
||||
|
||||
await loadServerHierarchicalMemory(CWD, true);
|
||||
|
||||
expect(consoleDebugSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[DEBUG] [MemoryDiscovery]'),
|
||||
expect.stringContaining(
|
||||
'Max directory scan limit (200) reached. Stopping downward scan at:',
|
||||
),
|
||||
);
|
||||
consoleDebugSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
351
packages/core/src/utils/memoryDiscovery.ts
Normal file
@@ -0,0 +1,351 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import * as fs from 'fs/promises';
import * as fsSync from 'fs';
import * as path from 'path';
import { homedir } from 'os';
import { GEMINI_CONFIG_DIR, GEMINI_MD_FILENAME } from '../tools/memoryTool.js';

// Simple console logger, similar to the one previously in CLI's config.ts
// TODO: Integrate with a more robust server-side logger if available/appropriate.
const logger = {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  debug: (...args: any[]) =>
    console.debug('[DEBUG] [MemoryDiscovery]', ...args),
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  warn: (...args: any[]) => console.warn('[WARN] [MemoryDiscovery]', ...args),
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  error: (...args: any[]) =>
    console.error('[ERROR] [MemoryDiscovery]', ...args),
};

// TODO(adh): Refactor to use a shared ignore list with other tools like glob and read-many-files.
const DEFAULT_IGNORE_DIRECTORIES = [
  'node_modules',
  '.git',
  'dist',
  'build',
  'out',
  'coverage',
  '.vscode',
  '.idea',
  '.DS_Store',
];

const MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY = 200;

interface GeminiFileContent {
  filePath: string;
  content: string | null;
}

async function findProjectRoot(startDir: string): Promise<string | null> {
  let currentDir = path.resolve(startDir);
  while (true) {
    const gitPath = path.join(currentDir, '.git');
    try {
      const stats = await fs.stat(gitPath);
      if (stats.isDirectory()) {
        return currentDir;
      }
    } catch (error: unknown) {
      if (typeof error === 'object' && error !== null && 'code' in error) {
        const fsError = error as { code: string; message: string };
        if (fsError.code !== 'ENOENT') {
          logger.warn(
            `Error checking for .git directory at ${gitPath}: ${fsError.message}`,
          );
        }
      } else {
        logger.warn(
          `Non-standard error checking for .git directory at ${gitPath}: ${String(error)}`,
        );
      }
    }
    const parentDir = path.dirname(currentDir);
    if (parentDir === currentDir) {
      return null;
    }
    currentDir = parentDir;
  }
}

async function collectDownwardGeminiFiles(
  directory: string,
  debugMode: boolean,
  ignoreDirs: string[],
  scannedDirCount: { count: number },
  maxScanDirs: number,
): Promise<string[]> {
  if (scannedDirCount.count >= maxScanDirs) {
    if (debugMode)
      logger.debug(
        `Max directory scan limit (${maxScanDirs}) reached. Stopping downward scan at: ${directory}`,
      );
    return [];
  }
  scannedDirCount.count++;

  if (debugMode)
    logger.debug(
      `Scanning downward for ${GEMINI_MD_FILENAME} files in: ${directory} (scanned: ${scannedDirCount.count}/${maxScanDirs})`,
    );
  const collectedPaths: string[] = [];
  try {
    const entries = await fs.readdir(directory, { withFileTypes: true });
    for (const entry of entries) {
      const fullPath = path.join(directory, entry.name);
      if (entry.isDirectory()) {
        if (ignoreDirs.includes(entry.name)) {
          if (debugMode)
            logger.debug(`Skipping ignored directory: ${fullPath}`);
          continue;
        }
        const subDirPaths = await collectDownwardGeminiFiles(
          fullPath,
          debugMode,
          ignoreDirs,
          scannedDirCount,
          maxScanDirs,
        );
        collectedPaths.push(...subDirPaths);
      } else if (entry.isFile() && entry.name === GEMINI_MD_FILENAME) {
        try {
          await fs.access(fullPath, fsSync.constants.R_OK);
          collectedPaths.push(fullPath);
          if (debugMode)
            logger.debug(
              `Found readable downward ${GEMINI_MD_FILENAME}: ${fullPath}`,
            );
        } catch {
          if (debugMode)
            logger.debug(
              `Downward ${GEMINI_MD_FILENAME} not readable, skipping: ${fullPath}`,
            );
        }
      }
    }
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error);
    logger.warn(`Error scanning directory ${directory}: ${message}`);
    if (debugMode) logger.debug(`Failed to scan directory: ${directory}`);
  }
  return collectedPaths;
}

async function getGeminiMdFilePathsInternal(
  currentWorkingDirectory: string,
  userHomePath: string, // Keep userHomePath as a parameter for clarity
  debugMode: boolean,
): Promise<string[]> {
  const resolvedCwd = path.resolve(currentWorkingDirectory);
  const resolvedHome = path.resolve(userHomePath);
  const globalMemoryPath = path.join(
    resolvedHome,
    GEMINI_CONFIG_DIR,
    GEMINI_MD_FILENAME,
  );
  const paths: string[] = [];

  if (debugMode)
    logger.debug(
      `Searching for ${GEMINI_MD_FILENAME} starting from CWD: ${resolvedCwd}`,
    );
  if (debugMode) logger.debug(`User home directory: ${resolvedHome}`);

  try {
    await fs.access(globalMemoryPath, fsSync.constants.R_OK);
    paths.push(globalMemoryPath);
    if (debugMode)
      logger.debug(
        `Found readable global ${GEMINI_MD_FILENAME}: ${globalMemoryPath}`,
      );
  } catch {
    if (debugMode)
      logger.debug(
        `Global ${GEMINI_MD_FILENAME} not found or not readable: ${globalMemoryPath}`,
      );
  }

  const projectRoot = await findProjectRoot(resolvedCwd);
  if (debugMode)
    logger.debug(`Determined project root: ${projectRoot ?? 'None'}`);

  const upwardPaths: string[] = [];
  let currentDir = resolvedCwd;
  // Determine the directory that signifies the top of the project or user-specific space.
  const ultimateStopDir = projectRoot
    ? path.dirname(projectRoot)
    : path.dirname(resolvedHome);

  while (currentDir && currentDir !== path.dirname(currentDir)) {
    // Loop until filesystem root or currentDir is empty
    if (debugMode) {
      logger.debug(
        `Checking for ${GEMINI_MD_FILENAME} in (upward scan): ${currentDir}`,
      );
    }

    // Skip the global .gemini directory itself during upward scan from CWD,
    // as global is handled separately and explicitly first.
    if (currentDir === path.join(resolvedHome, GEMINI_CONFIG_DIR)) {
      if (debugMode) {
        logger.debug(
          `Upward scan reached global config dir path, stopping upward search here: ${currentDir}`,
        );
      }
      break;
    }

    const potentialPath = path.join(currentDir, GEMINI_MD_FILENAME);
    try {
      await fs.access(potentialPath, fsSync.constants.R_OK);
      // Add to upwardPaths only if it's not the already added globalMemoryPath
      if (potentialPath !== globalMemoryPath) {
        upwardPaths.unshift(potentialPath);
        if (debugMode) {
          logger.debug(
            `Found readable upward ${GEMINI_MD_FILENAME}: ${potentialPath}`,
          );
        }
      }
    } catch {
      if (debugMode) {
        logger.debug(
          `Upward ${GEMINI_MD_FILENAME} not found or not readable in: ${currentDir}`,
        );
      }
    }

    // Stop condition: if currentDir is the ultimateStopDir, break after this iteration.
    if (currentDir === ultimateStopDir) {
      if (debugMode)
        logger.debug(
          `Reached ultimate stop directory for upward scan: ${currentDir}`,
        );
      break;
    }

    currentDir = path.dirname(currentDir);
  }
  paths.push(...upwardPaths);

  if (debugMode)
    logger.debug(`Starting downward scan from CWD: ${resolvedCwd}`);
  const scannedDirCount = { count: 0 };
  const downwardPaths = await collectDownwardGeminiFiles(
    resolvedCwd,
    debugMode,
    DEFAULT_IGNORE_DIRECTORIES,
    scannedDirCount,
    MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY,
  );
  downwardPaths.sort(); // Sort for consistent ordering, though hierarchy might be more complex
  if (debugMode && downwardPaths.length > 0)
    logger.debug(
      `Found downward ${GEMINI_MD_FILENAME} files (sorted): ${JSON.stringify(downwardPaths)}`,
    );
  // Add downward paths only if they haven't been included already (e.g. from upward scan)
  for (const dPath of downwardPaths) {
    if (!paths.includes(dPath)) {
      paths.push(dPath);
    }
  }

  if (debugMode)
    logger.debug(
      `Final ordered ${GEMINI_MD_FILENAME} paths to read: ${JSON.stringify(paths)}`,
    );
  return paths;
}

async function readGeminiMdFiles(
  filePaths: string[],
  debugMode: boolean,
): Promise<GeminiFileContent[]> {
  const results: GeminiFileContent[] = [];
  for (const filePath of filePaths) {
    try {
      const content = await fs.readFile(filePath, 'utf-8');
      results.push({ filePath, content });
      if (debugMode)
        logger.debug(
          `Successfully read: ${filePath} (Length: ${content.length})`,
        );
    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      logger.warn(
        `Warning: Could not read ${GEMINI_MD_FILENAME} file at ${filePath}. Error: ${message}`,
      );
      results.push({ filePath, content: null }); // Still include it with null content
      if (debugMode) logger.debug(`Failed to read: ${filePath}`);
    }
  }
  return results;
}

function concatenateInstructions(
  instructionContents: GeminiFileContent[],
  // CWD is needed to resolve relative paths for display markers
  currentWorkingDirectoryForDisplay: string,
): string {
  return instructionContents
    .filter((item) => typeof item.content === 'string')
    .map((item) => {
      const trimmedContent = (item.content as string).trim();
      if (trimmedContent.length === 0) {
        return null;
      }
      const displayPath = path.isAbsolute(item.filePath)
        ? path.relative(currentWorkingDirectoryForDisplay, item.filePath)
        : item.filePath;
      return `--- Context from: ${displayPath} ---\n${trimmedContent}\n--- End of Context from: ${displayPath} ---`;
    })
    .filter((block): block is string => block !== null)
    .join('\n\n');
}

/**
 * Loads hierarchical GEMINI.md files and concatenates their content.
 * This function is intended for use by the server.
 */
export async function loadServerHierarchicalMemory(
  currentWorkingDirectory: string,
  debugMode: boolean,
): Promise<{ memoryContent: string; fileCount: number }> {
  if (debugMode)
    logger.debug(
      `Loading server hierarchical memory for CWD: ${currentWorkingDirectory}`,
    );
  // For the server, homedir() refers to the server process's home.
  // This is consistent with how MemoryTool already finds the global path.
  const userHomePath = homedir();
  const filePaths = await getGeminiMdFilePathsInternal(
    currentWorkingDirectory,
    userHomePath,
    debugMode,
  );
  if (filePaths.length === 0) {
    if (debugMode) logger.debug('No GEMINI.md files found in hierarchy.');
    return { memoryContent: '', fileCount: 0 };
  }
  const contentsWithPaths = await readGeminiMdFiles(filePaths, debugMode);
  // Pass CWD for relative path display in concatenated content
  const combinedInstructions = concatenateInstructions(
    contentsWithPaths,
    currentWorkingDirectory,
  );
  if (debugMode)
    logger.debug(
      `Combined instructions length: ${combinedInstructions.length}`,
    );
  if (debugMode && combinedInstructions.length > 0)
    logger.debug(
      `Combined instructions (snippet): ${combinedInstructions.substring(0, 500)}...`,
    );
  return { memoryContent: combinedInstructions, fileCount: filePaths.length };
}
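
For orientation, a minimal usage sketch of the exported entry point follows; the relative import path and the printMemory wrapper are illustrative assumptions, not part of this commit.

// Hypothetical caller of loadServerHierarchicalMemory (sketch).
import { loadServerHierarchicalMemory } from './memoryDiscovery.js';

async function printMemory(cwd: string): Promise<void> {
  // The second argument toggles the module's debug logging.
  const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
    cwd,
    false,
  );
  console.log(`Loaded ${fileCount} GEMINI.md file(s)`);
  console.log(memoryContent);
}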
15
packages/core/src/utils/messageInspectors.ts
Normal file
@@ -0,0 +1,15 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { Content } from '@google/genai';

export function isFunctionResponse(content: Content): boolean {
  return (
    content.role === 'user' &&
    !!content.parts &&
    content.parts.every((part) => !!part.functionResponse)
  );
}
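
A quick illustration of the predicate's behavior (a sketch; the Content literals are abbreviated, and the functionResponse payload shape is an assumption about typical @google/genai usage):

import { Content } from '@google/genai';
import { isFunctionResponse } from './messageInspectors.js';

const toolResult: Content = {
  role: 'user',
  parts: [{ functionResponse: { name: 'read_file', response: {} } }],
};
const plainText: Content = { role: 'user', parts: [{ text: 'hi' }] };

isFunctionResponse(toolResult); // true: every part is a functionResponse
isFunctionResponse(plainText); // false: the part is plain text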
235
packages/core/src/utils/nextSpeakerChecker.test.ts
Normal file
@@ -0,0 +1,235 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
import { Content, GoogleGenAI, Models } from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
import { GeminiChat } from '../core/geminiChat.js';

// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
vi.mock('../config/config.js');

// Define mocks for GoogleGenAI and Models instances that will be used across tests
const mockModelsInstance = {
  generateContent: vi.fn(),
  generateContentStream: vi.fn(),
  countTokens: vi.fn(),
  embedContent: vi.fn(),
  batchEmbedContents: vi.fn(),
} as unknown as Models;

const mockGoogleGenAIInstance = {
  getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance),
  // Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods
} as unknown as GoogleGenAI;

vi.mock('@google/genai', async () => {
  const actualGenAI =
    await vi.importActual<typeof import('@google/genai')>('@google/genai');
  return {
    ...actualGenAI,
    GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance
    // If Models is instantiated directly in GeminiChat, mock its constructor too
    // For now, assuming Models instance is obtained via getGenerativeModel
  };
});

describe('checkNextSpeaker', () => {
  let chatInstance: GeminiChat;
  let mockGeminiClient: GeminiClient;
  let MockConfig: Mock;
  const abortSignal = new AbortController().signal;

  beforeEach(() => {
    MockConfig = vi.mocked(Config);
    const mockConfigInstance = new MockConfig(
      'test-api-key',
      'gemini-pro',
      false,
      '.',
      false,
      undefined,
      false,
      undefined,
      undefined,
      undefined,
    );

    mockGeminiClient = new GeminiClient(mockConfigInstance);

    // Reset mocks before each test to ensure test isolation
    vi.mocked(mockModelsInstance.generateContent).mockReset();
    vi.mocked(mockModelsInstance.generateContentStream).mockReset();

    // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
    chatInstance = new GeminiChat(
      mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
      mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
      'gemini-pro', // model name
      {},
      [], // initial history
    );

    // Spy on getHistory for chatInstance
    vi.spyOn(chatInstance, 'getHistory');
  });

  afterEach(() => {
    vi.clearAllMocks();
  });

  it('should return null if history is empty', async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([]);
    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
    expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
  });

  it('should return null if the last speaker was the user', async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'user', parts: [{ text: 'Hello' }] },
    ] as Content[]);
    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
    expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
  });

  it("should return { next_speaker: 'model' } when model intends to continue", async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'I will now do something.' }] },
    ] as Content[]);
    const mockApiResponse: NextSpeakerResponse = {
      reasoning: 'Model stated it will do something.',
      next_speaker: 'model',
    };
    (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toEqual(mockApiResponse);
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
  });

  it("should return { next_speaker: 'user' } when model asks a question", async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'What would you like to do?' }] },
    ] as Content[]);
    const mockApiResponse: NextSpeakerResponse = {
      reasoning: 'Model asked a question.',
      next_speaker: 'user',
    };
    (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toEqual(mockApiResponse);
  });

  it("should return { next_speaker: 'user' } when model makes a statement", async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'This is a statement.' }] },
    ] as Content[]);
    const mockApiResponse: NextSpeakerResponse = {
      reasoning: 'Model made a statement, awaiting user input.',
      next_speaker: 'user',
    };
    (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toEqual(mockApiResponse);
  });

  it('should return null if geminiClient.generateJson throws an error', async () => {
    const consoleWarnSpy = vi
      .spyOn(console, 'warn')
      .mockImplementation(() => {});
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'Some model output.' }] },
    ] as Content[]);
    (mockGeminiClient.generateJson as Mock).mockRejectedValue(
      new Error('API Error'),
    );

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
    consoleWarnSpy.mockRestore();
  });

  it('should return null if geminiClient.generateJson returns invalid JSON (missing next_speaker)', async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'Some model output.' }] },
    ] as Content[]);
    (mockGeminiClient.generateJson as Mock).mockResolvedValue({
      reasoning: 'This is incomplete.',
    } as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
  });

  it('should return null if geminiClient.generateJson returns a non-string next_speaker', async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'Some model output.' }] },
    ] as Content[]);
    (mockGeminiClient.generateJson as Mock).mockResolvedValue({
      reasoning: 'Model made a statement, awaiting user input.',
      next_speaker: 123, // Invalid type
    } as unknown as NextSpeakerResponse);

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
  });

  it('should return null if geminiClient.generateJson returns an invalid next_speaker string value', async () => {
    (chatInstance.getHistory as Mock).mockReturnValue([
      { role: 'model', parts: [{ text: 'Some model output.' }] },
    ] as Content[]);
    (mockGeminiClient.generateJson as Mock).mockResolvedValue({
      reasoning: 'Model made a statement, awaiting user input.',
      next_speaker: 'neither', // Invalid enum value
    } as unknown as NextSpeakerResponse);

    const result = await checkNextSpeaker(
      chatInstance,
      mockGeminiClient,
      abortSignal,
    );
    expect(result).toBeNull();
  });
});
151
packages/core/src/utils/nextSpeakerChecker.ts
Normal file
@@ -0,0 +1,151 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { Content, SchemaUnion, Type } from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js';

const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
**Decision Rules (apply in order):**
1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next.
2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next.
3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next.
**Output Format:**
Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
\`\`\`json
{
  "type": "object",
  "properties": {
    "reasoning": {
      "type": "string",
      "description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn."
    },
    "next_speaker": {
      "type": "string",
      "enum": ["user", "model"],
      "description": "Who should speak next based *only* on the preceding turn and the decision rules."
    }
  },
  "required": ["next_speaker", "reasoning"]
}
\`\`\`
`;

const RESPONSE_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    reasoning: {
      type: Type.STRING,
      description:
        "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.",
    },
    next_speaker: {
      type: Type.STRING,
      enum: ['user', 'model'],
      description:
        'Who should speak next based *only* on the preceding turn and the decision rules',
    },
  },
  required: ['reasoning', 'next_speaker'],
};

export interface NextSpeakerResponse {
  reasoning: string;
  next_speaker: 'user' | 'model';
}

export async function checkNextSpeaker(
  chat: GeminiChat,
  geminiClient: GeminiClient,
  abortSignal: AbortSignal,
): Promise<NextSpeakerResponse | null> {
  // We need to capture the curated history because there are many moments when the model will return invalid turns
  // that, when passed back up to the endpoint, will break subsequent calls. For example, the model may respond
  // with an empty part collection; if that message were sent back to the server, it would respond with a 400
  // indicating that model part collections MUST have content.
  const curatedHistory = chat.getHistory(/* curated */ true);

  // Ensure there's a model response to analyze
  if (curatedHistory.length === 0) {
    // Cannot determine next speaker if history is empty.
    return null;
  }

  const comprehensiveHistory = chat.getHistory();
  // If comprehensiveHistory is empty, there is no last message to check.
  // This case should ideally be caught by the curatedHistory.length check earlier,
  // but as a safeguard:
  if (comprehensiveHistory.length === 0) {
    return null;
  }
  const lastComprehensiveMessage =
    comprehensiveHistory[comprehensiveHistory.length - 1];

  // If the last message is a user message containing only function_responses,
  // then the model should speak next.
  if (
    lastComprehensiveMessage &&
    isFunctionResponse(lastComprehensiveMessage)
  ) {
    return {
      reasoning:
        'The last message was a function response, so the model should speak next.',
      next_speaker: 'model',
    };
  }

  if (
    lastComprehensiveMessage &&
    lastComprehensiveMessage.role === 'model' &&
    lastComprehensiveMessage.parts &&
    lastComprehensiveMessage.parts.length === 0
  ) {
    lastComprehensiveMessage.parts.push({ text: '' });
    return {
      reasoning:
        'The last message was a filler model message with no content (nothing for user to act on), model should speak next.',
      next_speaker: 'model',
    };
  }

  // Things checked out. Let's proceed to potentially making an LLM request.

  const lastMessage = curatedHistory[curatedHistory.length - 1];
  if (!lastMessage || lastMessage.role !== 'model') {
    // Cannot determine next speaker if the last turn wasn't from the model
    // or if history is empty.
    return null;
  }

  const contents: Content[] = [
    ...curatedHistory,
    { role: 'user', parts: [{ text: CHECK_PROMPT }] },
  ];

  try {
    const parsedResponse = (await geminiClient.generateJson(
      contents,
      RESPONSE_SCHEMA,
      abortSignal,
    )) as unknown as NextSpeakerResponse;

    if (
      parsedResponse &&
      parsedResponse.next_speaker &&
      ['user', 'model'].includes(parsedResponse.next_speaker)
    ) {
      return parsedResponse;
    }
    return null;
  } catch (error) {
    console.warn(
      'Failed to talk to Gemini endpoint when seeing if conversation should continue.',
      error,
    );
    return null;
  }
}
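
A rough caller-side sketch of checkNextSpeaker; the maybeContinue wrapper is hypothetical, and the chat and client arguments stand in for instances constructed elsewhere:

import { GeminiClient } from '../core/client.js';
import { GeminiChat } from '../core/geminiChat.js';
import { checkNextSpeaker } from './nextSpeakerChecker.js';

async function maybeContinue(
  chat: GeminiChat,
  client: GeminiClient,
): Promise<void> {
  const controller = new AbortController();
  // Returns null when the check is inapplicable or the model's JSON reply is invalid.
  const decision = await checkNextSpeaker(chat, client, controller.signal);
  if (decision?.next_speaker === 'model') {
    // Kick off another model turn here.
  }
}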
139
packages/core/src/utils/paths.ts
Normal file
@@ -0,0 +1,139 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import path from 'node:path';
import os from 'os';

/**
 * Replaces the home directory with a tilde.
 * @param path - The path to tildeify.
 * @returns The tildeified path.
 */
export function tildeifyPath(path: string): string {
  const homeDir = os.homedir();
  if (path.startsWith(homeDir)) {
    return path.replace(homeDir, '~');
  }
  return path;
}

/**
 * Shortens a path string if it exceeds maxLen, prioritizing the start and end segments.
 * Example: /path/to/a/very/long/file.txt -> /path/.../long/file.txt
 */
export function shortenPath(filePath: string, maxLen: number = 35): string {
  if (filePath.length <= maxLen) {
    return filePath;
  }

  const parsedPath = path.parse(filePath);
  const root = parsedPath.root;
  const separator = path.sep;

  // Get segments of the path *after* the root
  const relativePath = filePath.substring(root.length);
  const segments = relativePath.split(separator).filter((s) => s !== ''); // Filter out empty segments

  // Handle cases with no segments after root (e.g., "/", "C:\") or only one segment
  if (segments.length <= 1) {
    // Fall back to simple start/end truncation for very short paths or single segments
    const keepLen = Math.floor((maxLen - 3) / 2);
    // Ensure keepLen is not negative if maxLen is very small
    if (keepLen <= 0) {
      return filePath.substring(0, maxLen - 3) + '...';
    }
    const start = filePath.substring(0, keepLen);
    const end = filePath.substring(filePath.length - keepLen);
    return `${start}...${end}`;
  }

  const firstDir = segments[0];
  const startComponent = root + firstDir;

  const endPartSegments: string[] = [];
  // Base length: startComponent + separator + "..."
  let currentLength = startComponent.length + separator.length + 3;

  // Iterate backwards through segments (excluding the first one)
  for (let i = segments.length - 1; i >= 1; i--) {
    const segment = segments[i];
    // Length needed if we add this segment: current + separator + segment
    const lengthWithSegment = currentLength + separator.length + segment.length;

    if (lengthWithSegment <= maxLen) {
      endPartSegments.unshift(segment); // Add to the beginning of the end part
      currentLength = lengthWithSegment;
    } else {
      // Adding this segment would exceed maxLen
      break;
    }
  }

  // Construct the final path
  let result = startComponent + separator + '...';
  if (endPartSegments.length > 0) {
    result += separator + endPartSegments.join(separator);
  }

  // As a final check, if the result is somehow still too long (e.g., startComponent + ... is too long),
  // fall back to simple truncation of the original path
  if (result.length > maxLen) {
    const keepLen = Math.floor((maxLen - 3) / 2);
    if (keepLen <= 0) {
      return filePath.substring(0, maxLen - 3) + '...';
    }
    const start = filePath.substring(0, keepLen);
    const end = filePath.substring(filePath.length - keepLen);
    return `${start}...${end}`;
  }

  return result;
}

/**
 * Calculates the relative path from a root directory to a target path.
 * Ensures both paths are resolved before calculating.
 * Returns '.' if the target path is the same as the root directory.
 *
 * @param targetPath The absolute or relative path to make relative.
 * @param rootDirectory The absolute path of the directory to make the target path relative to.
 * @returns The relative path from rootDirectory to targetPath.
 */
export function makeRelative(
  targetPath: string,
  rootDirectory: string,
): string {
  const resolvedTargetPath = path.resolve(targetPath);
  const resolvedRootDirectory = path.resolve(rootDirectory);

  const relativePath = path.relative(resolvedRootDirectory, resolvedTargetPath);

  // If the paths are the same, path.relative returns '', return '.' instead
  return relativePath || '.';
}

/**
 * Escapes spaces in a file path.
 */
export function escapePath(filePath: string): string {
  let result = '';
  for (let i = 0; i < filePath.length; i++) {
    // Only escape spaces that are not already escaped.
    if (filePath[i] === ' ' && (i === 0 || filePath[i - 1] !== '\\')) {
      result += '\\ ';
    } else {
      result += filePath[i];
    }
  }
  return result;
}

/**
 * Unescapes spaces in a file path.
 */
export function unescapePath(filePath: string): string {
  return filePath.replace(/\\ /g, ' ');
}
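
Illustrative calls for the path helpers (a sketch; the commented results assume POSIX separators):

import { shortenPath, makeRelative, escapePath, unescapePath } from './paths.js';

shortenPath('/path/to/a/very/long/nested/file.txt', 24); // '/path/.../file.txt'
makeRelative('/repo/src/index.ts', '/repo'); // 'src/index.ts'
makeRelative('/repo', '/repo'); // '.'
unescapePath(escapePath('/tmp/my file.txt')); // round-trips to '/tmp/my file.txt'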
238
packages/core/src/utils/retry.test.ts
Normal file
@@ -0,0 +1,238 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { retryWithBackoff } from './retry.js';

// Define an interface for the error with a status property
interface HttpError extends Error {
  status?: number;
}

// Helper to create a mock function that fails a certain number of times
const createFailingFunction = (
  failures: number,
  successValue: string = 'success',
) => {
  let attempts = 0;
  return vi.fn(async () => {
    attempts++;
    if (attempts <= failures) {
      // Simulate a retryable error
      const error: HttpError = new Error(`Simulated error attempt ${attempts}`);
      error.status = 500; // Simulate a server error
      throw error;
    }
    return successValue;
  });
};

// Custom error for testing non-retryable conditions
class NonRetryableError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'NonRetryableError';
  }
}

describe('retryWithBackoff', () => {
  beforeEach(() => {
    vi.useFakeTimers();
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it('should return the result on the first attempt if successful', async () => {
    const mockFn = createFailingFunction(0);
    const result = await retryWithBackoff(mockFn);
    expect(result).toBe('success');
    expect(mockFn).toHaveBeenCalledTimes(1);
  });

  it('should retry and succeed if failures are within maxAttempts', async () => {
    const mockFn = createFailingFunction(2);
    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 10,
    });

    await vi.runAllTimersAsync(); // Ensure all delays and retries complete

    const result = await promise;
    expect(result).toBe('success');
    expect(mockFn).toHaveBeenCalledTimes(3);
  });

  it('should throw an error if all attempts fail', async () => {
    const mockFn = createFailingFunction(3);

    // 1. Start the retryable operation, which returns a promise.
    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 10,
    });

    // 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*.
    // This ensures a 'catch' handler is present before the promise can reject.
    // The result is a new promise that resolves when the assertion is met.
    const assertionPromise = expect(promise).rejects.toThrow(
      'Simulated error attempt 3',
    );

    // 3. Now, advance the timers. This will trigger the retries and the
    // eventual rejection. The handler attached in step 2 will catch it.
    await vi.runAllTimersAsync();

    // 4. Await the assertion promise itself to ensure the test was successful.
    await assertionPromise;

    // 5. Finally, assert the number of calls.
    expect(mockFn).toHaveBeenCalledTimes(3);
  });

  it('should not retry if shouldRetry returns false', async () => {
    const mockFn = vi.fn(async () => {
      throw new NonRetryableError('Non-retryable error');
    });
    const shouldRetry = (error: Error) => !(error instanceof NonRetryableError);

    const promise = retryWithBackoff(mockFn, {
      shouldRetry,
      initialDelayMs: 10,
    });

    await expect(promise).rejects.toThrow('Non-retryable error');
    expect(mockFn).toHaveBeenCalledTimes(1);
  });

  it('should use default shouldRetry if not provided, retrying on 429', async () => {
    const mockFn = vi.fn(async () => {
      const error = new Error('Too Many Requests') as any;
      error.status = 429;
      throw error;
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 2,
      initialDelayMs: 10,
    });

    // Attach the rejection expectation *before* running timers
    const assertionPromise =
      expect(promise).rejects.toThrow('Too Many Requests');

    // Run timers to trigger retries and eventual rejection
    await vi.runAllTimersAsync();

    // Await the assertion
    await assertionPromise;

    expect(mockFn).toHaveBeenCalledTimes(2);
  });

  it('should use default shouldRetry if not provided, not retrying on 400', async () => {
    const mockFn = vi.fn(async () => {
      const error = new Error('Bad Request') as any;
      error.status = 400;
      throw error;
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 2,
      initialDelayMs: 10,
    });
    await expect(promise).rejects.toThrow('Bad Request');
    expect(mockFn).toHaveBeenCalledTimes(1);
  });

  it('should respect maxDelayMs', async () => {
    const mockFn = createFailingFunction(3);
    const setTimeoutSpy = vi.spyOn(global, 'setTimeout');

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 4,
      initialDelayMs: 100,
      maxDelayMs: 250, // Max delay is less than 100 * 2 * 2 = 400
    });

    await vi.advanceTimersByTimeAsync(1000); // Advance well past all delays
    await promise;

    const delays = setTimeoutSpy.mock.calls.map((call) => call[1] as number);

    // Delays should be around initial, initial*2, maxDelay (due to cap)
    // Jitter makes exact assertion hard, so we check ranges / caps
    expect(delays.length).toBe(3);
    expect(delays[0]).toBeGreaterThanOrEqual(100 * 0.7);
    expect(delays[0]).toBeLessThanOrEqual(100 * 1.3);
    expect(delays[1]).toBeGreaterThanOrEqual(200 * 0.7);
    expect(delays[1]).toBeLessThanOrEqual(200 * 1.3);
    // The third delay should be capped by maxDelayMs (250ms), accounting for jitter
    expect(delays[2]).toBeGreaterThanOrEqual(250 * 0.7);
    expect(delays[2]).toBeLessThanOrEqual(250 * 1.3);

    setTimeoutSpy.mockRestore();
  });

  it('should handle jitter correctly, ensuring varied delays', async () => {
    let mockFn = createFailingFunction(5);
    const setTimeoutSpy = vi.spyOn(global, 'setTimeout');

    // Run retryWithBackoff multiple times to observe jitter
    const runRetry = () =>
      retryWithBackoff(mockFn, {
        maxAttempts: 2, // Only one retry, so one delay
        initialDelayMs: 100,
        maxDelayMs: 1000,
      });

    // We expect rejections as mockFn fails 5 times
    const promise1 = runRetry();
    // Attach the rejection expectation *before* running timers
    const assertionPromise1 = expect(promise1).rejects.toThrow();
    await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry
    await assertionPromise1;

    const firstDelaySet = setTimeoutSpy.mock.calls.map(
      (call) => call[1] as number,
    );
    setTimeoutSpy.mockClear(); // Clear calls for the next run

    // Reset mockFn to reset its internal attempt counter for the next run
    mockFn = createFailingFunction(5); // Re-initialize with 5 failures

    const promise2 = runRetry();
    // Attach the rejection expectation *before* running timers
    const assertionPromise2 = expect(promise2).rejects.toThrow();
    await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry
    await assertionPromise2;

    const secondDelaySet = setTimeoutSpy.mock.calls.map(
      (call) => call[1] as number,
    );

    // Check that the delays are not exactly the same due to jitter
    // This is a probabilistic test, but with +/-30% jitter, it's highly likely they differ.
    if (firstDelaySet.length > 0 && secondDelaySet.length > 0) {
      // Check the first delay of each set
      expect(firstDelaySet[0]).not.toBe(secondDelaySet[0]);
    } else {
      // If somehow no delays were captured (e.g. test setup issue), fail explicitly
      throw new Error('Delays were not captured for jitter test');
    }

    // Ensure delays are within the expected jitter range [70, 130] for initialDelayMs = 100
    [...firstDelaySet, ...secondDelaySet].forEach((d) => {
      expect(d).toBeGreaterThanOrEqual(100 * 0.7);
      expect(d).toBeLessThanOrEqual(100 * 1.3);
    });

    setTimeoutSpy.mockRestore();
  });
});
227
packages/core/src/utils/retry.ts
Normal file
@@ -0,0 +1,227 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

export interface RetryOptions {
  maxAttempts: number;
  initialDelayMs: number;
  maxDelayMs: number;
  shouldRetry: (error: Error) => boolean;
}

const DEFAULT_RETRY_OPTIONS: RetryOptions = {
  maxAttempts: 5,
  initialDelayMs: 5000,
  maxDelayMs: 30000, // 30 seconds
  shouldRetry: defaultShouldRetry,
};

/**
 * Default predicate function to determine if a retry should be attempted.
 * Retries on 429 (Too Many Requests) and 5xx server errors.
 * @param error The error object.
 * @returns True if the error is a transient error, false otherwise.
 */
function defaultShouldRetry(error: Error | unknown): boolean {
  // Check for common transient error status codes, either in the message or in a status property
  if (error && typeof (error as { status?: number }).status === 'number') {
    const status = (error as { status: number }).status;
    if (status === 429 || (status >= 500 && status < 600)) {
      return true;
    }
  }
  if (error instanceof Error && error.message) {
    if (error.message.includes('429')) return true;
    if (error.message.match(/5\d{2}/)) return true;
  }
  return false;
}

/**
 * Delays execution for a specified number of milliseconds.
 * @param ms The number of milliseconds to delay.
 * @returns A promise that resolves after the delay.
 */
function delay(ms: number): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

/**
 * Retries a function with exponential backoff and jitter.
 * @param fn The asynchronous function to retry.
 * @param options Optional retry configuration.
 * @returns A promise that resolves with the result of the function if successful.
 * @throws The last error encountered if all attempts fail.
 */
export async function retryWithBackoff<T>(
  fn: () => Promise<T>,
  options?: Partial<RetryOptions>,
): Promise<T> {
  const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = {
    ...DEFAULT_RETRY_OPTIONS,
    ...options,
  };

  let attempt = 0;
  let currentDelay = initialDelayMs;

  while (attempt < maxAttempts) {
    attempt++;
    try {
      return await fn();
    } catch (error) {
      if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
        throw error;
      }

      const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error);

      if (delayDurationMs > 0) {
        // Respect the Retry-After header when present and parseable
        console.warn(
          `Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
          error,
        );
        await delay(delayDurationMs);
        // Reset currentDelay for the next potential non-429 error, or in case Retry-After is absent next time
        currentDelay = initialDelayMs;
      } else {
        // Fall back to exponential backoff with jitter
        logRetryAttempt(attempt, error, errorStatus);
        // Add jitter: +/- 30% of currentDelay
        const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
        const delayWithJitter = Math.max(0, currentDelay + jitter);
        await delay(delayWithJitter);
        currentDelay = Math.min(maxDelayMs, currentDelay * 2);
      }
    }
  }
  // This line should theoretically be unreachable due to the throw in the catch block.
  // Added for type safety and to satisfy the compiler that a promise is always returned.
  throw new Error('Retry attempts exhausted');
}
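A minimal sketch of wiring retryWithBackoff to a caller; fetchConfig, its URL, and loadConfigWithRetry are hypothetical, but the options and the status-based default retry predicate come from this module:

// Hypothetical caller (sketch).
async function fetchConfig(): Promise<string> {
  const res = await fetch('https://example.com/config.json');
  if (!res.ok) {
    const error = new Error(`HTTP ${res.status}`) as Error & { status?: number };
    error.status = res.status; // defaultShouldRetry keys off this property
    throw error;
  }
  return res.text();
}

async function loadConfigWithRetry(): Promise<string> {
  return retryWithBackoff(fetchConfig, {
    maxAttempts: 3,
    initialDelayMs: 1000, // doubles per attempt, capped at maxDelayMs, with +/-30% jitter
  });
}
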
/**
 * Extracts the HTTP status code from an error object.
 * @param error The error object.
 * @returns The HTTP status code, or undefined if not found.
 */
function getErrorStatus(error: unknown): number | undefined {
  if (typeof error === 'object' && error !== null) {
    if ('status' in error && typeof error.status === 'number') {
      return error.status;
    }
    // Check for error.response.status (common in axios errors)
    if (
      'response' in error &&
      typeof (error as { response?: unknown }).response === 'object' &&
      (error as { response?: unknown }).response !== null
    ) {
      const response = (
        error as { response: { status?: unknown; headers?: unknown } }
      ).response;
      if ('status' in response && typeof response.status === 'number') {
        return response.status;
      }
    }
  }
  return undefined;
}

/**
 * Extracts the Retry-After delay from an error object's headers.
 * @param error The error object.
 * @returns The delay in milliseconds, or 0 if not found or invalid.
 */
function getRetryAfterDelayMs(error: unknown): number {
  if (typeof error === 'object' && error !== null) {
    // Check for error.response.headers (common in axios errors)
    if (
      'response' in error &&
      typeof (error as { response?: unknown }).response === 'object' &&
      (error as { response?: unknown }).response !== null
    ) {
      const response = (error as { response: { headers?: unknown } }).response;
      if (
        'headers' in response &&
        typeof response.headers === 'object' &&
        response.headers !== null
      ) {
        const headers = response.headers as { 'retry-after'?: unknown };
        const retryAfterHeader = headers['retry-after'];
        if (typeof retryAfterHeader === 'string') {
          const retryAfterSeconds = parseInt(retryAfterHeader, 10);
          if (!isNaN(retryAfterSeconds)) {
            return retryAfterSeconds * 1000;
          }
          // It might be an HTTP date
          const retryAfterDate = new Date(retryAfterHeader);
          if (!isNaN(retryAfterDate.getTime())) {
            return Math.max(0, retryAfterDate.getTime() - Date.now());
          }
        }
      }
    }
  }
  return 0;
}
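
// The Retry-After header takes two standard forms, both handled above.
// Expected parses (values illustrative):
//
//   'retry-after': '30'
//     -> 30_000 ms (delta-seconds form)
//   'retry-after': 'Wed, 21 Oct 2025 07:28:00 GMT'
//     -> milliseconds until that HTTP date, clamped to 0 if already past
//
// The numeric branch is tried first; the date branch is only reached when
// parseInt yields NaN, which is the case for HTTP dates (they start with a
// day name).
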
/**
 * Determines the delay duration based on the error, prioritizing Retry-After header.
 * @param error The error object.
 * @returns An object containing the delay duration in milliseconds and the error status.
 */
function getDelayDurationAndStatus(error: unknown): {
  delayDurationMs: number;
  errorStatus: number | undefined;
} {
  const errorStatus = getErrorStatus(error);
  let delayDurationMs = 0;

  if (errorStatus === 429) {
    delayDurationMs = getRetryAfterDelayMs(error);
  }
  return { delayDurationMs, errorStatus };
}

/**
 * Logs a message for a retry attempt when using exponential backoff.
 * @param attempt The current attempt number.
 * @param error The error that caused the retry.
 * @param errorStatus The HTTP status code of the error, if available.
 */
function logRetryAttempt(
  attempt: number,
  error: unknown,
  errorStatus?: number,
): void {
  let message = `Attempt ${attempt} failed. Retrying with backoff...`;
  if (errorStatus) {
    message = `Attempt ${attempt} failed with status ${errorStatus}. Retrying with backoff...`;
  }

  if (errorStatus === 429) {
    console.warn(message, error);
  } else if (errorStatus && errorStatus >= 500 && errorStatus < 600) {
    console.error(message, error);
  } else if (error instanceof Error) {
    // Fallback for errors that might not have a status but have a message
    if (error.message.includes('429')) {
      console.warn(
        `Attempt ${attempt} failed with 429 error (no Retry-After header). Retrying with backoff...`,
        error,
      );
    } else if (error.message.match(/5\d{2}/)) {
      console.error(
        `Attempt ${attempt} failed with 5xx error. Retrying with backoff...`,
        error,
      );
    } else {
      console.warn(message, error); // Default to warn for other errors
    }
  } else {
    console.warn(message, error); // Default to warn if error type is unknown
  }
}
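
// End-to-end, delay selection works like this (illustrative values):
//
//   getDelayDurationAndStatus({ status: 429, response: { headers: { 'retry-after': '2' } } })
//     -> { delayDurationMs: 2000, errorStatus: 429 }  // honors the server hint
//   getDelayDurationAndStatus({ status: 503 })
//     -> { delayDurationMs: 0, errorStatus: 503 }     // falls back to jittered backoff
//
// Only 429 responses consult Retry-After; 5xx errors always use the
// exponential schedule.
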
58
packages/core/src/utils/schemaValidator.ts
Normal file
@@ -0,0 +1,58 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

/**
 * Simple utility to validate objects against JSON Schemas
 */
export class SchemaValidator {
  /**
   * Validates data against a JSON schema
   * @param schema JSON Schema to validate against
   * @param data Data to validate
   * @returns True if valid, false otherwise
   */
  static validate(schema: Record<string, unknown>, data: unknown): boolean {
    // This is a simplified implementation
    // In a real application, you would use a library like Ajv for proper validation

    // Check for required fields
    if (schema.required && Array.isArray(schema.required)) {
      const required = schema.required as string[];
      const dataObj = data as Record<string, unknown>;

      for (const field of required) {
        if (dataObj[field] === undefined) {
          console.error(`Missing required field: ${field}`);
          return false;
        }
      }
    }

    // Check property types if properties are defined
    if (schema.properties && typeof schema.properties === 'object') {
      const properties = schema.properties as Record<string, { type?: string }>;
      const dataObj = data as Record<string, unknown>;

      for (const [key, prop] of Object.entries(properties)) {
        if (dataObj[key] !== undefined && prop.type) {
          const expectedType = prop.type;
          const actualType = Array.isArray(dataObj[key])
            ? 'array'
            : typeof dataObj[key];

          if (expectedType !== actualType) {
            console.error(
              `Type mismatch for property "${key}": expected ${expectedType}, got ${actualType}`,
            );
            return false;
          }
        }
      }
    }

    return true;
  }
}
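
// A minimal usage sketch (the schema and payloads are illustrative):
//
//   const schema = {
//     required: ['name'],
//     properties: { name: { type: 'string' }, tags: { type: 'array' } },
//   };
//
//   SchemaValidator.validate(schema, { name: 'qwen', tags: [] }); // true
//   SchemaValidator.validate(schema, { tags: 'oops' });           // false
//
// The second call fails on the first check it hits: 'name' is missing (and
// 'tags' would also mismatch, being a string where an array is expected).
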