chore: sync gemini-cli v0.1.19

Commit authored by:
tanzhenxin
2025-08-18 19:55:46 +08:00
244 changed files with 19407 additions and 5030 deletions

View File

@@ -9,8 +9,10 @@ import {
toGenerateContentRequest,
fromGenerateContentResponse,
CaGenerateContentResponse,
toContents,
} from './converter.js';
import {
ContentListUnion,
GenerateContentParameters,
GenerateContentResponse,
FinishReason,
@@ -295,4 +297,57 @@ describe('converter', () => {
);
});
});
describe('toContents', () => {
  // Each case feeds a different ContentListUnion shape through toContents()
  // and checks the normalized Content[] that comes out: bare parts and
  // strings are wrapped into a user-role Content.
  it('should handle Content', () => {
    const single: ContentListUnion = {
      role: 'user',
      parts: [{ text: 'hello' }],
    };
    const normalized = toContents(single);
    expect(normalized).toEqual([
      { role: 'user', parts: [{ text: 'hello' }] },
    ]);
  });

  it('should handle array of Contents', () => {
    // An array of fully-formed Contents passes through unchanged.
    const conversation: ContentListUnion = [
      { role: 'user', parts: [{ text: 'hello' }] },
      { role: 'model', parts: [{ text: 'hi' }] },
    ];
    expect(toContents(conversation)).toEqual([
      { role: 'user', parts: [{ text: 'hello' }] },
      { role: 'model', parts: [{ text: 'hi' }] },
    ]);
  });

  it('should handle Part', () => {
    const lonePart: ContentListUnion = { text: 'a part' };
    expect(toContents(lonePart)).toEqual([
      { role: 'user', parts: [{ text: 'a part' }] },
    ]);
  });

  it('should handle array of Parts', () => {
    // A mixed array of Part objects and raw strings: each element becomes
    // its own user-role Content.
    const mixedParts = [{ text: 'part 1' }, 'part 2'];
    expect(toContents(mixedParts)).toEqual([
      { role: 'user', parts: [{ text: 'part 1' }] },
      { role: 'user', parts: [{ text: 'part 2' }] },
    ]);
  });

  it('should handle string', () => {
    const bare: ContentListUnion = 'a string';
    expect(toContents(bare)).toEqual([
      { role: 'user', parts: [{ text: 'a string' }] },
    ]);
  });

  it('should handle array of strings', () => {
    const bareStrings: ContentListUnion = ['string 1', 'string 2'];
    expect(toContents(bareStrings)).toEqual([
      { role: 'user', parts: [{ text: 'string 1' }] },
      { role: 'user', parts: [{ text: 'string 2' }] },
    ]);
  });
});
});

View File

@@ -22,7 +22,6 @@ import {
Part,
SafetySetting,
PartUnion,
SchemaUnion,
SpeechConfigUnion,
ThinkingConfig,
ToolListUnion,
@@ -61,7 +60,7 @@ interface VertexGenerationConfig {
frequencyPenalty?: number;
seed?: number;
responseMimeType?: string;
responseSchema?: SchemaUnion;
responseJsonSchema?: unknown;
routingConfig?: GenerationConfigRoutingConfig;
modelSelectionConfig?: ModelSelectionConfig;
responseModalities?: string[];
@@ -157,7 +156,7 @@ function toVertexGenerateContentRequest(
};
}
function toContents(contents: ContentListUnion): Content[] {
export function toContents(contents: ContentListUnion): Content[] {
if (Array.isArray(contents)) {
// it's a Content[] or a PartsUnion[]
return contents.map(toContent);
@@ -230,7 +229,7 @@ function toVertexGenerationConfig(
frequencyPenalty: config.frequencyPenalty,
seed: config.seed,
responseMimeType: config.responseMimeType,
responseSchema: config.responseSchema,
responseJsonSchema: config.responseJsonSchema,
routingConfig: config.routingConfig,
modelSelectionConfig: config.modelSelectionConfig,
responseModalities: config.responseModalities,

View File

@@ -366,7 +366,7 @@ async function cacheCredentials(credentials: Credentials) {
await fs.mkdir(path.dirname(filePath), { recursive: true });
const credString = JSON.stringify(credentials, null, 2);
await fs.writeFile(filePath, credString);
await fs.writeFile(filePath, credString, { mode: 0o600 });
}
function getCachedCredentialPath(): string {

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { beforeEach, describe, it, expect, vi } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
@@ -12,6 +12,10 @@ import { UserTierId } from './types.js';
vi.mock('google-auth-library');
describe('CodeAssistServer', () => {
beforeEach(() => {
vi.resetAllMocks();
});
it('should be able to be constructed', () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(

View File

@@ -53,9 +53,6 @@ export async function setupUser(client: OAuth2Client): Promise<UserData> {
}
const tier = getOnboardTier(loadRes);
if (tier.userDefinedCloudaicompanionProject && !projectId) {
throw new ProjectIdRequiredError();
}
const onboardReq: OnboardUserRequest = {
tierId: tier.id,
@@ -69,8 +66,13 @@ export async function setupUser(client: OAuth2Client): Promise<UserData> {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await caServer.onboardUser(onboardReq);
}
if (!lroRes.response?.cloudaicompanionProject?.id && !projectId) {
throw new ProjectIdRequiredError();
}
return {
projectId: lroRes.response?.cloudaicompanionProject?.id || '',
projectId: lroRes.response?.cloudaicompanionProject?.id || projectId!,
userTier: tier.id,
};
}

View File

@@ -162,6 +162,18 @@ describe('Server Config (config.ts)', () => {
await expect(config.initialize()).resolves.toBeUndefined();
});
it('should throw an error if initialized more than once', async () => {
  // The first initialize() resolves; a second call must be rejected with
  // the double-initialization error.
  const config = new Config({ ...baseParams, checkpointing: false });
  await expect(config.initialize()).resolves.toBeUndefined();
  await expect(config.initialize()).rejects.toThrow(
    'Config was already initialized',
  );
});
});
describe('refreshAuth', () => {

View File

@@ -69,6 +69,10 @@ export interface BugCommandSettings {
urlTemplate: string;
}
export interface ChatCompressionSettings {
contextPercentageThreshold?: number;
}
export interface SummarizeToolOutputSettings {
tokenBudget?: number;
}
@@ -189,7 +193,6 @@ export interface ConfigParameters {
extensionContextFilePaths?: string[];
maxSessionTurns?: number;
sessionTokenLimit?: number;
maxFolderItems?: number;
experimentalAcp?: boolean;
listExtensions?: boolean;
extensions?: GeminiCLIExtension[];
@@ -197,6 +200,8 @@ export interface ConfigParameters {
noBrowser?: boolean;
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
ideModeFeature?: boolean;
folderTrustFeature?: boolean;
folderTrust?: boolean;
ideMode?: boolean;
enableOpenAILogging?: boolean;
sampling_params?: Record<string, unknown>;
@@ -213,6 +218,8 @@ export interface ConfigParameters {
loadMemoryFromIncludeDirectories?: boolean;
// Web search providers
tavilyApiKey?: string;
chatCompression?: ChatCompressionSettings;
interactive?: boolean;
}
export class Config {
@@ -257,6 +264,8 @@ export class Config {
private readonly extensionContextFilePaths: string[];
private readonly noBrowser: boolean;
private readonly ideModeFeature: boolean;
private readonly folderTrustFeature: boolean;
private readonly folderTrust: boolean;
private ideMode: boolean;
private ideClient: IdeClient;
private inFallbackMode = false;
@@ -267,7 +276,6 @@ export class Config {
}>;
private readonly maxSessionTurns: number;
private readonly sessionTokenLimit: number;
private readonly maxFolderItems: number;
private readonly listExtensions: boolean;
private readonly _extensions: GeminiCLIExtension[];
private readonly _blockedMcpServers: Array<{
@@ -289,6 +297,9 @@ export class Config {
private readonly cliVersion?: string;
private readonly loadMemoryFromIncludeDirectories: boolean = false;
private readonly tavilyApiKey?: string;
private readonly chatCompression: ChatCompressionSettings | undefined;
private readonly interactive: boolean;
private initialized: boolean = false;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -343,7 +354,6 @@ export class Config {
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
this.maxFolderItems = params.maxFolderItems ?? 20;
this.experimentalAcp = params.experimentalAcp ?? false;
this.listExtensions = params.listExtensions ?? false;
this._extensions = params.extensions ?? [];
@@ -351,12 +361,10 @@ export class Config {
this.noBrowser = params.noBrowser ?? false;
this.summarizeToolOutput = params.summarizeToolOutput;
this.ideModeFeature = params.ideModeFeature ?? false;
this.folderTrustFeature = params.folderTrustFeature ?? false;
this.folderTrust = params.folderTrust ?? false;
this.ideMode = params.ideMode ?? false;
this.ideClient = IdeClient.getInstance();
if (this.ideMode && this.ideModeFeature) {
this.ideClient.connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.START));
}
this.systemPromptMappings = params.systemPromptMappings;
this.enableOpenAILogging = params.enableOpenAILogging ?? false;
this.sampling_params = params.sampling_params;
@@ -365,6 +373,8 @@ export class Config {
this.loadMemoryFromIncludeDirectories =
params.loadMemoryFromIncludeDirectories ?? false;
this.chatCompression = params.chatCompression;
this.interactive = params.interactive ?? false;
// Web search
this.tavilyApiKey = params.tavilyApiKey;
@@ -386,7 +396,14 @@ export class Config {
}
}
/**
* Must only be called once, throws if called again.
*/
async initialize(): Promise<void> {
if (this.initialized) {
throw Error('Config was already initialized');
}
this.initialized = true;
// Initialize centralized FileDiscoveryService
this.getFileService();
if (this.getCheckpointingEnabled()) {
@@ -468,10 +485,6 @@ export class Config {
return this.sessionTokenLimit;
}
getMaxFolderItems(): number {
return this.maxFolderItems;
}
setQuotaErrorOccurred(value: boolean): void {
this.quotaErrorOccurred = value;
}
@@ -718,6 +731,14 @@ export class Config {
return this.ideMode;
}
/** Whether the folder-trust feature flag was set in ConfigParameters (defaults to false). */
getFolderTrustFeature(): boolean {
return this.folderTrustFeature;
}
/** Whether the current folder is trusted, as supplied via ConfigParameters (defaults to false). */
getFolderTrust(): boolean {
return this.folderTrust;
}
setIdeMode(value: boolean): void {
this.ideMode = value;
}
@@ -728,7 +749,7 @@ export class Config {
await this.ideClient.connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.SESSION));
} else {
this.ideClient.disconnect();
await this.ideClient.disconnect();
}
}
@@ -762,6 +783,14 @@ export class Config {
return this.systemPromptMappings;
}
/** Chat-compression settings from ConfigParameters, or undefined when none were provided. */
getChatCompression(): ChatCompressionSettings | undefined {
return this.chatCompression;
}
/** Whether this session was configured as interactive (defaults to false). */
isInteractive(): boolean {
return this.interactive;
}
async getGitService(): Promise<GitService> {
if (!this.gitService) {
this.gitService = new GitService(this.targetDir);

View File

@@ -214,7 +214,6 @@ describe('Gemini Client (client.ts)', () => {
getFileService: vi.fn().mockReturnValue(fileService),
getMaxSessionTurns: vi.fn().mockReturnValue(0),
getSessionTokenLimit: vi.fn().mockReturnValue(32000),
getMaxFolderItems: vi.fn().mockReturnValue(20),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
getNoBrowser: vi.fn().mockReturnValue(false),
@@ -222,13 +221,14 @@ describe('Gemini Client (client.ts)', () => {
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
getIdeModeFeature: vi.fn().mockReturnValue(false),
getIdeMode: vi.fn().mockReturnValue(true),
getDebugMode: vi.fn().mockReturnValue(false),
getWorkspaceContext: vi.fn().mockReturnValue({
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
}),
getGeminiClient: vi.fn(),
setFallbackMode: vi.fn(),
getDebugMode: vi.fn().mockReturnValue(false),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
getChatCompression: vi.fn().mockReturnValue(undefined),
};
const MockedConfig = vi.mocked(Config, true);
MockedConfig.mockImplementation(
@@ -491,8 +491,7 @@ describe('Gemini Client (client.ts)', () => {
const mockChat = {
addHistory: vi.fn(),
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
client['chat'] = mockChat as any;
client['chat'] = mockChat as unknown as GeminiChat;
const newContent = {
role: 'user',
@@ -574,14 +573,19 @@ describe('Gemini Client (client.ts)', () => {
expect(newChat).toBe(initialChat);
});
it('should trigger summarization if token count is at threshold', async () => {
it('should trigger summarization if token count is at threshold with contextPercentageThreshold setting', async () => {
const MOCKED_TOKEN_LIMIT = 1000;
const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5;
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
});
mockGetHistory.mockReturnValue([
{ role: 'user', parts: [{ text: '...history...' }] },
]);
const originalTokenCount = 1000 * 0.7;
const originalTokenCount =
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
const newTokenCount = 100;
mockCountTokens
@@ -704,7 +708,7 @@ describe('Gemini Client (client.ts)', () => {
});
describe('sendMessageStream', () => {
it('should include IDE context when ideModeFeature is enabled', async () => {
it('should include editor context when ideModeFeature is enabled', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -762,21 +766,30 @@ describe('Gemini Client (client.ts)', () => {
// Assert
expect(ideContext.getIdeContext).toHaveBeenCalled();
const expectedContext = `
This is the file that the user is looking at:
- Path: /path/to/active/file.ts
This is the cursor position in the file:
- Cursor Position: Line 5, Character 10
This is the selected text in the file:
- hello
Here are some other files the user has open, with the most recent at the top:
- /path/to/recent/file1.ts
- /path/to/recent/file2.ts
Here is the user's editor context as a JSON object. This is for your information only.
\`\`\`json
${JSON.stringify(
{
activeFile: {
path: '/path/to/active/file.ts',
cursor: {
line: 5,
character: 10,
},
selectedText: 'hello',
},
otherOpenFiles: ['/path/to/recent/file1.ts', '/path/to/recent/file2.ts'],
},
null,
2,
)}
\`\`\`
`.trim();
const expectedRequest = [{ text: expectedContext }, ...initialRequest];
expect(mockTurnRunFn).toHaveBeenCalledWith(
expectedRequest,
expect.any(Object),
);
const expectedRequest = [{ text: expectedContext }];
expect(mockChat.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: expectedRequest,
});
});
it('should not add context if ideModeFeature is enabled but no open files', async () => {
@@ -876,18 +889,29 @@ Here are some other files the user has open, with the most recent at the top:
// Assert
expect(ideContext.getIdeContext).toHaveBeenCalled();
const expectedContext = `
This is the file that the user is looking at:
- Path: /path/to/active/file.ts
This is the cursor position in the file:
- Cursor Position: Line 5, Character 10
This is the selected text in the file:
- hello
Here is the user's editor context as a JSON object. This is for your information only.
\`\`\`json
${JSON.stringify(
{
activeFile: {
path: '/path/to/active/file.ts',
cursor: {
line: 5,
character: 10,
},
selectedText: 'hello',
},
},
null,
2,
)}
\`\`\`
`.trim();
const expectedRequest = [{ text: expectedContext }, ...initialRequest];
expect(mockTurnRunFn).toHaveBeenCalledWith(
expectedRequest,
expect.any(Object),
);
const expectedRequest = [{ text: expectedContext }];
expect(mockChat.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: expectedRequest,
});
});
it('should add context if ideModeFeature is enabled and there are open files but no active file', async () => {
@@ -941,15 +965,22 @@ This is the selected text in the file:
// Assert
expect(ideContext.getIdeContext).toHaveBeenCalled();
const expectedContext = `
Here are some files the user has open, with the most recent at the top:
- /path/to/recent/file1.ts
- /path/to/recent/file2.ts
Here is the user's editor context as a JSON object. This is for your information only.
\`\`\`json
${JSON.stringify(
{
otherOpenFiles: ['/path/to/recent/file1.ts', '/path/to/recent/file2.ts'],
},
null,
2,
)}
\`\`\`
`.trim();
const expectedRequest = [{ text: expectedContext }, ...initialRequest];
expect(mockTurnRunFn).toHaveBeenCalledWith(
expectedRequest,
expect.any(Object),
);
const expectedRequest = [{ text: expectedContext }];
expect(mockChat.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: expectedRequest,
});
});
it('should return the turn instance after the stream is complete', async () => {
@@ -1227,6 +1258,268 @@ Here are some files the user has open, with the most recent at the top:
`${eventCount} events generated (properly bounded by MAX_TURNS)`,
);
});
// Exercises the delta path of the IDE-context feature: once a full editor
// context has been sent (lastSentIdeContext is set), sendMessageStream should
// only append a JSON "summary of changes" message to history when the editor
// state actually changed.
describe('Editor context delta', () => {
// Single-chunk stream returned by the mocked Turn.run().
// NOTE(review): this async generator is created once and shared by every
// test in this describe; it is exhausted after the first full iteration —
// confirm later cases do not depend on receiving its values again.
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
beforeEach(() => {
client['forceFullIdeContext'] = false; // Reset before each delta test
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
mockTurnRunFn.mockReturnValue(mockStream);
// Replace the chat with a stub so addHistory calls can be inspected.
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
setHistory: vi.fn(),
sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }),
// Assume history is not empty for delta checks
getHistory: vi
.fn()
.mockReturnValue([
{ role: 'user', parts: [{ text: 'previous message' }] },
]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
});
// Table of previous/current active-file pairs and whether a delta message
// is expected. Any difference in path, cursor line/character, or selected
// text (added, removed, or changed) should produce a delta.
const testCases = [
{
description: 'sends delta when active file changes',
previousActiveFile: {
path: '/path/to/old/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: true,
},
{
description: 'sends delta when cursor line changes',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 1, character: 10 },
selectedText: 'hello',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: true,
},
{
description: 'sends delta when cursor character changes',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 1 },
selectedText: 'hello',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: true,
},
{
description: 'sends delta when selected text changes',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'world',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: true,
},
{
description: 'sends delta when selected text is added',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: true,
},
{
description: 'sends delta when selected text is removed',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
},
shouldSendContext: true,
},
{
description: 'does not send context when nothing changes',
previousActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
currentActiveFile: {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
},
shouldSendContext: false,
},
];
it.each(testCases)(
'$description',
async ({
previousActiveFile,
currentActiveFile,
shouldSendContext,
}) => {
// Setup previous context
client['lastSentIdeContext'] = {
workspaceState: {
openFiles: [
{
path: previousActiveFile.path,
cursor: previousActiveFile.cursor,
selectedText: previousActiveFile.selectedText,
isActive: true,
timestamp: Date.now() - 1000,
},
],
},
};
// Setup current context
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
openFiles: [
{ ...currentActiveFile, isActive: true, timestamp: Date.now() },
],
},
});
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-delta',
);
for await (const _ of stream) {
// consume stream
}
const mockChat = client['chat'] as unknown as {
addHistory: (typeof vi)['fn'];
};
if (shouldSendContext) {
// A delta message (not a full snapshot) must have been recorded.
expect(mockChat.addHistory).toHaveBeenCalledWith(
expect.objectContaining({
parts: expect.arrayContaining([
expect.objectContaining({
text: expect.stringContaining(
"Here is a summary of changes in the user's editor context",
),
}),
]),
}),
);
} else {
expect(mockChat.addHistory).not.toHaveBeenCalled();
}
},
);
// When history is empty the client must resend the FULL context, even
// though the editor state itself is unchanged since the last send.
it('sends full context when history is cleared, even if editor state is unchanged', async () => {
const activeFile = {
path: '/path/to/active/file.ts',
cursor: { line: 5, character: 10 },
selectedText: 'hello',
};
// Setup previous context
client['lastSentIdeContext'] = {
workspaceState: {
openFiles: [
{
path: activeFile.path,
cursor: activeFile.cursor,
selectedText: activeFile.selectedText,
isActive: true,
timestamp: Date.now() - 1000,
},
],
},
};
// Setup current context (same as previous)
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
openFiles: [
{ ...activeFile, isActive: true, timestamp: Date.now() },
],
},
});
// Make history empty
const mockChat = client['chat'] as unknown as {
getHistory: ReturnType<(typeof vi)['fn']>;
addHistory: ReturnType<(typeof vi)['fn']>;
};
mockChat.getHistory.mockReturnValue([]);
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-history-cleared',
);
for await (const _ of stream) {
// consume stream
}
expect(mockChat.addHistory).toHaveBeenCalledWith(
expect.objectContaining({
parts: expect.arrayContaining([
expect.objectContaining({
text: expect.stringContaining(
"Here is the user's editor context",
),
}),
]),
}),
);
// Also verify it's the full context, not a delta.
// Parses the JSON out of the fenced block in the recorded message.
const call = mockChat.addHistory.mock.calls[0][0];
const contextText = call.parts[0].text;
const contextJson = JSON.parse(
contextText.match(/```json\n(.*)\n```/s)![1],
);
expect(contextJson).toHaveProperty('activeFile');
expect(contextJson.activeFile.path).toBe('/path/to/active/file.ts');
});
});
});
describe('generateContent', () => {

View File

@@ -7,8 +7,6 @@
import {
EmbedContentParameters,
GenerateContentConfig,
Part,
SchemaUnion,
PartListUnion,
Content,
Tool,
@@ -16,7 +14,10 @@ import {
FunctionDeclaration,
Schema,
} from '@google/genai';
import { getFolderStructure } from '../utils/getFolderStructure.js';
import {
getDirectoryContextString,
getEnvironmentContext,
} from '../utils/environmentContext.js';
import {
Turn,
ServerGeminiStreamEvent,
@@ -26,8 +27,6 @@ import {
import { Config } from '../config/config.js';
import { UserTierId } from '../code_assist/types.js';
import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { getFunctionCalls } from '../utils/generateContentResponseUtilities.js';
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
@@ -41,12 +40,14 @@ import {
ContentGeneratorConfig,
createContentGenerator,
} from './contentGenerator.js';
import { getFunctionCalls } from '../utils/generateContentResponseUtilities.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { LoopDetectionService } from '../services/loopDetectionService.js';
import { ideContext } from '../ide/ideContext.js';
import { logNextSpeakerCheck } from '../telemetry/loggers.js';
import { NextSpeakerCheckEvent } from '../telemetry/types.js';
import { IdeContext, File } from '../ide/ideContext.js';
function isThinkingSupported(model: string) {
if (model.startsWith('gemini-2.5')) return true;
@@ -109,6 +110,8 @@ export class GeminiClient {
private readonly loopDetector: LoopDetectionService;
private lastPromptId: string;
private lastSentIdeContext: IdeContext | undefined;
private forceFullIdeContext = true;
constructor(private config: Config) {
if (config.getProxy()) {
@@ -161,6 +164,7 @@ export class GeminiClient {
setHistory(history: Content[]) {
this.getChat().setHistory(history);
this.forceFullIdeContext = true;
}
async setTools(): Promise<void> {
@@ -181,115 +185,13 @@ export class GeminiClient {
this.getChat().addHistory({
role: 'user',
parts: [{ text: await this.getDirectoryContext() }],
parts: [{ text: await getDirectoryContextString(this.config) }],
});
}
private async getDirectoryContext(): Promise<string> {
const workspaceContext = this.config.getWorkspaceContext();
const workspaceDirectories = workspaceContext.getDirectories();
const folderStructures = await Promise.all(
workspaceDirectories.map((dir) =>
getFolderStructure(dir, {
fileService: this.config.getFileService(),
}),
),
);
const folderStructure = folderStructures.join('\n');
const dirList = workspaceDirectories.map((dir) => ` - ${dir}`).join('\n');
const workingDirPreamble = `I'm currently working in the following directories:\n${dirList}\n Folder structures are as follows:\n${folderStructure}`;
return workingDirPreamble;
}
private async getEnvironment(): Promise<Part[]> {
const today = new Date().toLocaleDateString(undefined, {
weekday: 'long',
year: 'numeric',
month: 'long',
day: 'numeric',
});
const platform = process.platform;
const workspaceContext = this.config.getWorkspaceContext();
const workspaceDirectories = workspaceContext.getDirectories();
const folderStructures = await Promise.all(
workspaceDirectories.map((dir) =>
getFolderStructure(dir, {
fileService: this.config.getFileService(),
}),
),
);
const folderStructure = folderStructures.join('\n');
let workingDirPreamble: string;
if (workspaceDirectories.length === 1) {
workingDirPreamble = `I'm currently working in the directory: ${workspaceDirectories[0]}`;
} else {
const dirList = workspaceDirectories
.map((dir) => ` - ${dir}`)
.join('\n');
workingDirPreamble = `I'm currently working in the following directories:\n${dirList}`;
}
const context = `
This is the Qwen Code. We are setting up the context for our chat.
Today's date is ${today}.
My operating system is: ${platform}
${workingDirPreamble}
Here is the folder structure of the current working directories:\n
${folderStructure}
`.trim();
const initialParts: Part[] = [{ text: context }];
const toolRegistry = await this.config.getToolRegistry();
// Add full file context if the flag is set
if (this.config.getFullContext()) {
try {
const readManyFilesTool = toolRegistry.getTool(
'read_many_files',
) as ReadManyFilesTool;
if (readManyFilesTool) {
// Read all files in the target directory
const result = await readManyFilesTool.execute(
{
paths: ['**/*'], // Read everything recursively
useDefaultExcludes: true, // Use default excludes
},
AbortSignal.timeout(30000),
);
if (result.llmContent) {
initialParts.push({
text: `\n--- Full File Context ---\n${result.llmContent}`,
});
} else {
console.warn(
'Full context requested, but read_many_files returned no content.',
);
}
} else {
console.warn(
'Full context requested, but read_many_files tool not found.',
);
}
} catch (error) {
// Not using reportError here as it's a startup/config phase, not a chat/generation phase error.
console.error('Error reading full file context:', error);
initialParts.push({
text: '\n--- Error reading full file context ---',
});
}
}
return initialParts;
}
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
const envParts = await this.getEnvironment();
this.forceFullIdeContext = true;
const envParts = await getEnvironmentContext(this.config);
const toolRegistry = await this.config.getToolRegistry();
const toolDeclarations = toolRegistry.getFunctionDeclarations();
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
@@ -338,6 +240,174 @@ export class GeminiClient {
}
}
/**
 * Builds the editor-context message that gets appended to chat history:
 * either a full JSON snapshot of the IDE context, or a JSON delta computed
 * against `this.lastSentIdeContext`.
 *
 * @param forceFullContext When true, emit the full snapshot; a full snapshot
 *   is also emitted when no context has been sent yet.
 * @returns `contextParts` — the message lines (empty when there is nothing
 *   to send) — and `newIdeContext`, the snapshot the caller should store as
 *   `lastSentIdeContext`.
 */
private getIdeContextParts(forceFullContext: boolean): {
contextParts: string[];
newIdeContext: IdeContext | undefined;
} {
const currentIdeContext = ideContext.getIdeContext();
if (!currentIdeContext) {
// No IDE context available at all — nothing to send, nothing to record.
return { contextParts: [], newIdeContext: undefined };
}
if (forceFullContext || !this.lastSentIdeContext) {
// Send full context as JSON
const openFiles = currentIdeContext.workspaceState?.openFiles || [];
const activeFile = openFiles.find((f) => f.isActive);
const otherOpenFiles = openFiles
.filter((f) => !f.isActive)
.map((f) => f.path);
const contextData: Record<string, unknown> = {};
if (activeFile) {
// Empty-string selectedText is normalized to undefined so it is
// dropped from the JSON output.
contextData.activeFile = {
path: activeFile.path,
cursor: activeFile.cursor
? {
line: activeFile.cursor.line,
character: activeFile.cursor.character,
}
: undefined,
selectedText: activeFile.selectedText || undefined,
};
}
if (otherOpenFiles.length > 0) {
contextData.otherOpenFiles = otherOpenFiles;
}
if (Object.keys(contextData).length === 0) {
// Context exists but is empty (no open files): send nothing, yet still
// record the snapshot so the next call can diff against it.
return { contextParts: [], newIdeContext: currentIdeContext };
}
const jsonString = JSON.stringify(contextData, null, 2);
const contextParts = [
"Here is the user's editor context as a JSON object. This is for your information only.",
'```json',
jsonString,
'```',
];
if (this.config.getDebugMode()) {
console.log(contextParts.join('\n'));
}
return {
contextParts,
newIdeContext: currentIdeContext,
};
} else {
// Calculate and send delta as JSON
const delta: Record<string, unknown> = {};
const changes: Record<string, unknown> = {};
// Index last-sent and current open files by path so opened/closed files
// can be detected with set membership.
const lastFiles = new Map(
(this.lastSentIdeContext.workspaceState?.openFiles || []).map(
(f: File) => [f.path, f],
),
);
const currentFiles = new Map(
(currentIdeContext.workspaceState?.openFiles || []).map((f: File) => [
f.path,
f,
]),
);
const openedFiles: string[] = [];
for (const [path] of currentFiles.entries()) {
if (!lastFiles.has(path)) {
openedFiles.push(path);
}
}
if (openedFiles.length > 0) {
changes.filesOpened = openedFiles;
}
const closedFiles: string[] = [];
for (const [path] of lastFiles.entries()) {
if (!currentFiles.has(path)) {
closedFiles.push(path);
}
}
if (closedFiles.length > 0) {
changes.filesClosed = closedFiles;
}
const lastActiveFile = (
this.lastSentIdeContext.workspaceState?.openFiles || []
).find((f: File) => f.isActive);
const currentActiveFile = (
currentIdeContext.workspaceState?.openFiles || []
).find((f: File) => f.isActive);
if (currentActiveFile) {
if (!lastActiveFile || lastActiveFile.path !== currentActiveFile.path) {
// Active file switched (or became active for the first time): report
// the whole new active-file state rather than individual fields.
changes.activeFileChanged = {
path: currentActiveFile.path,
cursor: currentActiveFile.cursor
? {
line: currentActiveFile.cursor.line,
character: currentActiveFile.cursor.character,
}
: undefined,
selectedText: currentActiveFile.selectedText || undefined,
};
} else {
// Same active file: report cursor movement and selection changes
// independently.
const lastCursor = lastActiveFile.cursor;
const currentCursor = currentActiveFile.cursor;
if (
currentCursor &&
(!lastCursor ||
lastCursor.line !== currentCursor.line ||
lastCursor.character !== currentCursor.character)
) {
changes.cursorMoved = {
path: currentActiveFile.path,
cursor: {
line: currentCursor.line,
character: currentCursor.character,
},
};
}
// Missing selection and empty-string selection compare as equal here.
const lastSelectedText = lastActiveFile.selectedText || '';
const currentSelectedText = currentActiveFile.selectedText || '';
if (lastSelectedText !== currentSelectedText) {
changes.selectionChanged = {
path: currentActiveFile.path,
selectedText: currentSelectedText,
};
}
}
} else if (lastActiveFile) {
// There was an active file before but none now: path null signals
// deactivation, previousPath identifies which file it was.
changes.activeFileChanged = {
path: null,
previousPath: lastActiveFile.path,
};
}
if (Object.keys(changes).length === 0) {
// Nothing changed since the last send — suppress the message but
// refresh the stored snapshot.
return { contextParts: [], newIdeContext: currentIdeContext };
}
delta.changes = changes;
const jsonString = JSON.stringify(delta, null, 2);
const contextParts = [
"Here is a summary of changes in the user's editor context, in JSON format. This is for your information only.",
'```json',
jsonString,
'```',
];
if (this.config.getDebugMode()) {
console.log(contextParts.join('\n'));
}
return {
contextParts,
newIdeContext: currentIdeContext,
};
}
}
async *sendMessageStream(
request: PartListUnion,
signal: AbortSignal,
@@ -379,7 +449,7 @@ export class GeminiClient {
const currentHistory = this.getChat().getHistory(true);
const userMemory = this.config.getUserMemory();
const systemPrompt = getCoreSystemPrompt(userMemory);
const environment = await this.getEnvironment();
const environment = await getEnvironmentContext(this.config);
// Create a mock request content to count total tokens
const mockRequestContent = [
@@ -416,49 +486,17 @@ export class GeminiClient {
}
if (this.config.getIdeModeFeature() && this.config.getIdeMode()) {
const ideContextState = ideContext.getIdeContext();
const openFiles = ideContextState?.workspaceState?.openFiles;
if (openFiles && openFiles.length > 0) {
const contextParts: string[] = [];
const firstFile = openFiles[0];
const activeFile = firstFile.isActive ? firstFile : undefined;
if (activeFile) {
contextParts.push(
`This is the file that the user is looking at:\n- Path: ${activeFile.path}`,
);
if (activeFile.cursor) {
contextParts.push(
`This is the cursor position in the file:\n- Cursor Position: Line ${activeFile.cursor.line}, Character ${activeFile.cursor.character}`,
);
}
if (activeFile.selectedText) {
contextParts.push(
`This is the selected text in the file:\n- ${activeFile.selectedText}`,
);
}
}
const otherOpenFiles = activeFile ? openFiles.slice(1) : openFiles;
if (otherOpenFiles.length > 0) {
const recentFiles = otherOpenFiles
.map((file) => `- ${file.path}`)
.join('\n');
const heading = activeFile
? `Here are some other files the user has open, with the most recent at the top:`
: `Here are some files the user has open, with the most recent at the top:`;
contextParts.push(`${heading}\n${recentFiles}`);
}
if (contextParts.length > 0) {
request = [
{ text: contextParts.join('\n') },
...(Array.isArray(request) ? request : [request]),
];
}
const { contextParts, newIdeContext } = this.getIdeContextParts(
this.forceFullIdeContext || this.getHistory().length === 0,
);
if (contextParts.length > 0) {
this.getChat().addHistory({
role: 'user',
parts: [{ text: contextParts.join('\n') }],
});
}
this.lastSentIdeContext = newIdeContext;
this.forceFullIdeContext = false;
}
const turn = new Turn(this.getChat(), prompt_id);
@@ -517,7 +555,7 @@ export class GeminiClient {
async generateJson(
contents: Content[],
schema: SchemaUnion,
schema: Record<string, unknown>,
abortSignal: AbortSignal,
model?: string,
config: GenerateContentConfig = {},
@@ -723,12 +761,16 @@ export class GeminiClient {
return null;
}
const contextPercentageThreshold =
this.config.getChatCompression()?.contextPercentageThreshold;
// Don't compress if not forced and we are under the limit.
if (
!force &&
originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model)
) {
return null;
if (!force) {
const threshold =
contextPercentageThreshold ?? this.COMPRESSION_TOKEN_THRESHOLD;
if (originalTokenCount < threshold * tokenLimit(model)) {
return null;
}
}
let compressBeforeIndex = findIndexAfterFraction(
@@ -771,6 +813,7 @@ export class GeminiClient {
},
...historyToKeep,
]);
this.forceFullIdeContext = true;
const { totalTokens: newTokenCount } =
await this.getContentGenerator().countTokens({

View File

@@ -9,10 +9,12 @@ import {
createContentGenerator,
AuthType,
createContentGeneratorConfig,
ContentGenerator,
} from './contentGenerator.js';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { GoogleGenAI } from '@google/genai';
import { Config } from '../config/config.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
@@ -23,7 +25,7 @@ const mockConfig = {
describe('createContentGenerator', () => {
it('should create a CodeAssistContentGenerator', async () => {
const mockGenerator = {} as unknown;
const mockGenerator = {} as unknown as ContentGenerator;
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
mockGenerator as never,
);
@@ -35,13 +37,15 @@ describe('createContentGenerator', () => {
mockConfig,
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toBe(mockGenerator);
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
);
});
it('should create a GoogleGenAI content generator', async () => {
const mockGenerator = {
models: {},
} as unknown;
} as unknown as GoogleGenAI;
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
const generator = await createContentGenerator(
{
@@ -60,7 +64,12 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toBe((mockGenerator as GoogleGenAI).models);
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
});
});

View File

@@ -18,6 +18,7 @@ import { DEFAULT_GEMINI_MODEL, DEFAULT_QWEN_MODEL } from '../config/models.js';
import { Config } from '../config/config.js';
import { getEffectiveModel } from './modelCheck.js';
import { UserTierId } from '../code_assist/types.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
@@ -161,11 +162,14 @@ export async function createContentGenerator(
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
return createCodeAssistContentGenerator(
httpOptions,
config.authType,
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
),
gcConfig,
sessionId,
);
}
@@ -178,8 +182,7 @@ export async function createContentGenerator(
vertexai: config.vertexai,
httpOptions,
});
return googleGenAI.models;
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
}
if (config.authType === AuthType.USE_OPENAI) {

View File

@@ -24,44 +24,15 @@ import {
} from '../index.js';
import { Part, PartListUnion } from '@google/genai';
import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js';
/**
 * Test double for a scheduler tool: records each execute() invocation via a
 * spy and can be toggled to demand an 'exec' confirmation before running.
 */
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
  // When true, shouldConfirmExecute() requests an 'exec' confirmation.
  shouldConfirm = false;
  // Spy capturing the params passed to every execute() call.
  executeFn = vi.fn();

  constructor(name = 'mockTool') {
    super(name, name, 'A mock tool', Icon.Hammer, {});
  }

  async shouldConfirmExecute(
    _params: Record<string, unknown>,
    _abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false> {
    // Guard clause: confirmation disabled means no prompt at all.
    if (!this.shouldConfirm) {
      return false;
    }
    return {
      type: 'exec',
      title: 'Confirm Mock Tool',
      command: 'do_thing',
      rootCommand: 'do_thing',
      onConfirm: async () => {},
    };
  }

  async execute(
    params: Record<string, unknown>,
    _abortSignal: AbortSignal,
  ): Promise<ToolResult> {
    this.executeFn(params);
    return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
  }
}
import {
ModifiableDeclarativeTool,
ModifyContext,
} from '../tools/modifiable-tool.js';
import { MockTool } from '../test-utils/tools.js';
class MockModifiableTool
extends MockTool
implements ModifiableTool<Record<string, unknown>>
implements ModifiableDeclarativeTool<Record<string, unknown>>
{
constructor(name = 'mockModifiableTool') {
super(name);
@@ -83,15 +54,13 @@ class MockModifiableTool
};
}
async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
async shouldConfirmExecute(): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'edit',
title: 'Confirm Mock Tool',
fileName: 'test.txt',
filePath: 'test.txt',
fileDiff: 'diff',
originalContent: 'originalContent',
newContent: 'newContent',
@@ -106,14 +75,15 @@ describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => mockTool,
getTool: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByName: () => mockTool,
getToolByDisplayName: () => mockTool,
getToolByName: () => declarativeTool,
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@@ -176,14 +146,15 @@ describe('CoreToolScheduler', () => {
describe('CoreToolScheduler with payload', () => {
it('should update args and diff and execute tool when payload is provided', async () => {
const mockTool = new MockModifiableTool();
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => mockTool,
getTool: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByName: () => mockTool,
getToolByDisplayName: () => mockTool,
getToolByName: () => declarativeTool,
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@@ -220,10 +191,7 @@ describe('CoreToolScheduler with payload', () => {
await scheduler.schedule([request], abortController.signal);
const confirmationDetails = await mockTool.shouldConfirmExecute(
{},
abortController.signal,
);
const confirmationDetails = await mockTool.shouldConfirmExecute();
if (confirmationDetails) {
const payload: ToolConfirmationPayload = { newContent: 'final version' };
@@ -434,6 +402,7 @@ describe('CoreToolScheduler edit cancellation', () => {
type: 'edit',
title: 'Confirm Edit',
fileName: 'test.txt',
filePath: 'test.txt',
fileDiff:
'--- test.txt\n+++ test.txt\n@@ -1,1 +1,1 @@\n-old content\n+new content',
originalContent: 'old content',
@@ -454,14 +423,15 @@ describe('CoreToolScheduler edit cancellation', () => {
}
const mockEditTool = new MockEditTool();
const declarativeTool = mockEditTool;
const toolRegistry = {
getTool: () => mockEditTool,
getTool: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByName: () => mockEditTool,
getToolByDisplayName: () => mockEditTool,
getToolByName: () => declarativeTool,
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@@ -539,18 +509,23 @@ describe('CoreToolScheduler YOLO mode', () => {
it('should execute tool requiring confirmation directly without waiting', async () => {
// Arrange
const mockTool = new MockTool();
mockTool.executeFn.mockReturnValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation.
mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => mockTool,
getToolByName: () => mockTool,
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
// Other properties are not needed for this test but are included for type consistency.
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByDisplayName: () => mockTool,
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@@ -617,3 +592,195 @@ describe('CoreToolScheduler YOLO mode', () => {
}
});
});
// Exercises the scheduler's request queue: a schedule() issued while another
// call is still running must wait until the running call completes.
describe('CoreToolScheduler request queueing', () => {
  it('should queue a request if another is running', async () => {
    // Manually-resolvable promise lets the test hold the first tool call in
    // the 'executing' state until we explicitly finish it.
    let resolveFirstCall: (result: ToolResult) => void;
    const firstCallPromise = new Promise<ToolResult>((resolve) => {
      resolveFirstCall = resolve;
    });

    const mockTool = new MockTool();
    mockTool.executeFn.mockImplementation(() => firstCallPromise);
    const declarativeTool = mockTool;

    // Minimal structurally-typed registry; only the getters matter here.
    const toolRegistry = {
      getTool: () => declarativeTool,
      getToolByName: () => declarativeTool,
      getFunctionDeclarations: () => [],
      tools: new Map(),
      discovery: {} as any,
      registerTool: () => {},
      getToolByDisplayName: () => declarativeTool,
      getTools: () => [],
      discoverTools: async () => {},
      getAllTools: () => [],
      getToolsByServer: () => [],
    };

    const onAllToolCallsComplete = vi.fn();
    const onToolCallsUpdate = vi.fn();

    const mockConfig = {
      getSessionId: () => 'test-session-id',
      getUsageStatisticsEnabled: () => true,
      getDebugMode: () => false,
      getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
    } as unknown as Config;

    const scheduler = new CoreToolScheduler({
      config: mockConfig,
      toolRegistry: Promise.resolve(toolRegistry as any),
      onAllToolCallsComplete,
      onToolCallsUpdate,
      getPreferredEditor: () => 'vscode',
      onEditorClose: vi.fn(),
    });

    const abortController = new AbortController();
    const request1 = {
      callId: '1',
      name: 'mockTool',
      args: { a: 1 },
      isClientInitiated: false,
      prompt_id: 'prompt-1',
    };
    const request2 = {
      callId: '2',
      name: 'mockTool',
      args: { b: 2 },
      isClientInitiated: false,
      prompt_id: 'prompt-2',
    };

    // Schedule the first call, which will pause execution.
    scheduler.schedule([request1], abortController.signal);

    // Wait for the first call to be in the 'executing' state.
    await vi.waitFor(() => {
      const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
      expect(calls?.[0]?.status).toBe('executing');
    });

    // Schedule the second call while the first is "running".
    const schedulePromise2 = scheduler.schedule(
      [request2],
      abortController.signal,
    );

    // Ensure the second tool call hasn't been executed yet.
    expect(mockTool.executeFn).toHaveBeenCalledTimes(1);
    expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });

    // Complete the first tool call.
    resolveFirstCall!({
      llmContent: 'First call complete',
      returnDisplay: 'First call complete',
    });

    // Wait for the second schedule promise to resolve.
    await schedulePromise2;

    // Wait for the second call to be in the 'executing' state.
    await vi.waitFor(() => {
      const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
      expect(calls?.[0]?.status).toBe('executing');
    });

    // Now the second tool call should have been executed.
    expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
    expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });

    // Let the second call finish.
    const secondCallResult = {
      llmContent: 'Second call complete',
      returnDisplay: 'Second call complete',
    };
    // Since the mock is shared, we need to resolve the current promise.
    // In a real scenario, a new promise would be created for the second call.
    // NOTE(review): firstCallPromise is already settled at this point, so this
    // second resolve is a no-op — the second call actually completes with the
    // *first* call's result. Harmless because only `status` is asserted below,
    // but confirm that is intentional.
    resolveFirstCall!(secondCallResult);

    // Wait for the second completion.
    await vi.waitFor(() => {
      expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
    });

    // Verify the completion callbacks were called correctly.
    expect(onAllToolCallsComplete.mock.calls[0][0][0].status).toBe('success');
    expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
  });

  it('should handle two synchronous calls to schedule', async () => {
    const mockTool = new MockTool();
    const declarativeTool = mockTool;

    const toolRegistry = {
      getTool: () => declarativeTool,
      getToolByName: () => declarativeTool,
      getFunctionDeclarations: () => [],
      tools: new Map(),
      discovery: {} as any,
      registerTool: () => {},
      getToolByDisplayName: () => declarativeTool,
      getTools: () => [],
      discoverTools: async () => {},
      getAllTools: () => [],
      getToolsByServer: () => [],
    };

    const onAllToolCallsComplete = vi.fn();
    const onToolCallsUpdate = vi.fn();

    const mockConfig = {
      getSessionId: () => 'test-session-id',
      getUsageStatisticsEnabled: () => true,
      getDebugMode: () => false,
      getApprovalMode: () => ApprovalMode.YOLO,
    } as unknown as Config;

    const scheduler = new CoreToolScheduler({
      config: mockConfig,
      toolRegistry: Promise.resolve(toolRegistry as any),
      onAllToolCallsComplete,
      onToolCallsUpdate,
      getPreferredEditor: () => 'vscode',
      onEditorClose: vi.fn(),
    });

    const abortController = new AbortController();
    const request1 = {
      callId: '1',
      name: 'mockTool',
      args: { a: 1 },
      isClientInitiated: false,
      prompt_id: 'prompt-1',
    };
    const request2 = {
      callId: '2',
      name: 'mockTool',
      args: { b: 2 },
      isClientInitiated: false,
      prompt_id: 'prompt-2',
    };

    // Schedule two calls synchronously.
    const schedulePromise1 = scheduler.schedule(
      [request1],
      abortController.signal,
    );
    const schedulePromise2 = scheduler.schedule(
      [request2],
      abortController.signal,
    );

    // Wait for both promises to resolve.
    await Promise.all([schedulePromise1, schedulePromise2]);

    // Ensure the tool was called twice with the correct arguments.
    expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
    expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
    expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });

    // Ensure completion callbacks were called twice.
    expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
  });
});

View File

@@ -8,7 +8,6 @@ import {
ToolCallRequestInfo,
ToolCallResponseInfo,
ToolConfirmationOutcome,
Tool,
ToolCallConfirmationDetails,
ToolResult,
ToolResultDisplay,
@@ -20,11 +19,13 @@ import {
ToolCallEvent,
ToolConfirmationPayload,
ToolErrorType,
AnyDeclarativeTool,
AnyToolInvocation,
} from '../index.js';
import { Part, PartListUnion } from '@google/genai';
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
import {
isModifiableTool,
isModifiableDeclarativeTool,
ModifyContext,
modifyWithEditor,
} from '../tools/modifiable-tool.js';
@@ -33,7 +34,8 @@ import * as Diff from 'diff';
export type ValidatingToolCall = {
status: 'validating';
request: ToolCallRequestInfo;
tool: Tool;
tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
startTime?: number;
outcome?: ToolConfirmationOutcome;
};
@@ -41,7 +43,8 @@ export type ValidatingToolCall = {
export type ScheduledToolCall = {
status: 'scheduled';
request: ToolCallRequestInfo;
tool: Tool;
tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
startTime?: number;
outcome?: ToolConfirmationOutcome;
};
@@ -50,6 +53,7 @@ export type ErroredToolCall = {
status: 'error';
request: ToolCallRequestInfo;
response: ToolCallResponseInfo;
tool?: AnyDeclarativeTool;
durationMs?: number;
outcome?: ToolConfirmationOutcome;
};
@@ -57,8 +61,9 @@ export type ErroredToolCall = {
export type SuccessfulToolCall = {
status: 'success';
request: ToolCallRequestInfo;
tool: Tool;
tool: AnyDeclarativeTool;
response: ToolCallResponseInfo;
invocation: AnyToolInvocation;
durationMs?: number;
outcome?: ToolConfirmationOutcome;
};
@@ -66,7 +71,8 @@ export type SuccessfulToolCall = {
export type ExecutingToolCall = {
status: 'executing';
request: ToolCallRequestInfo;
tool: Tool;
tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
liveOutput?: string;
startTime?: number;
outcome?: ToolConfirmationOutcome;
@@ -76,7 +82,8 @@ export type CancelledToolCall = {
status: 'cancelled';
request: ToolCallRequestInfo;
response: ToolCallResponseInfo;
tool: Tool;
tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
durationMs?: number;
outcome?: ToolConfirmationOutcome;
};
@@ -84,7 +91,8 @@ export type CancelledToolCall = {
export type WaitingToolCall = {
status: 'awaiting_approval';
request: ToolCallRequestInfo;
tool: Tool;
tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
confirmationDetails: ToolCallConfirmationDetails;
startTime?: number;
outcome?: ToolConfirmationOutcome;
@@ -117,7 +125,7 @@ export type OutputUpdateHandler = (
export type AllToolCallsCompleteHandler = (
completedToolCalls: CompletedToolCall[],
) => void;
) => Promise<void>;
export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void;
@@ -236,6 +244,14 @@ export class CoreToolScheduler {
private getPreferredEditor: () => EditorType | undefined;
private config: Config;
private onEditorClose: () => void;
private isFinalizingToolCalls = false;
private isScheduling = false;
private requestQueue: Array<{
request: ToolCallRequestInfo | ToolCallRequestInfo[];
signal: AbortSignal;
resolve: () => void;
reject: (reason?: Error) => void;
}> = [];
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
@@ -289,6 +305,7 @@ export class CoreToolScheduler {
// currentCall is a non-terminal state here and should have startTime and tool.
const existingStartTime = currentCall.startTime;
const toolInstance = currentCall.tool;
const invocation = currentCall.invocation;
const outcome = currentCall.outcome;
@@ -300,6 +317,7 @@ export class CoreToolScheduler {
return {
request: currentCall.request,
tool: toolInstance,
invocation,
status: 'success',
response: auxiliaryData as ToolCallResponseInfo,
durationMs,
@@ -313,6 +331,7 @@ export class CoreToolScheduler {
return {
request: currentCall.request,
status: 'error',
tool: toolInstance,
response: auxiliaryData as ToolCallResponseInfo,
durationMs,
outcome,
@@ -326,6 +345,7 @@ export class CoreToolScheduler {
confirmationDetails: auxiliaryData as ToolCallConfirmationDetails,
startTime: existingStartTime,
outcome,
invocation,
} as WaitingToolCall;
case 'scheduled':
return {
@@ -334,6 +354,7 @@ export class CoreToolScheduler {
status: 'scheduled',
startTime: existingStartTime,
outcome,
invocation,
} as ScheduledToolCall;
case 'cancelled': {
const durationMs = existingStartTime
@@ -358,6 +379,7 @@ export class CoreToolScheduler {
return {
request: currentCall.request,
tool: toolInstance,
invocation,
status: 'cancelled',
response: {
callId: currentCall.request.callId,
@@ -385,6 +407,7 @@ export class CoreToolScheduler {
status: 'validating',
startTime: existingStartTime,
outcome,
invocation,
} as ValidatingToolCall;
case 'executing':
return {
@@ -393,6 +416,7 @@ export class CoreToolScheduler {
status: 'executing',
startTime: existingStartTime,
outcome,
invocation,
} as ExecutingToolCall;
default: {
const exhaustiveCheck: never = newStatus;
@@ -406,114 +430,247 @@ export class CoreToolScheduler {
private setArgsInternal(targetCallId: string, args: unknown): void {
this.toolCalls = this.toolCalls.map((call) => {
if (call.request.callId !== targetCallId) return call;
// We should never be asked to set args on an ErroredToolCall, but
// we guard for the case anyways.
if (call.request.callId !== targetCallId || call.status === 'error') {
return call;
}
const invocationOrError = this.buildInvocation(
call.tool,
args as Record<string, unknown>,
);
if (invocationOrError instanceof Error) {
const response = createErrorResponse(
call.request,
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
);
return {
request: { ...call.request, args: args as Record<string, unknown> },
status: 'error',
tool: call.tool,
response,
} as ErroredToolCall;
}
return {
...call,
request: { ...call.request, args: args as Record<string, unknown> },
invocation: invocationOrError,
};
});
}
private isRunning(): boolean {
return this.toolCalls.some(
(call) =>
call.status === 'executing' || call.status === 'awaiting_approval',
return (
this.isFinalizingToolCalls ||
this.toolCalls.some(
(call) =>
call.status === 'executing' || call.status === 'awaiting_approval',
)
);
}
async schedule(
private buildInvocation(
tool: AnyDeclarativeTool,
args: object,
): AnyToolInvocation | Error {
try {
return tool.build(args);
} catch (e) {
if (e instanceof Error) {
return e;
}
return new Error(String(e));
}
}
schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
): Promise<void> {
if (this.isRunning()) {
throw new Error(
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
);
}
const requestsToProcess = Array.isArray(request) ? request : [request];
const toolRegistry = await this.toolRegistry;
const newToolCalls: ToolCall[] = requestsToProcess.map(
(reqInfo): ToolCall => {
const toolInstance = toolRegistry.getTool(reqInfo.name);
if (!toolInstance) {
return {
status: 'error',
request: reqInfo,
response: createErrorResponse(
reqInfo,
new Error(`Tool "${reqInfo.name}" not found in registry.`),
ToolErrorType.TOOL_NOT_REGISTERED,
),
durationMs: 0,
};
}
return {
status: 'validating',
request: reqInfo,
tool: toolInstance,
startTime: Date.now(),
};
},
);
this.toolCalls = this.toolCalls.concat(newToolCalls);
this.notifyToolCallsUpdate();
for (const toolCall of newToolCalls) {
if (toolCall.status !== 'validating') {
continue;
}
const { request: reqInfo, tool: toolInstance } = toolCall;
try {
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
this.setStatusInternal(reqInfo.callId, 'scheduled');
} else {
const confirmationDetails = await toolInstance.shouldConfirmExecute(
reqInfo.args,
signal,
if (this.isRunning() || this.isScheduling) {
return new Promise((resolve, reject) => {
const abortHandler = () => {
// Find and remove the request from the queue
const index = this.requestQueue.findIndex(
(item) => item.request === request,
);
if (confirmationDetails) {
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} else {
this.setStatusInternal(reqInfo.callId, 'scheduled');
if (index > -1) {
this.requestQueue.splice(index, 1);
reject(new Error('Tool call cancelled while in queue.'));
}
}
} catch (error) {
this.setStatusInternal(
reqInfo.callId,
'error',
createErrorResponse(
reqInfo,
error instanceof Error ? error : new Error(String(error)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
};
signal.addEventListener('abort', abortHandler, { once: true });
this.requestQueue.push({
request,
signal,
resolve: () => {
signal.removeEventListener('abort', abortHandler);
resolve();
},
reject: (reason?: Error) => {
signal.removeEventListener('abort', abortHandler);
reject(reason);
},
});
});
}
return this._schedule(request, signal);
}
private async _schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
): Promise<void> {
this.isScheduling = true;
try {
if (this.isRunning()) {
throw new Error(
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
);
}
const requestsToProcess = Array.isArray(request) ? request : [request];
const toolRegistry = await this.toolRegistry;
const newToolCalls: ToolCall[] = requestsToProcess.map(
(reqInfo): ToolCall => {
const toolInstance = toolRegistry.getTool(reqInfo.name);
if (!toolInstance) {
return {
status: 'error',
request: reqInfo,
response: createErrorResponse(
reqInfo,
new Error(`Tool "${reqInfo.name}" not found in registry.`),
ToolErrorType.TOOL_NOT_REGISTERED,
),
durationMs: 0,
};
}
const invocationOrError = this.buildInvocation(
toolInstance,
reqInfo.args,
);
if (invocationOrError instanceof Error) {
return {
status: 'error',
request: reqInfo,
tool: toolInstance,
response: createErrorResponse(
reqInfo,
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
};
}
return {
status: 'validating',
request: reqInfo,
tool: toolInstance,
invocation: invocationOrError,
startTime: Date.now(),
};
},
);
this.toolCalls = this.toolCalls.concat(newToolCalls);
this.notifyToolCallsUpdate();
for (const toolCall of newToolCalls) {
if (toolCall.status !== 'validating') {
continue;
}
const { request: reqInfo, invocation } = toolCall;
try {
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
} else {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (confirmationDetails) {
// Allow IDE to resolve confirmation
if (
confirmationDetails.type === 'edit' &&
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} else {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
}
}
} catch (error) {
this.setStatusInternal(
reqInfo.callId,
'error',
createErrorResponse(
reqInfo,
error instanceof Error ? error : new Error(String(error)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
);
}
}
this.attemptExecutionOfScheduledCalls(signal);
void this.checkAndNotifyCompletion();
} finally {
this.isScheduling = false;
}
this.attemptExecutionOfScheduledCalls(signal);
this.checkAndNotifyCompletion();
}
async handleConfirmationResponse(
@@ -531,13 +688,7 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
this.toolCalls = this.toolCalls.map((call) => {
if (call.request.callId !== callId) return call;
return {
...call,
outcome,
};
});
this.setToolCallOutcome(callId, outcome);
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal(
@@ -547,7 +698,7 @@ export class CoreToolScheduler {
);
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
const waitingToolCall = toolCall as WaitingToolCall;
if (isModifiableTool(waitingToolCall.tool)) {
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
const modifyContext = waitingToolCall.tool.getModifyContext(signal);
const editorType = this.getPreferredEditor();
if (!editorType) {
@@ -602,7 +753,7 @@ export class CoreToolScheduler {
): Promise<void> {
if (
toolCall.confirmationDetails.type !== 'edit' ||
!isModifiableTool(toolCall.tool)
!isModifiableDeclarativeTool(toolCall.tool)
) {
return;
}
@@ -651,6 +802,7 @@ export class CoreToolScheduler {
const scheduledCall = toolCall;
const { callId, name: toolName } = scheduledCall.request;
const invocation = scheduledCall.invocation;
this.setStatusInternal(callId, 'executing');
const liveOutputCallback =
@@ -668,8 +820,8 @@ export class CoreToolScheduler {
}
: undefined;
scheduledCall.tool
.execute(scheduledCall.request.args, signal, liveOutputCallback)
invocation
.execute(signal, liveOutputCallback)
.then(async (toolResult: ToolResult) => {
if (signal.aborted) {
this.setStatusInternal(
@@ -722,7 +874,7 @@ export class CoreToolScheduler {
}
}
private checkAndNotifyCompletion(): void {
private async checkAndNotifyCompletion(): Promise<void> {
const allCallsAreTerminal = this.toolCalls.every(
(call) =>
call.status === 'success' ||
@@ -739,9 +891,18 @@ export class CoreToolScheduler {
}
if (this.onAllToolCallsComplete) {
this.onAllToolCallsComplete(completedCalls);
this.isFinalizingToolCalls = true;
await this.onAllToolCallsComplete(completedCalls);
this.isFinalizingToolCalls = false;
}
this.notifyToolCallsUpdate();
// After completion, process the next item in the queue.
if (this.requestQueue.length > 0) {
const next = this.requestQueue.shift()!;
this._schedule(next.request, next.signal)
.then(next.resolve)
.catch(next.reject);
}
}
}
@@ -750,4 +911,14 @@ export class CoreToolScheduler {
this.onToolCallsUpdate([...this.toolCalls]);
}
}
private setToolCallOutcome(callId: string, outcome: ToolConfirmationOutcome) {
this.toolCalls = this.toolCalls.map((call) => {
if (call.request.callId !== callId) return call;
return {
...call,
outcome,
};
});
}
}

View File

@@ -14,24 +14,15 @@ import {
SendMessageParameters,
createUserContent,
Part,
GenerateContentResponseUsageMetadata,
Tool,
} from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { ContentGenerator, AuthType } from './contentGenerator.js';
import { Config } from '../config/config.js';
import {
logApiRequest,
logApiResponse,
logApiError,
} from '../telemetry/loggers.js';
import {
ApiErrorEvent,
ApiRequestEvent,
ApiResponseEvent,
} from '../telemetry/types.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { hasCycleInSchema } from '../tools/tools.js';
import { StructuredError } from './turn.js';
/**
* Returns true if the response is valid, false otherwise.
@@ -137,80 +128,6 @@ export class GeminiChat {
validateHistory(history);
}
private _getRequestTextFromContents(contents: Content[]): string {
return JSON.stringify(contents);
}
private async _logApiRequest(
contents: Content[],
model: string,
prompt_id: string,
): Promise<void> {
const requestText = this._getRequestTextFromContents(contents);
logApiRequest(
this.config,
new ApiRequestEvent(model, prompt_id, requestText),
);
}
private async _logApiResponse(
durationMs: number,
prompt_id: string,
usageMetadata?: GenerateContentResponseUsageMetadata,
responseText?: string,
responseId?: string,
): Promise<void> {
const authType = this.config.getContentGeneratorConfig()?.authType;
// Don't log API responses for openaiContentGenerator
if (authType === AuthType.QWEN_OAUTH || authType === AuthType.USE_OPENAI) {
return;
}
logApiResponse(
this.config,
new ApiResponseEvent(
responseId || `gemini-${Date.now()}`,
this.config.getModel(),
durationMs,
prompt_id,
authType,
usageMetadata,
responseText,
),
);
}
private _logApiError(
durationMs: number,
error: unknown,
prompt_id: string,
responseId?: string,
): void {
const errorMessage = error instanceof Error ? error.message : String(error);
const errorType = error instanceof Error ? error.name : 'unknown';
const authType = this.config.getContentGeneratorConfig()?.authType;
// Don't log API errors for openaiContentGenerator
if (authType === AuthType.QWEN_OAUTH || authType === AuthType.USE_OPENAI) {
return;
}
logApiError(
this.config,
new ApiErrorEvent(
responseId,
this.config.getModel(),
errorMessage,
durationMs,
prompt_id,
authType,
errorType,
),
);
}
/**
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
* Uses a fallback handler if provided by the config; otherwise, returns null.
@@ -263,6 +180,9 @@ export class GeminiChat {
return null;
}
/** Replaces the system instruction used for subsequent model requests. */
setSystemInstruction(sysInstr: string) {
  this.generationConfig.systemInstruction = sysInstr;
}
/**
* Sends a message to the model and returns the response.
*
@@ -291,9 +211,6 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.config.getModel(), prompt_id);
const startTime = Date.now();
let response: GenerateContentResponse;
try {
@@ -321,25 +238,19 @@ export class GeminiChat {
};
response = await retryWithBackoff(apiCall, {
shouldRetry: (error: Error) => {
if (error && error.message) {
shouldRetry: (error: unknown) => {
// Check for known error messages and codes.
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
}
return false;
return false; // Don't retry other errors by default
},
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
authType: this.config.getContentGeneratorConfig()?.authType,
});
const durationMs = Date.now() - startTime;
await this._logApiResponse(
durationMs,
prompt_id,
response.usageMetadata,
JSON.stringify(response),
response.responseId,
);
this.sendPromise = (async () => {
const outputContent = response.candidates?.[0]?.content;
@@ -367,8 +278,6 @@ export class GeminiChat {
});
return response;
} catch (error) {
const durationMs = Date.now() - startTime;
this._logApiError(durationMs, error, prompt_id);
this.sendPromise = Promise.resolve();
throw error;
}
@@ -403,9 +312,6 @@ export class GeminiChat {
await this.sendPromise;
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.config.getModel(), prompt_id);
const startTime = Date.now();
try {
const apiCall = () => {
@@ -436,9 +342,10 @@ export class GeminiChat {
// the stream. For simple 429/500 errors on initial call, this is fine.
// If errors occur mid-stream, this setup won't resume the stream; it will restart it.
const streamResponse = await retryWithBackoff(apiCall, {
shouldRetry: (error: Error) => {
// Check error messages for status codes, or specific error names if known
if (error && error.message) {
shouldRetry: (error: unknown) => {
// Check for known error messages and codes.
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
}
@@ -456,16 +363,9 @@ export class GeminiChat {
.then(() => undefined)
.catch(() => undefined);
const result = this.processStreamResponse(
streamResponse,
userContent,
startTime,
prompt_id,
);
const result = this.processStreamResponse(streamResponse, userContent);
return result;
} catch (error) {
const durationMs = Date.now() - startTime;
this._logApiError(durationMs, error, prompt_id);
this.sendPromise = Promise.resolve();
throw error;
}
@@ -526,22 +426,37 @@ export class GeminiChat {
this.generationConfig.tools = tools;
}
getFinalUsageMetadata(
chunks: GenerateContentResponse[],
): GenerateContentResponseUsageMetadata | undefined {
const lastChunkWithMetadata = chunks
.slice()
.reverse()
.find((chunk) => chunk.usageMetadata);
return lastChunkWithMetadata?.usageMetadata;
async maybeIncludeSchemaDepthContext(error: StructuredError): Promise<void> {
// Check for potentially problematic cyclic tools with cyclic schemas
// and include a recommendation to remove potentially problematic tools.
if (
isSchemaDepthError(error.message) ||
isInvalidArgumentError(error.message)
) {
const tools = (await this.config.getToolRegistry()).getAllTools();
const cyclicSchemaTools: string[] = [];
for (const tool of tools) {
if (
(tool.schema.parametersJsonSchema &&
hasCycleInSchema(tool.schema.parametersJsonSchema)) ||
(tool.schema.parameters && hasCycleInSchema(tool.schema.parameters))
) {
cyclicSchemaTools.push(tool.displayName);
}
}
if (cyclicSchemaTools.length > 0) {
const extraDetails =
`\n\nThis error was probably caused by cyclic schema references in one of the following tools, try disabling them with excludeTools:\n\n - ` +
cyclicSchemaTools.join(`\n - `) +
`\n`;
error.message += extraDetails;
}
}
}
private async *processStreamResponse(
streamResponse: AsyncGenerator<GenerateContentResponse>,
inputContent: Content,
startTime: number,
prompt_id: string,
) {
const outputContent: Content[] = [];
const chunks: GenerateContentResponse[] = [];
@@ -564,26 +479,16 @@ export class GeminiChat {
}
} catch (error) {
errorOccurred = true;
const durationMs = Date.now() - startTime;
this._logApiError(durationMs, error, prompt_id);
throw error;
}
if (!errorOccurred) {
const durationMs = Date.now() - startTime;
const allParts: Part[] = [];
for (const content of outputContent) {
if (content.parts) {
allParts.push(...content.parts);
}
}
await this._logApiResponse(
durationMs,
prompt_id,
this.getFinalUsageMetadata(chunks),
JSON.stringify(chunks),
chunks[chunks.length - 1]?.responseId,
);
}
this.recordHistory(inputContent, outputContent);
}
@@ -755,3 +660,12 @@ export class GeminiChat {
return null;
}
}
// Substring emitted by the API when a tool schema nests too deeply.
const SCHEMA_DEPTH_ERROR_FRAGMENT = 'maximum schema depth exceeded';

/** Visible for Testing */
export function isSchemaDepthError(errorMessage: string): boolean {
  return errorMessage.includes(SCHEMA_DEPTH_ERROR_FRAGMENT);
}
// Substring emitted by the API for INVALID_ARGUMENT failures.
const INVALID_ARGUMENT_ERROR_FRAGMENT = 'Request contains an invalid argument';

/** Visible for Testing */
export function isInvalidArgumentError(errorMessage: string): boolean {
  return errorMessage.includes(INVALID_ARGUMENT_ERROR_FRAGMENT);
}

View File

@@ -565,6 +565,52 @@ describe('Logger', () => {
});
});
// Tests for Logger.checkpointExists: existence checks for tagged checkpoint
// files stored under TEST_GEMINI_DIR.
describe('checkpointExists', () => {
  const tag = 'exists-test';
  let taggedFilePath: string;

  beforeEach(() => {
    // Checkpoint files are named `checkpoint-<tag>.json`.
    taggedFilePath = path.join(TEST_GEMINI_DIR, `checkpoint-${tag}.json`);
  });

  it('should return true if the checkpoint file exists', async () => {
    await fs.writeFile(taggedFilePath, '{}');

    const exists = await logger.checkpointExists(tag);

    expect(exists).toBe(true);
  });

  it('should return false if the checkpoint file does not exist', async () => {
    const exists = await logger.checkpointExists('non-existent-tag');
    expect(exists).toBe(false);
  });

  it('should throw an error if logger is not initialized', async () => {
    const uninitializedLogger = new Logger(testSessionId);
    uninitializedLogger.close();
    await expect(uninitializedLogger.checkpointExists(tag)).rejects.toThrow(
      'Logger not initialized. Cannot check for checkpoint existence.',
    );
  });

  it('should re-throw an error if fs.access fails for reasons other than not existing', async () => {
    // Only ENOENT maps to `false`; any other failure (e.g. EACCES) must be
    // logged and re-thrown to the caller.
    vi.spyOn(fs, 'access').mockRejectedValueOnce(
      new Error('EACCES: permission denied'),
    );
    const consoleErrorSpy = vi
      .spyOn(console, 'error')
      .mockImplementation(() => {});

    await expect(logger.checkpointExists(tag)).rejects.toThrow(
      'EACCES: permission denied',
    );
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      `Failed to check checkpoint existence for ${taggedFilePath}:`,
      expect.any(Error),
    );
  });
});
describe('close', () => {
it('should reset logger state', async () => {
await logger.logMessage(MessageSenderType.USER, 'A message');

View File

@@ -310,6 +310,29 @@ export class Logger {
}
}
/**
 * Checks whether a checkpoint file exists for the given tag.
 *
 * @param tag Identifier of the checkpoint to look for.
 * @returns true when the checkpoint file is accessible, false when it is
 *   absent.
 * @throws If the logger is not initialized, or if the filesystem check fails
 *   for any reason other than the file not existing.
 */
async checkpointExists(tag: string): Promise<boolean> {
  if (!this.initialized) {
    throw new Error(
      'Logger not initialized. Cannot check for checkpoint existence.',
    );
  }
  const filePath = this._checkpointPath(tag);
  try {
    await fs.access(filePath);
  } catch (error) {
    if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
      // "Not found" is the expected negative case, not a failure.
      return false;
    }
    // Anything else (e.g. permission errors) is unexpected: log with context
    // and surface it to the caller.
    console.error(
      `Failed to check checkpoint existence for ${filePath}:`,
      error,
    );
    throw error;
  }
  return true;
}
close(): void {
this.initialized = false;
this.logFilePath = undefined;

View File

@@ -0,0 +1,196 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
Content,
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponseUsageMetadata,
GenerateContentResponse,
} from '@google/genai';
import {
ApiRequestEvent,
ApiResponseEvent,
ApiErrorEvent,
} from '../telemetry/types.js';
import { Config } from '../config/config.js';
import {
logApiError,
logApiRequest,
logApiResponse,
} from '../telemetry/loggers.js';
import { ContentGenerator } from './contentGenerator.js';
import { toContents } from '../code_assist/converter.js';
/** Shape of errors that carry an HTTP-style numeric status code. */
interface StructuredError {
  status: number;
}

/** Type guard: true when `error` is a non-null object with a numeric `status`. */
export function isStructuredError(error: unknown): error is StructuredError {
  if (typeof error !== 'object' || error === null) {
    return false;
  }
  return typeof (error as Partial<StructuredError>).status === 'number';
}
/**
* A decorator that wraps a ContentGenerator to add logging to API calls.
*/
export class LoggingContentGenerator implements ContentGenerator {
constructor(
private readonly wrapped: ContentGenerator,
private readonly config: Config,
) {}
private logApiRequest(
contents: Content[],
model: string,
promptId: string,
): void {
const requestText = JSON.stringify(contents);
logApiRequest(
this.config,
new ApiRequestEvent(model, promptId, requestText),
);
}
private _logApiResponse(
responseId: string,
durationMs: number,
prompt_id: string,
usageMetadata?: GenerateContentResponseUsageMetadata,
responseText?: string,
): void {
logApiResponse(
this.config,
new ApiResponseEvent(
responseId,
this.config.getModel(),
durationMs,
prompt_id,
this.config.getContentGeneratorConfig()?.authType,
usageMetadata,
responseText,
),
);
}
private _logApiError(
responseId: string | undefined,
durationMs: number,
error: unknown,
prompt_id: string,
): void {
const errorMessage = error instanceof Error ? error.message : String(error);
const errorType = error instanceof Error ? error.name : 'unknown';
logApiError(
this.config,
new ApiErrorEvent(
responseId,
this.config.getModel(),
errorMessage,
durationMs,
prompt_id,
this.config.getContentGeneratorConfig()?.authType,
errorType,
isStructuredError(error)
? (error as StructuredError).status
: undefined,
),
);
}
async generateContent(
req: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
try {
const response = await this.wrapped.generateContent(req, userPromptId);
const durationMs = Date.now() - startTime;
this._logApiResponse(
response.responseId ?? '',
durationMs,
userPromptId,
response.usageMetadata,
JSON.stringify(response),
);
return response;
} catch (error) {
const durationMs = Date.now() - startTime;
this._logApiError(undefined, durationMs, error, userPromptId);
throw error;
}
}
async generateContentStream(
req: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
let stream: AsyncGenerator<GenerateContentResponse>;
try {
stream = await this.wrapped.generateContentStream(req, userPromptId);
} catch (error) {
const durationMs = Date.now() - startTime;
this._logApiError(undefined, durationMs, error, userPromptId);
throw error;
}
return this.loggingStreamWrapper(stream, startTime, userPromptId);
}
private async *loggingStreamWrapper(
stream: AsyncGenerator<GenerateContentResponse>,
startTime: number,
userPromptId: string,
): AsyncGenerator<GenerateContentResponse> {
let lastResponse: GenerateContentResponse | undefined;
let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined;
try {
for await (const response of stream) {
lastResponse = response;
if (response.usageMetadata) {
lastUsageMetadata = response.usageMetadata;
}
yield response;
}
} catch (error) {
const durationMs = Date.now() - startTime;
this._logApiError(undefined, durationMs, error, userPromptId);
throw error;
}
const durationMs = Date.now() - startTime;
if (lastResponse) {
this._logApiResponse(
lastResponse.responseId ?? '',
durationMs,
userPromptId,
lastUsageMetadata,
JSON.stringify(lastResponse),
);
}
}
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
return this.wrapped.countTokens(req);
}
async embedContent(
req: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.wrapped.embedContent(req);
}
}

View File

@@ -10,12 +10,10 @@ import {
ToolRegistry,
ToolCallRequestInfo,
ToolResult,
Tool,
ToolCallConfirmationDetails,
Config,
Icon,
} from '../index.js';
import { Part, Type } from '@google/genai';
import { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js';
const mockConfig = {
getSessionId: () => 'test-session-id',
@@ -25,36 +23,11 @@ const mockConfig = {
describe('executeToolCall', () => {
let mockToolRegistry: ToolRegistry;
let mockTool: Tool;
let mockTool: MockTool;
let abortController: AbortController;
beforeEach(() => {
mockTool = {
name: 'testTool',
displayName: 'Test Tool',
description: 'A tool for testing',
icon: Icon.Hammer,
schema: {
name: 'testTool',
description: 'A tool for testing',
parameters: {
type: Type.OBJECT,
properties: {
param1: { type: Type.STRING },
},
required: ['param1'],
},
},
execute: vi.fn(),
validateToolParams: vi.fn(() => null),
shouldConfirmExecute: vi.fn(() =>
Promise.resolve(false as false | ToolCallConfirmationDetails),
),
isOutputMarkdown: false,
canUpdateOutput: false,
getDescription: vi.fn(),
toolLocations: vi.fn(() => []),
};
mockTool = new MockTool();
mockToolRegistry = {
getTool: vi.fn(),
@@ -77,7 +50,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Success!',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockResolvedValue(toolResult);
vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -87,7 +60,7 @@ describe('executeToolCall', () => {
);
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
expect(mockTool.execute).toHaveBeenCalledWith(
expect(mockTool.buildAndExecute).toHaveBeenCalledWith(
request.args,
abortController.signal,
);
@@ -149,7 +122,7 @@ describe('executeToolCall', () => {
};
const executionError = new Error('Tool execution failed');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockRejectedValue(executionError);
vi.spyOn(mockTool, 'buildAndExecute').mockRejectedValue(executionError);
const response = await executeToolCall(
mockConfig,
@@ -183,25 +156,27 @@ describe('executeToolCall', () => {
const cancellationError = new Error('Operation cancelled');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => {
if (signal?.aborted) {
return Promise.reject(cancellationError);
}
return new Promise((_resolve, reject) => {
signal?.addEventListener('abort', () => {
reject(cancellationError);
vi.spyOn(mockTool, 'buildAndExecute').mockImplementation(
async (_args, signal) => {
if (signal?.aborted) {
return Promise.reject(cancellationError);
}
return new Promise((_resolve, reject) => {
signal?.addEventListener('abort', () => {
reject(cancellationError);
});
// Simulate work that might happen if not aborted immediately
const timeoutId = setTimeout(
() =>
reject(
new Error('Should have been cancelled if not aborted prior'),
),
100,
);
signal?.addEventListener('abort', () => clearTimeout(timeoutId));
});
// Simulate work that might happen if not aborted immediately
const timeoutId = setTimeout(
() =>
reject(
new Error('Should have been cancelled if not aborted prior'),
),
100,
);
signal?.addEventListener('abort', () => clearTimeout(timeoutId));
});
});
},
);
abortController.abort(); // Abort before calling
const response = await executeToolCall(
@@ -232,7 +207,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image processed',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockResolvedValue(toolResult);
vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,

View File

@@ -5,6 +5,7 @@
*/
import {
FileDiff,
logToolCall,
ToolCallRequestInfo,
ToolCallResponseInfo,
@@ -14,6 +15,7 @@ import {
} from '../index.js';
import { Config } from '../config/config.js';
import { convertToFunctionResponse } from './coreToolScheduler.js';
import { ToolCallDecision } from '../telemetry/tool-call-decision.js';
/**
* Executes a single tool call non-interactively.
@@ -64,7 +66,7 @@ export async function executeToolCall(
try {
// Directly execute without confirmation or live output handling
const effectiveAbortSignal = abortSignal ?? new AbortController().signal;
const toolResult: ToolResult = await tool.execute(
const toolResult: ToolResult = await tool.buildAndExecute(
toolCallRequest.args,
effectiveAbortSignal,
// No live output callback for non-interactive mode
@@ -74,6 +76,24 @@ export async function executeToolCall(
const tool_display = toolResult.returnDisplay;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let metadata: { [key: string]: any } = {};
if (
toolResult.error === undefined &&
typeof tool_display === 'object' &&
tool_display !== null &&
'diffStat' in tool_display
) {
const diffStat = (tool_display as FileDiff).diffStat;
if (diffStat) {
metadata = {
ai_added_lines: diffStat.ai_added_lines,
ai_removed_lines: diffStat.ai_removed_lines,
user_added_lines: diffStat.user_added_lines,
user_removed_lines: diffStat.user_removed_lines,
};
}
}
const durationMs = Date.now() - startTime;
logToolCall(config, {
'event.name': 'tool_call',
@@ -87,6 +107,8 @@ export async function executeToolCall(
error_type:
toolResult.error === undefined ? undefined : toolResult.error.type,
prompt_id: toolCallRequest.prompt_id,
metadata,
decision: ToolCallDecision.AUTO_ACCEPT,
});
const response = convertToFunctionResponse(

View File

@@ -0,0 +1,814 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, Mock, afterEach } from 'vitest';
import {
ContextState,
SubAgentScope,
SubagentTerminateMode,
PromptConfig,
ModelConfig,
RunConfig,
OutputConfig,
ToolConfig,
} from './subagent.js';
import { Config, ConfigParameters } from '../config/config.js';
import { GeminiChat } from './geminiChat.js';
import { createContentGenerator } from './contentGenerator.js';
import { getEnvironmentContext } from '../utils/environmentContext.js';
import { executeToolCall } from './nonInteractiveToolExecutor.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
import {
Content,
FunctionCall,
FunctionDeclaration,
GenerateContentConfig,
Type,
} from '@google/genai';
import { ToolErrorType } from '../tools/tool-error.js';
vi.mock('./geminiChat.js');
vi.mock('./contentGenerator.js');
vi.mock('../utils/environmentContext.js');
vi.mock('./nonInteractiveToolExecutor.js');
vi.mock('../ide/ide-client.js');
/**
 * Builds a real, initialized Config for tests (with a fake auth refresh) and
 * swaps its tool registry for a mock, optionally extended with
 * caller-supplied overrides.
 */
async function createMockConfig(
  toolRegistryMocks = {},
): Promise<{ config: Config; toolRegistry: ToolRegistry }> {
  const config = new Config({
    sessionId: 'test-session',
    model: DEFAULT_GEMINI_MODEL,
    targetDir: '.',
    debugMode: false,
    cwd: process.cwd(),
  });
  await config.initialize();
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  await config.refreshAuth('test-auth' as any);

  // Minimal registry mock; individual tests override pieces as needed.
  const mockToolRegistry = {
    getTool: vi.fn(),
    getFunctionDeclarationsFiltered: vi.fn().mockReturnValue([]),
    ...toolRegistryMocks,
  } as unknown as ToolRegistry;
  vi.spyOn(config, 'getToolRegistry').mockResolvedValue(mockToolRegistry);

  return { config, toolRegistry: mockToolRegistry };
}
// Helper to simulate LLM responses (sequence of tool calls over multiple turns)
// Helper to simulate LLM responses: each call to the returned mock consumes
// the next entry of `functionCallsList` (a turn), yielding either function
// calls or plain text. Running past the end behaves like 'stop'.
const createMockStream = (
  functionCallsList: Array<FunctionCall[] | 'stop'>,
) => {
  let turn = 0;
  return vi.fn().mockImplementation(() => {
    const response = functionCallsList[turn] || 'stop';
    turn += 1;
    return (async function* () {
      if (response !== 'stop' && response.length > 0) {
        yield { functionCalls: response };
      } else {
        // 'stop' and an empty call list both end the run; the subagent logic
        // primarily cares about the absence of functionCalls, not the text.
        yield { text: 'Done.' };
      }
    })();
  });
};
describe('subagent.ts', () => {
// ContextState is a simple key/value store used to template subagent prompts.
describe('ContextState', () => {
  it('should set and get values correctly', () => {
    const context = new ContextState();
    context.set('key1', 'value1');
    context.set('key2', 123);
    expect(context.get('key1')).toBe('value1');
    expect(context.get('key2')).toBe(123);
    // get_keys() reports keys in insertion order.
    expect(context.get_keys()).toEqual(['key1', 'key2']);
  });

  it('should return undefined for missing keys', () => {
    const context = new ContextState();
    expect(context.get('missing')).toBeUndefined();
  });
});
describe('SubAgentScope', () => {
let mockSendMessageStream: Mock;
const defaultModelConfig: ModelConfig = {
model: 'gemini-1.5-flash-latest',
temp: 0.5, // Specific temp to test override
top_p: 1,
};
const defaultRunConfig: RunConfig = {
max_time_minutes: 5,
max_turns: 10,
};
beforeEach(async () => {
vi.clearAllMocks();
vi.mocked(getEnvironmentContext).mockResolvedValue([
{ text: 'Env Context' },
]);
vi.mocked(createContentGenerator).mockResolvedValue({
getGenerativeModel: vi.fn(),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
mockSendMessageStream = vi.fn();
// We mock the implementation of the constructor.
vi.mocked(GeminiChat).mockImplementation(
() =>
({
sendMessageStream: mockSendMessageStream,
}) as unknown as GeminiChat,
);
});
afterEach(() => {
vi.restoreAllMocks();
});
// Helper to safely access generationConfig from mock calls.
// NOTE(review): assumes the generation config is the third constructor
// argument of GeminiChat (index 2) — confirm against the GeminiChat ctor.
const getGenerationConfigFromMock = (
  callIndex = 0,
): GenerateContentConfig & { systemInstruction?: string | Content } => {
  const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex];
  const generationConfig = callArgs?.[2];

  // Ensure it's defined before proceeding (fail the test, then satisfy the
  // type checker so the return below is safe).
  expect(generationConfig).toBeDefined();
  if (!generationConfig) throw new Error('generationConfig is undefined');

  return generationConfig as GenerateContentConfig & {
    systemInstruction?: string | Content;
  };
};
describe('create (Tool Validation)', () => {
const promptConfig: PromptConfig = { systemPrompt: 'Test prompt' };
it('should create a SubAgentScope successfully with minimal config', async () => {
const { config } = await createMockConfig();
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
);
expect(scope).toBeInstanceOf(SubAgentScope);
});
it('should throw an error if a tool requires confirmation', async () => {
const mockTool = {
schema: { parameters: { type: Type.OBJECT, properties: {} } },
build: vi.fn().mockReturnValue({
shouldConfirmExecute: vi.fn().mockResolvedValue({
type: 'exec',
title: 'Confirm',
command: 'rm -rf /',
}),
}),
};
const { config } = await createMockConfig({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getTool: vi.fn().mockReturnValue(mockTool as any),
});
const toolConfig: ToolConfig = { tools: ['risky_tool'] };
await expect(
SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
toolConfig,
),
).rejects.toThrow(
'Tool "risky_tool" requires user confirmation and cannot be used in a non-interactive subagent.',
);
});
it('should succeed if tools do not require confirmation', async () => {
const mockTool = {
schema: { parameters: { type: Type.OBJECT, properties: {} } },
build: vi.fn().mockReturnValue({
shouldConfirmExecute: vi.fn().mockResolvedValue(null),
}),
};
const { config } = await createMockConfig({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getTool: vi.fn().mockReturnValue(mockTool as any),
});
const toolConfig: ToolConfig = { tools: ['safe_tool'] };
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
toolConfig,
);
expect(scope).toBeInstanceOf(SubAgentScope);
});
it('should skip interactivity check and warn for tools with required parameters', async () => {
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
const mockToolWithParams = {
schema: {
parameters: {
type: Type.OBJECT,
properties: {
path: { type: Type.STRING },
},
required: ['path'],
},
},
// build should not be called, but we mock it to be safe
build: vi.fn(),
};
const { config } = await createMockConfig({
getTool: vi.fn().mockReturnValue(mockToolWithParams),
});
const toolConfig: ToolConfig = { tools: ['tool_with_params'] };
// The creation should succeed without throwing
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
toolConfig,
);
expect(scope).toBeInstanceOf(SubAgentScope);
// Check that the warning was logged
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Cannot check tool "tool_with_params" for interactivity because it requires parameters. Assuming it is safe for non-interactive use.',
);
// Ensure build was never called
expect(mockToolWithParams.build).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
});
describe('runNonInteractive - Initialization and Prompting', () => {
it('should correctly template the system prompt and initialize GeminiChat', async () => {
const { config } = await createMockConfig();
vi.mocked(GeminiChat).mockClear();
const promptConfig: PromptConfig = {
systemPrompt: 'Hello ${name}, your task is ${task}.',
};
const context = new ContextState();
context.set('name', 'Agent');
context.set('task', 'Testing');
// Model stops immediately
mockSendMessageStream.mockImplementation(createMockStream(['stop']));
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
);
await scope.runNonInteractive(context);
// Check if GeminiChat was initialized correctly by the subagent
expect(GeminiChat).toHaveBeenCalledTimes(1);
const callArgs = vi.mocked(GeminiChat).mock.calls[0];
// Check Generation Config
const generationConfig = getGenerationConfigFromMock();
// Check temperature override
expect(generationConfig.temperature).toBe(defaultModelConfig.temp);
expect(generationConfig.systemInstruction).toContain(
'Hello Agent, your task is Testing.',
);
expect(generationConfig.systemInstruction).toContain(
'Important Rules:',
);
// Check History (should include environment context)
const history = callArgs[3];
expect(history).toEqual([
{ role: 'user', parts: [{ text: 'Env Context' }] },
{
role: 'model',
parts: [{ text: 'Got it. Thanks for the context!' }],
},
]);
});
it('should include output instructions in the system prompt when outputs are defined', async () => {
const { config } = await createMockConfig();
vi.mocked(GeminiChat).mockClear();
const promptConfig: PromptConfig = { systemPrompt: 'Do the task.' };
const outputConfig: OutputConfig = {
outputs: {
result1: 'The first result',
},
};
const context = new ContextState();
// Model stops immediately
mockSendMessageStream.mockImplementation(createMockStream(['stop']));
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
undefined, // ToolConfig
outputConfig,
);
await scope.runNonInteractive(context);
const generationConfig = getGenerationConfigFromMock();
const systemInstruction = generationConfig.systemInstruction as string;
expect(systemInstruction).toContain('Do the task.');
expect(systemInstruction).toContain(
'you MUST emit the required output variables',
);
expect(systemInstruction).toContain(
"Use 'self.emitvalue' to emit the 'result1' key",
);
});
it('should use initialMessages instead of systemPrompt if provided', async () => {
const { config } = await createMockConfig();
vi.mocked(GeminiChat).mockClear();
const initialMessages: Content[] = [
{ role: 'user', parts: [{ text: 'Hi' }] },
];
const promptConfig: PromptConfig = { initialMessages };
const context = new ContextState();
// Model stops immediately
mockSendMessageStream.mockImplementation(createMockStream(['stop']));
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
);
await scope.runNonInteractive(context);
const callArgs = vi.mocked(GeminiChat).mock.calls[0];
const generationConfig = getGenerationConfigFromMock();
const history = callArgs[3];
expect(generationConfig.systemInstruction).toBeUndefined();
expect(history).toEqual([
{ role: 'user', parts: [{ text: 'Env Context' }] },
{
role: 'model',
parts: [{ text: 'Got it. Thanks for the context!' }],
},
...initialMessages,
]);
});
it('should throw an error if template variables are missing', async () => {
const { config } = await createMockConfig();
const promptConfig: PromptConfig = {
systemPrompt: 'Hello ${name}, you are missing ${missing}.',
};
const context = new ContextState();
context.set('name', 'Agent');
// 'missing' is not set
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
);
// The error from templating causes the runNonInteractive to reject and the terminate_reason to be ERROR.
await expect(scope.runNonInteractive(context)).rejects.toThrow(
'Missing context values for the following keys: missing',
);
expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.ERROR);
});
it('should validate that systemPrompt and initialMessages are mutually exclusive', async () => {
const { config } = await createMockConfig();
const promptConfig: PromptConfig = {
systemPrompt: 'System',
initialMessages: [{ role: 'user', parts: [{ text: 'Hi' }] }],
};
const context = new ContextState();
const agent = await SubAgentScope.create(
'TestAgent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
);
await expect(agent.runNonInteractive(context)).rejects.toThrow(
'PromptConfig cannot have both `systemPrompt` and `initialMessages` defined.',
);
expect(agent.output.terminate_reason).toBe(SubagentTerminateMode.ERROR);
});
});
// Covers the main agent loop: plain termination, the scope-local
// `self.emitvalue` tool, external tool execution, tool-error propagation,
// and the "nudge" sent when the model stops before emitting required outputs.
describe('runNonInteractive - Execution and Tool Use', () => {
  const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' };
  it('should terminate with GOAL if no outputs are expected and model stops', async () => {
    const { config } = await createMockConfig();
    // Model stops immediately
    mockSendMessageStream.mockImplementation(createMockStream(['stop']));
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
      // No ToolConfig, No OutputConfig
    );
    await scope.runNonInteractive(new ContextState());
    expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL);
    expect(scope.output.emitted_vars).toEqual({});
    expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
    // Check the initial message
    expect(mockSendMessageStream.mock.calls[0][0].message).toEqual([
      { text: 'Get Started!' },
    ]);
  });
  it('should handle self.emitvalue and terminate with GOAL when outputs are met', async () => {
    const { config } = await createMockConfig();
    const outputConfig: OutputConfig = {
      outputs: { result: 'The final result' },
    };
    // Turn 1: Model responds with emitvalue call
    // Turn 2: Model stops after receiving the tool response
    mockSendMessageStream.mockImplementation(
      createMockStream([
        [
          {
            name: 'self.emitvalue',
            args: {
              emit_variable_name: 'result',
              emit_variable_value: 'Success!',
            },
          },
        ],
        'stop',
      ]),
    );
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
      undefined,
      outputConfig,
    );
    await scope.runNonInteractive(new ContextState());
    expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL);
    expect(scope.output.emitted_vars).toEqual({ result: 'Success!' });
    expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
    // Check the tool response sent back in the second call
    const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
    expect(secondCallArgs.message).toEqual([
      { text: 'Emitted variable result successfully' },
    ]);
  });
  it('should execute external tools and provide the response to the model', async () => {
    const listFilesToolDef: FunctionDeclaration = {
      name: 'list_files',
      description: 'Lists files',
      parameters: { type: Type.OBJECT, properties: {} },
    };
    const { config, toolRegistry } = await createMockConfig({
      getFunctionDeclarationsFiltered: vi
        .fn()
        .mockReturnValue([listFilesToolDef]),
    });
    const toolConfig: ToolConfig = { tools: ['list_files'] };
    // Turn 1: Model calls the external tool
    // Turn 2: Model stops
    mockSendMessageStream.mockImplementation(
      createMockStream([
        [
          {
            id: 'call_1',
            name: 'list_files',
            args: { path: '.' },
          },
        ],
        'stop',
      ]),
    );
    // Mock the tool execution result
    vi.mocked(executeToolCall).mockResolvedValue({
      callId: 'call_1',
      responseParts: 'file1.txt\nfile2.ts',
      resultDisplay: 'Listed 2 files',
      error: undefined,
      errorType: undefined, // Or ToolErrorType.NONE if available and appropriate
    });
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
      toolConfig,
    );
    await scope.runNonInteractive(new ContextState());
    // Check tool execution
    expect(executeToolCall).toHaveBeenCalledWith(
      config,
      expect.objectContaining({ name: 'list_files', args: { path: '.' } }),
      toolRegistry,
      expect.any(AbortSignal),
    );
    // Check the response sent back to the model
    const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
    expect(secondCallArgs.message).toEqual([
      { text: 'file1.txt\nfile2.ts' },
    ]);
    expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL);
  });
  it('should provide specific tool error responses to the model', async () => {
    const { config } = await createMockConfig();
    const toolConfig: ToolConfig = { tools: ['failing_tool'] };
    // Turn 1: Model calls the failing tool
    // Turn 2: Model stops after receiving the error response
    mockSendMessageStream.mockImplementation(
      createMockStream([
        [
          {
            id: 'call_fail',
            name: 'failing_tool',
            args: {},
          },
        ],
        'stop',
      ]),
    );
    // Mock the tool execution failure.
    vi.mocked(executeToolCall).mockResolvedValue({
      callId: 'call_fail',
      responseParts: 'ERROR: Tool failed catastrophically', // This should be sent to the model
      resultDisplay: 'Tool failed catastrophically',
      error: new Error('Failure'),
      errorType: ToolErrorType.INVALID_TOOL_PARAMS,
    });
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
      toolConfig,
    );
    await scope.runNonInteractive(new ContextState());
    // The agent should send the specific error message from responseParts.
    const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
    expect(secondCallArgs.message).toEqual([
      {
        text: 'ERROR: Tool failed catastrophically',
      },
    ]);
  });
  it('should nudge the model if it stops before emitting all required variables', async () => {
    const { config } = await createMockConfig();
    const outputConfig: OutputConfig = {
      outputs: { required_var: 'Must be present' },
    };
    // Turn 1: Model stops prematurely
    // Turn 2: Model responds to the nudge and emits the variable
    // Turn 3: Model stops
    mockSendMessageStream.mockImplementation(
      createMockStream([
        'stop',
        [
          {
            name: 'self.emitvalue',
            args: {
              emit_variable_name: 'required_var',
              emit_variable_value: 'Here it is',
            },
          },
        ],
        'stop',
      ]),
    );
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
      undefined,
      outputConfig,
    );
    await scope.runNonInteractive(new ContextState());
    // Check the nudge message sent in Turn 2
    const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
    // We check that the message contains the required variable name and the nudge phrasing.
    expect(secondCallArgs.message[0].text).toContain('required_var');
    expect(secondCallArgs.message[0].text).toContain(
      'You have stopped calling tools',
    );
    expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL);
    expect(scope.output.emitted_vars).toEqual({
      required_var: 'Here it is',
    });
    expect(mockSendMessageStream).toHaveBeenCalledTimes(3);
  });
});
// Covers the three abnormal termination paths: MAX_TURNS, TIMEOUT (using
// fake timers, since the timeout is only checked between turns), and ERROR
// when the underlying model call rejects.
describe('runNonInteractive - Termination and Recovery', () => {
  const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' };
  it('should terminate with MAX_TURNS if the limit is reached', async () => {
    const { config } = await createMockConfig();
    const runConfig: RunConfig = { ...defaultRunConfig, max_turns: 2 };
    // Model keeps looping by calling emitvalue repeatedly
    mockSendMessageStream.mockImplementation(
      createMockStream([
        [
          {
            name: 'self.emitvalue',
            args: { emit_variable_name: 'loop', emit_variable_value: 'v1' },
          },
        ],
        [
          {
            name: 'self.emitvalue',
            args: { emit_variable_name: 'loop', emit_variable_value: 'v2' },
          },
        ],
        // This turn should not happen
        [
          {
            name: 'self.emitvalue',
            args: { emit_variable_name: 'loop', emit_variable_value: 'v3' },
          },
        ],
      ]),
    );
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      runConfig,
    );
    await scope.runNonInteractive(new ContextState());
    expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
    expect(scope.output.terminate_reason).toBe(
      SubagentTerminateMode.MAX_TURNS,
    );
  });
  it('should terminate with TIMEOUT if the time limit is reached during an LLM call', async () => {
    // Use fake timers to reliably test timeouts
    vi.useFakeTimers();
    const { config } = await createMockConfig();
    const runConfig: RunConfig = { max_time_minutes: 5, max_turns: 100 };
    // We need to control the resolution of the sendMessageStream promise to advance the timer during execution.
    let resolveStream: (
      value: AsyncGenerator<unknown, void, unknown>,
    ) => void;
    const streamPromise = new Promise<
      AsyncGenerator<unknown, void, unknown>
    >((resolve) => {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      resolveStream = resolve as any;
    });
    // The LLM call will hang until we resolve the promise.
    mockSendMessageStream.mockReturnValue(streamPromise);
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      runConfig,
    );
    const runPromise = scope.runNonInteractive(new ContextState());
    // Advance time beyond the limit (6 minutes) while the agent is awaiting the LLM response.
    await vi.advanceTimersByTimeAsync(6 * 60 * 1000);
    // Now resolve the stream. The model returns 'stop'.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    resolveStream!(createMockStream(['stop'])() as any);
    await runPromise;
    expect(scope.output.terminate_reason).toBe(
      SubagentTerminateMode.TIMEOUT,
    );
    expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
    vi.useRealTimers();
  });
  it('should terminate with ERROR if the model call throws', async () => {
    const { config } = await createMockConfig();
    mockSendMessageStream.mockRejectedValue(new Error('API Failure'));
    const scope = await SubAgentScope.create(
      'test-agent',
      config,
      promptConfig,
      defaultModelConfig,
      defaultRunConfig,
    );
    await expect(
      scope.runNonInteractive(new ContextState()),
    ).rejects.toThrow('API Failure');
    expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.ERROR);
  });
});
});
});

View File

@@ -0,0 +1,681 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { reportError } from '../utils/errorReporting.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { Config } from '../config/config.js';
import { ToolCallRequestInfo } from './turn.js';
import { executeToolCall } from './nonInteractiveToolExecutor.js';
import { createContentGenerator } from './contentGenerator.js';
import { getEnvironmentContext } from '../utils/environmentContext.js';
import {
Content,
Part,
FunctionCall,
GenerateContentConfig,
FunctionDeclaration,
Type,
} from '@google/genai';
import { GeminiChat } from './geminiChat.js';
/**
* @fileoverview Defines the configuration interfaces for a subagent.
*
* These interfaces specify the structure for defining the subagent's prompt,
* the model parameters, and the execution settings.
*/
/**
 * Describes the possible termination modes for a subagent.
 * This enum provides a clear indication of why a subagent's execution might have ended.
 */
export enum SubagentTerminateMode {
  /**
   * Indicates that the subagent's execution terminated due to an unrecoverable error.
   * This is also the initial/default value of `OutputObject.terminate_reason`
   * until the run completes.
   */
  ERROR = 'ERROR',
  /**
   * Indicates that the subagent's execution terminated because it exceeded the maximum allowed working time.
   */
  TIMEOUT = 'TIMEOUT',
  /**
   * Indicates that the subagent's execution successfully completed all its defined goals.
   */
  GOAL = 'GOAL',
  /**
   * Indicates that the subagent's execution terminated because it exceeded the maximum number of turns.
   */
  MAX_TURNS = 'MAX_TURNS',
}
/**
 * Represents the output structure of a subagent's execution.
 * This interface defines the data that a subagent will return upon completion,
 * including any emitted variables and the reason for its termination.
 */
export interface OutputObject {
  /**
   * A record of key-value pairs representing variables emitted by the subagent
   * during its execution (via the scope-local `self.emitvalue` tool).
   * These variables can be used by the calling agent.
   */
  emitted_vars: Record<string, string>;
  /**
   * The reason for the subagent's termination, indicating whether it completed
   * successfully, timed out, or encountered an error.
   */
  terminate_reason: SubagentTerminateMode;
}
/**
 * Configures the initial prompt for the subagent.
 *
 * Exactly one of `systemPrompt` or `initialMessages` must be provided;
 * supplying both (or neither) causes the run to fail.
 */
export interface PromptConfig {
  /**
   * A single system prompt string that defines the subagent's persona and instructions.
   * May contain `${key}` placeholders resolved from the `ContextState` at run time.
   * Note: You should use either `systemPrompt` or `initialMessages`, but not both.
   */
  systemPrompt?: string;
  /**
   * An array of user/model content pairs to seed the chat history for few-shot prompting.
   * Note: You should use either `systemPrompt` or `initialMessages`, but not both.
   */
  initialMessages?: Content[];
}
/**
 * Configures the tools available to the subagent during its execution.
 *
 * Tools must be safe for non-interactive use: creation fails if a named,
 * parameterless tool would require user confirmation.
 */
export interface ToolConfig {
  /**
   * A list of tool names (from the tool registry) or full function declarations
   * that the subagent is permitted to use.
   */
  tools: Array<string | FunctionDeclaration>;
}
/**
 * Configures the expected outputs for the subagent.
 */
export interface OutputConfig {
  /**
   * A record describing the variables the subagent is expected to emit,
   * mapping each variable name to a human-readable description of its value.
   * The subagent will be prompted to emit these (via `self.emitvalue`)
   * before terminating.
   */
  outputs: Record<string, string>;
}
/**
 * Configures the generative model parameters for the subagent.
 * This interface specifies the model to be used and its associated generation settings,
 * such as temperature and top-p values, which influence the creativity and diversity of the model's output.
 */
export interface ModelConfig {
  /**
   * The name or identifier of the model to be used (e.g., 'gemini-2.5-pro').
   *
   * TODO: In the future, this needs to support 'auto' or some other string to support routing use cases.
   */
  model: string;
  /**
   * The temperature for the model's sampling process
   * (mapped to `GenerateContentConfig.temperature`).
   */
  temp: number;
  /**
   * The top-p value for nucleus sampling
   * (mapped to `GenerateContentConfig.topP`).
   */
  top_p: number;
}
/**
 * Configures the execution environment and constraints for the subagent.
 * This interface defines parameters that control the subagent's runtime behavior,
 * such as maximum execution time, to prevent infinite loops or excessive resource consumption.
 *
 * TODO: Consider adding max_tokens as a form of budgeting.
 */
export interface RunConfig {
  /**
   * The maximum execution time for the subagent in minutes.
   * Note: the limit is checked between turns, so a single long-running model
   * call may overshoot it.
   */
  max_time_minutes: number;
  /**
   * The maximum number of conversational turns (a user message + model response)
   * before the execution is terminated. Helps prevent infinite loops.
   * When omitted, no turn limit is enforced.
   */
  max_turns?: number;
}
/**
 * Manages the runtime context state for the subagent.
 * This class provides a mechanism to store and retrieve key-value pairs
 * that represent the dynamic state and variables accessible to the subagent
 * during its execution.
 */
export class ContextState {
  // Backed by a Map rather than a plain object so that keys colliding with
  // Object.prototype members behave correctly: with a `{}` store,
  // `get('toString')` would return the inherited function instead of
  // `undefined`, and `set('__proto__', v)` would mutate the prototype
  // instead of storing a value.
  private state = new Map<string, unknown>();
  /**
   * Retrieves a value from the context state.
   *
   * @param key - The key of the value to retrieve.
   * @returns The value associated with the key, or undefined if the key is not found.
   */
  get(key: string): unknown {
    return this.state.get(key);
  }
  /**
   * Sets a value in the context state.
   *
   * @param key - The key to set the value under.
   * @param value - The value to set.
   */
  set(key: string, value: unknown): void {
    this.state.set(key, value);
  }
  /**
   * Retrieves all keys in the context state.
   *
   * @returns An array of all keys in the context state, in insertion order.
   */
  get_keys(): string[] {
    return Array.from(this.state.keys());
  }
}
/**
 * Replaces `${...}` placeholders in a template string with values from a context.
 *
 * Every `${key}` placeholder is collected first and validated against the
 * keys available in the `ContextState`; only when all are present is the
 * substitution performed.
 *
 * @param template The template string containing placeholders.
 * @param context The `ContextState` object providing placeholder values.
 * @returns The populated string with all placeholders replaced.
 * @throws {Error} if any placeholder key is not found in the context.
 */
function templateString(template: string, context: ContextState): string {
  const placeholderRegex = /\$\{(\w+)\}/g;
  const available = new Set(context.get_keys());
  // Collect missing keys in order of first appearance, without duplicates.
  const missing: string[] = [];
  for (const match of template.matchAll(placeholderRegex)) {
    const key = match[1];
    if (!available.has(key) && !missing.includes(key)) {
      missing.push(key);
    }
  }
  if (missing.length > 0) {
    throw new Error(
      `Missing context values for the following keys: ${missing.join(', ')}`,
    );
  }
  // All keys are present; substitute each placeholder with its stringified value.
  return template.replace(placeholderRegex, (_match, key) =>
    String(context.get(key)),
  );
}
/**
 * Represents the scope and execution environment for a subagent.
 * This class orchestrates the subagent's lifecycle, managing its chat interactions,
 * runtime context, and the collection of its outputs.
 */
export class SubAgentScope {
  // Run result. terminate_reason starts as ERROR and is only upgraded when
  // the run reaches a definite outcome; emitted_vars is filled by
  // `self.emitvalue` calls handled in processFunctionCalls.
  output: OutputObject = {
    terminate_reason: SubagentTerminateMode.ERROR,
    emitted_vars: {},
  };
  // Unique id (name + random suffix) used to build per-turn prompt ids.
  private readonly subagentId: string;
  /**
   * Constructs a new SubAgentScope instance.
   * @param name - The name for the subagent, used for logging and identification.
   * @param runtimeContext - The shared runtime configuration and services.
   * @param promptConfig - Configuration for the subagent's prompt and behavior.
   * @param modelConfig - Configuration for the generative model parameters.
   * @param runConfig - Configuration for the subagent's execution environment.
   * @param toolConfig - Optional configuration for tools available to the subagent.
   * @param outputConfig - Optional configuration for the subagent's expected outputs.
   */
  private constructor(
    readonly name: string,
    readonly runtimeContext: Config,
    private readonly promptConfig: PromptConfig,
    private readonly modelConfig: ModelConfig,
    private readonly runConfig: RunConfig,
    private readonly toolConfig?: ToolConfig,
    private readonly outputConfig?: OutputConfig,
  ) {
    const randomPart = Math.random().toString(36).slice(2, 8);
    this.subagentId = `${this.name}-${randomPart}`;
  }
  /**
   * Creates and validates a new SubAgentScope instance.
   * This factory method ensures that all tools provided in the prompt configuration
   * are valid for non-interactive use before creating the subagent instance.
   * @param {string} name - The name of the subagent.
   * @param {Config} runtimeContext - The shared runtime configuration and services.
   * @param {PromptConfig} promptConfig - Configuration for the subagent's prompt and behavior.
   * @param {ModelConfig} modelConfig - Configuration for the generative model parameters.
   * @param {RunConfig} runConfig - Configuration for the subagent's execution environment.
   * @param {ToolConfig} [toolConfig] - Optional configuration for tools.
   * @param {OutputConfig} [outputConfig] - Optional configuration for expected outputs.
   * @returns {Promise<SubAgentScope>} A promise that resolves to a valid SubAgentScope instance.
   * @throws {Error} If any tool requires user confirmation.
   */
  static async create(
    name: string,
    runtimeContext: Config,
    promptConfig: PromptConfig,
    modelConfig: ModelConfig,
    runConfig: RunConfig,
    toolConfig?: ToolConfig,
    outputConfig?: OutputConfig,
  ): Promise<SubAgentScope> {
    if (toolConfig) {
      const toolRegistry: ToolRegistry = await runtimeContext.getToolRegistry();
      // Only tools referenced by name can be validated against the registry;
      // inline FunctionDeclarations are accepted as-is.
      const toolsToLoad: string[] = [];
      for (const tool of toolConfig.tools) {
        if (typeof tool === 'string') {
          toolsToLoad.push(tool);
        }
      }
      for (const toolName of toolsToLoad) {
        const tool = toolRegistry.getTool(toolName);
        if (tool) {
          const requiredParams = tool.schema.parameters?.required ?? [];
          if (requiredParams.length > 0) {
            // This check is imperfect. A tool might require parameters but still
            // be interactive (e.g., `delete_file(path)`). However, we cannot
            // build a generic invocation without knowing what dummy parameters
            // to provide. Crashing here because `build({})` fails is worse
            // than allowing a potential hang later if an interactive tool is
            // used. This is a best-effort check.
            console.warn(
              `Cannot check tool "${toolName}" for interactivity because it requires parameters. Assuming it is safe for non-interactive use.`,
            );
            continue;
          }
          const invocation = tool.build({});
          const confirmationDetails = await invocation.shouldConfirmExecute(
            new AbortController().signal,
          );
          if (confirmationDetails) {
            throw new Error(
              `Tool "${toolName}" requires user confirmation and cannot be used in a non-interactive subagent.`,
            );
          }
        }
      }
    }
    return new SubAgentScope(
      name,
      runtimeContext,
      promptConfig,
      modelConfig,
      runConfig,
      toolConfig,
      outputConfig,
    );
  }
  /**
   * Runs the subagent in a non-interactive mode.
   * This method orchestrates the subagent's execution loop, including prompt templating,
   * tool execution, and termination conditions.
   * @param {ContextState} context - The current context state containing variables for prompt templating.
   * @returns {Promise<void>} A promise that resolves when the subagent has completed its execution.
   */
  async runNonInteractive(context: ContextState): Promise<void> {
    const chat = await this.createChatObject(context);
    if (!chat) {
      this.output.terminate_reason = SubagentTerminateMode.ERROR;
      return;
    }
    const abortController = new AbortController();
    const toolRegistry: ToolRegistry =
      await this.runtimeContext.getToolRegistry();
    // Prepare the list of tools available to the subagent.
    const toolsList: FunctionDeclaration[] = [];
    if (this.toolConfig) {
      const toolsToLoad: string[] = [];
      for (const tool of this.toolConfig.tools) {
        if (typeof tool === 'string') {
          toolsToLoad.push(tool);
        } else {
          toolsList.push(tool);
        }
      }
      toolsList.push(
        ...toolRegistry.getFunctionDeclarationsFiltered(toolsToLoad),
      );
    }
    // Add local scope functions if outputs are expected.
    if (this.outputConfig && this.outputConfig.outputs) {
      toolsList.push(...this.getScopeLocalFuncDefs());
    }
    let currentMessages: Content[] = [
      { role: 'user', parts: [{ text: 'Get Started!' }] },
    ];
    const startTime = Date.now();
    let turnCounter = 0;
    try {
      while (true) {
        // Check termination conditions. Note: max_time_minutes is only
        // checked here and after the stream completes, so a single long
        // model call can overshoot the limit.
        if (
          this.runConfig.max_turns &&
          turnCounter >= this.runConfig.max_turns
        ) {
          this.output.terminate_reason = SubagentTerminateMode.MAX_TURNS;
          break;
        }
        let durationMin = (Date.now() - startTime) / (1000 * 60);
        if (durationMin >= this.runConfig.max_time_minutes) {
          this.output.terminate_reason = SubagentTerminateMode.TIMEOUT;
          break;
        }
        const promptId = `${this.runtimeContext.getSessionId()}#${this.subagentId}#${turnCounter++}`;
        // Only the first Content's parts are sent; processFunctionCalls
        // always returns exactly one Content, so nothing is dropped.
        const messageParams = {
          message: currentMessages[0]?.parts || [],
          config: {
            abortSignal: abortController.signal,
            tools: [{ functionDeclarations: toolsList }],
          },
        };
        const responseStream = await chat.sendMessageStream(
          messageParams,
          promptId,
        );
        // Drain the stream, collecting any function calls the model makes.
        const functionCalls: FunctionCall[] = [];
        for await (const resp of responseStream) {
          // NOTE: returning here leaves terminate_reason at its default (ERROR).
          if (abortController.signal.aborted) return;
          if (resp.functionCalls) functionCalls.push(...resp.functionCalls);
        }
        // Re-check the time budget after the (possibly slow) model call.
        durationMin = (Date.now() - startTime) / (1000 * 60);
        if (durationMin >= this.runConfig.max_time_minutes) {
          this.output.terminate_reason = SubagentTerminateMode.TIMEOUT;
          break;
        }
        if (functionCalls.length > 0) {
          currentMessages = await this.processFunctionCalls(
            functionCalls,
            toolRegistry,
            abortController,
            promptId,
          );
        } else {
          // Model stopped calling tools. Check if goal is met.
          if (
            !this.outputConfig ||
            Object.keys(this.outputConfig.outputs).length === 0
          ) {
            this.output.terminate_reason = SubagentTerminateMode.GOAL;
            break;
          }
          const remainingVars = Object.keys(this.outputConfig.outputs).filter(
            (key) => !(key in this.output.emitted_vars),
          );
          if (remainingVars.length === 0) {
            this.output.terminate_reason = SubagentTerminateMode.GOAL;
            break;
          }
          // Some required outputs are missing: nudge the model to emit them.
          const nudgeMessage = `You have stopped calling tools but have not emitted the following required variables: ${remainingVars.join(
            ', ',
          )}. Please use the 'self.emitvalue' tool to emit them now, or continue working if necessary.`;
          console.debug(nudgeMessage);
          currentMessages = [
            {
              role: 'user',
              parts: [{ text: nudgeMessage }],
            },
          ];
        }
      }
    } catch (error) {
      console.error('Error during subagent execution:', error);
      this.output.terminate_reason = SubagentTerminateMode.ERROR;
      throw error;
    }
  }
  /**
   * Processes a list of function calls, executing each one and collecting their responses.
   * This method iterates through the provided function calls, executes them using the
   * `executeToolCall` function (or handles `self.emitvalue` internally), and aggregates
   * their results. It also manages error reporting for failed tool executions.
   * @param {FunctionCall[]} functionCalls - An array of `FunctionCall` objects to process.
   * @param {ToolRegistry} toolRegistry - The tool registry to look up and execute tools.
   * @param {AbortController} abortController - An `AbortController` to signal cancellation of tool executions.
   * @param {string} promptId - The prompt id of the turn that produced these calls.
   * @returns {Promise<Content[]>} A promise that resolves to an array of `Content` parts representing the tool responses,
   * which are then used to update the chat history.
   */
  private async processFunctionCalls(
    functionCalls: FunctionCall[],
    toolRegistry: ToolRegistry,
    abortController: AbortController,
    promptId: string,
  ): Promise<Content[]> {
    const toolResponseParts: Part[] = [];
    for (const functionCall of functionCalls) {
      const callId = functionCall.id ?? `${functionCall.name}-${Date.now()}`;
      const requestInfo: ToolCallRequestInfo = {
        callId,
        name: functionCall.name as string,
        args: (functionCall.args ?? {}) as Record<string, unknown>,
        isClientInitiated: true,
        prompt_id: promptId,
      };
      let toolResponse;
      // Handle scope-local tools first.
      if (functionCall.name === 'self.emitvalue') {
        const valName = String(requestInfo.args['emit_variable_name']);
        const valVal = String(requestInfo.args['emit_variable_value']);
        this.output.emitted_vars[valName] = valVal;
        toolResponse = {
          callId,
          responseParts: `Emitted variable ${valName} successfully`,
          resultDisplay: `Emitted variable ${valName} successfully`,
          error: undefined,
        };
      } else {
        toolResponse = await executeToolCall(
          this.runtimeContext,
          requestInfo,
          toolRegistry,
          abortController.signal,
        );
      }
      if (toolResponse.error) {
        console.error(
          `Error executing tool ${functionCall.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
        );
      }
      // Normalize responseParts (string | Part | array) into Part objects.
      if (toolResponse.responseParts) {
        const parts = Array.isArray(toolResponse.responseParts)
          ? toolResponse.responseParts
          : [toolResponse.responseParts];
        for (const part of parts) {
          if (typeof part === 'string') {
            toolResponseParts.push({ text: part });
          } else if (part) {
            toolResponseParts.push(part);
          }
        }
      }
    }
    // If all tool calls failed, inform the model so it can re-evaluate.
    if (functionCalls.length > 0 && toolResponseParts.length === 0) {
      toolResponseParts.push({
        text: 'All tool calls failed. Please analyze the errors and try an alternative approach.',
      });
    }
    return [{ role: 'user', parts: toolResponseParts }];
  }
  /**
   * Validates the prompt configuration and builds the GeminiChat used for
   * the run, seeding history with environment context.
   * @param context - The context for system-prompt templating.
   * @returns The chat instance, or `undefined` if initialization failed
   * (the error is reported via `reportError`).
   */
  private async createChatObject(context: ContextState) {
    if (!this.promptConfig.systemPrompt && !this.promptConfig.initialMessages) {
      throw new Error(
        'PromptConfig must have either `systemPrompt` or `initialMessages` defined.',
      );
    }
    if (this.promptConfig.systemPrompt && this.promptConfig.initialMessages) {
      throw new Error(
        'PromptConfig cannot have both `systemPrompt` and `initialMessages` defined.',
      );
    }
    const envParts = await getEnvironmentContext(this.runtimeContext);
    const envHistory: Content[] = [
      { role: 'user', parts: envParts },
      { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
    ];
    const start_history = [
      ...envHistory,
      ...(this.promptConfig.initialMessages ?? []),
    ];
    const systemInstruction = this.promptConfig.systemPrompt
      ? this.buildChatSystemPrompt(context)
      : undefined;
    try {
      const generationConfig: GenerateContentConfig & {
        systemInstruction?: string | Content;
      } = {
        temperature: this.modelConfig.temp,
        topP: this.modelConfig.top_p,
      };
      if (systemInstruction) {
        generationConfig.systemInstruction = systemInstruction;
      }
      const contentGenerator = await createContentGenerator(
        this.runtimeContext.getContentGeneratorConfig(),
        this.runtimeContext,
        this.runtimeContext.getSessionId(),
      );
      // NOTE(review): this mutates the shared runtime Config's model for the
      // whole session, not just this subagent — confirm this is intended.
      this.runtimeContext.setModel(this.modelConfig.model);
      return new GeminiChat(
        this.runtimeContext,
        contentGenerator,
        generationConfig,
        start_history,
      );
    } catch (error) {
      await reportError(
        error,
        'Error initializing Gemini chat session.',
        start_history,
        'startChat',
      );
      // The calling function will handle the undefined return.
      return undefined;
    }
  }
  /**
   * Returns an array of FunctionDeclaration objects for tools that are local to the subagent's scope.
   * Currently, this includes the `self.emitvalue` tool for emitting variables.
   * @returns An array of `FunctionDeclaration` objects.
   */
  private getScopeLocalFuncDefs() {
    const emitValueTool: FunctionDeclaration = {
      name: 'self.emitvalue',
      description: `* This tool emits A SINGLE return value from this execution, such that it can be collected and presented to the calling function.
* You can only emit ONE VALUE each time you call this tool. You are expected to call this tool MULTIPLE TIMES if you have MULTIPLE OUTPUTS.`,
      parameters: {
        type: Type.OBJECT,
        properties: {
          emit_variable_name: {
            description: 'This is the name of the variable to be returned.',
            type: Type.STRING,
          },
          emit_variable_value: {
            description:
              'This is the _value_ to be returned for this variable.',
            type: Type.STRING,
          },
        },
        required: ['emit_variable_name', 'emit_variable_value'],
      },
    };
    return [emitValueTool];
  }
  /**
   * Builds the system prompt for the chat based on the provided configurations.
   * It templates the base system prompt and appends instructions for emitting
   * variables if an `OutputConfig` is provided.
   * @param {ContextState} context - The context for templating.
   * @returns {string} The complete system prompt.
   */
  private buildChatSystemPrompt(context: ContextState): string {
    if (!this.promptConfig.systemPrompt) {
      // This should ideally be caught in createChatObject, but serves as a safeguard.
      return '';
    }
    let finalPrompt = templateString(this.promptConfig.systemPrompt, context);
    // Add instructions for emitting variables if needed.
    if (this.outputConfig && this.outputConfig.outputs) {
      let outputInstructions =
        '\n\nAfter you have achieved all other goals, you MUST emit the required output variables. For each expected output, make one final call to the `self.emitvalue` tool.';
      for (const [key, value] of Object.entries(this.outputConfig.outputs)) {
        outputInstructions += `\n* Use 'self.emitvalue' to emit the '${key}' key, with a value described as: '${value}'`;
      }
      finalPrompt += outputInstructions;
    }
    // Add general non-interactive instructions.
    finalPrompt += `
Important Rules:
* You are running in a non-interactive mode. You CANNOT ask the user for input or clarification. You must proceed with the information you have.
* Once you believe all goals have been met and all required outputs have been emitted, stop calling tools.`;
    return finalPrompt;
  }
}

View File

@@ -17,12 +17,14 @@ import { GeminiChat } from './geminiChat.js';
const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn();
const mockMaybeIncludeSchemaDepthContext = vi.fn();
vi.mock('@google/genai', async (importOriginal) => {
const actual = await importOriginal<typeof import('@google/genai')>();
const MockChat = vi.fn().mockImplementation(() => ({
sendMessageStream: mockSendMessageStream,
getHistory: mockGetHistory,
maybeIncludeSchemaDepthContext: mockMaybeIncludeSchemaDepthContext,
}));
return {
...actual,
@@ -46,6 +48,7 @@ describe('Turn', () => {
type MockedChatInstance = {
sendMessageStream: typeof mockSendMessageStream;
getHistory: typeof mockGetHistory;
maybeIncludeSchemaDepthContext: typeof mockMaybeIncludeSchemaDepthContext;
};
let mockChatInstance: MockedChatInstance;
@@ -54,6 +57,7 @@ describe('Turn', () => {
mockChatInstance = {
sendMessageStream: mockSendMessageStream,
getHistory: mockGetHistory,
maybeIncludeSchemaDepthContext: mockMaybeIncludeSchemaDepthContext,
};
turn = new Turn(mockChatInstance as unknown as GeminiChat, 'prompt-id-1');
mockGetHistory.mockReturnValue([]);
@@ -200,7 +204,7 @@ describe('Turn', () => {
{ role: 'model', parts: [{ text: 'Previous history' }] },
];
mockGetHistory.mockReturnValue(historyContent);
mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined);
const events = [];
for await (const event of turn.run(
reqParts,

View File

@@ -288,6 +288,7 @@ export class Turn {
message: getErrorMessage(error),
status,
};
await this.chat.maybeIncludeSchemaDepthContext(structuredError);
yield { type: GeminiEventType.Error, value: { error: structuredError } };
return;
}

View File

@@ -6,12 +6,33 @@
export enum DetectedIde {
VSCode = 'vscode',
VSCodium = 'vscodium',
Cursor = 'cursor',
CloudShell = 'cloudshell',
Codespaces = 'codespaces',
Windsurf = 'windsurf',
FirebaseStudio = 'firebasestudio',
Trae = 'trae',
}
export function getIdeDisplayName(ide: DetectedIde): string {
switch (ide) {
case DetectedIde.VSCode:
return 'VS Code';
case DetectedIde.VSCodium:
return 'VSCodium';
case DetectedIde.Cursor:
return 'Cursor';
case DetectedIde.CloudShell:
return 'Cloud Shell';
case DetectedIde.Codespaces:
return 'GitHub Codespaces';
case DetectedIde.Windsurf:
return 'Windsurf';
case DetectedIde.FirebaseStudio:
return 'Firebase Studio';
case DetectedIde.Trae:
return 'Trae';
default: {
// This ensures that if a new IDE is added to the enum, we get a compile-time error.
const exhaustiveCheck: never = ide;
@@ -21,8 +42,24 @@ export function getIdeDisplayName(ide: DetectedIde): string {
}
export function detectIde(): DetectedIde | undefined {
if (process.env.TERM_PROGRAM === 'vscode') {
return DetectedIde.VSCode;
// Only VSCode-based integrations are currently supported.
if (process.env.TERM_PROGRAM !== 'vscode') {
return undefined;
}
return undefined;
if (process.env.CURSOR_TRACE_ID) {
return DetectedIde.Cursor;
}
if (process.env.CODESPACES) {
return DetectedIde.Codespaces;
}
if (process.env.EDITOR_IN_CLOUD_SHELL) {
return DetectedIde.CloudShell;
}
if (process.env.TERM_PRODUCT === 'Trae') {
return DetectedIde.Trae;
}
if (process.env.FIREBASE_DEPLOY_AGENT) {
return DetectedIde.FirebaseStudio;
}
return DetectedIde.VSCode;
}

View File

@@ -4,18 +4,29 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'node:fs';
import * as path from 'node:path';
import {
detectIde,
DetectedIde,
getIdeDisplayName,
} from '../ide/detect-ide.js';
import { ideContext, IdeContextNotificationSchema } from '../ide/ideContext.js';
import {
ideContext,
IdeContextNotificationSchema,
IdeDiffAcceptedNotificationSchema,
IdeDiffClosedNotificationSchema,
CloseDiffResponseSchema,
DiffUpdateResult,
} from '../ide/ideContext.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
const logger = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
debug: (...args: any[]) => console.debug('[DEBUG] [IDEClient]', ...args),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (...args: any[]) => console.error('[ERROR] [IDEClient]', ...args),
};
export type IDEConnectionState = {
@@ -29,6 +40,16 @@ export enum IDEConnectionStatus {
Connecting = 'connecting',
}
function getRealPath(path: string): string {
try {
return fs.realpathSync(path);
} catch (_e) {
// If realpathSync fails, it might be because the path doesn't exist.
// In that case, we can fall back to the original path.
return path;
}
}
/**
* Manages the connection to and interaction with the IDE server.
*/
@@ -42,6 +63,7 @@ export class IdeClient {
};
private readonly currentIde: DetectedIde | undefined;
private readonly currentIdeDisplayName: string | undefined;
private diffResponses = new Map<string, (result: DiffUpdateResult) => void>();
private constructor() {
this.currentIde = detectIde();
@@ -58,13 +80,21 @@ export class IdeClient {
}
async connect(): Promise<void> {
this.setState(IDEConnectionStatus.Connecting);
if (!this.currentIde || !this.currentIdeDisplayName) {
this.setState(IDEConnectionStatus.Disconnected);
this.setState(
IDEConnectionStatus.Disconnected,
`IDE integration is not supported in your current environment. To use this feature, run Gemini CLI in one of these supported IDEs: ${Object.values(
DetectedIde,
)
.map((ide) => getIdeDisplayName(ide))
.join(', ')}`,
false,
);
return;
}
this.setState(IDEConnectionStatus.Connecting);
if (!this.validateWorkspacePath()) {
return;
}
@@ -77,7 +107,83 @@ export class IdeClient {
await this.establishConnection(port);
}
disconnect() {
/**
* A diff is accepted with any modifications if the user performs one of the
* following actions:
* - Clicks the checkbox icon in the IDE to accept
* - Runs `command+shift+p` > "Gemini CLI: Accept Diff in IDE" to accept
* - Selects "accept" in the CLI UI
* - Saves the file via `ctrl/command+s`
*
* A diff is rejected if the user performs one of the following actions:
* - Clicks the "x" icon in the IDE
* - Runs "Gemini CLI: Close Diff in IDE"
* - Selects "no" in the CLI UI
* - Closes the file
*/
async openDiff(
filePath: string,
newContent?: string,
): Promise<DiffUpdateResult> {
return new Promise<DiffUpdateResult>((resolve, reject) => {
this.diffResponses.set(filePath, resolve);
this.client
?.callTool({
name: `openDiff`,
arguments: {
filePath,
newContent,
},
})
.catch((err) => {
logger.debug(`callTool for ${filePath} failed:`, err);
reject(err);
});
});
}
async closeDiff(filePath: string): Promise<string | undefined> {
try {
const result = await this.client?.callTool({
name: `closeDiff`,
arguments: {
filePath,
},
});
if (result) {
const parsed = CloseDiffResponseSchema.parse(result);
return parsed.content;
}
} catch (err) {
logger.debug(`callTool for ${filePath} failed:`, err);
}
return;
}
// Closes the diff. Instead of waiting for a notification,
// manually resolves the diff resolver as the desired outcome.
async resolveDiffFromCli(filePath: string, outcome: 'accepted' | 'rejected') {
const content = await this.closeDiff(filePath);
const resolver = this.diffResponses.get(filePath);
if (resolver) {
if (outcome === 'accepted') {
resolver({ status: 'accepted', content });
} else {
resolver({ status: 'rejected', content: undefined });
}
this.diffResponses.delete(filePath);
}
}
async disconnect() {
if (this.state.status === IDEConnectionStatus.Disconnected) {
return;
}
for (const filePath of this.diffResponses.keys()) {
await this.closeDiff(filePath);
}
this.diffResponses.clear();
this.setState(
IDEConnectionStatus.Disconnected,
'IDE integration disabled. To enable it again, run /ide enable.',
@@ -93,19 +199,35 @@ export class IdeClient {
return this.state;
}
private setState(status: IDEConnectionStatus, details?: string) {
getDetectedIdeDisplayName(): string | undefined {
return this.currentIdeDisplayName;
}
private setState(
status: IDEConnectionStatus,
details?: string,
logToConsole = false,
) {
const isAlreadyDisconnected =
this.state.status === IDEConnectionStatus.Disconnected &&
status === IDEConnectionStatus.Disconnected;
// Only update details if the state wasn't already disconnected, so that
// the first detail message is preserved.
// Only update details & log to console if the state wasn't already
// disconnected, so that the first detail message is preserved.
if (!isAlreadyDisconnected) {
this.state = { status, details };
if (details) {
if (logToConsole) {
logger.error(details);
} else {
// We only want to log disconnect messages to debug
// if they are not already being logged to the console.
logger.debug(details);
}
}
}
if (status === IDEConnectionStatus.Disconnected) {
logger.debug('IDE integration disconnected:', details);
ideContext.clearIdeContext();
}
}
@@ -116,6 +238,7 @@ export class IdeClient {
this.setState(
IDEConnectionStatus.Disconnected,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
true,
);
return false;
}
@@ -123,13 +246,19 @@ export class IdeClient {
this.setState(
IDEConnectionStatus.Disconnected,
`To use this feature, please open a single workspace folder in ${this.currentIdeDisplayName} and try again.`,
true,
);
return false;
}
if (ideWorkspacePath !== process.cwd()) {
const idePath = getRealPath(ideWorkspacePath).toLocaleLowerCase();
const cwd = getRealPath(process.cwd()).toLocaleLowerCase();
const rel = path.relative(idePath, cwd);
if (rel.startsWith('..') || path.isAbsolute(rel)) {
this.setState(
IDEConnectionStatus.Disconnected,
`Directory mismatch. Gemini CLI is running in a different location than the open workspace in ${this.currentIdeDisplayName}. Please run the CLI from the same directory as your project's root folder.`,
true,
);
return false;
}
@@ -141,7 +270,8 @@ export class IdeClient {
if (!port) {
this.setState(
IDEConnectionStatus.Disconnected,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try restarting your terminal. To install the extension, run /ide install.`,
true,
);
return undefined;
}
@@ -163,14 +293,43 @@ export class IdeClient {
this.setState(
IDEConnectionStatus.Disconnected,
`IDE connection error. The connection was lost unexpectedly. Please try reconnecting by running /ide enable`,
true,
);
};
this.client.onclose = () => {
this.setState(
IDEConnectionStatus.Disconnected,
`IDE connection error. The connection was lost unexpectedly. Please try reconnecting by running /ide enable`,
true,
);
};
this.client.setNotificationHandler(
IdeDiffAcceptedNotificationSchema,
(notification) => {
const { filePath, content } = notification.params;
const resolver = this.diffResponses.get(filePath);
if (resolver) {
resolver({ status: 'accepted', content });
this.diffResponses.delete(filePath);
} else {
logger.debug(`No resolver found for ${filePath}`);
}
},
);
this.client.setNotificationHandler(
IdeDiffClosedNotificationSchema,
(notification) => {
const { filePath } = notification.params;
const resolver = this.diffResponses.get(filePath);
if (resolver) {
resolver({ status: 'rejected', content: undefined });
this.diffResponses.delete(filePath);
} else {
logger.debug(`No resolver found for ${filePath}`);
}
},
);
}
private async establishConnection(port: string) {
@@ -183,7 +342,7 @@ export class IdeClient {
});
transport = new StreamableHTTPClientTransport(
new URL(`http://localhost:${port}/mcp`),
new URL(`http://${getIdeServerHost()}:${port}/mcp`),
);
this.registerClientHandlers();
@@ -194,7 +353,8 @@ export class IdeClient {
} catch (_error) {
this.setState(
IDEConnectionStatus.Disconnected,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try restarting your terminal. To install the extension, run /ide install.`,
true,
);
if (transport) {
try {
@@ -236,11 +396,13 @@ export class IdeClient {
this.client?.close();
}
getDetectedIdeDisplayName(): string | undefined {
return this.currentIdeDisplayName;
}
setDisconnected() {
this.setState(IDEConnectionStatus.Disconnected);
}
}
function getIdeServerHost() {
const isInContainer =
fs.existsSync('/.dockerenv') || fs.existsSync('/run/.containerenv');
return isInContainer ? 'host.docker.internal' : 'localhost';
}

View File

@@ -24,9 +24,17 @@ describe('ide-installer', () => {
expect(installer).toBeInstanceOf(Object);
});
it('should return null for an unknown IDE', () => {
it('should return an OpenVSXInstaller for "vscodium"', () => {
const installer = getIdeInstaller(DetectedIde.VSCodium);
expect(installer).not.toBeNull();
expect(installer).toBeInstanceOf(Object);
});
it('should return a DefaultIDEInstaller for an unknown IDE', () => {
const installer = getIdeInstaller('unknown' as DetectedIde);
expect(installer).toBeNull();
// Assuming DefaultIDEInstaller is the fallback
expect(installer).not.toBeNull();
expect(installer).toBeInstanceOf(Object);
});
});
@@ -59,4 +67,44 @@ describe('ide-installer', () => {
});
});
});
describe('OpenVSXInstaller', () => {
let installer: IdeInstaller;
beforeEach(() => {
installer = getIdeInstaller(DetectedIde.VSCodium)!;
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('install', () => {
it('should call execSync with the correct command and return success', async () => {
const execSyncSpy = vi
.spyOn(child_process, 'execSync')
.mockImplementation(() => '');
const result = await installer.install();
expect(execSyncSpy).toHaveBeenCalledWith(
'npx ovsx get google.gemini-cli-vscode-ide-companion',
{ stdio: 'pipe' },
);
expect(result.success).toBe(true);
expect(result.message).toContain(
'VS Code companion extension was installed successfully from OpenVSX',
);
});
it('should return a failure message on failed installation', async () => {
vi.spyOn(child_process, 'execSync').mockImplementation(() => {
throw new Error('Command failed');
});
const result = await installer.install();
expect(result.success).toBe(false);
expect(result.message).toContain(
'Failed to install VS Code companion extension from OpenVSX',
);
});
});
});
});

View File

@@ -147,11 +147,31 @@ class VsCodeInstaller implements IdeInstaller {
}
}
class OpenVSXInstaller implements IdeInstaller {
async install(): Promise<InstallResult> {
// TODO: Use the correct extension path.
const command = `npx ovsx get google.gemini-cli-vscode-ide-companion`;
try {
child_process.execSync(command, { stdio: 'pipe' });
return {
success: true,
message:
'VS Code companion extension was installed successfully from OpenVSX. Please restart your terminal to complete the setup.',
};
} catch (_error) {
return {
success: false,
message: `Failed to install VS Code companion extension from OpenVSX. Please try installing it manually.`,
};
}
}
}
export function getIdeInstaller(ide: DetectedIde): IdeInstaller | null {
switch (ide) {
case DetectedIde.VSCode:
return new VsCodeInstaller();
default:
return null;
return new OpenVSXInstaller();
}
}

View File

@@ -36,10 +36,69 @@ export type IdeContext = z.infer<typeof IdeContextSchema>;
* Zod schema for validating the 'ide/contextUpdate' notification from the IDE.
*/
export const IdeContextNotificationSchema = z.object({
jsonrpc: z.literal('2.0'),
method: z.literal('ide/contextUpdate'),
params: IdeContextSchema,
});
export const IdeDiffAcceptedNotificationSchema = z.object({
jsonrpc: z.literal('2.0'),
method: z.literal('ide/diffAccepted'),
params: z.object({
filePath: z.string(),
content: z.string(),
}),
});
export const IdeDiffClosedNotificationSchema = z.object({
jsonrpc: z.literal('2.0'),
method: z.literal('ide/diffClosed'),
params: z.object({
filePath: z.string(),
content: z.string().optional(),
}),
});
export const CloseDiffResponseSchema = z
.object({
content: z
.array(
z.object({
text: z.string(),
type: z.literal('text'),
}),
)
.min(1),
})
.transform((val, ctx) => {
try {
const parsed = JSON.parse(val.content[0].text);
const innerSchema = z.object({ content: z.string().optional() });
const validationResult = innerSchema.safeParse(parsed);
if (!validationResult.success) {
validationResult.error.issues.forEach((issue) => ctx.addIssue(issue));
return z.NEVER;
}
return validationResult.data;
} catch (_) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: 'Invalid JSON in text content',
});
return z.NEVER;
}
});
export type DiffUpdateResult =
| {
status: 'accepted';
content?: string;
}
| {
status: 'rejected';
content: undefined;
};
type IdeContextSubscriber = (ideContext: IdeContext | undefined) => void;
/**

View File

@@ -99,5 +99,3 @@ export { sessionId } from './utils/session.js';
export * from './utils/browser.js';
// OpenAI Logging Utilities
export { OpenAILogger, openaiLogger } from './utils/openaiLogger.js';
export { default as OpenAILogViewer } from './utils/openaiLogViewer.js';
export { default as OpenAIAnalytics } from './utils/openaiAnalytics.js';

View File

@@ -12,34 +12,78 @@ import { MCPServerConfig } from '../config/config.js';
vi.mock('google-auth-library');
describe('GoogleCredentialProvider', () => {
const validConfig = {
url: 'https://test.googleapis.com',
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
it('should throw an error if no scopes are provided', () => {
expect(() => new GoogleCredentialProvider()).toThrow(
const config = {
url: 'https://test.googleapis.com',
} as MCPServerConfig;
expect(() => new GoogleCredentialProvider(config)).toThrow(
'Scopes must be provided in the oauth config for Google Credentials provider',
);
});
it('should use scopes from the config if provided', () => {
new GoogleCredentialProvider(validConfig);
expect(GoogleAuth).toHaveBeenCalledWith({
scopes: ['scope1', 'scope2'],
});
});
it('should throw an error for a non-allowlisted host', () => {
const config = {
url: 'https://example.com',
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
expect(() => new GoogleCredentialProvider(config)).toThrow(
'Host "example.com" is not an allowed host for Google Credential provider.',
);
});
it('should allow luci.app', () => {
const config = {
url: 'https://luci.app',
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
new GoogleCredentialProvider(config);
expect(GoogleAuth).toHaveBeenCalledWith({
scopes: ['scope1', 'scope2'],
});
});
it('should allow sub.luci.app', () => {
const config = {
url: 'https://sub.luci.app',
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
new GoogleCredentialProvider(config);
});
it('should not allow googleapis.com without a subdomain', () => {
const config = {
url: 'https://googleapis.com',
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
expect(() => new GoogleCredentialProvider(config)).toThrow(
'Host "googleapis.com" is not an allowed host for Google Credential provider.',
);
});
describe('with provider instance', () => {
let provider: GoogleCredentialProvider;
beforeEach(() => {
const config = {
oauth: {
scopes: ['scope1', 'scope2'],
},
} as MCPServerConfig;
provider = new GoogleCredentialProvider(config);
provider = new GoogleCredentialProvider(validConfig);
vi.clearAllMocks();
});

View File

@@ -14,6 +14,8 @@ import {
import { GoogleAuth } from 'google-auth-library';
import { MCPServerConfig } from '../config/config.js';
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
export class GoogleCredentialProvider implements OAuthClientProvider {
private readonly auth: GoogleAuth;
@@ -29,6 +31,20 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
private _clientInformation?: OAuthClientInformationFull;
constructor(private readonly config?: MCPServerConfig) {
const url = this.config?.url || this.config?.httpUrl;
if (!url) {
throw new Error(
'URL must be provided in the config for Google Credentials provider',
);
}
const hostname = new URL(url).hostname;
if (!ALLOWED_HOSTS.some((pattern) => pattern.test(hostname))) {
throw new Error(
`Host "${hostname}" is not an allowed host for Google Credential provider.`,
);
}
const scopes = this.config?.oauth?.scopes;
if (!scopes || scopes.length === 0) {
throw new Error(

View File

@@ -202,6 +202,80 @@ describe('LoopDetectionService', () => {
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should not detect loops when content transitions into a code block', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
// Add some repetitive content outside of code block
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 2; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// Now transition into a code block - this should prevent loop detection
// even though we were already close to the threshold
const codeBlockStart = '```javascript\n';
const isLoop = service.addAndCheck(createContentEvent(codeBlockStart));
expect(isLoop).toBe(false);
// Continue adding repetitive content inside the code block - should not trigger loop
for (let i = 0; i < CONTENT_LOOP_THRESHOLD; i++) {
const isLoopInside = service.addAndCheck(
createContentEvent(repeatedContent),
);
expect(isLoopInside).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should skip loop detection when already inside a code block (this.inCodeBlock)', () => {
service.reset('');
// Start with content that puts us inside a code block
service.addAndCheck(createContentEvent('Here is some code:\n```\n'));
// Verify we are now inside a code block and any content should be ignored for loop detection
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD + 5; i++) {
const isLoop = service.addAndCheck(createContentEvent(repeatedContent));
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should correctly track inCodeBlock state with multiple fence transitions', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
// Outside code block - should track content
service.addAndCheck(createContentEvent('Normal text '));
// Enter code block (1 fence) - should stop tracking
const enterResult = service.addAndCheck(createContentEvent('```\n'));
expect(enterResult).toBe(false);
// Inside code block - should not track loops
for (let i = 0; i < 5; i++) {
const insideResult = service.addAndCheck(
createContentEvent(repeatedContent),
);
expect(insideResult).toBe(false);
}
// Exit code block (2nd fence) - should reset tracking but still return false
const exitResult = service.addAndCheck(createContentEvent('```\n'));
expect(exitResult).toBe(false);
// Enter code block again (3rd fence) - should stop tracking again
const reenterResult = service.addAndCheck(
createContentEvent('```python\n'),
);
expect(reenterResult).toBe(false);
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should detect a loop when repetitive content is outside a code block', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
@@ -281,6 +355,200 @@ describe('LoopDetectionService', () => {
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking when a table is detected', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// This should reset tracking and not trigger a loop
service.addAndCheck(createContentEvent('| Column 1 | Column 2 |'));
// Add more repeated content after table - should not trigger loop
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(createContentEvent(repeatedContent));
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking when a list item is detected', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// This should reset tracking and not trigger a loop
service.addAndCheck(createContentEvent('* List item'));
// Add more repeated content after list - should not trigger loop
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(createContentEvent(repeatedContent));
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking when a heading is detected', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// This should reset tracking and not trigger a loop
service.addAndCheck(createContentEvent('## Heading'));
// Add more repeated content after heading - should not trigger loop
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(createContentEvent(repeatedContent));
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking when a blockquote is detected', () => {
service.reset('');
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// This should reset tracking and not trigger a loop
service.addAndCheck(createContentEvent('> Quote text'));
// Add more repeated content after blockquote - should not trigger loop
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(createContentEvent(repeatedContent));
expect(isLoop).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking for various list item formats', () => {
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
// Test different list formats - make sure they start at beginning of line
const listFormats = [
'* Bullet item',
'- Dash item',
'+ Plus item',
'1. Numbered item',
'42. Another numbered item',
];
listFormats.forEach((listFormat, index) => {
service.reset('');
// Build up to near threshold
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// Reset should occur with list item - add newline to ensure it starts at beginning
service.addAndCheck(createContentEvent('\n' + listFormat));
// Should not trigger loop after reset - use different content to avoid any cached state issues
const newRepeatedContent = createRepetitiveContent(
index + 100,
CONTENT_CHUNK_SIZE,
);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(
createContentEvent(newRepeatedContent),
);
expect(isLoop).toBe(false);
}
});
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking for various table formats', () => {
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
const tableFormats = [
'| Column 1 | Column 2 |',
'|---|---|',
'|++|++|',
'+---+---+',
];
tableFormats.forEach((tableFormat, index) => {
service.reset('');
// Build up to near threshold
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// Reset should occur with table format - add newline to ensure it starts at beginning
service.addAndCheck(createContentEvent('\n' + tableFormat));
// Should not trigger loop after reset - use different content to avoid any cached state issues
const newRepeatedContent = createRepetitiveContent(
index + 200,
CONTENT_CHUNK_SIZE,
);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(
createContentEvent(newRepeatedContent),
);
expect(isLoop).toBe(false);
}
});
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should reset tracking for various heading levels', () => {
const repeatedContent = createRepetitiveContent(1, CONTENT_CHUNK_SIZE);
const headingFormats = [
'# H1 Heading',
'## H2 Heading',
'### H3 Heading',
'#### H4 Heading',
'##### H5 Heading',
'###### H6 Heading',
];
headingFormats.forEach((headingFormat, index) => {
service.reset('');
// Build up to near threshold
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
service.addAndCheck(createContentEvent(repeatedContent));
}
// Reset should occur with heading - add newline to ensure it starts at beginning
service.addAndCheck(createContentEvent('\n' + headingFormat));
// Should not trigger loop after reset - use different content to avoid any cached state issues
const newRepeatedContent = createRepetitiveContent(
index + 300,
CONTENT_CHUNK_SIZE,
);
for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
const isLoop = service.addAndCheck(
createContentEvent(newRepeatedContent),
);
expect(isLoop).toBe(false);
}
});
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
});

View File

@@ -9,7 +9,6 @@ import { GeminiEventType, ServerGeminiStreamEvent } from '../core/turn.js';
import { logLoopDetected } from '../telemetry/loggers.js';
import { LoopDetectedEvent, LoopType } from '../telemetry/types.js';
import { Config, DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
import { SchemaUnion, Type } from '@google/genai';
const TOOL_CALL_LOOP_THRESHOLD = 5;
const CONTENT_LOOP_THRESHOLD = 10;
@@ -161,20 +160,26 @@ export class LoopDetectionService {
* as repetitive code structures are common and not necessarily loops.
*/
private checkContentLoop(content: string): boolean {
// Code blocks can often contain repetitive syntax that is not indicative of a loop.
// To avoid false positives, we detect when we are inside a code block and
// temporarily disable loop detection.
// Different content elements can often contain repetitive syntax that is not indicative of a loop.
// To avoid false positives, we detect when we encounter different content types and
// reset tracking to avoid analyzing content that spans across different element boundaries.
const numFences = (content.match(/```/g) ?? []).length;
if (numFences) {
// Reset tracking when a code fence is detected to avoid analyzing content
// that spans across code block boundaries.
const hasTable = /(^|\n)\s*(\|.*\||[|+-]{3,})/.test(content);
const hasListItem =
/(^|\n)\s*[*-+]\s/.test(content) || /(^|\n)\s*\d+\.\s/.test(content);
const hasHeading = /(^|\n)#+\s/.test(content);
const hasBlockquote = /(^|\n)>\s/.test(content);
if (numFences || hasTable || hasListItem || hasHeading || hasBlockquote) {
// Reset tracking when different content elements are detected to avoid analyzing content
// that spans across different element boundaries.
this.resetContentTracking();
}
const wasInCodeBlock = this.inCodeBlock;
this.inCodeBlock =
numFences % 2 === 0 ? this.inCodeBlock : !this.inCodeBlock;
if (wasInCodeBlock) {
if (wasInCodeBlock || this.inCodeBlock) {
return false;
}
@@ -335,16 +340,16 @@ Please analyze the conversation history to determine the possibility that the co
...recentHistory,
{ role: 'user', parts: [{ text: prompt }] },
];
const schema: SchemaUnion = {
type: Type.OBJECT,
const schema: Record<string, unknown> = {
type: 'object',
properties: {
reasoning: {
type: Type.STRING,
type: 'string',
description:
'Your reasoning on if the conversation is looping without forward progress.',
},
confidence: {
type: Type.NUMBER,
type: 'number',
description:
'A number between 0.0 and 1.0 representing your confidence that the conversation is in an unproductive state.',
},

View File

@@ -185,6 +185,16 @@ describe('ShellExecutionService', () => {
expect(result.error).toBe(spawnError);
expect(result.exitCode).toBe(1);
});
it('handles errors that do not fire the exit event', async () => {
const error = new Error('spawn abc ENOENT');
const { result } = await simulateExecution('touch cat.jpg', (cp) => {
cp.emit('error', error); // No exit event is fired.
});
expect(result.error).toBe(error);
expect(result.exitCode).toBe(1);
});
});
describe('Aborting Commands', () => {

View File

@@ -174,7 +174,19 @@ export class ShellExecutionService {
child.stdout.on('data', (data) => handleOutput(data, 'stdout'));
child.stderr.on('data', (data) => handleOutput(data, 'stderr'));
child.on('error', (err) => {
const { stdout, stderr, finalBuffer } = cleanup();
error = err;
resolve({
error,
stdout,
stderr,
rawOutput: finalBuffer,
output: stdout + (stderr ? `\n${stderr}` : ''),
exitCode: 1,
signal: null,
aborted: false,
pid: child.pid,
});
});
const abortHandler = async () => {
@@ -200,18 +212,8 @@ export class ShellExecutionService {
abortSignal.addEventListener('abort', abortHandler, { once: true });
child.on('exit', (code, signal) => {
exited = true;
abortSignal.removeEventListener('abort', abortHandler);
if (stdoutDecoder) {
stdout += stripAnsi(stdoutDecoder.decode());
}
if (stderrDecoder) {
stderr += stripAnsi(stderrDecoder.decode());
}
const finalBuffer = Buffer.concat(outputChunks);
child.on('exit', (code: number, signal: NodeJS.Signals) => {
const { stdout, stderr, finalBuffer } = cleanup();
resolve({
rawOutput: finalBuffer,
@@ -225,6 +227,25 @@ export class ShellExecutionService {
pid: child.pid,
});
});
/**
* Cleans up a process (and it's accompanying state) that is exiting or
* erroring and returns output formatted output buffers and strings
*/
function cleanup() {
exited = true;
abortSignal.removeEventListener('abort', abortHandler);
if (stdoutDecoder) {
stdout += stripAnsi(stdoutDecoder.decode());
}
if (stderrDecoder) {
stderr += stripAnsi(stderrDecoder.decode());
}
const finalBuffer = Buffer.concat(outputChunks);
return { stdout, stderr, finalBuffer };
}
});
return { pid: child.pid, result };

View File

@@ -0,0 +1,263 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as https from 'https';
import { ClientRequest, IncomingMessage } from 'http';
import { Readable, Writable } from 'stream';
import {
ClearcutLogger,
LogResponse,
LogEventEntry,
} from './clearcut-logger.js';
import { Config } from '../../config/config.js';
import * as userAccount from '../../utils/user_account.js';
import * as userId from '../../utils/user_id.js';
// Mock dependencies
vi.mock('https-proxy-agent');
vi.mock('https');
vi.mock('../../utils/user_account');
vi.mock('../../utils/user_id');
const mockHttps = vi.mocked(https);
const mockUserAccount = vi.mocked(userAccount);
const mockUserId = vi.mocked(userId);
describe('ClearcutLogger', () => {
  let mockConfig: Config;
  let logger: ClearcutLogger | undefined;

  // Helpers that reach into the logger's private fields (bracket access
  // bypasses TS visibility) so tests can assert on internal queue state.
  const getEvents = (l: ClearcutLogger): LogEventEntry[][] =>
    l['events'].toArray() as LogEventEntry[][];

  const getEventsSize = (l: ClearcutLogger): number => l['events'].size;

  const getMaxEvents = (l: ClearcutLogger): number => l['max_events'];

  const getMaxRetryEvents = (l: ClearcutLogger): number =>
    l['max_retry_events'];

  const requeueFailedEvents = (l: ClearcutLogger, events: LogEventEntry[][]) =>
    l['requeueFailedEvents'](events);

  beforeEach(() => {
    vi.useFakeTimers();
    vi.setSystemTime(new Date());

    mockConfig = {
      getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
      getDebugMode: vi.fn().mockReturnValue(false),
      getSessionId: vi.fn().mockReturnValue('test-session-id'),
      getProxy: vi.fn().mockReturnValue(undefined),
    } as unknown as Config;

    mockUserAccount.getCachedGoogleAccount.mockReturnValue('test@google.com');
    mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(1);
    mockUserId.getInstallationId.mockReturnValue('test-installation-id');

    logger = ClearcutLogger.getInstance(mockConfig);
    expect(logger).toBeDefined();
  });

  afterEach(() => {
    // Reset the singleton so each test constructs a fresh logger.
    ClearcutLogger.clearInstance();
    vi.useRealTimers();
    vi.restoreAllMocks();
  });

  it('should not return an instance if usage statistics are disabled', () => {
    ClearcutLogger.clearInstance();
    vi.spyOn(mockConfig, 'getUsageStatisticsEnabled').mockReturnValue(false);
    const disabledLogger = ClearcutLogger.getInstance(mockConfig);
    expect(disabledLogger).toBeUndefined();
  });

  describe('enqueueLogEvent', () => {
    it('should add events to the queue', () => {
      logger!.enqueueLogEvent({ test: 'event1' });
      expect(getEventsSize(logger!)).toBe(1);
    });

    it('should evict the oldest event when the queue is full', () => {
      const maxEvents = getMaxEvents(logger!);

      // Fill the fixed-size deque to capacity.
      for (let i = 0; i < maxEvents; i++) {
        logger!.enqueueLogEvent({ event_id: i });
      }

      expect(getEventsSize(logger!)).toBe(maxEvents);
      const firstEvent = JSON.parse(
        getEvents(logger!)[0][0].source_extension_json,
      );
      expect(firstEvent.event_id).toBe(0);

      // This should push out the first event
      logger!.enqueueLogEvent({ event_id: maxEvents });

      expect(getEventsSize(logger!)).toBe(maxEvents);
      const newFirstEvent = JSON.parse(
        getEvents(logger!)[0][0].source_extension_json,
      );
      expect(newFirstEvent.event_id).toBe(1);

      const lastEvent = JSON.parse(
        getEvents(logger!)[maxEvents - 1][0].source_extension_json,
      );
      expect(lastEvent.event_id).toBe(maxEvents);
    });
  });

  describe('flushToClearcut', () => {
    let mockRequest: Writable;
    let mockResponse: Readable & Partial<IncomingMessage>;

    beforeEach(() => {
      // Stub the https request/response pair so no real network I/O happens;
      // the mocked https.request immediately hands mockResponse to the
      // response callback.
      mockRequest = new Writable({
        write(chunk, encoding, callback) {
          callback();
        },
      });
      vi.spyOn(mockRequest, 'on');
      vi.spyOn(mockRequest, 'end').mockReturnThis();
      vi.spyOn(mockRequest, 'destroy').mockReturnThis();

      mockResponse = new Readable({ read() {} }) as Readable &
        Partial<IncomingMessage>;

      mockHttps.request.mockImplementation(
        (
          _options: string | https.RequestOptions | URL,
          ...args: unknown[]
        ): ClientRequest => {
          const callback = args.find((arg) => typeof arg === 'function') as
            | ((res: IncomingMessage) => void)
            | undefined;

          if (callback) {
            callback(mockResponse as IncomingMessage);
          }
          return mockRequest as ClientRequest;
        },
      );
    });

    it('should clear events on successful flush', async () => {
      mockResponse.statusCode = 200;
      const mockResponseBody = { nextRequestWaitMs: 1000 };
      // Encoded protobuf for {nextRequestWaitMs: 1000} which is `08 E8 07`
      const encodedResponse = Buffer.from([8, 232, 7]);

      logger!.enqueueLogEvent({ event_id: 1 });
      const flushPromise = logger!.flushToClearcut();

      mockResponse.push(encodedResponse);
      mockResponse.push(null); // End the stream

      const response: LogResponse = await flushPromise;

      expect(getEventsSize(logger!)).toBe(0);
      expect(response.nextRequestWaitMs).toBe(
        mockResponseBody.nextRequestWaitMs,
      );
    });

    it('should handle a network error and requeue events', async () => {
      logger!.enqueueLogEvent({ event_id: 1 });
      logger!.enqueueLogEvent({ event_id: 2 });
      expect(getEventsSize(logger!)).toBe(2);

      const flushPromise = logger!.flushToClearcut();
      mockRequest.emit('error', new Error('Network error'));
      await flushPromise;

      // Both events must be back in the queue, in their original order.
      expect(getEventsSize(logger!)).toBe(2);
      const events = getEvents(logger!);
      expect(JSON.parse(events[0][0].source_extension_json).event_id).toBe(1);
    });

    it('should handle an HTTP error and requeue events', async () => {
      mockResponse.statusCode = 500;
      mockResponse.statusMessage = 'Internal Server Error';
      logger!.enqueueLogEvent({ event_id: 1 });
      logger!.enqueueLogEvent({ event_id: 2 });
      expect(getEventsSize(logger!)).toBe(2);

      const flushPromise = logger!.flushToClearcut();
      mockResponse.emit('end'); // End the response to trigger promise resolution
      await flushPromise;

      expect(getEventsSize(logger!)).toBe(2);
      const events = getEvents(logger!);
      expect(JSON.parse(events[0][0].source_extension_json).event_id).toBe(1);
    });
  });

  describe('requeueFailedEvents logic', () => {
    it('should limit the number of requeued events to max_retry_events', () => {
      const maxRetryEvents = getMaxRetryEvents(logger!);
      const eventsToLogCount = maxRetryEvents + 5;
      const eventsToSend: LogEventEntry[][] = [];
      for (let i = 0; i < eventsToLogCount; i++) {
        eventsToSend.push([
          {
            event_time_ms: Date.now(),
            source_extension_json: JSON.stringify({ event_id: i }),
          },
        ]);
      }

      requeueFailedEvents(logger!, eventsToSend);

      expect(getEventsSize(logger!)).toBe(maxRetryEvents);
      const firstRequeuedEvent = JSON.parse(
        getEvents(logger!)[0][0].source_extension_json,
      );
      // The last `maxRetryEvents` are kept. The oldest of those is at index `eventsToLogCount - maxRetryEvents`.
      expect(firstRequeuedEvent.event_id).toBe(
        eventsToLogCount - maxRetryEvents,
      );
    });

    it('should not requeue more events than available space in the queue', () => {
      const maxEvents = getMaxEvents(logger!);
      const spaceToLeave = 5;
      const initialEventCount = maxEvents - spaceToLeave;
      for (let i = 0; i < initialEventCount; i++) {
        logger!.enqueueLogEvent({ event_id: `initial_${i}` });
      }
      expect(getEventsSize(logger!)).toBe(initialEventCount);

      const failedEventsCount = 10; // More than spaceToLeave
      const eventsToSend: LogEventEntry[][] = [];
      for (let i = 0; i < failedEventsCount; i++) {
        eventsToSend.push([
          {
            event_time_ms: Date.now(),
            source_extension_json: JSON.stringify({ event_id: `failed_${i}` }),
          },
        ]);
      }

      requeueFailedEvents(logger!, eventsToSend);

      // availableSpace is 5. eventsToRequeue is min(10, 5) = 5.
      // Total size should be initialEventCount + 5 = maxEvents.
      expect(getEventsSize(logger!)).toBe(maxEvents);

      // The requeued events are the *last* 5 of the failed events.
      // startIndex = max(0, 10 - 5) = 5.
      // Loop unshifts events from index 9 down to 5.
      // The first element in the deque is the one with id 'failed_5'.
      const firstRequeuedEvent = JSON.parse(
        getEvents(logger!)[0][0].source_extension_json,
      );
      expect(firstRequeuedEvent.event_id).toBe('failed_5');
    });
  });
});

View File

@@ -30,8 +30,8 @@ import {
getCachedGoogleAccount,
getLifetimeGoogleAccounts,
} from '../../utils/user_account.js';
import { HttpError, retryWithBackoff } from '../../utils/retry.js';
import { getInstallationId } from '../../utils/user_id.js';
import { FixedDeque } from 'mnemonist';
const start_session_event_name = 'start_session';
const new_prompt_event_name = 'new_prompt';
@@ -51,18 +51,60 @@ export interface LogResponse {
nextRequestWaitMs?: number;
}
/** A single event as buffered for delivery to Clearcut. */
export interface LogEventEntry {
  // Epoch milliseconds at which the event was enqueued.
  event_time_ms: number;
  // The event payload, pre-serialized to JSON.
  source_extension_json: string;
}

/** One key/value metadata pair attached to a log event. */
export type EventValue = {
  gemini_cli_key: EventMetadataKey | string;
  value: string;
};

/** The Clearcut log-event envelope emitted for Gemini CLI telemetry. */
export type LogEvent = {
  console_type: string;
  application: number;
  event_name: string;
  event_metadata: EventValue[][];
  // Exactly one of client_email / client_install_id is expected to be set
  // (see the sessionable-id note where the event is constructed).
  client_email?: string;
  client_install_id?: string;
};
/**
 * Identifies the "surface" (distribution channel) Gemini CLI is running in.
 *
 * Bundled environments are detected via the environment variables their
 * runtimes set: Cloud Shell sets CLOUD_SHELL=true and Firebase Studio sets
 * MONOSPACE_ENV=true. Otherwise an explicit SURFACE value is honored, and
 * manual installs fall back to 'SURFACE_NOT_SET'.
 */
function determineSurface(): string {
  if (process.env.CLOUD_SHELL === 'true') {
    return 'CLOUD_SHELL';
  }
  if (process.env.MONOSPACE_ENV === 'true') {
    return 'FIREBASE_STUDIO';
  }
  return process.env.SURFACE || 'SURFACE_NOT_SET';
}
// Singleton class for batch posting log events to Clearcut. When a new event comes in, the elapsed time
// is checked and events are flushed to Clearcut if at least a minute has passed since the last flush.
export class ClearcutLogger {
private static instance: ClearcutLogger;
private config?: Config;
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Clearcut expects this format.
private readonly events: any = [];
private readonly events: FixedDeque<LogEventEntry[]>;
private last_flush_time: number = Date.now();
private flush_interval_ms: number = 1000 * 60; // Wait at least a minute before flushing events.
private readonly max_events: number = 1000; // Maximum events to keep in memory
private readonly max_retry_events: number = 100; // Maximum failed events to retry
private flushing: boolean = false; // Prevent concurrent flush operations
private pendingFlush: boolean = false; // Track if a flush was requested during an ongoing flush
private constructor(config?: Config) {
this.config = config;
this.events = new FixedDeque<LogEventEntry[]>(Array, this.max_events);
}
static getInstance(config?: Config): ClearcutLogger | undefined {
@@ -74,30 +116,67 @@ export class ClearcutLogger {
return ClearcutLogger.instance;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Clearcut expects this format.
enqueueLogEvent(event: any): void {
this.events.push([
{
event_time_ms: Date.now(),
source_extension_json: safeJsonStringify(event),
},
]);
/** For testing purposes only. */
static clearInstance(): void {
// @ts-expect-error - ClearcutLogger is a singleton, but we need to clear it for tests.
ClearcutLogger.instance = undefined;
}
createLogEvent(name: string, data: object[]): object {
const email = getCachedGoogleAccount();
const totalAccounts = getLifetimeGoogleAccounts();
data.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: totalAccounts.toString(),
});
enqueueLogEvent(event: object): void {
try {
// Manually handle overflow for FixedDeque, which throws when full.
const wasAtCapacity = this.events.size >= this.max_events;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const logEvent: any = {
if (wasAtCapacity) {
this.events.shift(); // Evict oldest element to make space.
}
this.events.push([
{
event_time_ms: Date.now(),
source_extension_json: safeJsonStringify(event),
},
]);
if (wasAtCapacity && this.config?.getDebugMode()) {
console.debug(
`ClearcutLogger: Dropped old event to prevent memory leak (queue size: ${this.events.size})`,
);
}
} catch (error) {
if (this.config?.getDebugMode()) {
console.error('ClearcutLogger: Failed to enqueue log event.', error);
}
}
}
/**
 * Appends the metadata fields every Clearcut event must carry: the lifetime
 * Google-account count and the current surface (distribution channel).
 * Mutates `data` in place.
 */
addDefaultFields(data: EventValue[]): void {
  const accountCount = getLifetimeGoogleAccounts();
  data.push(
    {
      gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
      value: accountCount.toString(),
    },
    {
      gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
      value: determineSurface(),
    },
  );
}
createLogEvent(name: string, data: EventValue[]): LogEvent {
const email = getCachedGoogleAccount();
// Add default fields that should exist for all logs
this.addDefaultFields(data);
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
application: 102,
event_name: name,
event_metadata: [data] as object[],
event_metadata: [data],
};
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
@@ -121,16 +200,25 @@ export class ClearcutLogger {
}
async flushToClearcut(): Promise<LogResponse> {
if (this.flushing) {
if (this.config?.getDebugMode()) {
console.debug(
'ClearcutLogger: Flush already in progress, marking pending flush.',
);
}
this.pendingFlush = true;
return Promise.resolve({});
}
this.flushing = true;
if (this.config?.getDebugMode()) {
console.log('Flushing log events to Clearcut.');
}
const eventsToSend = [...this.events];
if (eventsToSend.length === 0) {
return {};
}
const eventsToSend = this.events.toArray() as LogEventEntry[][];
this.events.clear();
const flushFn = () =>
new Promise<Buffer>((resolve, reject) => {
return new Promise<{ buffer: Buffer; statusCode?: number }>(
(resolve, reject) => {
const request = [
{
log_source_name: 'CONCORD',
@@ -144,6 +232,7 @@ export class ClearcutLogger {
path: '/log',
method: 'POST',
headers: { 'Content-Length': Buffer.byteLength(body) },
timeout: 30000, // 30-second timeout
};
const bufs: Buffer[] = [];
const req = https.request(
@@ -152,49 +241,77 @@ export class ClearcutLogger {
agent: this.getProxyAgent(),
},
(res) => {
if (
res.statusCode &&
(res.statusCode < 200 || res.statusCode >= 300)
) {
const err: HttpError = new Error(
`Request failed with status ${res.statusCode}`,
);
err.status = res.statusCode;
res.resume();
return reject(err);
}
res.on('error', reject); // Handle stream errors
res.on('data', (buf) => bufs.push(buf));
res.on('end', () => resolve(Buffer.concat(bufs)));
res.on('end', () => {
try {
const buffer = Buffer.concat(bufs);
// Check if we got a successful response
if (
res.statusCode &&
res.statusCode >= 200 &&
res.statusCode < 300
) {
resolve({ buffer, statusCode: res.statusCode });
} else {
// HTTP error - reject with status code for retry handling
reject(
new Error(`HTTP ${res.statusCode}: ${res.statusMessage}`),
);
}
} catch (e) {
reject(e);
}
});
},
);
req.on('error', reject);
req.on('error', (e) => {
// Network-level error
reject(e);
});
req.on('timeout', () => {
if (!req.destroyed) {
req.destroy(new Error('Request timeout after 30 seconds'));
}
});
req.end(body);
},
)
.then(({ buffer }) => {
try {
this.last_flush_time = Date.now();
return this.decodeLogResponse(buffer) || {};
} catch (error: unknown) {
console.error('Error decoding log response:', error);
return {};
}
})
.catch((error: unknown) => {
// Handle both network-level and HTTP-level errors
if (this.config?.getDebugMode()) {
console.error('Error flushing log events:', error);
}
// Re-queue failed events for retry
this.requeueFailedEvents(eventsToSend);
// Return empty response to maintain the Promise<LogResponse> contract
return {};
})
.finally(() => {
this.flushing = false;
// If a flush was requested while we were flushing, flush again
if (this.pendingFlush) {
this.pendingFlush = false;
// Fire and forget the pending flush
this.flushToClearcut().catch((error) => {
if (this.config?.getDebugMode()) {
console.debug('Error in pending flush to Clearcut:', error);
}
});
}
});
try {
const responseBuffer = await retryWithBackoff(flushFn, {
maxAttempts: 3,
initialDelayMs: 200,
shouldRetry: (err: unknown) => {
if (!(err instanceof Error)) return false;
const status = (err as HttpError).status as number | undefined;
// If status is not available, it's likely a network error
if (status === undefined) return true;
// Retry on 429 (Too many Requests) and 5xx server errors.
return status === 429 || (status >= 500 && status < 600);
},
});
this.events.splice(0, eventsToSend.length);
this.last_flush_time = Date.now();
return this.decodeLogResponse(responseBuffer) || {};
} catch (error) {
if (this.config?.getDebugMode()) {
console.error('Clearcut flush failed after multiple retries.', error);
}
return {};
}
}
// Visible for testing. Decodes protobuf-encoded response from Clearcut server.
@@ -237,12 +354,7 @@ export class ClearcutLogger {
}
logStartSessionEvent(event: StartSessionEvent): void {
const surface =
process.env.CLOUD_SHELL === 'true'
? 'CLOUD_SHELL'
: process.env.SURFACE || 'SURFACE_NOT_SET';
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_START_SESSION_MODEL,
value: event.model,
@@ -307,10 +419,6 @@ export class ClearcutLogger {
EventMetadataKey.GEMINI_CLI_START_SESSION_TELEMETRY_LOG_USER_PROMPTS_ENABLED,
value: event.telemetry_log_user_prompts_enabled.toString(),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
},
];
// Flush start event immediately
@@ -321,7 +429,7 @@ export class ClearcutLogger {
}
logNewPromptEvent(event: UserPromptEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_USER_PROMPT_LENGTH,
value: JSON.stringify(event.prompt_length),
@@ -345,7 +453,7 @@ export class ClearcutLogger {
}
logToolCallEvent(event: ToolCallEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_TOOL_CALL_NAME,
value: JSON.stringify(event.function_name),
@@ -376,13 +484,31 @@ export class ClearcutLogger {
},
];
if (event.metadata) {
const metadataMapping: { [key: string]: EventMetadataKey } = {
ai_added_lines: EventMetadataKey.GEMINI_CLI_AI_ADDED_LINES,
ai_removed_lines: EventMetadataKey.GEMINI_CLI_AI_REMOVED_LINES,
user_added_lines: EventMetadataKey.GEMINI_CLI_USER_ADDED_LINES,
user_removed_lines: EventMetadataKey.GEMINI_CLI_USER_REMOVED_LINES,
};
for (const [key, gemini_cli_key] of Object.entries(metadataMapping)) {
if (event.metadata[key] !== undefined) {
data.push({
gemini_cli_key,
value: JSON.stringify(event.metadata[key]),
});
}
}
}
const logEvent = this.createLogEvent(tool_call_event_name, data);
this.enqueueLogEvent(logEvent);
this.flushIfNeeded();
}
logApiRequestEvent(event: ApiRequestEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_REQUEST_MODEL,
value: JSON.stringify(event.model),
@@ -398,7 +524,7 @@ export class ClearcutLogger {
}
logApiResponseEvent(event: ApiResponseEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_RESPONSE_MODEL,
value: JSON.stringify(event.model),
@@ -455,7 +581,7 @@ export class ClearcutLogger {
}
logApiErrorEvent(event: ApiErrorEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_API_ERROR_MODEL,
value: JSON.stringify(event.model),
@@ -487,7 +613,7 @@ export class ClearcutLogger {
}
logFlashFallbackEvent(event: FlashFallbackEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
value: JSON.stringify(event.auth_type),
@@ -505,7 +631,7 @@ export class ClearcutLogger {
}
logLoopDetectedEvent(event: LoopDetectedEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_PROMPT_ID,
value: JSON.stringify(event.prompt_id),
@@ -521,7 +647,7 @@ export class ClearcutLogger {
}
logNextSpeakerCheck(event: NextSpeakerCheckEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_PROMPT_ID,
value: JSON.stringify(event.prompt_id),
@@ -547,7 +673,7 @@ export class ClearcutLogger {
}
logSlashCommandEvent(event: SlashCommandEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SLASH_COMMAND_NAME,
value: JSON.stringify(event.command),
@@ -566,7 +692,7 @@ export class ClearcutLogger {
}
logMalformedJsonResponseEvent(event: MalformedJsonResponseEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key:
EventMetadataKey.GEMINI_CLI_MALFORMED_JSON_RESPONSE_MODEL,
@@ -581,7 +707,7 @@ export class ClearcutLogger {
}
logIdeConnectionEvent(event: IdeConnectionEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_IDE_CONNECTION_TYPE,
value: JSON.stringify(event.connection_type),
@@ -593,7 +719,7 @@ export class ClearcutLogger {
}
logEndSessionEvent(event: EndSessionEvent): void {
const data = [
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SESSION_ID,
value: event?.session_id?.toString() ?? '',
@@ -623,4 +749,57 @@ export class ClearcutLogger {
const event = new EndSessionEvent(this.config);
this.logEndSessionEvent(event);
}
/**
 * Puts events from a failed flush back at the front of the deque so they are
 * retried before newer events. The retry set is capped at max_retry_events
 * and further limited by the space remaining in the fixed-size queue; when
 * anything must be dropped, the oldest failed events go first.
 */
private requeueFailedEvents(eventsToSend: LogEventEntry[][]): void {
  // Add the events back to the front of the queue to be retried, but limit retry queue size
  const eventsToRetry = eventsToSend.slice(-this.max_retry_events); // Keep only the most recent events

  // Log a warning if we're dropping events
  if (
    eventsToSend.length > this.max_retry_events &&
    this.config?.getDebugMode()
  ) {
    console.warn(
      `ClearcutLogger: Dropping ${
        eventsToSend.length - this.max_retry_events
      } events due to retry queue limit. Total events: ${
        eventsToSend.length
      }, keeping: ${this.max_retry_events}`,
    );
  }

  // Determine how many events can be re-queued
  const availableSpace = this.max_events - this.events.size;
  const numEventsToRequeue = Math.min(eventsToRetry.length, availableSpace);

  if (numEventsToRequeue === 0) {
    if (this.config?.getDebugMode()) {
      console.debug(
        `ClearcutLogger: No events re-queued (queue size: ${this.events.size})`,
      );
    }
    return;
  }

  // Get the most recent events to re-queue
  const eventsToRequeue = eventsToRetry.slice(
    eventsToRetry.length - numEventsToRequeue,
  );

  // Prepend events to the front of the deque to be retried first.
  // We iterate backwards to maintain the original order of the failed events.
  for (let i = eventsToRequeue.length - 1; i >= 0; i--) {
    this.events.unshift(eventsToRequeue[i]);
  }
  // Clear any potential overflow
  while (this.events.size > this.max_events) {
    this.events.pop();
  }

  if (this.config?.getDebugMode()) {
    console.debug(
      `ClearcutLogger: Re-queued ${numEventsToRequeue} events for retry (queue size: ${this.events.size})`,
    );
  }
}
}

View File

@@ -197,6 +197,18 @@ export enum EventMetadataKey {
// Logs the type of the IDE connection.
GEMINI_CLI_IDE_CONNECTION_TYPE = 46,
// Logs AI added lines in edit/write tool response.
GEMINI_CLI_AI_ADDED_LINES = 47,
// Logs AI removed lines in edit/write tool response.
GEMINI_CLI_AI_REMOVED_LINES = 48,
// Logs user added lines in edit/write tool response.
GEMINI_CLI_USER_ADDED_LINES = 49,
// Logs user removed lines in edit/write tool response.
GEMINI_CLI_USER_REMOVED_LINES = 50,
}
export function getEventMetadataKey(

View File

@@ -14,7 +14,7 @@ import { ToolCallEvent } from './types.js';
import { Config } from '../config/config.js';
import { CompletedToolCall } from '../core/coreToolScheduler.js';
import { ToolCallRequestInfo, ToolCallResponseInfo } from '../core/turn.js';
import { Tool } from '../tools/tools.js';
import { MockTool } from '../test-utils/tools.js';
describe('Circular Reference Handling', () => {
it('should handle circular references in tool function arguments', () => {
@@ -56,11 +56,13 @@ describe('Circular Reference Handling', () => {
errorType: undefined,
};
const tool = new MockTool('mock-tool');
const mockCompletedToolCall: CompletedToolCall = {
status: 'success',
request: mockRequest,
response: mockResponse,
tool: {} as Tool,
tool,
invocation: tool.build({}),
durationMs: 100,
};
@@ -104,11 +106,13 @@ describe('Circular Reference Handling', () => {
errorType: undefined,
};
const tool = new MockTool('mock-tool');
const mockCompletedToolCall: CompletedToolCall = {
status: 'success',
request: mockRequest,
response: mockResponse,
tool: {} as Tool,
tool,
invocation: tool.build({}),
durationMs: 100,
};

View File

@@ -5,6 +5,7 @@
*/
import {
AnyToolInvocation,
AuthType,
CompletedToolCall,
ContentGeneratorConfig,
@@ -34,11 +35,11 @@ import {
logToolCall,
logFlashFallback,
} from './loggers.js';
import { ToolCallDecision } from './tool-call-decision.js';
import {
ApiRequestEvent,
ApiResponseEvent,
StartSessionEvent,
ToolCallDecision,
ToolCallEvent,
UserPromptEvent,
FlashFallbackEvent,
@@ -435,6 +436,7 @@ describe('loggers', () => {
});
it('should log a tool call with all fields', () => {
const tool = new EditTool(mockConfig);
const call: CompletedToolCall = {
status: 'success',
request: {
@@ -454,7 +456,8 @@ describe('loggers', () => {
error: undefined,
errorType: undefined,
},
tool: new EditTool(mockConfig),
tool,
invocation: {} as AnyToolInvocation,
durationMs: 100,
outcome: ToolConfirmationOutcome.ProceedOnce,
};
@@ -584,6 +587,7 @@ describe('loggers', () => {
},
outcome: ToolConfirmationOutcome.ModifyWithEditor,
tool: new EditTool(mockConfig),
invocation: {} as AnyToolInvocation,
durationMs: 100,
};
const event = new ToolCallEvent(call);
@@ -648,6 +652,7 @@ describe('loggers', () => {
errorType: undefined,
},
tool: new EditTool(mockConfig),
invocation: {} as AnyToolInvocation,
durationMs: 100,
};
const event = new ToolCallEvent(call);

View File

@@ -221,5 +221,87 @@ describe('Telemetry Metrics', () => {
mimetype: 'application/javascript',
});
});
// A provided diffStat must have its four line counts copied verbatim into
// the counter attributes.
it('should include diffStat when provided', () => {
  initializeMetricsModule(mockConfig);
  mockCounterAddFn.mockClear();

  const diffStat = {
    ai_added_lines: 5,
    ai_removed_lines: 2,
    user_added_lines: 3,
    user_removed_lines: 1,
  };

  recordFileOperationMetricModule(
    mockConfig,
    FileOperation.UPDATE,
    undefined,
    undefined,
    undefined,
    diffStat,
  );

  expect(mockCounterAddFn).toHaveBeenCalledWith(1, {
    'session.id': 'test-session-id',
    operation: FileOperation.UPDATE,
    ai_added_lines: 5,
    ai_removed_lines: 2,
    user_added_lines: 3,
    user_removed_lines: 1,
  });
});

// Omitting diffStat must not introduce any diff-related attributes.
it('should not include diffStat attributes when diffStat is not provided', () => {
  initializeMetricsModule(mockConfig);
  mockCounterAddFn.mockClear();

  recordFileOperationMetricModule(
    mockConfig,
    FileOperation.UPDATE,
    10,
    'text/plain',
    'txt',
    undefined,
  );

  expect(mockCounterAddFn).toHaveBeenCalledWith(1, {
    'session.id': 'test-session-id',
    operation: FileOperation.UPDATE,
    lines: 10,
    mimetype: 'text/plain',
    extension: 'txt',
  });
});

// Zero counts are valid data and must be recorded rather than dropped.
it('should handle diffStat with all zero values', () => {
  initializeMetricsModule(mockConfig);
  mockCounterAddFn.mockClear();

  const diffStat = {
    ai_added_lines: 0,
    ai_removed_lines: 0,
    user_added_lines: 0,
    user_removed_lines: 0,
  };

  recordFileOperationMetricModule(
    mockConfig,
    FileOperation.UPDATE,
    undefined,
    undefined,
    undefined,
    diffStat,
  );

  expect(mockCounterAddFn).toHaveBeenCalledWith(1, {
    'session.id': 'test-session-id',
    operation: FileOperation.UPDATE,
    ai_added_lines: 0,
    ai_removed_lines: 0,
    user_added_lines: 0,
    user_removed_lines: 0,
  });
});
});
});

View File

@@ -23,6 +23,7 @@ import {
METRIC_FILE_OPERATION_COUNT,
} from './constants.js';
import { Config } from '../config/config.js';
import { DiffStat } from '../tools/tools.js';
export enum FileOperation {
CREATE = 'create',
@@ -100,7 +101,7 @@ export function recordToolCallMetrics(
functionName: string,
durationMs: number,
success: boolean,
decision?: 'accept' | 'reject' | 'modify',
decision?: 'accept' | 'reject' | 'modify' | 'auto_accept',
): void {
if (!toolCallCounter || !toolCallLatencyHistogram || !isMetricsInitialized)
return;
@@ -189,6 +190,7 @@ export function recordFileOperationMetric(
lines?: number,
mimetype?: string,
extension?: string,
diffStat?: DiffStat,
): void {
if (!fileOperationCounter || !isMetricsInitialized) return;
const attributes: Attributes = {
@@ -198,5 +200,11 @@ export function recordFileOperationMetric(
if (lines !== undefined) attributes.lines = lines;
if (mimetype !== undefined) attributes.mimetype = mimetype;
if (extension !== undefined) attributes.extension = extension;
if (diffStat !== undefined) {
attributes.ai_added_lines = diffStat.ai_added_lines;
attributes.ai_removed_lines = diffStat.ai_removed_lines;
attributes.user_added_lines = diffStat.user_added_lines;
attributes.user_removed_lines = diffStat.user_removed_lines;
}
fileOperationCounter.add(1, attributes);
}

View File

@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { ToolConfirmationOutcome } from '../tools/tools.js';
/**
 * Normalized user decision for a tool call, derived from a
 * ToolConfirmationOutcome via getDecisionFromOutcome.
 */
export enum ToolCallDecision {
  // Mapped from ProceedOnce: explicit approval of this single invocation.
  ACCEPT = 'accept',
  // Mapped from Cancel (and any unrecognized outcome).
  REJECT = 'reject',
  // Mapped from ModifyWithEditor.
  MODIFY = 'modify',
  // Mapped from the ProceedAlways* outcomes: approved via a standing grant.
  AUTO_ACCEPT = 'auto_accept',
}
/**
 * Collapses a ToolConfirmationOutcome into the coarser ToolCallDecision used
 * for telemetry.
 *
 * ProceedOnce maps to ACCEPT, every "proceed always" variant to AUTO_ACCEPT,
 * ModifyWithEditor to MODIFY, and Cancel (or anything unrecognized) to
 * REJECT.
 */
export function getDecisionFromOutcome(
  outcome: ToolConfirmationOutcome,
): ToolCallDecision {
  if (outcome === ToolConfirmationOutcome.ProceedOnce) {
    return ToolCallDecision.ACCEPT;
  }
  if (
    outcome === ToolConfirmationOutcome.ProceedAlways ||
    outcome === ToolConfirmationOutcome.ProceedAlwaysServer ||
    outcome === ToolConfirmationOutcome.ProceedAlwaysTool
  ) {
    return ToolCallDecision.AUTO_ACCEPT;
  }
  if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
    return ToolCallDecision.MODIFY;
  }
  return ToolCallDecision.REJECT;
}

View File

@@ -7,31 +7,12 @@
import { GenerateContentResponseUsageMetadata } from '@google/genai';
import { Config } from '../config/config.js';
import { CompletedToolCall } from '../core/coreToolScheduler.js';
import { ToolConfirmationOutcome } from '../tools/tools.js';
import { FileDiff } from '../tools/tools.js';
import { AuthType } from '../core/contentGenerator.js';
export enum ToolCallDecision {
ACCEPT = 'accept',
REJECT = 'reject',
MODIFY = 'modify',
}
export function getDecisionFromOutcome(
outcome: ToolConfirmationOutcome,
): ToolCallDecision {
switch (outcome) {
case ToolConfirmationOutcome.ProceedOnce:
case ToolConfirmationOutcome.ProceedAlways:
case ToolConfirmationOutcome.ProceedAlwaysServer:
case ToolConfirmationOutcome.ProceedAlwaysTool:
return ToolCallDecision.ACCEPT;
case ToolConfirmationOutcome.ModifyWithEditor:
return ToolCallDecision.MODIFY;
case ToolConfirmationOutcome.Cancel:
default:
return ToolCallDecision.REJECT;
}
}
import {
getDecisionFromOutcome,
ToolCallDecision,
} from './tool-call-decision.js';
export class StartSessionEvent {
'event.name': 'cli_config';
@@ -125,6 +106,8 @@ export class ToolCallEvent {
error?: string;
error_type?: string;
prompt_id: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
metadata?: { [key: string]: any };
constructor(call: CompletedToolCall) {
this['event.name'] = 'tool_call';
@@ -139,6 +122,23 @@ export class ToolCallEvent {
this.error = call.response.error?.message;
this.error_type = call.response.errorType;
this.prompt_id = call.request.prompt_id;
if (
call.status === 'success' &&
typeof call.response.resultDisplay === 'object' &&
call.response.resultDisplay !== null &&
'diffStat' in call.response.resultDisplay
) {
const diffStat = (call.response.resultDisplay as FileDiff).diffStat;
if (diffStat) {
this.metadata = {
ai_added_lines: diffStat.ai_added_lines,
ai_removed_lines: diffStat.ai_removed_lines,
user_added_lines: diffStat.user_added_lines,
user_removed_lines: diffStat.user_removed_lines,
};
}
}
}
}

View File

@@ -6,12 +6,8 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { UiTelemetryService } from './uiTelemetry.js';
import {
ApiErrorEvent,
ApiResponseEvent,
ToolCallEvent,
ToolCallDecision,
} from './types.js';
import { ToolCallDecision } from './tool-call-decision.js';
import { ApiErrorEvent, ApiResponseEvent, ToolCallEvent } from './types.js';
import {
EVENT_API_ERROR,
EVENT_API_RESPONSE,
@@ -23,7 +19,8 @@ import {
SuccessfulToolCall,
} from '../core/coreToolScheduler.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { Tool, ToolConfirmationOutcome } from '../tools/tools.js';
import { ToolConfirmationOutcome } from '../tools/tools.js';
import { MockTool } from '../test-utils/tools.js';
const createFakeCompletedToolCall = (
name: string,
@@ -39,12 +36,14 @@ const createFakeCompletedToolCall = (
isClientInitiated: false,
prompt_id: 'prompt-id-1',
};
const tool = new MockTool(name);
if (success) {
return {
status: 'success',
request,
tool: { name } as Tool, // Mock tool
tool,
invocation: tool.build({}),
response: {
callId: request.callId,
responseParts: {
@@ -65,6 +64,7 @@ const createFakeCompletedToolCall = (
return {
status: 'error',
request,
tool,
response: {
callId: request.callId,
responseParts: {
@@ -104,6 +104,7 @@ describe('UiTelemetryService', () => {
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
byName: {},
},
@@ -362,6 +363,7 @@ describe('UiTelemetryService', () => {
[ToolCallDecision.ACCEPT]: 1,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
});
});
@@ -395,6 +397,7 @@ describe('UiTelemetryService', () => {
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 1,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
});
});
@@ -434,11 +437,13 @@ describe('UiTelemetryService', () => {
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
});
expect(tools.byName['test_tool'].decisions).toEqual({
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
});
});
@@ -483,6 +488,7 @@ describe('UiTelemetryService', () => {
[ToolCallDecision.ACCEPT]: 1,
[ToolCallDecision.REJECT]: 1,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
});
});

View File

@@ -11,12 +11,8 @@ import {
EVENT_TOOL_CALL,
} from './constants.js';
import {
ApiErrorEvent,
ApiResponseEvent,
ToolCallEvent,
ToolCallDecision,
} from './types.js';
import { ToolCallDecision } from './tool-call-decision.js';
import { ApiErrorEvent, ApiResponseEvent, ToolCallEvent } from './types.js';
export type UiEvent =
| (ApiResponseEvent & { 'event.name': typeof EVENT_API_RESPONSE })
@@ -32,6 +28,7 @@ export interface ToolCallStats {
[ToolCallDecision.ACCEPT]: number;
[ToolCallDecision.REJECT]: number;
[ToolCallDecision.MODIFY]: number;
[ToolCallDecision.AUTO_ACCEPT]: number;
};
}
@@ -62,6 +59,7 @@ export interface SessionMetrics {
[ToolCallDecision.ACCEPT]: number;
[ToolCallDecision.REJECT]: number;
[ToolCallDecision.MODIFY]: number;
[ToolCallDecision.AUTO_ACCEPT]: number;
};
byName: Record<string, ToolCallStats>;
};
@@ -94,6 +92,7 @@ const createInitialMetrics = (): SessionMetrics => ({
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
byName: {},
},
@@ -192,6 +191,7 @@ export class UiTelemetryService extends EventEmitter {
[ToolCallDecision.ACCEPT]: 0,
[ToolCallDecision.REJECT]: 0,
[ToolCallDecision.MODIFY]: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
};
}

View File

@@ -0,0 +1,63 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import {
BaseTool,
Icon,
ToolCallConfirmationDetails,
ToolResult,
} from '../tools/tools.js';
import { Schema, Type } from '@google/genai';
/**
* A highly configurable mock tool for testing purposes.
*/
/**
 * A highly configurable mock tool for testing purposes.
 *
 * Tests can stub per-call behavior through `executeFn` (a vi mock) and flip
 * `shouldConfirm` to exercise the confirmation flow.
 */
export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> {
  executeFn = vi.fn();
  shouldConfirm = false;

  constructor(
    name = 'mock-tool',
    displayName?: string,
    description = 'A mock tool for testing.',
    params: Schema = {
      type: Type.OBJECT,
      properties: { param: { type: Type.STRING } },
    },
  ) {
    super(name, displayName ?? name, description, Icon.Hammer, params);
  }

  /**
   * Runs the stubbed `executeFn`; when the stub yields nothing, falls back to
   * a generic success result.
   */
  async execute(
    params: { [key: string]: unknown },
    _abortSignal: AbortSignal,
  ): Promise<ToolResult> {
    const stubbed = this.executeFn(params);
    if (stubbed != null) {
      return stubbed;
    }
    const message = `Tool ${this.name} executed successfully.`;
    return {
      llmContent: message,
      returnDisplay: message,
    };
  }

  /**
   * Returns exec-style confirmation details when `shouldConfirm` is set,
   * otherwise opts out of confirmation entirely.
   */
  async shouldConfirmExecute(
    _params: { [key: string]: unknown },
    _abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false> {
    if (!this.shouldConfirm) {
      return false;
    }
    return {
      type: 'exec' as const,
      title: `Confirm ${this.displayName}`,
      command: this.name,
      rootCommand: this.name,
      onConfirm: async () => {},
    };
  }
}

View File

@@ -0,0 +1,129 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it } from 'vitest';
import { getDiffStat } from './diffOptions.js';
describe('getDiffStat', () => {
  const fileName = 'test.txt';

  it('should return 0 for all stats when there are no changes', () => {
    const oldStr = 'line1\nline2\n';
    const aiStr = 'line1\nline2\n';
    const userStr = 'line1\nline2\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 0,
      ai_removed_lines: 0,
      user_added_lines: 0,
      user_removed_lines: 0,
    });
  });

  it('should correctly report AI additions', () => {
    const oldStr = 'line1\nline2\n';
    const aiStr = 'line1\nline2\nline3\n';
    const userStr = 'line1\nline2\nline3\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 0,
      user_added_lines: 0,
      user_removed_lines: 0,
    });
  });

  it('should correctly report AI removals', () => {
    const oldStr = 'line1\nline2\nline3\n';
    const aiStr = 'line1\nline3\n';
    const userStr = 'line1\nline3\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 0,
      ai_removed_lines: 1,
      user_added_lines: 0,
      user_removed_lines: 0,
    });
  });

  it('should correctly report AI modifications', () => {
    // A changed line surfaces as one removal plus one addition.
    const oldStr = 'line1\nline2\nline3\n';
    const aiStr = 'line1\nline_two\nline3\n';
    const userStr = 'line1\nline_two\nline3\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 1,
      user_added_lines: 0,
      user_removed_lines: 0,
    });
  });

  it('should correctly report user additions', () => {
    // User stats are measured against the AI proposal (aiStr), not oldStr.
    const oldStr = 'line1\nline2\n';
    const aiStr = 'line1\nline2\nline3\n';
    const userStr = 'line1\nline2\nline3\nline4\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 0,
      user_added_lines: 1,
      user_removed_lines: 0,
    });
  });

  it('should correctly report user removals', () => {
    // The user reverting the AI's added line counts as one user removal.
    const oldStr = 'line1\nline2\n';
    const aiStr = 'line1\nline2\nline3\n';
    const userStr = 'line1\nline2\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 0,
      user_added_lines: 0,
      user_removed_lines: 1,
    });
  });

  it('should correctly report user modifications', () => {
    const oldStr = 'line1\nline2\n';
    const aiStr = 'line1\nline2\nline3\n';
    const userStr = 'line1\nline2\nline_three\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 0,
      user_added_lines: 1,
      user_removed_lines: 1,
    });
  });

  it('should handle complex changes from both AI and user', () => {
    const oldStr = 'line1\nline2\nline3\nline4\n';
    const aiStr = 'line_one\nline2\nline_three\nline4\n';
    const userStr = 'line_one\nline_two\nline_three\nline4\nline5\n';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 2,
      ai_removed_lines: 2,
      user_added_lines: 2,
      user_removed_lines: 1,
    });
  });

  it('should report a single line modification as one addition and one removal', () => {
    const oldStr = 'hello world';
    const aiStr = 'hello universe';
    const userStr = 'hello universe';
    const diffStat = getDiffStat(fileName, oldStr, aiStr, userStr);
    expect(diffStat).toEqual({
      ai_added_lines: 1,
      ai_removed_lines: 1,
      user_added_lines: 0,
      user_removed_lines: 0,
    });
  });
});

View File

@@ -5,8 +5,61 @@
*/
import * as Diff from 'diff';
import { DiffStat } from './tools.js';
// Shared options for every patch generated by the edit tooling: 3 lines of
// context around each hunk, with whitespace-only differences ignored when
// matching lines.
export const DEFAULT_DIFF_OPTIONS: Diff.PatchOptions = {
  context: 3,
  ignoreWhitespace: true,
};
/**
 * Computes line-level diff statistics for an edit flow with two stages:
 * the AI's proposed change (oldStr -> aiStr) and the user's follow-up
 * change on top of the proposal (aiStr -> userStr).
 *
 * @param fileName Name used for both sides of the generated patches.
 * @param oldStr The original file content.
 * @param aiStr The content proposed by the AI.
 * @param userStr The content after any user modifications.
 * @returns Added/removed line counts for the AI stage and the user stage.
 */
export function getDiffStat(
  fileName: string,
  oldStr: string,
  aiStr: string,
  userStr: string,
): DiffStat {
  // Tally '+' / '-' lines across every hunk of a structured patch.
  const tallyHunkLines = (parsed: Diff.ParsedDiff) => {
    let addedCount = 0;
    let removedCount = 0;
    for (const hunk of parsed.hunks) {
      for (const hunkLine of hunk.lines) {
        if (hunkLine.startsWith('+')) {
          addedCount++;
        } else if (hunkLine.startsWith('-')) {
          removedCount++;
        }
      }
    }
    return { added: addedCount, removed: removedCount };
  };

  // Stage 1: what the AI changed relative to the original content.
  const aiPatch = Diff.structuredPatch(
    fileName,
    fileName,
    oldStr,
    aiStr,
    'Current',
    'Proposed',
    DEFAULT_DIFF_OPTIONS,
  );
  const aiStats = tallyHunkLines(aiPatch);

  // Stage 2: what the user changed on top of the AI proposal.
  const userPatch = Diff.structuredPatch(
    fileName,
    fileName,
    aiStr,
    userStr,
    'Proposed',
    'User',
    DEFAULT_DIFF_OPTIONS,
  );
  const userStats = tallyHunkLines(userPatch);

  return {
    ai_added_lines: aiStats.added,
    ai_removed_lines: aiStats.removed,
    user_added_lines: userStats.added,
    user_removed_lines: userStats.removed,
  };
}

View File

@@ -10,6 +10,8 @@ const mockEnsureCorrectEdit = vi.hoisted(() => vi.fn());
const mockGenerateJson = vi.hoisted(() => vi.fn());
const mockOpenDiff = vi.hoisted(() => vi.fn());
import { IDEConnectionStatus } from '../ide/ide-client.js';
vi.mock('../utils/editCorrector.js', () => ({
ensureCorrectEdit: mockEnsureCorrectEdit,
}));
@@ -25,8 +27,8 @@ vi.mock('../utils/editor.js', () => ({
}));
import { describe, it, expect, beforeEach, afterEach, vi, Mock } from 'vitest';
import { EditTool, EditToolParams } from './edit.js';
import { FileDiff } from './tools.js';
import { applyReplacement, EditTool, EditToolParams } from './edit.js';
import { FileDiff, ToolConfirmationOutcome } from './tools.js';
import { ToolErrorType } from './tool-error.js';
import path from 'path';
import fs from 'fs';
@@ -58,6 +60,9 @@ describe('EditTool', () => {
getApprovalMode: vi.fn(),
setApprovalMode: vi.fn(),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getIdeClient: () => undefined,
getIdeMode: () => false,
getIdeModeFeature: () => false,
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
// Add other properties/methods of Config if EditTool uses them
// Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses:
@@ -150,45 +155,30 @@ describe('EditTool', () => {
fs.rmSync(tempDir, { recursive: true, force: true });
});
describe('_applyReplacement', () => {
// Access private method for testing
// Note: `tool` is initialized in `beforeEach` of the parent describe block
describe('applyReplacement', () => {
it('should return newString if isNewFile is true', () => {
expect((tool as any)._applyReplacement(null, 'old', 'new', true)).toBe(
'new',
);
expect(
(tool as any)._applyReplacement('existing', 'old', 'new', true),
).toBe('new');
expect(applyReplacement(null, 'old', 'new', true)).toBe('new');
expect(applyReplacement('existing', 'old', 'new', true)).toBe('new');
});
it('should return newString if currentContent is null and oldString is empty (defensive)', () => {
expect((tool as any)._applyReplacement(null, '', 'new', false)).toBe(
'new',
);
expect(applyReplacement(null, '', 'new', false)).toBe('new');
});
it('should return empty string if currentContent is null and oldString is not empty (defensive)', () => {
expect((tool as any)._applyReplacement(null, 'old', 'new', false)).toBe(
'',
);
expect(applyReplacement(null, 'old', 'new', false)).toBe('');
});
it('should replace oldString with newString in currentContent', () => {
expect(
(tool as any)._applyReplacement(
'hello old world old',
'old',
'new',
false,
),
).toBe('hello new world new');
expect(applyReplacement('hello old world old', 'old', 'new', false)).toBe(
'hello new world new',
);
});
it('should return currentContent if oldString is empty and not a new file', () => {
expect(
(tool as any)._applyReplacement('hello world', '', 'new', false),
).toBe('hello world');
expect(applyReplacement('hello world', '', 'new', false)).toBe(
'hello world',
);
});
});
@@ -234,15 +224,13 @@ describe('EditTool', () => {
filePath = path.join(rootDir, testFile);
});
it('should return false if params are invalid', async () => {
it('should throw an error if params are invalid', async () => {
const params: EditToolParams = {
file_path: 'relative.txt',
old_string: 'old',
new_string: 'new',
};
expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
expect(() => tool.build(params)).toThrow();
});
it('should request confirmation for valid edit', async () => {
@@ -254,8 +242,8 @@ describe('EditTool', () => {
};
// ensureCorrectEdit will be called by shouldConfirmExecute
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
const confirmation = await tool.shouldConfirmExecute(
params,
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toEqual(
@@ -275,9 +263,11 @@ describe('EditTool', () => {
new_string: 'new',
};
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
@@ -288,9 +278,11 @@ describe('EditTool', () => {
new_string: 'new',
};
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
it('should request confirmation for creating a new file (empty old_string)', async () => {
@@ -305,8 +297,8 @@ describe('EditTool', () => {
// as shouldConfirmExecute handles this for diff generation.
// If it is called, it should return 0 occurrences for a new file.
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
const confirmation = await tool.shouldConfirmExecute(
params,
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toEqual(
@@ -353,9 +345,8 @@ describe('EditTool', () => {
};
},
);
const confirmation = (await tool.shouldConfirmExecute(
params,
const invocation = tool.build(params);
const confirmation = (await invocation.shouldConfirmExecute(
new AbortController().signal,
)) as FileDiff;
@@ -403,15 +394,13 @@ describe('EditTool', () => {
});
});
it('should return error if params are invalid', async () => {
it('should throw error if params are invalid', async () => {
const params: EditToolParams = {
file_path: 'relative.txt',
old_string: 'old',
new_string: 'new',
};
const result = await tool.execute(params, new AbortController().signal);
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
expect(() => tool.build(params)).toThrow(/File path must be absolute/);
});
it('should edit an existing file and return diff with fileName', async () => {
@@ -428,12 +417,8 @@ describe('EditTool', () => {
// ensureCorrectEdit is NOT called by calculateEdit, only by shouldConfirmExecute
// So, the default mockEnsureCorrectEdit should correctly return 1 occurrence for 'old' in initialContent
// Simulate confirmation by setting shouldAlwaysEdit
(tool as any).shouldAlwaysEdit = true;
const result = await tool.execute(params, new AbortController().signal);
(tool as any).shouldAlwaysEdit = false; // Reset for other tests
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(/Successfully modified file/);
expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent);
@@ -456,7 +441,8 @@ describe('EditTool', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
ApprovalMode.AUTO_EDIT,
);
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(/Created new file/);
expect(fs.existsSync(newFilePath)).toBe(true);
@@ -472,7 +458,8 @@ describe('EditTool', () => {
new_string: 'replacement',
};
// The default mockEnsureCorrectEdit will return 0 occurrences for 'nonexistent'
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(
/0 occurrences found for old_string in/,
);
@@ -489,7 +476,8 @@ describe('EditTool', () => {
new_string: 'new',
};
// The default mockEnsureCorrectEdit will return 2 occurrences for 'old'
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(
/Expected 1 occurrence but found 2 for old_string in file/,
);
@@ -507,12 +495,8 @@ describe('EditTool', () => {
expected_replacements: 3,
};
// Simulate confirmation by setting shouldAlwaysEdit
(tool as any).shouldAlwaysEdit = true;
const result = await tool.execute(params, new AbortController().signal);
(tool as any).shouldAlwaysEdit = false; // Reset for other tests
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(/Successfully modified file/);
expect(fs.readFileSync(filePath, 'utf8')).toBe(
@@ -532,7 +516,8 @@ describe('EditTool', () => {
new_string: 'new',
expected_replacements: 3, // Expecting 3 but only 2 exist
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(
/Expected 3 occurrences but found 2 for old_string in file/,
);
@@ -548,7 +533,8 @@ describe('EditTool', () => {
old_string: '',
new_string: 'new content',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(/File already exists, cannot create/);
expect(result.returnDisplay).toMatch(
/Attempted to create a file that already exists/,
@@ -568,7 +554,8 @@ describe('EditTool', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
ApprovalMode.AUTO_EDIT,
);
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(
/User modified the `new_string` content/,
@@ -588,7 +575,8 @@ describe('EditTool', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
ApprovalMode.AUTO_EDIT,
);
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).not.toMatch(
/User modified the `new_string` content/,
@@ -607,7 +595,8 @@ describe('EditTool', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValueOnce(
ApprovalMode.AUTO_EDIT,
);
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).not.toMatch(
/User modified the `new_string` content/,
@@ -622,7 +611,8 @@ describe('EditTool', () => {
old_string: 'identical',
new_string: 'identical',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toMatch(/No changes to apply/);
expect(result.returnDisplay).toMatch(/No changes to apply/);
});
@@ -642,7 +632,8 @@ describe('EditTool', () => {
old_string: 'any',
new_string: 'new',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.FILE_NOT_FOUND);
});
@@ -653,7 +644,8 @@ describe('EditTool', () => {
old_string: '',
new_string: 'new content',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(
ToolErrorType.ATTEMPT_TO_CREATE_EXISTING_FILE,
);
@@ -666,7 +658,8 @@ describe('EditTool', () => {
old_string: 'not-found',
new_string: 'new',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.EDIT_NO_OCCURRENCE_FOUND);
});
@@ -678,7 +671,8 @@ describe('EditTool', () => {
new_string: 'new',
expected_replacements: 3,
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(
ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
);
@@ -691,18 +685,18 @@ describe('EditTool', () => {
old_string: 'content',
new_string: 'content',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.EDIT_NO_CHANGE);
});
it('should return INVALID_PARAMETERS error for relative path', async () => {
it('should throw INVALID_PARAMETERS error for relative path', async () => {
const params: EditToolParams = {
file_path: 'relative/path.txt',
old_string: 'a',
new_string: 'b',
};
const result = await tool.execute(params, new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS);
expect(() => tool.build(params)).toThrow();
});
it('should return FILE_WRITE_FAILURE on write error', async () => {
@@ -715,7 +709,8 @@ describe('EditTool', () => {
old_string: 'content',
new_string: 'new content',
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE);
});
});
@@ -728,8 +723,9 @@ describe('EditTool', () => {
old_string: 'identical_string',
new_string: 'identical_string',
};
const invocation = tool.build(params);
// shortenPath will be called internally, resulting in just the file name
expect(tool.getDescription(params)).toBe(
expect(invocation.getDescription()).toBe(
`No file changes to ${testFileName}`,
);
});
@@ -741,9 +737,10 @@ describe('EditTool', () => {
old_string: 'this is the old string value',
new_string: 'this is the new string value',
};
const invocation = tool.build(params);
// shortenPath will be called internally, resulting in just the file name
// The snippets are truncated at 30 chars + '...'
expect(tool.getDescription(params)).toBe(
expect(invocation.getDescription()).toBe(
`${testFileName}: this is the old string value => this is the new string value`,
);
});
@@ -755,7 +752,8 @@ describe('EditTool', () => {
old_string: 'old',
new_string: 'new',
};
expect(tool.getDescription(params)).toBe(`${testFileName}: old => new`);
const invocation = tool.build(params);
expect(invocation.getDescription()).toBe(`${testFileName}: old => new`);
});
it('should truncate long strings in the description', () => {
@@ -767,7 +765,8 @@ describe('EditTool', () => {
new_string:
'this is a very long new string that will also be truncated',
};
expect(tool.getDescription(params)).toBe(
const invocation = tool.build(params);
expect(invocation.getDescription()).toBe(
`${testFileName}: this is a very long old string... => this is a very long new string...`,
);
});
@@ -796,4 +795,57 @@ describe('EditTool', () => {
expect(error).toContain(rootDir);
});
});
describe('IDE mode', () => {
const testFile = 'edit_me.txt';
let filePath: string;
let ideClient: any;
beforeEach(() => {
filePath = path.join(rootDir, testFile);
ideClient = {
openDiff: vi.fn(),
getConnectionStatus: vi.fn().mockReturnValue({
status: IDEConnectionStatus.Connected,
}),
};
(mockConfig as any).getIdeMode = () => true;
(mockConfig as any).getIdeModeFeature = () => true;
(mockConfig as any).getIdeClient = () => ideClient;
});
it('should call ideClient.openDiff and update params on confirmation', async () => {
const initialContent = 'some old content here';
const newContent = 'some new content here';
const modifiedContent = 'some modified content here';
fs.writeFileSync(filePath, initialContent);
const params: EditToolParams = {
file_path: filePath,
old_string: 'old',
new_string: 'new',
};
mockEnsureCorrectEdit.mockResolvedValueOnce({
params: { ...params, old_string: 'old', new_string: 'new' },
occurrences: 1,
});
ideClient.openDiff.mockResolvedValueOnce({
status: 'accepted',
content: modifiedContent,
});
const invocation = tool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(ideClient.openDiff).toHaveBeenCalledWith(filePath, newContent);
if (confirmation && 'onConfirm' in confirmation) {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
expect(params.old_string).toBe(initialContent);
expect(params.new_string).toBe(modifiedContent);
});
});
});

View File

@@ -8,25 +8,46 @@ import * as fs from 'fs';
import * as path from 'path';
import * as Diff from 'diff';
import {
BaseTool,
BaseDeclarativeTool,
Icon,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolInvocation,
ToolLocation,
ToolResult,
ToolResultDisplay,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
import { Config, ApprovalMode } from '../config/config.js';
import { ensureCorrectEdit } from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
import { ReadFileTool } from './read-file.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { IDEConnectionStatus } from '../ide/ide-client.js';
/**
 * Computes the post-edit content of a file.
 *
 * @param currentContent The file's current content, or null when the file
 *   does not exist yet.
 * @param oldString The exact literal text to replace. An empty string only
 *   makes sense when creating a new file; on an existing file it is a no-op.
 * @param newString The replacement text (or the full content for a new file).
 * @param isNewFile When true, the result is simply `newString`.
 * @returns The resulting file content.
 */
export function applyReplacement(
  currentContent: string | null,
  oldString: string,
  newString: string,
  isNewFile: boolean,
): string {
  if (isNewFile) {
    return newString;
  }
  if (currentContent === null) {
    // Should not happen if not a new file, but defensively return empty or
    // newString if oldString is also empty.
    return oldString === '' ? newString : '';
  }
  // An empty oldString on an existing file must not modify the content.
  // (The previous `&& !isNewFile` clause here was dead code: isNewFile has
  // already returned above.)
  if (oldString === '') {
    return currentContent;
  }
  return currentContent.replaceAll(oldString, newString);
}
/**
* Parameters for the Edit tool
@@ -57,6 +78,11 @@ export interface EditToolParams {
* Whether the edit was modified manually by the user.
*/
modified_by_user?: boolean;
/**
* Initially proposed string.
*/
ai_proposed_string?: string;
}
interface CalculatedEdit {
@@ -67,112 +93,14 @@ interface CalculatedEdit {
isNewFile: boolean;
}
/**
* Implementation of the Edit tool logic
*/
export class EditTool
extends BaseTool<EditToolParams, ToolResult>
implements ModifiableTool<EditToolParams>
{
static readonly Name = 'replace';
class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
constructor(
private readonly config: Config,
public params: EditToolParams,
) {}
constructor(private readonly config: Config) {
super(
EditTool.Name,
'Edit',
`Replaces text within a file. By default, replaces a single occurrence, but can replace multiple occurrences when \`expected_replacements\` is specified. This tool requires providing significant context around the change to ensure precise targeting. Always use the ${ReadFileTool.Name} tool to examine the file's current content before attempting a text replacement.
The user has the ability to modify the \`new_string\` content. If modified, this will be stated in the response.
Expectation for required parameters:
1. \`file_path\` MUST be an absolute path; otherwise an error will be thrown.
2. \`old_string\` MUST be the exact literal text to replace (including all whitespace, indentation, newlines, and surrounding code etc.).
3. \`new_string\` MUST be the exact literal text to replace \`old_string\` with (also including all whitespace, indentation, newlines, and surrounding code etc.). Ensure the resulting code is correct and idiomatic.
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
Icon.Pencil,
{
properties: {
file_path: {
description:
"The absolute path to the file to modify. Must start with '/'.",
type: Type.STRING,
},
old_string: {
description:
'The exact literal text to replace, preferably unescaped. For single replacements (default), include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. For multiple replacements, specify expected_replacements parameter. If this string is not the exact literal text (i.e. you escaped it) or does not match exactly, the tool will fail.',
type: Type.STRING,
},
new_string: {
description:
'The exact literal text to replace `old_string` with, preferably unescaped. Provide the EXACT text. Ensure the resulting code is correct and idiomatic.',
type: Type.STRING,
},
expected_replacements: {
type: Type.NUMBER,
description:
'Number of replacements expected. Defaults to 1 if not specified. Use when you want to replace multiple occurrences.',
minimum: 1,
},
},
required: ['file_path', 'old_string', 'new_string'],
type: Type.OBJECT,
},
);
}
/**
* Validates the parameters for the Edit tool
* @param params Parameters to validate
* @returns Error message string or null if valid
*/
validateToolParams(params: EditToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
if (errors) {
return errors;
}
if (!path.isAbsolute(params.file_path)) {
return `File path must be absolute: ${params.file_path}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.file_path)) {
const directories = workspaceContext.getDirectories();
return `File path must be within one of the workspace directories: ${directories.join(', ')}`;
}
return null;
}
/**
* Determines any file locations affected by the tool execution
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(params: EditToolParams): ToolLocation[] {
return [{ path: params.file_path }];
}
private _applyReplacement(
currentContent: string | null,
oldString: string,
newString: string,
isNewFile: boolean,
): string {
if (isNewFile) {
return newString;
}
if (currentContent === null) {
// Should not happen if not a new file, but defensively return empty or newString if oldString is also empty
return oldString === '' ? newString : '';
}
// If oldString is empty and it's not a new file, do not modify the content.
if (oldString === '' && !isNewFile) {
return currentContent;
}
return currentContent.replaceAll(oldString, newString);
toolLocations(): ToolLocation[] {
return [{ path: this.params.file_path }];
}
/**
@@ -270,7 +198,7 @@ Expectation for required parameters:
};
}
const newContent = this._applyReplacement(
const newContent = applyReplacement(
currentContent,
finalOldString,
finalNewString,
@@ -291,23 +219,15 @@ Expectation for required parameters:
* It needs to calculate the diff to show the user.
*/
async shouldConfirmExecute(
params: EditToolParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
const validationError = this.validateToolParams(params);
if (validationError) {
console.error(
`[EditTool Wrapper] Attempted confirmation with invalid parameters: ${validationError}`,
);
return false;
}
let editData: CalculatedEdit;
try {
editData = await this.calculateEdit(params, abortSignal);
editData = await this.calculateEdit(this.params, abortSignal);
} catch (error) {
const errorMsg = error instanceof Error ? error.message : String(error);
console.log(`Error preparing edit: ${errorMsg}`);
@@ -319,7 +239,7 @@ Expectation for required parameters:
return false;
}
const fileName = path.basename(params.file_path);
const fileName = path.basename(this.params.file_path);
const fileDiff = Diff.createPatch(
fileName,
editData.currentContent ?? '',
@@ -328,10 +248,19 @@ Expectation for required parameters:
'Proposed',
DEFAULT_DIFF_OPTIONS,
);
const ideClient = this.config.getIdeClient();
const ideConfirmation =
this.config.getIdeModeFeature() &&
this.config.getIdeMode() &&
ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected
? ideClient.openDiff(this.params.file_path, editData.newContent)
: undefined;
const confirmationDetails: ToolEditConfirmationDetails = {
type: 'edit',
title: `Confirm Edit: ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`,
title: `Confirm Edit: ${shortenPath(makeRelative(this.params.file_path, this.config.getTargetDir()))}`,
fileName,
filePath: this.params.file_path,
fileDiff,
originalContent: editData.currentContent,
newContent: editData.newContent,
@@ -339,31 +268,39 @@ Expectation for required parameters:
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
if (ideConfirmation) {
const result = await ideConfirmation;
if (result.status === 'accepted' && result.content) {
// TODO(chrstn): See https://github.com/google-gemini/gemini-cli/pull/5618#discussion_r2255413084
// for info on a possible race condition where the file is modified on disk while being edited.
this.params.old_string = editData.currentContent ?? '';
this.params.new_string = result.content;
}
}
},
ideConfirmation,
};
return confirmationDetails;
}
getDescription(params: EditToolParams): string {
if (!params.file_path || !params.old_string || !params.new_string) {
return `Model did not provide valid parameters for edit tool`;
}
getDescription(): string {
const relativePath = makeRelative(
params.file_path,
this.params.file_path,
this.config.getTargetDir(),
);
if (params.old_string === '') {
if (this.params.old_string === '') {
return `Create ${shortenPath(relativePath)}`;
}
const oldStringSnippet =
params.old_string.split('\n')[0].substring(0, 30) +
(params.old_string.length > 30 ? '...' : '');
this.params.old_string.split('\n')[0].substring(0, 30) +
(this.params.old_string.length > 30 ? '...' : '');
const newStringSnippet =
params.new_string.split('\n')[0].substring(0, 30) +
(params.new_string.length > 30 ? '...' : '');
this.params.new_string.split('\n')[0].substring(0, 30) +
(this.params.new_string.length > 30 ? '...' : '');
if (params.old_string === params.new_string) {
if (this.params.old_string === this.params.new_string) {
return `No file changes to ${shortenPath(relativePath)}`;
}
return `${shortenPath(relativePath)}: ${oldStringSnippet} => ${newStringSnippet}`;
@@ -374,25 +311,10 @@ Expectation for required parameters:
* @param params Parameters for the edit operation
* @returns Result of the edit operation
*/
async execute(
params: EditToolParams,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: `Error: ${validationError}`,
error: {
message: validationError,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}
async execute(signal: AbortSignal): Promise<ToolResult> {
let editData: CalculatedEdit;
try {
editData = await this.calculateEdit(params, signal);
editData = await this.calculateEdit(this.params, signal);
} catch (error) {
const errorMsg = error instanceof Error ? error.message : String(error);
return {
@@ -417,16 +339,16 @@ Expectation for required parameters:
}
try {
this.ensureParentDirectoriesExist(params.file_path);
fs.writeFileSync(params.file_path, editData.newContent, 'utf8');
this.ensureParentDirectoriesExist(this.params.file_path);
fs.writeFileSync(this.params.file_path, editData.newContent, 'utf8');
let displayResult: ToolResultDisplay;
if (editData.isNewFile) {
displayResult = `Created ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`;
displayResult = `Created ${shortenPath(makeRelative(this.params.file_path, this.config.getTargetDir()))}`;
} else {
// Generate diff for display, even though core logic doesn't technically need it
// The CLI wrapper will use this part of the ToolResult
const fileName = path.basename(params.file_path);
const fileName = path.basename(this.params.file_path);
const fileDiff = Diff.createPatch(
fileName,
editData.currentContent ?? '', // Should not be null here if not isNewFile
@@ -435,22 +357,31 @@ Expectation for required parameters:
'Proposed',
DEFAULT_DIFF_OPTIONS,
);
const originallyProposedContent =
this.params.ai_proposed_string || this.params.new_string;
const diffStat = getDiffStat(
fileName,
editData.currentContent ?? '',
originallyProposedContent,
this.params.new_string,
);
displayResult = {
fileDiff,
fileName,
originalContent: editData.currentContent,
newContent: editData.newContent,
diffStat,
};
}
const llmSuccessMessageParts = [
editData.isNewFile
? `Created new file: ${params.file_path} with provided content.`
: `Successfully modified file: ${params.file_path} (${editData.occurrences} replacements).`,
? `Created new file: ${this.params.file_path} with provided content.`
: `Successfully modified file: ${this.params.file_path} (${editData.occurrences} replacements).`,
];
if (params.modified_by_user) {
if (this.params.modified_by_user) {
llmSuccessMessageParts.push(
`User modified the \`new_string\` content to be: ${params.new_string}.`,
`User modified the \`new_string\` content to be: ${this.params.new_string}.`,
);
}
@@ -480,6 +411,94 @@ Expectation for required parameters:
fs.mkdirSync(dirName, { recursive: true });
}
}
}
/**
* Implementation of the Edit tool logic
*/
export class EditTool
extends BaseDeclarativeTool<EditToolParams, ToolResult>
implements ModifiableDeclarativeTool<EditToolParams>
{
static readonly Name = 'replace';
constructor(private readonly config: Config) {
super(
EditTool.Name,
'Edit',
`Replaces text within a file. By default, replaces a single occurrence, but can replace multiple occurrences when \`expected_replacements\` is specified. This tool requires providing significant context around the change to ensure precise targeting. Always use the ${ReadFileTool.Name} tool to examine the file's current content before attempting a text replacement.
The user has the ability to modify the \`new_string\` content. If modified, this will be stated in the response.
Expectation for required parameters:
1. \`file_path\` MUST be an absolute path; otherwise an error will be thrown.
2. \`old_string\` MUST be the exact literal text to replace (including all whitespace, indentation, newlines, and surrounding code etc.).
3. \`new_string\` MUST be the exact literal text to replace \`old_string\` with (also including all whitespace, indentation, newlines, and surrounding code etc.). Ensure the resulting code is correct and idiomatic.
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
Icon.Pencil,
{
properties: {
file_path: {
description:
"The absolute path to the file to modify. Must start with '/'.",
type: 'string',
},
old_string: {
description:
'The exact literal text to replace, preferably unescaped. For single replacements (default), include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. For multiple replacements, specify expected_replacements parameter. If this string is not the exact literal text (i.e. you escaped it) or does not match exactly, the tool will fail.',
type: 'string',
},
new_string: {
description:
'The exact literal text to replace `old_string` with, preferably unescaped. Provide the EXACT text. Ensure the resulting code is correct and idiomatic.',
type: 'string',
},
expected_replacements: {
type: 'number',
description:
'Number of replacements expected. Defaults to 1 if not specified. Use when you want to replace multiple occurrences.',
minimum: 1,
},
},
required: ['file_path', 'old_string', 'new_string'],
type: 'object',
},
);
}
/**
* Validates the parameters for the Edit tool
* @param params Parameters to validate
* @returns Error message string or null if valid
*/
validateToolParams(params: EditToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!path.isAbsolute(params.file_path)) {
return `File path must be absolute: ${params.file_path}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.file_path)) {
const directories = workspaceContext.getDirectories();
return `File path must be within one of the workspace directories: ${directories.join(', ')}`;
}
return null;
}
protected createInvocation(
params: EditToolParams,
): ToolInvocation<EditToolParams, ToolResult> {
return new EditToolInvocation(this.config, params);
}
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
return {
@@ -495,7 +514,7 @@ Expectation for required parameters:
getProposedContent: async (params: EditToolParams): Promise<string> => {
try {
const currentContent = fs.readFileSync(params.file_path, 'utf8');
return this._applyReplacement(
return applyReplacement(
currentContent,
params.old_string,
params.new_string,
@@ -510,12 +529,16 @@ Expectation for required parameters:
oldContent: string,
modifiedProposedContent: string,
originalParams: EditToolParams,
): EditToolParams => ({
...originalParams,
old_string: oldContent,
new_string: modifiedProposedContent,
modified_by_user: true,
}),
): EditToolParams => {
const content = originalParams.new_string;
return {
...originalParams,
ai_proposed_string: content,
old_string: oldContent,
new_string: modifiedProposedContent,
modified_by_user: true,
};
},
};
}
}

View File

@@ -64,7 +64,8 @@ describe('GlobTool', () => {
describe('execute', () => {
it('should find files matching a simple pattern in the root', async () => {
const params: GlobToolParams = { pattern: '*.txt' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
@@ -73,7 +74,8 @@ describe('GlobTool', () => {
it('should find files case-sensitively when case_sensitive is true', async () => {
const params: GlobToolParams = { pattern: '*.txt', case_sensitive: true };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 1 file(s)');
expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
expect(result.llmContent).not.toContain(
@@ -83,7 +85,8 @@ describe('GlobTool', () => {
it('should find files case-insensitively by default (pattern: *.TXT)', async () => {
const params: GlobToolParams = { pattern: '*.TXT' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
@@ -94,7 +97,8 @@ describe('GlobTool', () => {
pattern: '*.TXT',
case_sensitive: false,
};
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt'));
expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT'));
@@ -102,7 +106,8 @@ describe('GlobTool', () => {
it('should find files using a pattern that includes a subdirectory', async () => {
const params: GlobToolParams = { pattern: 'sub/*.md' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain(
path.join(tempRootDir, 'sub', 'fileC.md'),
@@ -114,7 +119,8 @@ describe('GlobTool', () => {
it('should find files in a specified relative path (relative to rootDir)', async () => {
const params: GlobToolParams = { pattern: '*.md', path: 'sub' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain(
path.join(tempRootDir, 'sub', 'fileC.md'),
@@ -126,7 +132,8 @@ describe('GlobTool', () => {
it('should find files using a deep globstar pattern (e.g., **/*.log)', async () => {
const params: GlobToolParams = { pattern: '**/*.log' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 1 file(s)');
expect(result.llmContent).toContain(
path.join(tempRootDir, 'sub', 'deep', 'fileE.log'),
@@ -135,7 +142,8 @@ describe('GlobTool', () => {
it('should return "No files found" message when pattern matches nothing', async () => {
const params: GlobToolParams = { pattern: '*.nonexistent' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'No files found matching pattern "*.nonexistent"',
);
@@ -144,7 +152,8 @@ describe('GlobTool', () => {
it('should correctly sort files by modification time (newest first)', async () => {
const params: GlobToolParams = { pattern: '*.sortme' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
const llmContent = partListUnionToString(result.llmContent);
expect(llmContent).toContain('Found 2 file(s)');
@@ -242,8 +251,8 @@ describe('GlobTool', () => {
// Let's try to go further up.
const paramsOutside: GlobToolParams = {
pattern: '*.txt',
path: '../../../../../../../../../../tmp',
}; // Definitely outside
path: '../../../../../../../../../../tmp', // Definitely outside
};
expect(specificGlobTool.validateToolParams(paramsOutside)).toContain(
'resolves outside the allowed workspace directories',
);
@@ -290,7 +299,8 @@ describe('GlobTool', () => {
it('should work with paths in workspace subdirectories', async () => {
const params: GlobToolParams = { pattern: '*.md', path: 'sub' };
const result = await globTool.execute(params, abortSignal);
const invocation = globTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 2 file(s)');
expect(result.llmContent).toContain('fileC.md');

View File

@@ -8,8 +8,13 @@ import fs from 'fs';
import path from 'path';
import { glob } from 'glob';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
ToolInvocation,
ToolResult,
} from './tools.js';
import { shortenPath, makeRelative } from '../utils/paths.js';
import { Config } from '../config/config.js';
@@ -74,10 +79,168 @@ export interface GlobToolParams {
respect_git_ignore?: boolean;
}
class GlobToolInvocation extends BaseToolInvocation<
GlobToolParams,
ToolResult
> {
constructor(
private config: Config,
params: GlobToolParams,
) {
super(params);
}
getDescription(): string {
let description = `'${this.params.pattern}'`;
if (this.params.path) {
const searchDir = path.resolve(
this.config.getTargetDir(),
this.params.path || '.',
);
const relativePath = makeRelative(searchDir, this.config.getTargetDir());
description += ` within ${shortenPath(relativePath)}`;
}
return description;
}
async execute(signal: AbortSignal): Promise<ToolResult> {
try {
const workspaceContext = this.config.getWorkspaceContext();
const workspaceDirectories = workspaceContext.getDirectories();
// If a specific path is provided, resolve it and check if it's within workspace
let searchDirectories: readonly string[];
if (this.params.path) {
const searchDirAbsolute = path.resolve(
this.config.getTargetDir(),
this.params.path,
);
if (!workspaceContext.isPathWithinWorkspace(searchDirAbsolute)) {
return {
llmContent: `Error: Path "${this.params.path}" is not within any workspace directory`,
returnDisplay: `Path is not within workspace`,
};
}
searchDirectories = [searchDirAbsolute];
} else {
// Search across all workspace directories
searchDirectories = workspaceDirectories;
}
// Get centralized file discovery service
const respectGitIgnore =
this.params.respect_git_ignore ??
this.config.getFileFilteringRespectGitIgnore();
const fileDiscovery = this.config.getFileService();
// Collect entries from all search directories
let allEntries: GlobPath[] = [];
for (const searchDir of searchDirectories) {
const entries = (await glob(this.params.pattern, {
cwd: searchDir,
withFileTypes: true,
nodir: true,
stat: true,
nocase: !this.params.case_sensitive,
dot: true,
ignore: ['**/node_modules/**', '**/.git/**'],
follow: false,
signal,
})) as GlobPath[];
allEntries = allEntries.concat(entries);
}
const entries = allEntries;
// Apply git-aware filtering if enabled and in git repository
let filteredEntries = entries;
let gitIgnoredCount = 0;
if (respectGitIgnore) {
const relativePaths = entries.map((p) =>
path.relative(this.config.getTargetDir(), p.fullpath()),
);
const filteredRelativePaths = fileDiscovery.filterFiles(relativePaths, {
respectGitIgnore,
});
const filteredAbsolutePaths = new Set(
filteredRelativePaths.map((p) =>
path.resolve(this.config.getTargetDir(), p),
),
);
filteredEntries = entries.filter((entry) =>
filteredAbsolutePaths.has(entry.fullpath()),
);
gitIgnoredCount = entries.length - filteredEntries.length;
}
if (!filteredEntries || filteredEntries.length === 0) {
let message = `No files found matching pattern "${this.params.pattern}"`;
if (searchDirectories.length === 1) {
message += ` within ${searchDirectories[0]}`;
} else {
message += ` within ${searchDirectories.length} workspace directories`;
}
if (gitIgnoredCount > 0) {
message += ` (${gitIgnoredCount} files were git-ignored)`;
}
return {
llmContent: message,
returnDisplay: `No files found`,
};
}
// Set filtering such that we first show the most recent files
const oneDayInMs = 24 * 60 * 60 * 1000;
const nowTimestamp = new Date().getTime();
// Sort the filtered entries using the new helper function
const sortedEntries = sortFileEntries(
filteredEntries,
nowTimestamp,
oneDayInMs,
);
const sortedAbsolutePaths = sortedEntries.map((entry) =>
entry.fullpath(),
);
const fileListDescription = sortedAbsolutePaths.join('\n');
const fileCount = sortedAbsolutePaths.length;
let resultMessage = `Found ${fileCount} file(s) matching "${this.params.pattern}"`;
if (searchDirectories.length === 1) {
resultMessage += ` within ${searchDirectories[0]}`;
} else {
resultMessage += ` across ${searchDirectories.length} workspace directories`;
}
if (gitIgnoredCount > 0) {
resultMessage += ` (${gitIgnoredCount} additional files were git-ignored)`;
}
resultMessage += `, sorted by modification time (newest first):\n${fileListDescription}`;
return {
llmContent: resultMessage,
returnDisplay: `Found ${fileCount} matching file(s)`,
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(`GlobLogic execute Error: ${errorMessage}`, error);
return {
llmContent: `Error during glob search operation: ${errorMessage}`,
returnDisplay: `Error: An unexpected error occurred.`,
};
}
}
}
/**
* Implementation of the Glob tool logic
*/
export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
static readonly Name = 'glob';
constructor(private config: Config) {
@@ -91,26 +254,26 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
pattern: {
description:
"The glob pattern to match against (e.g., '**/*.py', 'docs/*.md').",
type: Type.STRING,
type: 'string',
},
path: {
description:
'Optional: The absolute path to the directory to search within. If omitted, searches the root directory.',
type: Type.STRING,
type: 'string',
},
case_sensitive: {
description:
'Optional: Whether the search should be case-sensitive. Defaults to false.',
type: Type.BOOLEAN,
type: 'boolean',
},
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when finding files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
type: 'boolean',
},
},
required: ['pattern'],
type: Type.OBJECT,
type: 'object',
},
);
}
@@ -119,7 +282,10 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
* Validates the parameters for the tool.
*/
validateToolParams(params: GlobToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
@@ -158,166 +324,9 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
return null;
}
/**
* Gets a description of the glob operation.
*/
getDescription(params: GlobToolParams): string {
let description = `'${params.pattern}'`;
if (params.path) {
const searchDir = path.resolve(
this.config.getTargetDir(),
params.path || '.',
);
const relativePath = makeRelative(searchDir, this.config.getTargetDir());
description += ` within ${shortenPath(relativePath)}`;
}
return description;
}
/**
* Executes the glob search with the given parameters
*/
async execute(
protected createInvocation(
params: GlobToolParams,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
try {
const workspaceContext = this.config.getWorkspaceContext();
const workspaceDirectories = workspaceContext.getDirectories();
// If a specific path is provided, resolve it and check if it's within workspace
let searchDirectories: readonly string[];
if (params.path) {
const searchDirAbsolute = path.resolve(
this.config.getTargetDir(),
params.path,
);
if (!workspaceContext.isPathWithinWorkspace(searchDirAbsolute)) {
return {
llmContent: `Error: Path "${params.path}" is not within any workspace directory`,
returnDisplay: `Path is not within workspace`,
};
}
searchDirectories = [searchDirAbsolute];
} else {
// Search across all workspace directories
searchDirectories = workspaceDirectories;
}
// Get centralized file discovery service
const respectGitIgnore =
params.respect_git_ignore ??
this.config.getFileFilteringRespectGitIgnore();
const fileDiscovery = this.config.getFileService();
// Collect entries from all search directories
let allEntries: GlobPath[] = [];
for (const searchDir of searchDirectories) {
const entries = (await glob(params.pattern, {
cwd: searchDir,
withFileTypes: true,
nodir: true,
stat: true,
nocase: !params.case_sensitive,
dot: true,
ignore: ['**/node_modules/**', '**/.git/**'],
follow: false,
signal,
})) as GlobPath[];
allEntries = allEntries.concat(entries);
}
const entries = allEntries;
// Apply git-aware filtering if enabled and in git repository
let filteredEntries = entries;
let gitIgnoredCount = 0;
if (respectGitIgnore) {
const relativePaths = entries.map((p) =>
path.relative(this.config.getTargetDir(), p.fullpath()),
);
const filteredRelativePaths = fileDiscovery.filterFiles(relativePaths, {
respectGitIgnore,
});
const filteredAbsolutePaths = new Set(
filteredRelativePaths.map((p) =>
path.resolve(this.config.getTargetDir(), p),
),
);
filteredEntries = entries.filter((entry) =>
filteredAbsolutePaths.has(entry.fullpath()),
);
gitIgnoredCount = entries.length - filteredEntries.length;
}
if (!filteredEntries || filteredEntries.length === 0) {
let message = `No files found matching pattern "${params.pattern}"`;
if (searchDirectories.length === 1) {
message += ` within ${searchDirectories[0]}`;
} else {
message += ` within ${searchDirectories.length} workspace directories`;
}
if (gitIgnoredCount > 0) {
message += ` (${gitIgnoredCount} files were git-ignored)`;
}
return {
llmContent: message,
returnDisplay: `No files found`,
};
}
// Set filtering such that we first show the most recent files
const oneDayInMs = 24 * 60 * 60 * 1000;
const nowTimestamp = new Date().getTime();
// Sort the filtered entries using the new helper function
const sortedEntries = sortFileEntries(
filteredEntries,
nowTimestamp,
oneDayInMs,
);
const sortedAbsolutePaths = sortedEntries.map((entry) =>
entry.fullpath(),
);
const fileListDescription = sortedAbsolutePaths.join('\n');
const fileCount = sortedAbsolutePaths.length;
let resultMessage = `Found ${fileCount} file(s) matching "${params.pattern}"`;
if (searchDirectories.length === 1) {
resultMessage += ` within ${searchDirectories[0]}`;
} else {
resultMessage += ` across ${searchDirectories.length} workspace directories`;
}
if (gitIgnoredCount > 0) {
resultMessage += ` (${gitIgnoredCount} additional files were git-ignored)`;
}
resultMessage += `, sorted by modification time (newest first):\n${fileListDescription}`;
return {
llmContent: resultMessage,
returnDisplay: `Found ${fileCount} matching file(s)`,
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(`GlobLogic execute Error: ${errorMessage}`, error);
return {
llmContent: `Error during glob search operation: ${errorMessage}`,
returnDisplay: `Error: An unexpected error occurred.`,
};
}
): ToolInvocation<GlobToolParams, ToolResult> {
return new GlobToolInvocation(this.config, params);
}
}

View File

@@ -126,7 +126,8 @@ describe('GrepTool', () => {
describe('execute', () => {
it('should find matches for a simple pattern in all files', async () => {
const params: GrepToolParams = { pattern: 'world' };
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 3 matches for pattern "world" in the workspace directory',
);
@@ -142,7 +143,8 @@ describe('GrepTool', () => {
it('should find matches in a specific path', async () => {
const params: GrepToolParams = { pattern: 'world', path: 'sub' };
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 1 match for pattern "world" in path "sub"',
);
@@ -153,7 +155,8 @@ describe('GrepTool', () => {
it('should find matches with an include glob', async () => {
const params: GrepToolParams = { pattern: 'hello', include: '*.js' };
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 1 match for pattern "hello" in the workspace directory (filter: "*.js"):',
);
@@ -174,7 +177,8 @@ describe('GrepTool', () => {
path: 'sub',
include: '*.js',
};
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 1 match for pattern "hello" in path "sub" (filter: "*.js")',
);
@@ -185,7 +189,8 @@ describe('GrepTool', () => {
it('should return "No matches found" when pattern does not exist', async () => {
const params: GrepToolParams = { pattern: 'nonexistentpattern' };
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'No matches found for pattern "nonexistentpattern" in the workspace directory.',
);
@@ -194,7 +199,8 @@ describe('GrepTool', () => {
it('should handle regex special characters correctly', async () => {
const params: GrepToolParams = { pattern: 'foo.*bar' }; // Matches 'const foo = "bar";'
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 1 match for pattern "foo.*bar" in the workspace directory:',
);
@@ -204,7 +210,8 @@ describe('GrepTool', () => {
it('should be case-insensitive by default (JS fallback)', async () => {
const params: GrepToolParams = { pattern: 'HELLO' };
const result = await grepTool.execute(params, abortSignal);
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'Found 2 matches for pattern "HELLO" in the workspace directory:',
);
@@ -216,14 +223,10 @@ describe('GrepTool', () => {
);
});
it('should return an error if params are invalid', async () => {
it('should throw an error if params are invalid', async () => {
const params = { path: '.' } as unknown as GrepToolParams; // Invalid: pattern missing
const result = await grepTool.execute(params, abortSignal);
expect(result.llmContent).toBe(
"Error: Invalid parameters provided. Reason: params must have required property 'pattern'",
);
expect(result.returnDisplay).toBe(
"Model provided invalid parameters. Error: params must have required property 'pattern'",
expect(() => grepTool.build(params)).toThrow(
/params must have required property 'pattern'/,
);
});
@@ -242,7 +245,8 @@ describe('GrepTool', () => {
// Search for 'world' which exists in both the regular file and the ignored file
const params: GrepToolParams = { pattern: 'world' };
const result = await grepToolWithIgnore.execute(params, abortSignal);
const invocation = grepToolWithIgnore.build(params);
const result = await invocation.execute(abortSignal);
// Should only find matches in the non-ignored files (3 matches)
expect(result.llmContent).toContain(
@@ -294,7 +298,8 @@ describe('GrepTool', () => {
const multiDirGrepTool = new GrepTool(multiDirConfig);
const params: GrepToolParams = { pattern: 'world' };
const result = await multiDirGrepTool.execute(params, abortSignal);
const invocation = multiDirGrepTool.build(params);
const result = await invocation.execute(abortSignal);
// Should find matches in both directories
expect(result.llmContent).toContain(
@@ -350,7 +355,8 @@ describe('GrepTool', () => {
// Search only in the 'sub' directory of the first workspace
const params: GrepToolParams = { pattern: 'world', path: 'sub' };
const result = await multiDirGrepTool.execute(params, abortSignal);
const invocation = multiDirGrepTool.build(params);
const result = await invocation.execute(abortSignal);
// Should only find matches in the specified sub directory
expect(result.llmContent).toContain(
@@ -370,7 +376,8 @@ describe('GrepTool', () => {
describe('getDescription', () => {
it('should generate correct description with pattern only', () => {
const params: GrepToolParams = { pattern: 'testPattern' };
expect(grepTool.getDescription(params)).toBe("'testPattern'");
const invocation = grepTool.build(params);
expect(invocation.getDescription()).toBe("'testPattern'");
});
it('should generate correct description with pattern and include', () => {
@@ -378,19 +385,21 @@ describe('GrepTool', () => {
pattern: 'testPattern',
include: '*.ts',
};
expect(grepTool.getDescription(params)).toBe("'testPattern' in *.ts");
const invocation = grepTool.build(params);
expect(invocation.getDescription()).toBe("'testPattern' in *.ts");
});
it('should generate correct description with pattern and path', () => {
it('should generate correct description with pattern and path', async () => {
const dirPath = path.join(tempRootDir, 'src', 'app');
await fs.mkdir(dirPath, { recursive: true });
const params: GrepToolParams = {
pattern: 'testPattern',
path: path.join('src', 'app'),
};
const invocation = grepTool.build(params);
// The path will be relative to the tempRootDir, so we check for containment.
expect(grepTool.getDescription(params)).toContain("'testPattern' within");
expect(grepTool.getDescription(params)).toContain(
path.join('src', 'app'),
);
expect(invocation.getDescription()).toContain("'testPattern' within");
expect(invocation.getDescription()).toContain(path.join('src', 'app'));
});
it('should indicate searching across all workspace directories when no path specified', () => {
@@ -403,28 +412,31 @@ describe('GrepTool', () => {
const multiDirGrepTool = new GrepTool(multiDirConfig);
const params: GrepToolParams = { pattern: 'testPattern' };
expect(multiDirGrepTool.getDescription(params)).toBe(
const invocation = multiDirGrepTool.build(params);
expect(invocation.getDescription()).toBe(
"'testPattern' across all workspace directories",
);
});
it('should generate correct description with pattern, include, and path', () => {
it('should generate correct description with pattern, include, and path', async () => {
const dirPath = path.join(tempRootDir, 'src', 'app');
await fs.mkdir(dirPath, { recursive: true });
const params: GrepToolParams = {
pattern: 'testPattern',
include: '*.ts',
path: path.join('src', 'app'),
};
expect(grepTool.getDescription(params)).toContain(
const invocation = grepTool.build(params);
expect(invocation.getDescription()).toContain(
"'testPattern' in *.ts within",
);
expect(grepTool.getDescription(params)).toContain(
path.join('src', 'app'),
);
expect(invocation.getDescription()).toContain(path.join('src', 'app'));
});
it('should use ./ for root path in description', () => {
const params: GrepToolParams = { pattern: 'testPattern', path: '.' };
expect(grepTool.getDescription(params)).toBe("'testPattern' within ./");
const invocation = grepTool.build(params);
expect(invocation.getDescription()).toBe("'testPattern' within ./");
});
});
});

View File

@@ -9,9 +9,14 @@ import fsPromises from 'fs/promises';
import path from 'path';
import { EOL } from 'os';
import { spawn } from 'child_process';
import { globIterate } from 'glob';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { globStream } from 'glob';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js';
@@ -49,46 +54,17 @@ interface GrepMatch {
line: string;
}
// --- GrepLogic Class ---
/**
* Implementation of the Grep tool logic (moved from CLI)
*/
export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
static readonly Name = 'search_file_content'; // Keep static name
constructor(private readonly config: Config) {
super(
GrepTool.Name,
'SearchText',
'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
Icon.Regex,
{
properties: {
pattern: {
description:
"The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').",
type: Type.STRING,
},
path: {
description:
'Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.',
type: Type.STRING,
},
include: {
description:
"Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).",
type: Type.STRING,
},
},
required: ['pattern'],
type: Type.OBJECT,
},
);
class GrepToolInvocation extends BaseToolInvocation<
GrepToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
params: GrepToolParams,
) {
super(params);
}
// --- Validation Methods ---
/**
* Checks if a path is within the root directory and resolves it.
* @param relativePath Path relative to the root directory (or undefined for root).
@@ -130,58 +106,11 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
return targetPath;
}
/**
* Validates the parameters for the tool
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: GrepToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
if (errors) {
return errors;
}
try {
new RegExp(params.pattern);
} catch (error) {
return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`;
}
// Only validate path if one is provided
if (params.path) {
try {
this.resolveAndValidatePath(params.path);
} catch (error) {
return getErrorMessage(error);
}
}
return null; // Parameters are valid
}
// --- Core Execution ---
/**
* Executes the grep search with the given parameters
* @param params Parameters for the grep search
* @returns Result of the grep search
*/
async execute(
params: GrepToolParams,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: `Model provided invalid parameters. Error: ${validationError}`,
};
}
async execute(signal: AbortSignal): Promise<ToolResult> {
try {
const workspaceContext = this.config.getWorkspaceContext();
const searchDirAbs = this.resolveAndValidatePath(params.path);
const searchDirDisplay = params.path || '.';
const searchDirAbs = this.resolveAndValidatePath(this.params.path);
const searchDirDisplay = this.params.path || '.';
// Determine which directories to search
let searchDirectories: readonly string[];
@@ -197,9 +126,9 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
let allMatches: GrepMatch[] = [];
for (const searchDir of searchDirectories) {
const matches = await this.performGrepSearch({
pattern: params.pattern,
pattern: this.params.pattern,
path: searchDir,
include: params.include,
include: this.params.include,
signal,
});
@@ -226,7 +155,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
}
if (allMatches.length === 0) {
const noMatchMsg = `No matches found for pattern "${params.pattern}" ${searchLocationDescription}${params.include ? ` (filter: "${params.include}")` : ''}.`;
const noMatchMsg = `No matches found for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}.`;
return { llmContent: noMatchMsg, returnDisplay: `No matches found` };
}
@@ -247,7 +176,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
const matchCount = allMatches.length;
const matchTerm = matchCount === 1 ? 'match' : 'matches';
let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${params.pattern}" ${searchLocationDescription}${params.include ? ` (filter: "${params.include}")` : ''}:
let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}:
---
`;
@@ -274,8 +203,6 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
}
}
// --- Grep Implementation Logic ---
/**
* Checks if a command is available in the system's PATH.
* @param {string} command The command name (e.g., 'git', 'grep').
@@ -353,17 +280,20 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
* @param params Parameters for the grep operation
* @returns A string describing the grep
*/
getDescription(params: GrepToolParams): string {
let description = `'${params.pattern}'`;
if (params.include) {
description += ` in ${params.include}`;
getDescription(): string {
let description = `'${this.params.pattern}'`;
if (this.params.include) {
description += ` in ${this.params.include}`;
}
if (params.path) {
if (this.params.path) {
const resolvedPath = path.resolve(
this.config.getTargetDir(),
params.path,
this.params.path,
);
if (resolvedPath === this.config.getTargetDir() || params.path === '.') {
if (
resolvedPath === this.config.getTargetDir() ||
this.params.path === '.'
) {
description += ` within ./`;
} else {
const relativePath = makeRelative(
@@ -445,7 +375,9 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
return this.parseGrepOutput(output, absolutePath);
} catch (gitError: unknown) {
console.debug(
`GrepLogic: git grep failed: ${getErrorMessage(gitError)}. Falling back...`,
`GrepLogic: git grep failed: ${getErrorMessage(
gitError,
)}. Falling back...`,
);
}
}
@@ -525,7 +457,9 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
return this.parseGrepOutput(output, absolutePath);
} catch (grepError: unknown) {
console.debug(
`GrepLogic: System grep failed: ${getErrorMessage(grepError)}. Falling back...`,
`GrepLogic: System grep failed: ${getErrorMessage(
grepError,
)}. Falling back...`,
);
}
}
@@ -550,7 +484,7 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
...fileDiscovery.getGeminiIgnorePatterns(),
]; // Use glob patterns for ignores here
const filesIterator = globIterate(globPattern, {
const filesIterator = globStream(globPattern, {
cwd: absolutePath,
dot: true,
ignore: ignorePatterns,
@@ -582,7 +516,9 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
// Ignore errors like permission denied or file gone during read
if (!isNodeError(readError) || readError.code !== 'ENOENT') {
console.debug(
`GrepLogic: Could not read/process ${fileAbsolutePath}: ${getErrorMessage(readError)}`,
`GrepLogic: Could not read/process ${fileAbsolutePath}: ${getErrorMessage(
readError,
)}`,
);
}
}
@@ -591,9 +527,129 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
return allMatches;
} catch (error: unknown) {
console.error(
`GrepLogic: Error in performGrepSearch (Strategy: ${strategyUsed}): ${getErrorMessage(error)}`,
`GrepLogic: Error in performGrepSearch (Strategy: ${strategyUsed}): ${getErrorMessage(
error,
)}`,
);
throw error; // Re-throw
}
}
}
// --- GrepLogic Class ---
/**
* Implementation of the Grep tool logic (moved from CLI)
*/
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
  static readonly Name = 'search_file_content'; // Keep static name

  constructor(private readonly config: Config) {
    super(
      GrepTool.Name,
      'SearchText',
      'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
      Icon.Regex,
      {
        properties: {
          pattern: {
            description:
              "The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').",
            type: 'string',
          },
          path: {
            description:
              'Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.',
            type: 'string',
          },
          include: {
            description:
              "Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).",
            type: 'string',
          },
        },
        required: ['pattern'],
        type: 'object',
      },
    );
  }

  /**
   * Checks if a path is within the root directory and resolves it.
   * @param relativePath Path relative to the root directory (or undefined for root).
   * @returns The absolute path if valid and exists, or null if no path specified (to search all directories).
   * @throws {Error} If path is outside root, doesn't exist, or isn't a directory.
   */
  private resolveAndValidatePath(relativePath?: string): string | null {
    // If no path specified, return null to indicate searching all workspace directories
    if (!relativePath) {
      return null;
    }

    const targetPath = path.resolve(this.config.getTargetDir(), relativePath);

    // Security Check: Ensure the resolved path is within workspace boundaries
    const workspaceContext = this.config.getWorkspaceContext();
    if (!workspaceContext.isPathWithinWorkspace(targetPath)) {
      const directories = workspaceContext.getDirectories();
      throw new Error(
        `Path validation failed: Attempted path "${relativePath}" resolves outside the allowed workspace directories: ${directories.join(', ')}`,
      );
    }

    // Check existence and type after resolving.
    //
    // The stat call and the directory check are deliberately separated:
    // previously the "not a directory" error was thrown inside the same
    // try block and got re-wrapped by its own catch into a misleading
    // "Failed to access path stats" message.
    let stats: fs.Stats;
    try {
      stats = fs.statSync(targetPath);
    } catch (error: unknown) {
      // ENOENT means the path simply does not exist; any other stat error
      // (e.g. EACCES) is surfaced as a generic stat failure. The previous
      // condition was inverted ("!== 'ENOENT'"), which mislabeled
      // permission errors as missing paths and vice versa.
      if (isNodeError(error) && error.code === 'ENOENT') {
        throw new Error(`Path does not exist: ${targetPath}`);
      }
      throw new Error(
        `Failed to access path stats for ${targetPath}: ${error}`,
      );
    }
    if (!stats.isDirectory()) {
      throw new Error(`Path is not a directory: ${targetPath}`);
    }

    return targetPath;
  }

  /**
   * Validates the parameters for the tool
   * @param params Parameters to validate
   * @returns An error message string if invalid, null otherwise
   */
  validateToolParams(params: GrepToolParams): string | null {
    const errors = SchemaValidator.validate(
      this.schema.parametersJsonSchema,
      params,
    );
    if (errors) {
      return errors;
    }

    // The pattern must be a syntactically valid regular expression.
    try {
      new RegExp(params.pattern);
    } catch (error) {
      return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`;
    }

    // Only validate path if one is provided
    if (params.path) {
      try {
        this.resolveAndValidatePath(params.path);
      } catch (error) {
        return getErrorMessage(error);
      }
    }

    return null; // Parameters are valid
  }

  /** Builds the invocation object that performs a single grep run. */
  protected createInvocation(
    params: GrepToolParams,
  ): ToolInvocation<GrepToolParams, ToolResult> {
    return new GrepToolInvocation(this.config, params);
  }
}

View File

@@ -7,7 +7,6 @@
import fs from 'fs';
import path from 'path';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@@ -82,35 +81,35 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
path: {
description:
'The absolute path to the directory to list (must be absolute, not relative)',
type: Type.STRING,
type: 'string',
},
ignore: {
description: 'List of glob patterns to ignore',
items: {
type: Type.STRING,
type: 'string',
},
type: Type.ARRAY,
type: 'array',
},
file_filtering_options: {
description:
'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
type: Type.OBJECT,
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: Type.BOOLEAN,
type: 'boolean',
},
},
},
},
required: ['path'],
type: Type.OBJECT,
type: 'object',
},
);
}
@@ -121,7 +120,10 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: LSToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}

View File

@@ -12,6 +12,8 @@ import {
isEnabled,
discoverTools,
discoverPrompts,
hasValidTypes,
connectToMcpServer,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -22,6 +24,8 @@ import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
@@ -97,6 +101,232 @@ describe('mcp-client', () => {
`Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
);
});
it('should skip tools if a parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: { type: 'string' },
},
},
},
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: { description: 'a param with no type' },
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should skip tools if a nested parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'object',
properties: {
nestedParam: {
description: 'a nested param with no type',
},
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should skip tool if an array item is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'array',
items: {
description: 'an array item with no type',
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should discover tool with no properties in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
it('should discover tool with empty properties object in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
properties: {},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
});
describe('connectToMcpServer', () => {
it('should register a roots/list handler', async () => {
const mockedClient = {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
callTool: vi.fn(),
connect: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockWorkspaceContext = {
getDirectories: vi
.fn()
.mockReturnValue(['/test/dir', '/another/project']),
} as unknown as WorkspaceContext;
await connectToMcpServer(
'test-server',
{
command: 'test-command',
},
false,
mockWorkspaceContext,
);
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
roots: {},
});
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
const roots = await handler();
expect(roots).toEqual({
roots: [
{
uri: pathToFileURL('/test/dir').toString(),
name: 'dir',
},
{
uri: pathToFileURL('/another/project').toString(),
name: 'project',
},
],
});
});
});
describe('discoverPrompts', () => {
@@ -309,7 +539,9 @@ describe('mcp-client', () => {
});
it('should connect via command', async () => {
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
const mockedTransport = vi
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport(
'test-server',
@@ -336,7 +568,7 @@ describe('mcp-client', () => {
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test-server',
httpUrl: 'http://test.googleapis.com',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
@@ -355,7 +587,7 @@ describe('mcp-client', () => {
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
url: 'http://test.googleapis.com',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
@@ -383,7 +615,7 @@ describe('mcp-client', () => {
false,
),
).rejects.toThrow(
'No URL configured for Google Credentials MCP server',
'URL must be provided in the config for Google Credentials provider',
);
});
});
@@ -433,4 +665,163 @@ describe('mcp-client', () => {
);
});
});
describe('hasValidTypes', () => {
it('should return true for a valid schema with anyOf', () => {
const schema = {
anyOf: [{ type: 'string' }, { type: 'number' }],
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false for an invalid schema with anyOf', () => {
const schema = {
anyOf: [{ type: 'string' }, { description: 'no type' }],
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a valid schema with allOf', () => {
const schema = {
allOf: [
{ type: 'string' },
{ type: 'object', properties: { foo: { type: 'string' } } },
],
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false for an invalid schema with allOf', () => {
const schema = {
allOf: [{ type: 'string' }, { description: 'no type' }],
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a valid schema with oneOf', () => {
const schema = {
oneOf: [{ type: 'string' }, { type: 'number' }],
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false for an invalid schema with oneOf', () => {
const schema = {
oneOf: [{ type: 'string' }, { description: 'no type' }],
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a valid schema with nested subschemas', () => {
const schema = {
anyOf: [
{ type: 'string' },
{
allOf: [
{ type: 'object', properties: { a: { type: 'string' } } },
{ type: 'object', properties: { b: { type: 'number' } } },
],
},
],
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false for an invalid schema with nested subschemas', () => {
const schema = {
anyOf: [
{ type: 'string' },
{
allOf: [
{ type: 'object', properties: { a: { type: 'string' } } },
{ description: 'no type' },
],
},
],
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a schema with a type and subschemas', () => {
const schema = {
type: 'string',
anyOf: [{ minLength: 1 }, { maxLength: 5 }],
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false for a schema with no type and no subschemas', () => {
const schema = {
description: 'a schema with no type',
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a valid schema', () => {
const schema = {
type: 'object',
properties: {
param1: { type: 'string' },
},
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return false if a parameter is missing a type', () => {
const schema = {
type: 'object',
properties: {
param1: { description: 'a param with no type' },
},
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return false if a nested parameter is missing a type', () => {
const schema = {
type: 'object',
properties: {
param1: {
type: 'object',
properties: {
nestedParam: {
description: 'a nested param with no type',
},
},
},
},
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return false if an array item is missing a type', () => {
const schema = {
type: 'object',
properties: {
param1: {
type: 'array',
items: {
description: 'an array item with no type',
},
},
},
};
expect(hasValidTypes(schema)).toBe(false);
});
it('should return true for a schema with no properties', () => {
const schema = {
type: 'object',
};
expect(hasValidTypes(schema)).toBe(true);
});
it('should return true for a schema with an empty properties object', () => {
const schema = {
type: 'object',
properties: {},
};
expect(hasValidTypes(schema)).toBe(true);
});
});
});

View File

@@ -20,6 +20,7 @@ import {
ListPromptsResultSchema,
GetPromptResult,
GetPromptResultSchema,
ListRootsRequestSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
@@ -33,6 +34,9 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import { WorkspaceContext } from '../utils/workspaceContext.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@@ -306,6 +310,7 @@ export async function discoverMcpTools(
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
try {
@@ -319,6 +324,7 @@ export async function discoverMcpTools(
toolRegistry,
promptRegistry,
debugMode,
workspaceContext,
),
);
await Promise.all(discoveryPromises);
@@ -363,6 +369,7 @@ export async function connectAndDiscover(
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -372,6 +379,7 @@ export async function connectAndDiscover(
mcpServerName,
mcpServerConfig,
debugMode,
workspaceContext,
);
mcpClient.onerror = (error) => {
@@ -416,6 +424,65 @@ export async function connectAndDiscover(
}
}
/**
 * Recursively validates that a JSON schema and all its nested properties and
 * items have a `type` defined.
 *
 * @param schema The JSON schema to validate.
 * @returns `true` if the schema is valid, `false` otherwise.
 *
 * @visibleForTesting
 */
export function hasValidTypes(schema: unknown): boolean {
  if (typeof schema !== 'object' || schema === null) {
    // Not an inspectable schema object, so there is nothing to be invalid.
    return true;
  }
  const node = schema as Record<string, unknown>;

  if (!node.type) {
    // Without a top-level `type`, the schema is only acceptable when it is
    // composed of subschemas (anyOf/allOf/oneOf), each of which must itself
    // carry valid types.
    let sawComposition = false;
    for (const keyword of ['anyOf', 'allOf', 'oneOf']) {
      const branches = node[keyword];
      if (!Array.isArray(branches)) {
        continue;
      }
      sawComposition = true;
      if (!branches.every((branch) => hasValidTypes(branch))) {
        return false;
      }
    }
    // Neither a `type` nor any composition keywords: invalid.
    if (!sawComposition) {
      return false;
    }
  }

  // Object schemas: every declared property must itself have valid types.
  if (
    node.type === 'object' &&
    typeof node.properties === 'object' &&
    node.properties !== null
  ) {
    for (const property of Object.values(node.properties)) {
      if (!hasValidTypes(property)) {
        return false;
      }
    }
  }

  // Array schemas: the item schema (when declared) must have valid types.
  if (node.type === 'array' && node.items) {
    return hasValidTypes(node.items);
  }

  return true;
}
/**
* Discovers and sanitizes tools from a connected MCP client.
* It retrieves function declarations from the client, filters out disabled tools,
@@ -448,6 +515,15 @@ export async function discoverTools(
continue;
}
if (!hasValidTypes(funcDecl.parametersJsonSchema)) {
console.warn(
`Skipping tool '${funcDecl.name}' from MCP server '${mcpServerName}' ` +
`because it has missing types in its parameter schema. Please file an ` +
`issue with the owner of the MCP server.`,
);
continue;
}
discoveredTools.push(
new DiscoveredMCPTool(
mcpCallableTool,
@@ -587,12 +663,30 @@ export async function connectToMcpServer(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<Client> {
const mcpClient = new Client({
name: 'qwen-code-mcp-client',
version: '0.0.1',
});
mcpClient.registerCapabilities({
roots: {},
});
mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
const roots = [];
for (const dir of workspaceContext.getDirectories()) {
roots.push({
uri: pathToFileURL(dir).toString(),
name: basename(dir),
});
}
return {
roots,
};
});
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
// TODO: remove this hack once GenAI SDK does callTool with request options
if ('callTool' in mcpClient) {

View File

@@ -12,13 +12,7 @@ import {
ToolMcpConfirmationDetails,
Icon,
} from './tools.js';
import {
CallableTool,
Part,
FunctionCall,
FunctionDeclaration,
Type,
} from '@google/genai';
import { CallableTool, Part, FunctionCall } from '@google/genai';
type ToolParams = Record<string, unknown>;
@@ -64,7 +58,7 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
readonly serverName: string,
readonly serverToolName: string,
description: string,
readonly parameterSchemaJson: unknown,
readonly parameterSchema: unknown,
readonly timeout?: number,
readonly trust?: boolean,
nameOverride?: string,
@@ -74,7 +68,7 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
`${serverToolName} (${serverName} MCP Server)`,
description,
Icon.Hammer,
{ type: Type.OBJECT }, // this is a dummy Schema for MCP, will be not be used to construct the FunctionDeclaration
parameterSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
);
@@ -86,25 +80,13 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
this.serverName,
this.serverToolName,
this.description,
this.parameterSchemaJson,
this.parameterSchema,
this.timeout,
this.trust,
`${this.serverName}__${this.serverToolName}`,
);
}
/**
* Overrides the base schema to use parametersJsonSchema when building
* FunctionDeclaration
*/
override get schema(): FunctionDeclaration {
return {
name: this.name,
description: this.description,
parametersJsonSchema: this.parameterSchemaJson,
};
}
async shouldConfirmExecute(
_params: ToolParams,
_abortSignal: AbortSignal,

View File

@@ -199,7 +199,17 @@ describe('MemoryTool', () => {
);
expect(memoryTool.schema).toBeDefined();
expect(memoryTool.schema.name).toBe('save_memory');
expect(memoryTool.schema.parameters?.properties?.fact).toBeDefined();
expect(memoryTool.schema.parametersJsonSchema).toStrictEqual({
type: 'object',
properties: {
fact: {
type: 'string',
description:
'The specific fact or piece of information to remember. Should be a clear, self-contained statement.',
},
},
required: ['fact'],
});
});
it('should call performAddMemoryEntry with correct parameters and return success', async () => {

View File

@@ -11,24 +11,24 @@ import {
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { FunctionDeclaration, Type } from '@google/genai';
import { FunctionDeclaration } from '@google/genai';
import * as fs from 'fs/promises';
import * as path from 'path';
import { homedir } from 'os';
import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
description:
'Saves a specific piece of information or fact to your long-term memory. Use this when the user explicitly asks you to remember something, or when they state a clear, concise fact that seems important to retain for future interactions.',
parameters: {
type: Type.OBJECT,
parametersJsonSchema: {
type: 'object',
properties: {
fact: {
type: Type.STRING,
type: 'string',
description:
'The specific fact or piece of information to remember. Should be a clear, self-contained statement.',
},
@@ -112,7 +112,7 @@ function ensureNewlineSeparation(currentContent: string): string {
export class MemoryTool
extends BaseTool<SaveMemoryParams, ToolResult>
implements ModifiableTool<SaveMemoryParams>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
private static readonly allowlist: Set<string> = new Set();
@@ -123,7 +123,7 @@ export class MemoryTool
'Save Memory',
memoryToolDescription,
Icon.LightBulb,
memoryToolSchemaData.parameters as Record<string, unknown>,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
@@ -220,6 +220,7 @@ export class MemoryTool
type: 'edit',
title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`,
fileName: memoryFilePath,
filePath: memoryFilePath,
fileDiff,
originalContent: currentContent,
newContent,

View File

@@ -8,8 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
modifyWithEditor,
ModifyContext,
ModifiableTool,
isModifiableTool,
ModifiableDeclarativeTool,
isModifiableDeclarativeTool,
} from './modifiable-tool.js';
import { EditorType } from '../utils/editor.js';
import fs from 'fs';
@@ -338,16 +338,16 @@ describe('isModifiableTool', () => {
const mockTool = {
name: 'test-tool',
getModifyContext: vi.fn(),
} as unknown as ModifiableTool<TestParams>;
} as unknown as ModifiableDeclarativeTool<TestParams>;
expect(isModifiableTool(mockTool)).toBe(true);
expect(isModifiableDeclarativeTool(mockTool)).toBe(true);
});
it('should return false for objects without getModifyContext method', () => {
const mockTool = {
name: 'test-tool',
} as unknown as ModifiableTool<TestParams>;
} as unknown as ModifiableDeclarativeTool<TestParams>;
expect(isModifiableTool(mockTool)).toBe(false);
expect(isModifiableDeclarativeTool(mockTool)).toBe(false);
});
});

View File

@@ -11,13 +11,14 @@ import fs from 'fs';
import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { isNodeError } from '../utils/errors.js';
import { Tool } from './tools.js';
import { AnyDeclarativeTool, DeclarativeTool, ToolResult } from './tools.js';
/**
* A tool that supports a modify operation.
* A declarative tool that supports a modify operation.
*/
export interface ModifiableTool<ToolParams> extends Tool<ToolParams> {
getModifyContext(abortSignal: AbortSignal): ModifyContext<ToolParams>;
export interface ModifiableDeclarativeTool<TParams extends object>
extends DeclarativeTool<TParams, ToolResult> {
getModifyContext(abortSignal: AbortSignal): ModifyContext<TParams>;
}
export interface ModifyContext<ToolParams> {
@@ -39,9 +40,12 @@ export interface ModifyResult<ToolParams> {
updatedDiff: string;
}
export function isModifiableTool<TParams>(
tool: Tool<TParams>,
): tool is ModifiableTool<TParams> {
/**
* Type guard to check if a declarative tool is modifiable.
*/
export function isModifiableDeclarativeTool(
tool: AnyDeclarativeTool,
): tool is ModifiableDeclarativeTool<object> {
return 'getModifyContext' in tool;
}

View File

@@ -6,6 +6,7 @@
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { ReadFileTool, ReadFileToolParams } from './read-file.js';
import { ToolErrorType } from './tool-error.js';
import path from 'path';
import os from 'os';
import fs from 'fs';
@@ -13,6 +14,7 @@ import fsp from 'fs/promises';
import { Config } from '../config/config.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { ToolInvocation, ToolResult } from './tools.js';
describe('ReadFileTool', () => {
let tempRootDir: string;
@@ -40,112 +42,137 @@ describe('ReadFileTool', () => {
}
});
describe('validateToolParams', () => {
it('should return null for valid params (absolute path within root)', () => {
describe('build', () => {
it('should return an invocation for valid params (absolute path within root)', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),
};
expect(tool.validateToolParams(params)).toBeNull();
const result = tool.build(params);
expect(typeof result).not.toBe('string');
});
it('should return null for valid params with offset and limit', () => {
it('should throw error if file path is relative', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),
offset: 0,
limit: 10,
absolute_path: 'relative/path.txt',
};
expect(tool.validateToolParams(params)).toBeNull();
});
it('should return error for relative path', () => {
const params: ReadFileToolParams = { absolute_path: 'test.txt' };
expect(tool.validateToolParams(params)).toBe(
`File path must be absolute, but was relative: test.txt. You must provide an absolute path.`,
expect(() => tool.build(params)).toThrow(
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
);
});
it('should return error for path outside root', () => {
const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt');
const params: ReadFileToolParams = { absolute_path: outsidePath };
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
it('should throw error if path is outside root', () => {
const params: ReadFileToolParams = {
absolute_path: '/outside/root.txt',
};
expect(() => tool.build(params)).toThrow(
/File path must be within one of the workspace directories/,
);
});
it('should return error for negative offset', () => {
it('should throw error if offset is negative', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),
offset: -1,
limit: 10,
};
expect(tool.validateToolParams(params)).toBe(
expect(() => tool.build(params)).toThrow(
'Offset must be a non-negative number',
);
});
it('should return error for non-positive limit', () => {
const paramsZero: ReadFileToolParams = {
it('should throw error if limit is zero or negative', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),
offset: 0,
limit: 0,
};
expect(tool.validateToolParams(paramsZero)).toBe(
expect(() => tool.build(params)).toThrow(
'Limit must be a positive number',
);
const paramsNegative: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'),
offset: 0,
limit: -5,
};
expect(tool.validateToolParams(paramsNegative)).toBe(
'Limit must be a positive number',
);
});
it('should return error for schema validation failure (e.g. missing path)', () => {
const params = { offset: 0 } as unknown as ReadFileToolParams;
expect(tool.validateToolParams(params)).toBe(
`params must have required property 'absolute_path'`,
);
});
});
describe('getDescription', () => {
it('should return a shortened, relative path', () => {
const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt');
const params: ReadFileToolParams = { absolute_path: filePath };
expect(tool.getDescription(params)).toBe(
path.join('sub', 'dir', 'file.txt'),
it('should return relative path without limit/offset', () => {
const subDir = path.join(tempRootDir, 'sub', 'dir');
const params: ReadFileToolParams = {
absolute_path: path.join(subDir, 'file.txt'),
};
const invocation = tool.build(params);
expect(typeof invocation).not.toBe('string');
expect(
(
invocation as ToolInvocation<ReadFileToolParams, ToolResult>
).getDescription(),
).toBe(path.join('sub', 'dir', 'file.txt'));
});
it('should return shortened path when file path is deep', () => {
const deepPath = path.join(
tempRootDir,
'very',
'deep',
'directory',
'structure',
'that',
'exceeds',
'the',
'normal',
'limit',
'file.txt',
);
const params: ReadFileToolParams = { absolute_path: deepPath };
const invocation = tool.build(params);
expect(typeof invocation).not.toBe('string');
const desc = (
invocation as ToolInvocation<ReadFileToolParams, ToolResult>
).getDescription();
expect(desc).toContain('...');
expect(desc).toContain('file.txt');
});
it('should handle non-normalized file paths correctly', () => {
const subDir = path.join(tempRootDir, 'sub', 'dir');
const params: ReadFileToolParams = {
absolute_path: path.join(subDir, '..', 'dir', 'file.txt'),
};
const invocation = tool.build(params);
expect(typeof invocation).not.toBe('string');
expect(
(
invocation as ToolInvocation<ReadFileToolParams, ToolResult>
).getDescription(),
).toBe(path.join('sub', 'dir', 'file.txt'));
});
it('should return . if path is the root directory', () => {
const params: ReadFileToolParams = { absolute_path: tempRootDir };
expect(tool.getDescription(params)).toBe('.');
const invocation = tool.build(params);
expect(typeof invocation).not.toBe('string');
expect(
(
invocation as ToolInvocation<ReadFileToolParams, ToolResult>
).getDescription(),
).toBe('.');
});
});
describe('execute', () => {
it('should return validation error if params are invalid', async () => {
const params: ReadFileToolParams = {
absolute_path: 'relative/path.txt',
};
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent:
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
returnDisplay:
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
});
});
it('should return error if file does not exist', async () => {
const filePath = path.join(tempRootDir, 'nonexistent.txt');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `File not found: ${filePath}`,
const result = await invocation.execute(abortSignal);
expect(result).toEqual({
llmContent:
'Could not read file because no file was found at the specified path.',
returnDisplay: 'File not found.',
error: {
message: `File not found: ${filePath}`,
type: ToolErrorType.FILE_NOT_FOUND,
},
});
});
@@ -154,59 +181,191 @@ describe('ReadFileTool', () => {
const fileContent = 'This is a test file.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({
expect(await invocation.execute(abortSignal)).toEqual({
llmContent: fileContent,
returnDisplay: '',
});
});
it('should return success result for an image file', async () => {
// A minimal 1x1 transparent PNG file.
const pngContent = Buffer.from([
137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0,
1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65,
84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73,
69, 78, 68, 174, 66, 96, 130,
]);
const filePath = path.join(tempRootDir, 'image.png');
await fsp.writeFile(filePath, pngContent);
const params: ReadFileToolParams = { absolute_path: filePath };
it('should return error if path is a directory', async () => {
const dirPath = path.join(tempRootDir, 'directory');
await fsp.mkdir(dirPath);
const params: ReadFileToolParams = { absolute_path: dirPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: pngContent.toString('base64'),
},
const result = await invocation.execute(abortSignal);
expect(result).toEqual({
llmContent:
'Could not read file because the provided path is a directory, not a file.',
returnDisplay: 'Path is a directory.',
error: {
message: `Path is a directory, not a file: ${dirPath}`,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
returnDisplay: `Read image file: image.png`,
});
});
it('should treat a non-image file with image extension as an image', async () => {
const filePath = path.join(tempRootDir, 'fake-image.png');
const fileContent = 'This is not a real png.';
it('should return error for a file that is too large', async () => {
const filePath = path.join(tempRootDir, 'largefile.txt');
// 21MB of content exceeds 20MB limit
const largeContent = 'x'.repeat(21 * 1024 * 1024);
await fsp.writeFile(filePath, largeContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result).toHaveProperty('error');
expect(result.error?.type).toBe(ToolErrorType.FILE_TOO_LARGE);
expect(result.error?.message).toContain(
'File size exceeds the 20MB limit',
);
});
it('should handle text file with lines exceeding maximum length', async () => {
const filePath = path.join(tempRootDir, 'longlines.txt');
const longLine = 'a'.repeat(2500); // Exceeds MAX_LINE_LENGTH_TEXT_FILE (2000)
const fileContent = `Short line\n${longLine}\nAnother short line`;
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: Buffer.from(fileContent).toString('base64'),
},
},
returnDisplay: `Read image file: fake-image.png`,
});
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'IMPORTANT: The file content has been truncated',
);
expect(result.llmContent).toContain('--- FILE CONTENT (truncated) ---');
expect(result.returnDisplay).toContain('some lines were shortened');
});
it('should pass offset and limit to read a slice of a text file', async () => {
it('should handle image file and return appropriate content', async () => {
const imagePath = path.join(tempRootDir, 'image.png');
// Minimal PNG header
const pngHeader = Buffer.from([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
]);
await fsp.writeFile(imagePath, pngHeader);
const params: ReadFileToolParams = { absolute_path: imagePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toEqual({
inlineData: {
data: pngHeader.toString('base64'),
mimeType: 'image/png',
},
});
expect(result.returnDisplay).toBe('Read image file: image.png');
});
it('should handle PDF file and return appropriate content', async () => {
const pdfPath = path.join(tempRootDir, 'document.pdf');
// Minimal PDF header
const pdfHeader = Buffer.from('%PDF-1.4');
await fsp.writeFile(pdfPath, pdfHeader);
const params: ReadFileToolParams = { absolute_path: pdfPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toEqual({
inlineData: {
data: pdfHeader.toString('base64'),
mimeType: 'application/pdf',
},
});
expect(result.returnDisplay).toBe('Read pdf file: document.pdf');
});
it('should handle binary file and skip content', async () => {
const binPath = path.join(tempRootDir, 'binary.bin');
// Binary data with null bytes
const binaryData = Buffer.from([0x00, 0xff, 0x00, 0xff]);
await fsp.writeFile(binPath, binaryData);
const params: ReadFileToolParams = { absolute_path: binPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe(
'Cannot display content of binary file: binary.bin',
);
expect(result.returnDisplay).toBe('Skipped binary file: binary.bin');
});
it('should handle SVG file as text', async () => {
const svgPath = path.join(tempRootDir, 'image.svg');
const svgContent = '<svg><circle cx="50" cy="50" r="40"/></svg>';
await fsp.writeFile(svgPath, svgContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: svgPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe(svgContent);
expect(result.returnDisplay).toBe('Read SVG as text: image.svg');
});
it('should handle large SVG file', async () => {
const svgPath = path.join(tempRootDir, 'large.svg');
// Create SVG content larger than 1MB
const largeContent = '<svg>' + 'x'.repeat(1024 * 1024 + 1) + '</svg>';
await fsp.writeFile(svgPath, largeContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: svgPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe(
'Cannot display content of SVG file larger than 1MB: large.svg',
);
expect(result.returnDisplay).toBe(
'Skipped large SVG file (>1MB): large.svg',
);
});
it('should handle empty file', async () => {
const emptyPath = path.join(tempRootDir, 'empty.txt');
await fsp.writeFile(emptyPath, '', 'utf-8');
const params: ReadFileToolParams = { absolute_path: emptyPath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toBe('');
expect(result.returnDisplay).toBe('');
});
it('should support offset and limit for text files', async () => {
const filePath = path.join(tempRootDir, 'paginated.txt');
const fileContent = Array.from(
{ length: 20 },
(_, i) => `Line ${i + 1}`,
).join('\n');
const lines = Array.from({ length: 20 }, (_, i) => `Line ${i + 1}`);
const fileContent = lines.join('\n');
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = {
@@ -214,16 +373,24 @@ describe('ReadFileTool', () => {
offset: 5, // Start from line 6
limit: 3,
};
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: [
'[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]',
'Line 6',
'Line 7',
'Line 8',
].join('\n'),
returnDisplay: 'Read lines 6-8 of 20 from paginated.txt',
});
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain(
'IMPORTANT: The file content has been truncated',
);
expect(result.llmContent).toContain(
'Status: Showing lines 6-8 of 20 total lines',
);
expect(result.llmContent).toContain('Line 6');
expect(result.llmContent).toContain('Line 7');
expect(result.llmContent).toContain('Line 8');
expect(result.returnDisplay).toBe(
'Read lines 6-8 of 20 from paginated.txt',
);
});
describe('with .geminiignore', () => {
@@ -234,66 +401,37 @@ describe('ReadFileTool', () => {
);
});
it('should return error if path is ignored by a .geminiignore pattern', async () => {
it('should throw error if path is ignored by a .geminiignore pattern', async () => {
const ignoredFilePath = path.join(tempRootDir, 'foo.bar');
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: ignoredFilePath,
};
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
returnDisplay: expectedError,
});
expect(() => tool.build(params)).toThrow(expectedError);
});
it('should return error if path is in an ignored directory', async () => {
it('should throw error if file is in an ignored directory', async () => {
const ignoredDirPath = path.join(tempRootDir, 'ignored');
await fsp.mkdir(ignoredDirPath);
const filePath = path.join(ignoredDirPath, 'somefile.txt');
await fsp.writeFile(filePath, 'content', 'utf-8');
await fsp.mkdir(ignoredDirPath, { recursive: true });
const ignoredFilePath = path.join(ignoredDirPath, 'file.txt');
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: filePath,
absolute_path: ignoredFilePath,
};
const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
returnDisplay: expectedError,
});
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
expect(() => tool.build(params)).toThrow(expectedError);
});
});
});
describe('workspace boundary validation', () => {
it('should validate paths are within workspace root', () => {
const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'file.txt'),
};
expect(tool.validateToolParams(params)).toBeNull();
});
it('should reject paths outside workspace root', () => {
const params: ReadFileToolParams = {
absolute_path: '/etc/passwd',
};
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
);
expect(error).toContain(tempRootDir);
});
it('should provide clear error message with workspace directories', () => {
const outsidePath = path.join(os.tmpdir(), 'outside-workspace.txt');
const params: ReadFileToolParams = {
absolute_path: outsidePath,
};
const error = tool.validateToolParams(params);
expect(error).toContain(
'File path must be within one of the workspace directories',
);
expect(error).toContain(tempRootDir);
it('should allow reading non-ignored files', async () => {
const allowedFilePath = path.join(tempRootDir, 'allowed.txt');
await fsp.writeFile(allowedFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: allowedFilePath,
};
const invocation = tool.build(params);
expect(typeof invocation).not.toBe('string');
});
});
});
});

View File

@@ -7,8 +7,16 @@
import path from 'path';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
ToolInvocation,
ToolLocation,
ToolResult,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { PartUnion } from '@google/genai';
import {
processSingleFileContent,
getSpecificMimeType,
@@ -39,44 +47,162 @@ export interface ReadFileToolParams {
limit?: number;
}
/**
 * A single bound invocation of the `read_file` tool: wraps one set of
 * validated {@link ReadFileToolParams} and performs the actual file read
 * when executed.
 */
class ReadFileToolInvocation extends BaseToolInvocation<
  ReadFileToolParams,
  ToolResult
> {
  constructor(
    private config: Config,
    params: ReadFileToolParams,
  ) {
    super(params);
  }

  /**
   * Returns the target path relative to the workspace target dir, shortened
   * for display in the UI.
   */
  getDescription(): string {
    const relativePath = makeRelative(
      this.params.absolute_path,
      this.config.getTargetDir(),
    );
    return shortenPath(relativePath);
  }

  /**
   * Reports the file (and optional starting line, taken from `offset`) this
   * invocation will touch, e.g. for IDE integrations.
   */
  override toolLocations(): ToolLocation[] {
    return [{ path: this.params.absolute_path, line: this.params.offset }];
  }

  /**
   * Reads the file via `processSingleFileContent`, translating its
   * string-based errors into structured `ToolErrorType`s and annotating
   * truncated reads with pagination instructions for the model.
   */
  async execute(): Promise<ToolResult> {
    const result = await processSingleFileContent(
      this.params.absolute_path,
      this.config.getTargetDir(),
      this.params.offset,
      this.params.limit,
    );

    if (result.error) {
      // Map error messages to ToolErrorType
      let errorType: ToolErrorType;
      let llmContent: string;

      // Check error message patterns to determine error type.
      // NOTE(review): this relies on the exact wording produced by
      // processSingleFileContent (and underlying Node errno strings like
      // ENOENT/EISDIR) — keep in sync if those messages change.
      if (
        result.error.includes('File not found') ||
        result.error.includes('does not exist') ||
        result.error.includes('ENOENT')
      ) {
        errorType = ToolErrorType.FILE_NOT_FOUND;
        llmContent =
          'Could not read file because no file was found at the specified path.';
      } else if (
        result.error.includes('is a directory') ||
        result.error.includes('EISDIR')
      ) {
        errorType = ToolErrorType.INVALID_TOOL_PARAMS;
        llmContent =
          'Could not read file because the provided path is a directory, not a file.';
      } else if (
        result.error.includes('too large') ||
        result.error.includes('File size exceeds')
      ) {
        errorType = ToolErrorType.FILE_TOO_LARGE;
        llmContent = `Could not read file. ${result.error}`;
      } else {
        // Other read errors map to READ_CONTENT_FAILURE
        errorType = ToolErrorType.READ_CONTENT_FAILURE;
        llmContent = `Could not read file. ${result.error}`;
      }

      return {
        llmContent,
        returnDisplay: result.returnDisplay || 'Error reading file',
        error: {
          message: result.error,
          type: errorType,
        },
      };
    }

    let llmContent: PartUnion;
    if (result.isTruncated) {
      // Non-null assertions: processSingleFileContent sets linesShown and
      // originalLineCount whenever isTruncated is true — presumably an
      // invariant of that helper; TODO confirm at its definition.
      const [start, end] = result.linesShown!;
      const total = result.originalLineCount!;
      // Suggest the offset for the *next* page; when no offset was given
      // (or it was 0), the shown window started at line 1, so `end` is the
      // correct 0-based offset to continue from.
      const nextOffset = this.params.offset
        ? this.params.offset + end - start + 1
        : end;
      llmContent = `
IMPORTANT: The file content has been truncated.
Status: Showing lines ${start}-${end} of ${total} total lines.
Action: To read more of the file, you can use the 'offset' and 'limit' parameters in a subsequent 'read_file' call. For example, to read the next section of the file, use offset: ${nextOffset}.

--- FILE CONTENT (truncated) ---
${result.llmContent}`;
    } else {
      llmContent = result.llmContent || '';
    }

    // Line count is only meaningful for text reads; image/PDF/binary reads
    // return non-string llmContent and record `undefined`.
    const lines =
      typeof result.llmContent === 'string'
        ? result.llmContent.split('\n').length
        : undefined;
    const mimetype = getSpecificMimeType(this.params.absolute_path);
    recordFileOperationMetric(
      this.config,
      FileOperation.READ,
      lines,
      mimetype,
      path.extname(this.params.absolute_path),
    );

    return {
      llmContent,
      returnDisplay: result.returnDisplay || '',
    };
  }
}
/**
* Implementation of the ReadFile tool logic
*/
export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
export class ReadFileTool extends BaseDeclarativeTool<
ReadFileToolParams,
ToolResult
> {
static readonly Name: string = 'read_file';
constructor(private config: Config) {
super(
ReadFileTool.Name,
'ReadFile',
'Reads and returns the content of a specified file from the local filesystem. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.',
`Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.`,
Icon.FileSearch,
{
properties: {
absolute_path: {
description:
"The absolute path to the file to read (e.g., '/home/user/project/file.txt'). Relative paths are not supported. You must provide an absolute path.",
type: Type.STRING,
type: 'string',
},
offset: {
description:
"Optional: For text files, the 0-based line number to start reading from. Requires 'limit' to be set. Use for paginating through large files.",
type: Type.NUMBER,
type: 'number',
},
limit: {
description:
"Optional: For text files, maximum number of lines to read. Use with 'offset' to paginate through large files. If omitted, reads the entire file (if feasible, up to a default limit).",
type: Type.NUMBER,
type: 'number',
},
},
required: ['absolute_path'],
type: Type.OBJECT,
type: 'object',
},
);
}
validateToolParams(params: ReadFileToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
protected validateToolParams(params: ReadFileToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
@@ -106,67 +232,9 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
return null;
}
getDescription(params: ReadFileToolParams): string {
if (
!params ||
typeof params.absolute_path !== 'string' ||
params.absolute_path.trim() === ''
) {
return `Path unavailable`;
}
const relativePath = makeRelative(
params.absolute_path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
toolLocations(params: ReadFileToolParams): ToolLocation[] {
return [{ path: params.absolute_path, line: params.offset }];
}
async execute(
protected createInvocation(
params: ReadFileToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
const result = await processSingleFileContent(
params.absolute_path,
this.config.getTargetDir(),
params.offset,
params.limit,
);
if (result.error) {
return {
llmContent: result.error, // The detailed error for LLM
returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error
};
}
const lines =
typeof result.llmContent === 'string'
? result.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(params.absolute_path);
recordFileOperationMetric(
this.config,
FileOperation.READ,
lines,
mimetype,
path.extname(params.absolute_path),
);
return {
llmContent: result.llmContent || '',
returnDisplay: result.returnDisplay || '',
};
): ToolInvocation<ReadFileToolParams, ToolResult> {
return new ReadFileToolInvocation(this.config, params);
}
}

View File

@@ -476,6 +476,34 @@ describe('ReadManyFilesTool', () => {
fs.rmSync(tempDir1, { recursive: true, force: true });
fs.rmSync(tempDir2, { recursive: true, force: true });
});
// A small file must pass through untouched while an oversized file is
// truncated and prefixed with a warning pointing the model at read_file.
it('should add a warning for truncated files', async () => {
  createFile('file1.txt', 'Content1');
  // Create a file that will be "truncated" by making it long
  const longContent = Array.from({ length: 2500 }, (_, i) => `L${i}`).join(
    '\n',
  );
  createFile('large-file.txt', longContent);
  const params = { paths: ['*.txt'] };
  const result = await tool.execute(params, new AbortController().signal);
  const content = result.llmContent as string[];
  // Locate each file's chunk in the returned parts by its path separator.
  const normalFileContent = content.find((c) => c.includes('file1.txt'));
  const truncatedFileContent = content.find((c) =>
    c.includes('large-file.txt'),
  );
  // The short file must NOT carry the truncation banner.
  expect(normalFileContent).not.toContain(
    '[WARNING: This file was truncated.',
  );
  expect(truncatedFileContent).toContain(
    "[WARNING: This file was truncated. To view the full content, use the 'read_file' tool on this specific file.]",
  );
  // Check that the actual content is still there but truncated
  // (early lines such as L200 survive; late lines such as L2400 are cut).
  expect(truncatedFileContent).toContain('L200');
  expect(truncatedFileContent).not.toContain('L2400');
});
});
describe('Batch Processing', () => {
@@ -495,7 +523,7 @@ describe('ReadManyFilesTool', () => {
fs.writeFileSync(fullPath, content);
};
it('should process files in parallel for performance', async () => {
it('should process files in parallel', async () => {
// Mock detectFileType to add artificial delay to simulate I/O
const detectFileTypeSpy = vi.spyOn(
await import('../utils/fileUtils.js'),
@@ -506,31 +534,21 @@ describe('ReadManyFilesTool', () => {
const fileCount = 4;
const files = createMultipleFiles(fileCount, 'Batch test');
// Mock with 100ms delay per file to simulate I/O operations
// Mock with 10ms delay per file to simulate I/O operations
detectFileTypeSpy.mockImplementation(async (_filePath: string) => {
await new Promise((resolve) => setTimeout(resolve, 100));
await new Promise((resolve) => setTimeout(resolve, 10));
return 'text';
});
const startTime = Date.now();
const params = { paths: files };
const result = await tool.execute(params, new AbortController().signal);
const endTime = Date.now();
const processingTime = endTime - startTime;
console.log(
`Processing time: ${processingTime}ms for ${fileCount} files`,
);
// Verify parallel processing performance improvement
// Parallel processing should complete in ~100ms (single file time)
// Sequential would take ~400ms (4 files × 100ms each)
expect(processingTime).toBeLessThan(200); // Should PASS with parallel implementation
// Verify all files were processed
const content = result.llmContent as string[];
expect(content).toHaveLength(fileCount);
for (let i = 0; i < fileCount; i++) {
expect(content.join('')).toContain(`Batch test ${i}`);
}
// Cleanup mock
detectFileTypeSpy.mockRestore();

View File

@@ -16,7 +16,7 @@ import {
DEFAULT_ENCODING,
getSpecificMimeType,
} from '../utils/fileUtils.js';
import { PartListUnion, Schema, Type } from '@google/genai';
import { PartListUnion } from '@google/genai';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
import {
recordFileOperationMetric,
@@ -150,47 +150,47 @@ export class ReadManyFilesTool extends BaseTool<
static readonly Name: string = 'read_many_files';
constructor(private config: Config) {
const parameterSchema: Schema = {
type: Type.OBJECT,
const parameterSchema = {
type: 'object',
properties: {
paths: {
type: Type.ARRAY,
type: 'array',
items: {
type: Type.STRING,
minLength: '1',
type: 'string',
minLength: 1,
},
minItems: '1',
minItems: 1,
description:
"Required. An array of glob patterns or paths relative to the tool's target directory. Examples: ['src/**/*.ts'], ['README.md', 'docs/']",
},
include: {
type: Type.ARRAY,
type: 'array',
items: {
type: Type.STRING,
minLength: '1',
type: 'string',
minLength: 1,
},
description:
'Optional. Additional glob patterns to include. These are merged with `paths`. Example: ["*.test.ts"] to specifically add test files if they were broadly excluded.',
default: [],
},
exclude: {
type: Type.ARRAY,
type: 'array',
items: {
type: Type.STRING,
minLength: '1',
type: 'string',
minLength: 1,
},
description:
'Optional. Glob patterns for files/directories to exclude. Added to default excludes if useDefaultExcludes is true. Example: ["**/*.log", "temp/"]',
default: [],
},
recursive: {
type: Type.BOOLEAN,
type: 'boolean',
description:
'Optional. Whether to search recursively (primarily controlled by `**` in glob patterns). Defaults to true.',
default: true,
},
useDefaultExcludes: {
type: Type.BOOLEAN,
type: 'boolean',
description:
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
default: true,
@@ -198,17 +198,17 @@ export class ReadManyFilesTool extends BaseTool<
file_filtering_options: {
description:
'Whether to respect ignore patterns from .gitignore or .geminiignore',
type: Type.OBJECT,
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: Type.BOOLEAN,
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: Type.BOOLEAN,
type: 'boolean',
},
},
},
@@ -235,7 +235,10 @@ Use this tool when the user's query implies needing the content of several files
}
validateParams(params: ReadManyFilesParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
@@ -524,11 +527,15 @@ Use this tool when the user's query implies needing the content of several files
'{filePath}',
filePath,
);
contentParts.push(
`${separator}\n\n${fileReadResult.llmContent}\n\n`,
);
let fileContentForLlm = '';
if (fileReadResult.isTruncated) {
fileContentForLlm += `[WARNING: This file was truncated. To view the full content, use the 'read_file' tool on this specific file.]\n\n`;
}
fileContentForLlm += fileReadResult.llmContent;
contentParts.push(`${separator}\n\n${fileContentForLlm}\n\n`);
} else {
contentParts.push(fileReadResult.llmContent); // This is a Part for image/pdf
// This is a Part for image/pdf, which we don't add the separator to.
contentParts.push(fileReadResult.llmContent);
}
processedFilesRelativePaths.push(relativePathForDisplay);

View File

@@ -25,6 +25,7 @@ vi.mock('../utils/summarizer.js');
import { isCommandAllowed } from '../utils/shell-utils.js';
import { ShellTool } from './shell.js';
import { ToolErrorType } from './tool-error.js';
import { type Config } from '../config/config.js';
import {
type ShellExecutionResult,
@@ -208,6 +209,42 @@ describe('ShellTool', () => {
expect(result.llmContent).not.toContain('pgrep');
});
it('should return error with error property for invalid parameters', async () => {
const result = await shellTool.execute(
{ command: '' }, // Empty command is invalid
mockAbortSignal,
);
expect(result.llmContent).toContain(
'Could not execute command due to invalid parameters:',
);
expect(result.returnDisplay).toBe('Command cannot be empty.');
expect(result.error).toEqual({
message: 'Command cannot be empty.',
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should return error with error property for invalid directory', async () => {
vi.mocked(fs.existsSync).mockReturnValue(false);
const result = await shellTool.execute(
{ command: 'ls', directory: 'nonexistent' },
mockAbortSignal,
);
expect(result.llmContent).toContain(
'Could not execute command due to invalid parameters:',
);
expect(result.returnDisplay).toBe(
"Directory 'nonexistent' is not a registered workspace directory.",
);
expect(result.error).toEqual({
message:
"Directory 'nonexistent' is not a registered workspace directory.",
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should summarize output when configured', async () => {
(mockConfig.getSummarizeToolOutputConfig as Mock).mockReturnValue({
[shellTool.name]: { tokenBudget: 1000 },

View File

@@ -17,7 +17,7 @@ import {
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { summarizeToolOutput } from '../utils/summarizer.js';
@@ -63,19 +63,19 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
Process Group PGID: Process group started or \`(none)\``,
Icon.Terminal,
{
type: Type.OBJECT,
type: 'object',
properties: {
command: {
type: Type.STRING,
type: 'string',
description: 'Exact bash command to execute as `bash -c <command>`',
},
description: {
type: Type.STRING,
type: 'string',
description:
'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.',
},
directory: {
type: Type.STRING,
type: 'string',
description:
'(OPTIONAL) Directory to run the command in, if not the project root directory. Must be relative to the project root directory and must already exist.',
},
@@ -112,7 +112,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}
return commandCheck.reason;
}
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
@@ -186,8 +189,12 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
});
if (validationError) {
return {
llmContent: validationError,
llmContent: `Could not execute command due to invalid parameters: ${validationError}`,
returnDisplay: validationError,
error: {
message: validationError,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}

View File

@@ -19,6 +19,10 @@ export enum ToolErrorType {
FILE_WRITE_FAILURE = 'file_write_failure',
READ_CONTENT_FAILURE = 'read_content_failure',
ATTEMPT_TO_CREATE_EXISTING_FILE = 'attempt_to_create_existing_file',
FILE_TOO_LARGE = 'file_too_large',
PERMISSION_DENIED = 'permission_denied',
NO_SPACE_LEFT = 'no_space_left',
TARGET_IS_DIRECTORY = 'target_is_directory',
// Edit-specific Errors
EDIT_PREPARATION_FAILURE = 'edit_preparation_failure',

View File

@@ -15,22 +15,12 @@ import {
Mocked,
} from 'vitest';
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
import {
ToolRegistry,
DiscoveredTool,
sanitizeParameters,
} from './tool-registry.js';
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import {
FunctionDeclaration,
CallableTool,
mcpToTool,
Type,
Schema,
} from '@google/genai';
import { FunctionDeclaration, CallableTool, mcpToTool } from '@google/genai';
import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/tools.js';
vi.mock('node:fs');
@@ -106,28 +96,6 @@ const createMockCallableTool = (
callTool: vi.fn(),
});
class MockTool extends BaseTool<{ param: string }, ToolResult> {
constructor(
name = 'mock-tool',
displayName = 'A mock tool',
description = 'A mock tool description',
) {
super(name, displayName, description, Icon.Hammer, {
type: Type.OBJECT,
properties: {
param: { type: Type.STRING },
},
required: ['param'],
});
}
async execute(params: { param: string }): Promise<ToolResult> {
return {
llmContent: `Executed with ${params.param}`,
returnDisplay: `Executed with ${params.param}`,
};
}
}
const baseConfigParams: ConfigParameters = {
cwd: '/tmp',
model: 'test-model',
@@ -275,18 +243,18 @@ describe('ToolRegistry', () => {
});
describe('discoverTools', () => {
it('should sanitize tool parameters during discovery from command', async () => {
it('should will preserve tool parametersJsonSchema during discovery from command', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
const unsanitizedToolDeclaration: FunctionDeclaration = {
name: 'tool-with-bad-format',
description: 'A tool with an invalid format property',
parameters: {
type: Type.OBJECT,
parametersJsonSchema: {
type: 'object',
properties: {
some_string: {
type: Type.STRING,
type: 'string',
format: 'uuid', // This is an unsupported format
},
},
@@ -329,11 +297,39 @@ describe('ToolRegistry', () => {
expect(discoveredTool).toBeDefined();
const registeredParams = (discoveredTool as DiscoveredTool).schema
.parameters as Schema;
expect(registeredParams.properties?.['some_string']).toBeDefined();
expect(registeredParams.properties?.['some_string']).toHaveProperty(
'format',
.parametersJsonSchema;
expect(registeredParams).toStrictEqual({
type: 'object',
properties: {
some_string: {
type: 'string',
format: 'uuid',
},
},
});
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
},
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
config.getPromptRegistry(),
false,
expect.any(Object),
);
});
@@ -357,214 +353,8 @@ describe('ToolRegistry', () => {
toolRegistry,
config.getPromptRegistry(),
false,
);
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
},
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
config.getPromptRegistry(),
false,
expect.any(Object),
);
});
});
});
describe('sanitizeParameters', () => {
it('should remove default when anyOf is present', () => {
const schema: Schema = {
anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
default: 'hello',
};
sanitizeParameters(schema);
expect(schema.default).toBeUndefined();
});
it('should recursively sanitize items in anyOf', () => {
const schema: Schema = {
anyOf: [
{
anyOf: [{ type: Type.STRING }],
default: 'world',
},
{ type: Type.NUMBER },
],
};
sanitizeParameters(schema);
expect(schema.anyOf![0].default).toBeUndefined();
});
it('should recursively sanitize items in items', () => {
const schema: Schema = {
items: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
};
sanitizeParameters(schema);
expect(schema.items!.default).toBeUndefined();
});
it('should recursively sanitize items in properties', () => {
const schema: Schema = {
properties: {
prop1: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
},
};
sanitizeParameters(schema);
expect(schema.properties!.prop1.default).toBeUndefined();
});
it('should handle complex nested schemas', () => {
const schema: Schema = {
properties: {
prop1: {
items: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
},
prop2: {
anyOf: [
{
properties: {
nestedProp: {
anyOf: [{ type: Type.NUMBER }],
default: 123,
},
},
},
],
},
},
};
sanitizeParameters(schema);
expect(schema.properties!.prop1.items!.default).toBeUndefined();
const nestedProp =
schema.properties!.prop2.anyOf![0].properties!.nestedProp;
expect(nestedProp?.default).toBeUndefined();
});
it('should remove unsupported format from a simple string property', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
name: { type: Type.STRING },
id: { type: Type.STRING, format: 'uuid' },
},
};
sanitizeParameters(schema);
expect(schema.properties?.['id']).toHaveProperty('format', undefined);
expect(schema.properties?.['name']).not.toHaveProperty('format');
});
it('should NOT remove supported format values', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
date: { type: Type.STRING, format: 'date-time' },
role: {
type: Type.STRING,
format: 'enum',
enum: ['admin', 'user'],
},
},
};
const originalSchema = JSON.parse(JSON.stringify(schema));
sanitizeParameters(schema);
expect(schema).toEqual(originalSchema);
});
it('should handle arrays of objects', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
items: {
type: Type.ARRAY,
items: {
type: Type.OBJECT,
properties: {
itemId: { type: Type.STRING, format: 'uuid' },
},
},
},
},
};
sanitizeParameters(schema);
expect(
(schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
).toHaveProperty('format', undefined);
});
it('should handle schemas with no properties to sanitize', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
count: { type: Type.NUMBER },
isActive: { type: Type.BOOLEAN },
},
};
const originalSchema = JSON.parse(JSON.stringify(schema));
sanitizeParameters(schema);
expect(schema).toEqual(originalSchema);
});
it('should not crash on an empty or undefined schema', () => {
expect(() => sanitizeParameters({})).not.toThrow();
expect(() => sanitizeParameters(undefined)).not.toThrow();
});
it('should handle complex nested schemas with cycles', () => {
const userNode: any = {
type: Type.OBJECT,
properties: {
id: { type: Type.STRING, format: 'uuid' },
name: { type: Type.STRING },
manager: {
type: Type.OBJECT,
properties: {
id: { type: Type.STRING, format: 'uuid' },
},
},
},
};
userNode.properties.reports = {
type: Type.ARRAY,
items: userNode,
};
const schema: Schema = {
type: Type.OBJECT,
properties: {
ceo: userNode,
},
};
expect(() => sanitizeParameters(schema)).not.toThrow();
expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
'format',
undefined,
);
expect(
schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
).toHaveProperty('format', undefined);
});
});

View File

@@ -4,8 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { FunctionDeclaration, Schema, Type } from '@google/genai';
import { Tool, ToolResult, BaseTool, Icon } from './tools.js';
import { FunctionDeclaration } from '@google/genai';
import { AnyDeclarativeTool, Icon, ToolResult, BaseTool } from './tools.js';
import { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
@@ -125,7 +125,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
}
export class ToolRegistry {
private tools: Map<string, Tool> = new Map();
private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
constructor(config: Config) {
@@ -136,7 +136,7 @@ export class ToolRegistry {
* Registers a tool definition.
* @param tool - The tool object containing schema and execution logic.
*/
registerTool(tool: Tool): void {
registerTool(tool: AnyDeclarativeTool): void {
if (this.tools.has(tool.name)) {
if (tool instanceof DiscoveredMCPTool) {
tool = tool.asFullyQualifiedTool();
@@ -178,6 +178,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
}
@@ -199,6 +200,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
}
@@ -225,6 +227,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
);
}
}
@@ -328,14 +331,12 @@ export class ToolRegistry {
console.warn('Discovered a tool with no name. Skipping.');
continue;
}
// Sanitize the parameters before registering the tool.
const parameters =
func.parameters &&
typeof func.parameters === 'object' &&
!Array.isArray(func.parameters)
? (func.parameters as Schema)
func.parametersJsonSchema &&
typeof func.parametersJsonSchema === 'object' &&
!Array.isArray(func.parametersJsonSchema)
? func.parametersJsonSchema
: {};
sanitizeParameters(parameters);
this.registerTool(
new DiscoveredTool(
this.config,
@@ -365,10 +366,26 @@ export class ToolRegistry {
return declarations;
}
/**
* Retrieves a filtered list of tool schemas based on a list of tool names.
* @param toolNames - An array of tool names to include.
* @returns An array of FunctionDeclarations for the specified tools.
*/
getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] {
const declarations: FunctionDeclaration[] = [];
for (const name of toolNames) {
const tool = this.tools.get(name);
if (tool) {
declarations.push(tool.schema);
}
}
return declarations;
}
/**
* Returns an array of all registered and discovered tool instances.
*/
getAllTools(): Tool[] {
getAllTools(): AnyDeclarativeTool[] {
return Array.from(this.tools.values()).sort((a, b) =>
a.displayName.localeCompare(b.displayName),
);
@@ -377,8 +394,8 @@ export class ToolRegistry {
/**
* Returns an array of tools registered from a specific MCP server.
*/
getToolsByServer(serverName: string): Tool[] {
const serverTools: Tool[] = [];
getToolsByServer(serverName: string): AnyDeclarativeTool[] {
const serverTools: AnyDeclarativeTool[] = [];
for (const tool of this.tools.values()) {
if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
serverTools.push(tool);
@@ -390,79 +407,7 @@ export class ToolRegistry {
/**
* Get the definition of a specific tool.
*/
getTool(name: string): Tool | undefined {
getTool(name: string): AnyDeclarativeTool | undefined {
return this.tools.get(name);
}
}
/**
 * Sanitizes a schema object in-place so it is accepted by the Gemini API.
 *
 * NOTE: the schema passed in is mutated directly.
 *
 * Clean-ups performed:
 * - drops `default` wherever `anyOf` is present;
 * - strips string `format` values other than 'enum' and 'date-time';
 * - walks nested schemas under `anyOf`, `items`, and `properties`;
 * - tolerates circular references without looping forever.
 *
 * @param schema The schema to clean up in place. May be undefined, in which
 *   case this is a no-op.
 */
export function sanitizeParameters(schema?: Schema) {
  // Seed the traversal with a fresh visited-set so cycles terminate.
  const visited = new Set<Schema>();
  _sanitizeParameters(schema, visited);
}
/**
 * Recursive worker behind {@link sanitizeParameters}.
 *
 * @param schema Node currently being cleaned (mutated in place).
 * @param visited Nodes already processed on this traversal; used to break
 *   reference cycles in the schema graph.
 */
function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
  if (!schema || visited.has(schema)) {
    return;
  }
  visited.add(schema);

  // Recurse into a child node, ignoring boolean sub-schemas.
  const descend = (node: Schema | undefined) => {
    if (node && typeof node !== 'boolean') {
      _sanitizeParameters(node, visited);
    }
  };

  if (schema.anyOf) {
    // Vertex AI gets confused if both anyOf and default are set.
    schema.default = undefined;
    schema.anyOf.forEach(descend);
  }

  descend(schema.items);

  if (schema.properties) {
    Object.values(schema.properties).forEach(descend);
  }

  // The Gemini API only allows enum on STRING-typed fields, and enum members
  // must themselves be strings: coerce the type and stringify the values,
  // dropping null/undefined entries first.
  if (schema.enum && Array.isArray(schema.enum)) {
    if (schema.type !== Type.STRING) {
      schema.type = Type.STRING;
    }
    schema.enum = schema.enum
      .filter((value: unknown) => value != null)
      .map((value: unknown) => String(value));
  }

  // Vertex AI only supports 'enum' and 'date-time' for STRING format.
  if (
    schema.type === Type.STRING &&
    schema.format &&
    schema.format !== 'enum' &&
    schema.format !== 'date-time'
  ) {
    schema.format = undefined;
  }
}

View File

@@ -0,0 +1,125 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { hasCycleInSchema } from './tools.js'; // unit under test
// Exercises cycle detection over JSON-schema-style objects that use local
// `$ref` pointers: direct self-references, parent/child references, mutual
// sibling references, and legitimate acyclic `$ref` usage.
describe('hasCycleInSchema', () => {
  it('should detect a simple direct cycle', () => {
    // A property whose $ref points back at itself.
    const schema = {
      properties: {
        data: {
          $ref: '#/properties/data',
        },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(true);
  });
  it('should detect a cycle from object properties referencing parent properties', () => {
    const schema = {
      type: 'object',
      properties: {
        data: {
          type: 'object',
          properties: {
            child: { $ref: '#/properties/data' },
          },
        },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(true);
  });
  it('should detect a cycle from array items referencing parent properties', () => {
    const schema = {
      type: 'object',
      properties: {
        data: {
          type: 'array',
          items: {
            type: 'object',
            properties: {
              child: { $ref: '#/properties/data/items' },
            },
          },
        },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(true);
  });
  it('should detect a cycle between sibling properties', () => {
    // a -> b and b -> a: cyclic even though neither refers to itself.
    const schema = {
      type: 'object',
      properties: {
        a: {
          type: 'object',
          properties: {
            child: { $ref: '#/properties/b' },
          },
        },
        b: {
          type: 'object',
          properties: {
            child: { $ref: '#/properties/a' },
          },
        },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(true);
  });
  it('should not detect a cycle in a valid schema', () => {
    const schema = {
      type: 'object',
      properties: {
        name: { type: 'string' },
        address: { $ref: '#/definitions/address' },
      },
      definitions: {
        address: {
          type: 'object',
          properties: {
            street: { type: 'string' },
            city: { type: 'string' },
          },
        },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(false);
  });
  it('should handle non-cyclic sibling refs', () => {
    // Two properties sharing one definition is reuse, not a cycle.
    const schema = {
      properties: {
        a: { $ref: '#/definitions/stringDef' },
        b: { $ref: '#/definitions/stringDef' },
      },
      definitions: {
        stringDef: { type: 'string' },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(false);
  });
  it('should handle nested but not cyclic refs', () => {
    const schema = {
      properties: {
        a: { $ref: '#/definitions/defA' },
      },
      definitions: {
        defA: { properties: { b: { $ref: '#/definitions/defB' } } },
        defB: { type: 'string' },
      },
    };
    expect(hasCycleInSchema(schema)).toBe(false);
  });
  it('should return false for an empty schema', () => {
    expect(hasCycleInSchema({})).toBe(false);
  });
});

View File

@@ -4,105 +4,276 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { FunctionDeclaration, PartListUnion, Schema } from '@google/genai';
import { FunctionDeclaration, PartListUnion } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import { DiffUpdateResult } from '../ide/ideContext.js';
/**
* Interface representing the base Tool functionality
* Represents a validated and ready-to-execute tool call.
* An instance of this is created by a `ToolBuilder`.
*/
export interface Tool<
TParams = unknown,
TResult extends ToolResult = ToolResult,
export interface ToolInvocation<
TParams extends object,
TResult extends ToolResult,
> {
/**
* The internal name of the tool (used for API calls)
* The validated parameters for this specific invocation.
*/
name: string;
params: TParams;
/**
* The user-friendly display name of the tool
* Gets a pre-execution description of the tool operation.
* @returns A markdown string describing what the tool will do.
*/
displayName: string;
getDescription(): string;
/**
* Description of what the tool does
* Determines what file system paths the tool will affect.
* @returns A list of such paths.
*/
description: string;
toolLocations(): ToolLocation[];
/**
* The icon to display when interacting via ACP
*/
icon: Icon;
/**
* Function declaration schema from @google/genai
*/
schema: FunctionDeclaration;
/**
* Whether the tool's output should be rendered as markdown
*/
isOutputMarkdown: boolean;
/**
* Whether the tool supports live (streaming) output
*/
canUpdateOutput: boolean;
/**
* Validates the parameters for the tool
* Should be called from both `shouldConfirmExecute` and `execute`
* `shouldConfirmExecute` should return false immediately if invalid
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: TParams): string | null;
/**
* Gets a pre-execution description of the tool operation
* @param params Parameters for the tool execution
* @returns A markdown string describing what the tool will do
* Optional for backward compatibility
*/
getDescription(params: TParams): string;
/**
* Determines what file system paths the tool will affect
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(params: TParams): ToolLocation[];
/**
* Determines if the tool should prompt for confirmation before execution
* @param params Parameters for the tool execution
* @returns Whether execute should be confirmed.
* Determines if the tool should prompt for confirmation before execution.
* @returns Confirmation details or false if no confirmation is needed.
*/
shouldConfirmExecute(
params: TParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>;
/**
* Executes the tool with the given parameters
* @param params Parameters for the tool execution
* @returns Result of the tool execution
* Executes the tool with the validated parameters.
* @param signal AbortSignal for tool cancellation.
* @param updateOutput Optional callback to stream output.
* @returns Result of the tool execution.
*/
execute(
params: TParams,
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<TResult>;
}
/**
 * Convenience skeleton implementation of {@link ToolInvocation}.
 *
 * Subclasses only have to provide `getDescription` and `execute`; the
 * remaining hooks default to "touches no paths" and "needs no confirmation".
 */
export abstract class BaseToolInvocation<
  TParams extends object,
  TResult extends ToolResult,
> implements ToolInvocation<TParams, TResult>
{
  constructor(readonly params: TParams) {}

  /** Subclasses describe the concrete operation they are about to perform. */
  abstract getDescription(): string;

  /** Subclasses perform the actual tool work. */
  abstract execute(
    signal: AbortSignal,
    updateOutput?: (output: string) => void,
  ): Promise<TResult>;

  /** Default: this invocation affects no file-system paths. */
  toolLocations(): ToolLocation[] {
    return [];
  }

  /** Default: execution proceeds without asking the user for confirmation. */
  shouldConfirmExecute(
    _abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false> {
    return Promise.resolve(false);
  }
}
/**
* A type alias for a tool invocation where the specific parameter and result types are not known.
*/
export type AnyToolInvocation = ToolInvocation<object, ToolResult>;
/**
 * Adapter that exposes a legacy `BaseTool` through the new `ToolInvocation`
 * interface. Every method simply forwards to the wrapped tool, supplying the
 * parameters that were captured at construction time.
 */
export class LegacyToolInvocation<
  TParams extends object,
  TResult extends ToolResult,
> implements ToolInvocation<TParams, TResult>
{
  constructor(
    private readonly wrapped: BaseTool<TParams, TResult>,
    readonly params: TParams,
  ) {}

  /** Forwards to the legacy tool's param-taking description. */
  getDescription(): string {
    return this.wrapped.getDescription(this.params);
  }

  /** Forwards to the legacy tool's param-taking location lookup. */
  toolLocations(): ToolLocation[] {
    return this.wrapped.toolLocations(this.params);
  }

  /** Forwards the confirmation decision to the legacy tool. */
  shouldConfirmExecute(
    abortSignal: AbortSignal,
  ): Promise<ToolCallConfirmationDetails | false> {
    return this.wrapped.shouldConfirmExecute(this.params, abortSignal);
  }

  /** Runs the legacy tool with the stored parameters. */
  execute(
    signal: AbortSignal,
    updateOutput?: (output: string) => void,
  ): Promise<TResult> {
    return this.wrapped.execute(this.params, signal, updateOutput);
  }
}
/**
 * Interface for a tool builder that validates parameters and creates
 * invocations. This is the contract the scheduler consumes: it reads the
 * metadata fields to advertise the tool, then calls `build` with the
 * model-supplied arguments to obtain an executable `ToolInvocation`.
 */
export interface ToolBuilder<
  TParams extends object,
  TResult extends ToolResult,
> {
  /**
   * The internal name of the tool (used for API calls).
   */
  name: string;
  /**
   * The user-friendly display name of the tool.
   */
  displayName: string;
  /**
   * Description of what the tool does.
   */
  description: string;
  /**
   * The icon to display when interacting via ACP.
   */
  icon: Icon;
  /**
   * Function declaration schema from @google/genai.
   */
  schema: FunctionDeclaration;
  /**
   * Whether the tool's output should be rendered as markdown.
   */
  isOutputMarkdown: boolean;
  /**
   * Whether the tool supports live (streaming) output.
   */
  canUpdateOutput: boolean;
  /**
   * Validates raw parameters and builds a ready-to-execute invocation.
   * @param params The raw, untrusted parameters from the model.
   * @returns A valid `ToolInvocation` if successful. Throws an error if validation fails.
   */
  build(params: TParams): ToolInvocation<TParams, TResult>;
}
/**
 * Base class for the declarative tool pattern, which separates parameter
 * validation from execution. New tools should extend this class (typically
 * via `BaseDeclarativeTool`) rather than the deprecated `BaseTool`.
 */
export abstract class DeclarativeTool<
  TParams extends object,
  TResult extends ToolResult,
> implements ToolBuilder<TParams, TResult>
{
  constructor(
    readonly name: string,
    readonly displayName: string,
    readonly description: string,
    readonly icon: Icon,
    readonly parameterSchema: unknown,
    readonly isOutputMarkdown: boolean = true,
    readonly canUpdateOutput: boolean = false,
  ) {}

  /**
   * The function declaration advertised to the model, derived from this
   * tool's name, description, and raw JSON parameter schema.
   */
  get schema(): FunctionDeclaration {
    const { name, description, parameterSchema } = this;
    return {
      name,
      description,
      parametersJsonSchema: parameterSchema,
    };
  }

  /**
   * Hook for subclasses to add custom validation logic beyond the JSON
   * schema check.
   * @param _params The raw parameters from the model.
   * @returns An error message string if invalid, null otherwise.
   */
  protected validateToolParams(_params: TParams): string | null {
    // The base implementation accepts everything; subclasses may override.
    return null;
  }

  /**
   * The core of the declarative pattern: validates `params` and, on
   * success, returns a `ToolInvocation` encapsulating this specific,
   * validated call. Implementations throw when validation fails.
   * @param params The raw, untrusted parameters from the model.
   * @returns A `ToolInvocation` instance.
   */
  abstract build(params: TParams): ToolInvocation<TParams, TResult>;

  /**
   * Convenience method that builds and executes the tool in one step.
   * Throws an error if validation fails.
   * @param params The raw, untrusted parameters from the model.
   * @param signal AbortSignal for tool cancellation.
   * @param updateOutput Optional callback to stream output.
   * @returns The result of the tool execution.
   */
  async buildAndExecute(
    params: TParams,
    signal: AbortSignal,
    updateOutput?: (output: string) => void,
  ): Promise<TResult> {
    return this.build(params).execute(signal, updateOutput);
  }
}
/**
 * Declarative tool base class that provides the standard `build` flow:
 * run `validateToolParams` first, then defer to `createInvocation` for the
 * final `ToolInvocation` instantiation. New tools should extend this class.
 */
export abstract class BaseDeclarativeTool<
  TParams extends object,
  TResult extends ToolResult,
> extends DeclarativeTool<TParams, TResult> {
  /** Validates raw params and produces a ready-to-execute invocation. */
  build(params: TParams): ToolInvocation<TParams, TResult> {
    const problem = this.validateToolParams(params);
    if (problem) {
      throw new Error(problem);
    }
    return this.createInvocation(params);
  }

  /** Subclasses construct the concrete invocation for validated params. */
  protected abstract createInvocation(
    params: TParams,
  ): ToolInvocation<TParams, TResult>;
}
/**
* A type alias for a declarative tool where the specific parameter and result types are not known.
*/
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
/**
* Base implementation for tools with common functionality
* @deprecated Use `DeclarativeTool` for new tools.
*/
export abstract class BaseTool<
TParams = unknown,
TParams extends object,
TResult extends ToolResult = ToolResult,
> implements Tool<TParams, TResult>
{
> extends DeclarativeTool<TParams, TResult> {
/**
* Creates a new instance of BaseTool
* @param name Internal name of the tool (used for API calls)
@@ -110,27 +281,34 @@ export abstract class BaseTool<
* @param description Description of what the tool does
* @param isOutputMarkdown Whether the tool's output should be rendered as markdown
* @param canUpdateOutput Whether the tool supports live (streaming) output
* @param parameterSchema Open API 3.0 Schema defining the parameters
* @param parameterSchema JSON Schema defining the parameters
*/
constructor(
readonly name: string,
readonly displayName: string,
readonly description: string,
readonly icon: Icon,
readonly parameterSchema: Schema,
readonly parameterSchema: unknown,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
) {}
) {
super(
name,
displayName,
description,
icon,
parameterSchema,
isOutputMarkdown,
canUpdateOutput,
);
}
/**
* Function declaration schema computed from name, description, and parameterSchema
*/
get schema(): FunctionDeclaration {
return {
name: this.name,
description: this.description,
parameters: this.parameterSchema,
};
build(params: TParams): ToolInvocation<TParams, TResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
throw new Error(validationError);
}
return new LegacyToolInvocation(this, params);
}
/**
@@ -228,6 +406,91 @@ export interface ToolResult {
};
}
/**
 * Detects cycles in a JSON schema caused by `$ref` references.
 *
 * Performs a depth-first traversal of the schema, tracking refs on the
 * current path (`pathRefs`) to detect back-edges, and all fully-explored
 * refs (`visitedRefs`) to avoid re-checking shared, acyclic subtrees
 * (e.g. two siblings referencing the same definition is not a cycle).
 *
 * Fixes over the previous version:
 * - A root self-reference `$ref: "#"` (no trailing slash) is now detected
 *   as a cycle, matching the existing handling of `"#/"`.
 * - JSON Pointer escape sequences are decoded per RFC 6901 (`~1` -> `/`,
 *   `~0` -> `~`), so refs through keys containing `/` or `~` resolve.
 *
 * @param schema The root of the JSON schema.
 * @returns `true` if a cycle is detected, `false` otherwise.
 */
export function hasCycleInSchema(schema: object): boolean {
  // Decodes one RFC 6901 JSON Pointer segment. Per the RFC, `~1` must be
  // decoded before `~0` so that `~01` becomes `~1` and not `/`.
  function decodeSegment(segment: string): string {
    return segment.replace(/~1/g, '/').replace(/~0/g, '~');
  }

  // Resolves a local `#/...` pointer against the root schema.
  // Returns null for non-local or broken refs (treated as unresolvable
  // rather than cyclic).
  function resolveRef(ref: string): object | null {
    if (!ref.startsWith('#/')) {
      return null;
    }
    const path = ref.substring(2).split('/').map(decodeSegment);
    let current: unknown = schema;
    for (const segment of path) {
      if (
        typeof current !== 'object' ||
        current === null ||
        !Object.prototype.hasOwnProperty.call(current, segment)
      ) {
        return null;
      }
      current = (current as Record<string, unknown>)[segment];
    }
    return current as object;
  }

  function traverse(
    node: unknown,
    visitedRefs: Set<string>,
    pathRefs: Set<string>,
  ): boolean {
    if (typeof node !== 'object' || node === null) {
      return false;
    }
    if (Array.isArray(node)) {
      for (const item of node) {
        if (traverse(item, visitedRefs, pathRefs)) {
          return true;
        }
      }
      return false;
    }
    if ('$ref' in node && typeof node.$ref === 'string') {
      const ref = node.$ref;
      // A ref back to the document root ('#' or '#/') is always a cycle,
      // as is any ref already on the current traversal path.
      if (ref === '#' || ref === '#/' || pathRefs.has(ref)) {
        return true; // Cycle detected!
      }
      if (visitedRefs.has(ref)) {
        return false; // Bail early, we have checked this ref before.
      }
      const resolvedNode = resolveRef(ref);
      if (resolvedNode) {
        // Add it to both visited and the current path.
        visitedRefs.add(ref);
        pathRefs.add(ref);
        const hasCycle = traverse(resolvedNode, visitedRefs, pathRefs);
        pathRefs.delete(ref); // Backtrack, leaving it in visited.
        return hasCycle;
      }
    }
    // Crawl all the properties of node.
    for (const key in node) {
      if (Object.prototype.hasOwnProperty.call(node, key)) {
        if (
          traverse(
            (node as Record<string, unknown>)[key],
            visitedRefs,
            pathRefs,
          )
        ) {
          return true;
        }
      }
    }
    return false;
  }

  return traverse(schema, new Set<string>(), new Set<string>());
}
export type ToolResultDisplay = string | FileDiff;
export interface FileDiff {
@@ -235,6 +498,14 @@ export interface FileDiff {
fileName: string;
originalContent: string | null;
newContent: string;
diffStat?: DiffStat;
}
export interface DiffStat {
ai_removed_lines: number;
ai_added_lines: number;
user_added_lines: number;
user_removed_lines: number;
}
export interface ToolEditConfirmationDetails {
@@ -245,10 +516,12 @@ export interface ToolEditConfirmationDetails {
payload?: ToolConfirmationPayload,
) => Promise<void>;
fileName: string;
filePath: string;
fileDiff: string;
originalContent: string | null;
newContent: string;
isModifying?: boolean;
ideConfirmation?: Promise<DiffUpdateResult>;
}
export interface ToolConfirmationPayload {

View File

@@ -12,7 +12,6 @@ import {
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { Config, ApprovalMode } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { fetchWithTimeout } from '../utils/fetch.js';
@@ -52,15 +51,15 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
properties: {
url: {
description: 'The URL to fetch content from',
type: Type.STRING,
type: 'string',
},
prompt: {
description: 'The prompt to run on the fetched content',
type: Type.STRING,
type: 'string',
},
},
required: ['url', 'prompt'],
type: Type.OBJECT,
type: 'object',
},
);
const proxy = config.getProxy();
@@ -127,7 +126,10 @@ ${textContent}
}
validateParams(params: WebFetchToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}

View File

@@ -75,7 +75,10 @@ export class WebSearchTool extends BaseTool<
* @returns An error message string if validation fails, null if valid
*/
validateParams(params: WebSearchToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}

View File

@@ -14,6 +14,7 @@ import {
type Mocked,
} from 'vitest';
import { WriteFileTool, WriteFileToolParams } from './write-file.js';
import { ToolErrorType } from './tool-error.js';
import {
FileDiff,
ToolConfirmationOutcome,
@@ -55,6 +56,9 @@ const mockConfigInternal = {
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
setApprovalMode: vi.fn(),
getGeminiClient: vi.fn(), // Initialize as a plain mock function
getIdeClient: vi.fn(),
getIdeMode: vi.fn(() => false),
getIdeModeFeature: vi.fn(() => false),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getApiKey: () => 'test-key',
getModel: () => 'test-model',
@@ -110,6 +114,14 @@ describe('WriteFileTool', () => {
mockConfigInternal.getGeminiClient.mockReturnValue(
mockGeminiClientInstance,
);
mockConfigInternal.getIdeClient.mockReturnValue({
openDiff: vi.fn(),
closeDiff: vi.fn(),
getIdeContext: vi.fn(),
subscribeToIdeContext: vi.fn(),
isCodeTrackerEnabled: vi.fn(),
getTrackedCode: vi.fn(),
});
tool = new WriteFileTool(mockConfig);
@@ -453,18 +465,27 @@ describe('WriteFileTool', () => {
it('should return error if params are invalid (relative path)', async () => {
const params = { file_path: 'relative.txt', content: 'test' };
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
expect(result.llmContent).toContain(
'Could not write file due to invalid parameters:',
);
expect(result.returnDisplay).toMatch(/File path must be absolute/);
expect(result.error).toEqual({
message: 'File path must be absolute: relative.txt',
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should return error if params are invalid (path outside root)', async () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = { file_path: outsidePath, content: 'test' };
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
expect(result.returnDisplay).toContain(
'Error: File path must be within one of the workspace directories',
expect(result.llmContent).toContain(
'Could not write file due to invalid parameters:',
);
expect(result.returnDisplay).toContain(
'File path must be within one of the workspace directories',
);
expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS);
});
it('should return error if _getCorrectedFileContent returns an error during execute', async () => {
@@ -479,10 +500,15 @@ describe('WriteFileTool', () => {
});
const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error checking existing file/);
expect(result.llmContent).toContain('Error checking existing file:');
expect(result.returnDisplay).toMatch(
/Error checking existing file: Simulated read error for execute/,
);
expect(result.error).toEqual({
message:
'Error checking existing file: Simulated read error for execute',
type: ToolErrorType.FILE_WRITE_FAILURE,
});
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
fs.chmodSync(filePath, 0o600);
@@ -500,7 +526,11 @@ describe('WriteFileTool', () => {
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
@@ -554,7 +584,11 @@ describe('WriteFileTool', () => {
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
@@ -595,7 +629,11 @@ describe('WriteFileTool', () => {
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
@@ -686,4 +724,114 @@ describe('WriteFileTool', () => {
expect(error).toContain(rootDir);
});
});
describe('specific error types for write failures', () => {
const abortSignal = new AbortController().signal;
it('should return PERMISSION_DENIED error when write fails with EACCES', async () => {
const filePath = path.join(rootDir, 'permission_denied_file.txt');
const content = 'test content';
// Mock writeFileSync to throw EACCES error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
const error = new Error('Permission denied') as NodeJS.ErrnoException;
error.code = 'EACCES';
throw error;
});
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
expect(result.error?.type).toBe(ToolErrorType.PERMISSION_DENIED);
expect(result.llmContent).toContain(
`Permission denied writing to file: ${filePath} (EACCES)`,
);
expect(result.returnDisplay).toContain('Permission denied');
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
});
it('should return NO_SPACE_LEFT error when write fails with ENOSPC', async () => {
const filePath = path.join(rootDir, 'no_space_file.txt');
const content = 'test content';
// Mock writeFileSync to throw ENOSPC error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
const error = new Error(
'No space left on device',
) as NodeJS.ErrnoException;
error.code = 'ENOSPC';
throw error;
});
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
expect(result.error?.type).toBe(ToolErrorType.NO_SPACE_LEFT);
expect(result.llmContent).toContain(
`No space left on device: ${filePath} (ENOSPC)`,
);
expect(result.returnDisplay).toContain('No space left');
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
});
it('should return TARGET_IS_DIRECTORY error when write fails with EISDIR', async () => {
const dirPath = path.join(rootDir, 'test_directory');
const content = 'test content';
// Mock fs.existsSync to return false to bypass validation
const originalExistsSync = fs.existsSync;
vi.spyOn(fs, 'existsSync').mockImplementation((path) => {
if (path === dirPath) {
return false; // Pretend directory doesn't exist to bypass validation
}
return originalExistsSync(path as string);
});
// Mock writeFileSync to throw EISDIR error
const originalWriteFileSync = fs.writeFileSync;
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
const error = new Error('Is a directory') as NodeJS.ErrnoException;
error.code = 'EISDIR';
throw error;
});
const params = { file_path: dirPath, content };
const result = await tool.execute(params, abortSignal);
expect(result.error?.type).toBe(ToolErrorType.TARGET_IS_DIRECTORY);
expect(result.llmContent).toContain(
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
);
expect(result.returnDisplay).toContain('Target is a directory');
vi.spyOn(fs, 'existsSync').mockImplementation(originalExistsSync);
vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync);
});
it('should return FILE_WRITE_FAILURE for generic write errors', async () => {
const filePath = path.join(rootDir, 'generic_error_file.txt');
const content = 'test content';
// Ensure fs.existsSync is not mocked for this test
vi.restoreAllMocks();
// Mock writeFileSync to throw generic error
vi.spyOn(fs, 'writeFileSync').mockImplementationOnce(() => {
throw new Error('Generic write error');
});
const params = { file_path: filePath, content };
const result = await tool.execute(params, abortSignal);
expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE);
expect(result.llmContent).toContain(
'Error writing to file: Generic write error',
);
expect(result.returnDisplay).toContain('Generic write error');
});
});
});

View File

@@ -17,7 +17,7 @@ import {
ToolCallConfirmationDetails,
Icon,
} from './tools.js';
import { Type } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js';
@@ -25,13 +25,14 @@ import {
ensureCorrectEdit,
ensureCorrectFileContent,
} from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js';
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { getSpecificMimeType } from '../utils/fileUtils.js';
import {
recordFileOperationMetric,
FileOperation,
} from '../telemetry/metrics.js';
import { IDEConnectionStatus } from '../ide/ide-client.js';
/**
* Parameters for the WriteFile tool
@@ -51,6 +52,11 @@ export interface WriteFileToolParams {
* Whether the proposed content was modified by the user.
*/
modified_by_user?: boolean;
/**
* Initially proposed content.
*/
ai_proposed_content?: string;
}
interface GetCorrectedFileContentResult {
@@ -65,7 +71,7 @@ interface GetCorrectedFileContentResult {
*/
export class WriteFileTool
extends BaseTool<WriteFileToolParams, ToolResult>
implements ModifiableTool<WriteFileToolParams>
implements ModifiableDeclarativeTool<WriteFileToolParams>
{
static readonly Name: string = 'write_file';
@@ -82,21 +88,24 @@ export class WriteFileTool
file_path: {
description:
"The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.",
type: Type.STRING,
type: 'string',
},
content: {
description: 'The content to write to the file.',
type: Type.STRING,
type: 'string',
},
},
required: ['file_path', 'content'],
type: Type.OBJECT,
type: 'object',
},
);
}
validateToolParams(params: WriteFileToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params);
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
@@ -184,10 +193,19 @@ export class WriteFileTool
DEFAULT_DIFF_OPTIONS,
);
const ideClient = this.config.getIdeClient();
const ideConfirmation =
this.config.getIdeModeFeature() &&
this.config.getIdeMode() &&
ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected
? ideClient.openDiff(params.file_path, correctedContent)
: undefined;
const confirmationDetails: ToolEditConfirmationDetails = {
type: 'edit',
title: `Confirm Write: ${shortenPath(relativePath)}`,
fileName,
filePath: params.file_path,
fileDiff,
originalContent,
newContent: correctedContent,
@@ -195,7 +213,15 @@ export class WriteFileTool
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
if (ideConfirmation) {
const result = await ideConfirmation;
if (result.status === 'accepted' && result.content) {
params.content = result.content;
}
}
},
ideConfirmation,
};
return confirmationDetails;
}
@@ -207,8 +233,12 @@ export class WriteFileTool
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: `Error: ${validationError}`,
llmContent: `Could not write file due to invalid parameters: ${validationError}`,
returnDisplay: validationError,
error: {
message: validationError,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}
@@ -220,10 +250,16 @@ export class WriteFileTool
if (correctedContentResult.error) {
const errDetails = correctedContentResult.error;
const errorMsg = `Error checking existing file: ${errDetails.message}`;
const errorMsg = errDetails.code
? `Error checking existing file '${params.file_path}': ${errDetails.message} (${errDetails.code})`
: `Error checking existing file: ${errDetails.message}`;
return {
llmContent: `Error checking existing file ${params.file_path}: ${errDetails.message}`,
llmContent: errorMsg,
returnDisplay: errorMsg,
error: {
message: errorMsg,
type: ToolErrorType.FILE_WRITE_FAILURE,
},
};
}
@@ -265,6 +301,15 @@ export class WriteFileTool
DEFAULT_DIFF_OPTIONS,
);
const originallyProposedContent =
params.ai_proposed_content || params.content;
const diffStat = getDiffStat(
fileName,
currentContentForDiff,
originallyProposedContent,
params.content,
);
const llmSuccessMessageParts = [
isNewFile
? `Successfully created and wrote to new file: ${params.file_path}.`
@@ -281,6 +326,7 @@ export class WriteFileTool
fileName,
originalContent: correctedContentResult.originalContent,
newContent: correctedContentResult.correctedContent,
diffStat,
};
const lines = fileContent.split('\n').length;
@@ -293,6 +339,7 @@ export class WriteFileTool
lines,
mimetype,
extension,
diffStat,
);
} else {
recordFileOperationMetric(
@@ -301,6 +348,7 @@ export class WriteFileTool
lines,
mimetype,
extension,
diffStat,
);
}
@@ -309,10 +357,43 @@ export class WriteFileTool
returnDisplay: displayResult,
};
} catch (error) {
const errorMsg = `Error writing to file: ${error instanceof Error ? error.message : String(error)}`;
// Capture detailed error information for debugging
let errorMsg: string;
let errorType = ToolErrorType.FILE_WRITE_FAILURE;
if (isNodeError(error)) {
// Handle specific Node.js errors with their error codes
errorMsg = `Error writing to file '${params.file_path}': ${error.message} (${error.code})`;
// Log specific error types for better debugging
if (error.code === 'EACCES') {
errorMsg = `Permission denied writing to file: ${params.file_path} (${error.code})`;
errorType = ToolErrorType.PERMISSION_DENIED;
} else if (error.code === 'ENOSPC') {
errorMsg = `No space left on device: ${params.file_path} (${error.code})`;
errorType = ToolErrorType.NO_SPACE_LEFT;
} else if (error.code === 'EISDIR') {
errorMsg = `Target is a directory, not a file: ${params.file_path} (${error.code})`;
errorType = ToolErrorType.TARGET_IS_DIRECTORY;
}
// Include stack trace in debug mode for better troubleshooting
if (this.config.getDebugMode() && error.stack) {
console.error('Write file error stack:', error.stack);
}
} else if (error instanceof Error) {
errorMsg = `Error writing to file: ${error.message}`;
} else {
errorMsg = `Error writing to file: ${String(error)}`;
}
return {
llmContent: `Error writing to file ${params.file_path}: ${errorMsg}`,
returnDisplay: `Error: ${errorMsg}`,
llmContent: errorMsg,
returnDisplay: errorMsg,
error: {
message: errorMsg,
type: errorType,
},
};
}
}
@@ -400,11 +481,15 @@ export class WriteFileTool
_oldContent: string,
modifiedProposedContent: string,
originalParams: WriteFileToolParams,
) => ({
...originalParams,
content: modifiedProposedContent,
modified_by_user: true,
}),
) => {
const content = originalParams.content;
return {
...originalParams,
ai_proposed_content: content,
content: modifiedProposedContent,
modified_by_user: true,
};
},
};
}
}

View File

@@ -190,80 +190,43 @@ describe('bfsFileSearch', () => {
});
});
it('should perform parallel directory scanning efficiently (performance test)', async () => {
// Create a more complex directory structure for performance testing
console.log('\n🚀 Testing Parallel BFS Performance...');
it('should find all files in a complex directory structure', async () => {
// Create a complex directory structure to test correctness at scale
// without flaky performance checks.
const numDirs = 50;
const numFilesPerDir = 2;
const numTargetDirs = 10;
// Create 50 directories with multiple levels for faster test execution
for (let i = 0; i < 50; i++) {
await createEmptyDir(`dir${i}`);
await createEmptyDir(`dir${i}`, 'subdir1');
await createEmptyDir(`dir${i}`, 'subdir2');
await createEmptyDir(`dir${i}`, 'subdir1', 'deep');
if (i < 10) {
// Add target files in some directories
await createTestFile('content', `dir${i}`, 'QWEN.md');
await createTestFile('content', `dir${i}`, 'subdir1', 'QWEN.md');
}
const dirCreationPromises: Array<Promise<unknown>> = [];
for (let i = 0; i < numDirs; i++) {
dirCreationPromises.push(createEmptyDir(`dir${i}`));
dirCreationPromises.push(createEmptyDir(`dir${i}`, 'subdir1'));
dirCreationPromises.push(createEmptyDir(`dir${i}`, 'subdir2'));
dirCreationPromises.push(createEmptyDir(`dir${i}`, 'subdir1', 'deep'));
}
await Promise.all(dirCreationPromises);
// Run multiple iterations to ensure consistency
const iterations = 3;
const durations: number[] = [];
let foundFiles = 0;
let firstResultSorted: string[] | undefined;
for (let i = 0; i < iterations; i++) {
const searchStartTime = performance.now();
const result = await bfsFileSearch(testRootDir, {
fileName: 'QWEN.md',
maxDirs: 200,
debug: false,
});
const duration = performance.now() - searchStartTime;
durations.push(duration);
// Verify consistency: all iterations should find the exact same files
if (firstResultSorted === undefined) {
foundFiles = result.length;
firstResultSorted = result.sort();
} else {
expect(result.sort()).toEqual(firstResultSorted);
}
console.log(`📊 Iteration ${i + 1}: ${duration.toFixed(2)}ms`);
const fileCreationPromises: Array<Promise<string>> = [];
for (let i = 0; i < numTargetDirs; i++) {
// Add target files in some directories
fileCreationPromises.push(
createTestFile('content', `dir${i}`, 'GEMINI.md'),
);
fileCreationPromises.push(
createTestFile('content', `dir${i}`, 'subdir1', 'GEMINI.md'),
);
}
const expectedFiles = await Promise.all(fileCreationPromises);
const avgDuration = durations.reduce((a, b) => a + b, 0) / durations.length;
const maxDuration = Math.max(...durations);
const minDuration = Math.min(...durations);
const result = await bfsFileSearch(testRootDir, {
fileName: 'GEMINI.md',
// Provide a generous maxDirs limit to ensure it doesn't prematurely stop
// in this large test case. Total dirs created is 200.
maxDirs: 250,
});
console.log(`📊 Average Duration: ${avgDuration.toFixed(2)}ms`);
console.log(
`📊 Min/Max Duration: ${minDuration.toFixed(2)}ms / ${maxDuration.toFixed(2)}ms`,
);
console.log(`📁 Found ${foundFiles} QWEN.md files`);
console.log(
`🏎️ Processing ~${Math.round(200 / (avgDuration / 1000))} dirs/second`,
);
// Verify we found the expected files
expect(foundFiles).toBe(20); // 10 dirs * 2 files each
// Performance expectation: check consistency rather than absolute time
const variance = maxDuration - minDuration;
const consistencyRatio = variance / avgDuration;
// Ensure reasonable performance (generous limit for CI environments)
expect(avgDuration).toBeLessThan(2000); // Very generous limit
// Ensure consistency across runs (variance should not be too high)
// More tolerant in CI environments where performance can be variable
const maxConsistencyRatio = process.env.CI ? 3.0 : 1.5;
expect(consistencyRatio).toBeLessThan(maxConsistencyRatio); // Max variance should be reasonable
console.log(
`✅ Performance test passed: avg=${avgDuration.toFixed(2)}ms, consistency=${(consistencyRatio * 100).toFixed(1)}% (threshold: ${(maxConsistencyRatio * 100).toFixed(0)}%)`,
);
// Verify we found the exact files we created
expect(result.length).toBe(numTargetDirs * numFilesPerDir);
expect(result.sort()).toEqual(expectedFiles.sort());
});
});

View File

@@ -4,12 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
Content,
GenerateContentConfig,
SchemaUnion,
Type,
} from '@google/genai';
import { Content, GenerateContentConfig } from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { EditToolParams, EditTool } from '../tools/edit.js';
import { WriteFileTool } from '../tools/write-file.js';
@@ -364,11 +359,11 @@ export async function ensureCorrectFileContent(
}
// Define the expected JSON schema for the LLM response for old_string correction
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
const OLD_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
corrected_target_snippet: {
type: Type.STRING,
type: 'string',
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
@@ -438,11 +433,11 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
}
// Define the expected JSON schema for the new_string correction LLM response
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
const NEW_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
corrected_new_string: {
type: Type.STRING,
type: 'string',
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
@@ -521,11 +516,11 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
}
}
const CORRECT_NEW_STRING_ESCAPING_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
const CORRECT_NEW_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
corrected_new_string_escaping: {
type: Type.STRING,
type: 'string',
description:
'The new_string with corrected escaping, ensuring it is a proper replacement for the old_string, especially considering potential over-escaping issues from previous LLM generations.',
},
@@ -593,11 +588,11 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
}
}
const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
const CORRECT_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
corrected_string_escaping: {
type: Type.STRING,
type: 'string',
description:
'The string with corrected escaping, ensuring it is valid, specially considering potential over-escaping issues from previous LLM generations.',
},

View File

@@ -0,0 +1,205 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import {
getEnvironmentContext,
getDirectoryContextString,
} from './environmentContext.js';
import { Config } from '../config/config.js';
import { getFolderStructure } from './getFolderStructure.js';
vi.mock('../config/config.js');
vi.mock('./getFolderStructure.js', () => ({
getFolderStructure: vi.fn(),
}));
vi.mock('../tools/read-many-files.js');
describe('getDirectoryContextString', () => {
  let configStub: Partial<Config>;

  beforeEach(() => {
    // Stub only the Config surface the function under test touches.
    configStub = {
      getFileService: vi.fn(),
      getWorkspaceContext: vi.fn().mockReturnValue({
        getDirectories: vi.fn().mockReturnValue(['/test/dir']),
      }),
    };
    vi.mocked(getFolderStructure).mockResolvedValue('Mock Folder Structure');
  });

  afterEach(() => {
    vi.resetAllMocks();
  });

  it('should return context string for a single directory', async () => {
    const result = await getDirectoryContextString(configStub as Config);

    expect(result).toContain(
      "I'm currently working in the directory: /test/dir",
    );
    expect(result).toContain(
      'Here is the folder structure of the current working directories:\n\nMock Folder Structure',
    );
  });

  it('should return context string for multiple directories', async () => {
    // Two workspace directories, each with its own crawled structure.
    (
      vi.mocked(configStub.getWorkspaceContext!)().getDirectories as Mock
    ).mockReturnValue(['/test/dir1', '/test/dir2']);
    vi.mocked(getFolderStructure)
      .mockResolvedValueOnce('Structure 1')
      .mockResolvedValueOnce('Structure 2');

    const result = await getDirectoryContextString(configStub as Config);

    expect(result).toContain(
      "I'm currently working in the following directories:\n - /test/dir1\n - /test/dir2",
    );
    expect(result).toContain(
      'Here is the folder structure of the current working directories:\n\nStructure 1\nStructure 2',
    );
  });
});
describe('getEnvironmentContext', () => {
  let mockConfig: Partial<Config>;
  let mockToolRegistry: { getTool: Mock };

  beforeEach(() => {
    // Freeze the clock so the "Today's date is ..." assertion is deterministic.
    vi.useFakeTimers();
    vi.setSystemTime(new Date('2025-08-05T12:00:00Z'));
    mockToolRegistry = {
      getTool: vi.fn(),
    };
    // Minimal Config stub: one workspace directory, full-context mode off by
    // default; individual tests override getFullContext as needed.
    mockConfig = {
      getWorkspaceContext: vi.fn().mockReturnValue({
        getDirectories: vi.fn().mockReturnValue(['/test/dir']),
      }),
      getFileService: vi.fn(),
      getFullContext: vi.fn().mockReturnValue(false),
      getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
    };
    vi.mocked(getFolderStructure).mockResolvedValue('Mock Folder Structure');
  });

  afterEach(() => {
    vi.useRealTimers();
    vi.resetAllMocks();
  });

  it('should return basic environment context for a single directory', async () => {
    const parts = await getEnvironmentContext(mockConfig as Config);
    // Without full context enabled, everything is folded into a single Part.
    expect(parts.length).toBe(1);
    const context = parts[0].text;
    expect(context).toContain("Today's date is Tuesday, August 5, 2025");
    expect(context).toContain(`My operating system is: ${process.platform}`);
    expect(context).toContain(
      "I'm currently working in the directory: /test/dir",
    );
    expect(context).toContain(
      'Here is the folder structure of the current working directories:\n\nMock Folder Structure',
    );
    // fileService is undefined because getFileService is a bare vi.fn().
    expect(getFolderStructure).toHaveBeenCalledWith('/test/dir', {
      fileService: undefined,
    });
  });

  it('should return basic environment context for multiple directories', async () => {
    (
      vi.mocked(mockConfig.getWorkspaceContext!)().getDirectories as Mock
    ).mockReturnValue(['/test/dir1', '/test/dir2']);
    // Ordered: first directory gets 'Structure 1', second gets 'Structure 2'.
    vi.mocked(getFolderStructure)
      .mockResolvedValueOnce('Structure 1')
      .mockResolvedValueOnce('Structure 2');
    const parts = await getEnvironmentContext(mockConfig as Config);
    expect(parts.length).toBe(1);
    const context = parts[0].text;
    expect(context).toContain(
      "I'm currently working in the following directories:\n - /test/dir1\n - /test/dir2",
    );
    expect(context).toContain(
      'Here is the folder structure of the current working directories:\n\nStructure 1\nStructure 2',
    );
    expect(getFolderStructure).toHaveBeenCalledTimes(2);
  });

  it('should include full file context when getFullContext is true', async () => {
    mockConfig.getFullContext = vi.fn().mockReturnValue(true);
    // build() returns an invocation whose execute() yields the file contents.
    const mockReadManyFilesTool = {
      build: vi.fn().mockReturnValue({
        execute: vi
          .fn()
          .mockResolvedValue({ llmContent: 'Full file content here' }),
      }),
    };
    mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool);
    const parts = await getEnvironmentContext(mockConfig as Config);
    // Second Part carries the aggregated file contents.
    expect(parts.length).toBe(2);
    expect(parts[1].text).toBe(
      '\n--- Full File Context ---\nFull file content here',
    );
    expect(mockToolRegistry.getTool).toHaveBeenCalledWith('read_many_files');
    expect(mockReadManyFilesTool.build).toHaveBeenCalledWith({
      paths: ['**/*'],
      useDefaultExcludes: true,
    });
  });

  it('should handle read_many_files returning no content', async () => {
    mockConfig.getFullContext = vi.fn().mockReturnValue(true);
    const mockReadManyFilesTool = {
      build: vi.fn().mockReturnValue({
        execute: vi.fn().mockResolvedValue({ llmContent: '' }),
      }),
    };
    mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool);
    const parts = await getEnvironmentContext(mockConfig as Config);
    expect(parts.length).toBe(1); // No extra part added
  });

  it('should handle read_many_files tool not being found', async () => {
    mockConfig.getFullContext = vi.fn().mockReturnValue(true);
    mockToolRegistry.getTool.mockReturnValue(null);
    const parts = await getEnvironmentContext(mockConfig as Config);
    expect(parts.length).toBe(1); // No extra part added
  });

  it('should handle errors when reading full file context', async () => {
    mockConfig.getFullContext = vi.fn().mockReturnValue(true);
    const mockReadManyFilesTool = {
      build: vi.fn().mockReturnValue({
        execute: vi.fn().mockRejectedValue(new Error('Read error')),
      }),
    };
    mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool);
    const parts = await getEnvironmentContext(mockConfig as Config);
    // A read failure is reported as a marker Part rather than thrown.
    expect(parts.length).toBe(2);
    expect(parts[1].text).toBe('\n--- Error reading full file context ---');
  });
});

View File

@@ -0,0 +1,109 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { Part } from '@google/genai';
import { Config } from '../config/config.js';
import { getFolderStructure } from './getFolderStructure.js';
/**
* Generates a string describing the current workspace directories and their structures.
* @param {Config} config - The runtime configuration and services.
* @returns {Promise<string>} A promise that resolves to the directory context string.
*/
export async function getDirectoryContextString(
config: Config,
): Promise<string> {
const workspaceContext = config.getWorkspaceContext();
const workspaceDirectories = workspaceContext.getDirectories();
const folderStructures = await Promise.all(
workspaceDirectories.map((dir) =>
getFolderStructure(dir, {
fileService: config.getFileService(),
}),
),
);
const folderStructure = folderStructures.join('\n');
let workingDirPreamble: string;
if (workspaceDirectories.length === 1) {
workingDirPreamble = `I'm currently working in the directory: ${workspaceDirectories[0]}`;
} else {
const dirList = workspaceDirectories.map((dir) => ` - ${dir}`).join('\n');
workingDirPreamble = `I'm currently working in the following directories:\n${dirList}`;
}
return `${workingDirPreamble}
Here is the folder structure of the current working directories:
${folderStructure}`;
}
/**
 * Retrieves environment-related information to be included in the chat context.
 * This includes the current working directory, date, operating system, and folder structure.
 * Optionally, it can also include the full file context if enabled.
 * @param {Config} config - The runtime configuration and services.
 * @returns A promise that resolves to an array of `Part` objects containing environment information.
 */
export async function getEnvironmentContext(config: Config): Promise<Part[]> {
  // Locale-formatted date, e.g. "Tuesday, August 5, 2025".
  const today = new Date().toLocaleDateString(undefined, {
    weekday: 'long',
    year: 'numeric',
    month: 'long',
    day: 'numeric',
  });
  const platform = process.platform;
  const directoryContext = await getDirectoryContextString(config);
  const context = `
This is the Qwen Code. We are setting up the context for our chat.
Today's date is ${today}.
My operating system is: ${platform}
${directoryContext}
`.trim();
  const initialParts: Part[] = [{ text: context }];
  // Add full file context if the flag is set
  if (config.getFullContext()) {
    try {
      // The tool registry is only needed in full-context mode, so resolve it
      // lazily here instead of awaiting it unconditionally on every call.
      const toolRegistry = await config.getToolRegistry();
      const readManyFilesTool = toolRegistry.getTool('read_many_files');
      if (readManyFilesTool) {
        const invocation = readManyFilesTool.build({
          paths: ['**/*'], // Read everything recursively
          useDefaultExcludes: true, // Use default excludes
        });
        // Read all files in the target directory, bounded by a 30s timeout.
        const result = await invocation.execute(AbortSignal.timeout(30000));
        if (result.llmContent) {
          initialParts.push({
            // NOTE(review): llmContent is interpolated directly; assumes it is
            // string-like here — confirm against the tool's result type.
            text: `\n--- Full File Context ---\n${result.llmContent}`,
          });
        } else {
          console.warn(
            'Full context requested, but read_many_files returned no content.',
          );
        }
      } else {
        console.warn(
          'Full context requested, but read_many_files tool not found.',
        );
      }
    } catch (error) {
      // Not using reportError here as it's a startup/config phase, not a chat/generation phase error.
      console.error('Error reading full file context:', error);
      initialParts.push({
        text: '\n--- Error reading full file context ---',
      });
    }
  }
  return initialParts;
}

View File

@@ -196,9 +196,13 @@ describe('fileUtils', () => {
vi.restoreAllMocks(); // Restore spies on actualNodeFs
});
it('should detect typescript type by extension (ts)', async () => {
it('should detect typescript type by extension (ts, mts, cts, tsx)', async () => {
expect(await detectFileType('file.ts')).toBe('text');
expect(await detectFileType('file.test.ts')).toBe('text');
expect(await detectFileType('file.mts')).toBe('text');
expect(await detectFileType('vite.config.mts')).toBe('text');
expect(await detectFileType('file.cts')).toBe('text');
expect(await detectFileType('component.tsx')).toBe('text');
});
it('should detect image type by extension (png)', async () => {
@@ -416,10 +420,7 @@ describe('fileUtils', () => {
); // Read lines 6-10
const expectedContent = lines.slice(5, 10).join('\n');
expect(result.llmContent).toContain(expectedContent);
expect(result.llmContent).toContain(
'[File content truncated: showing lines 6-10 of 20 total lines. Use offset/limit parameters to view more.]',
);
expect(result.llmContent).toBe(expectedContent);
expect(result.returnDisplay).toBe('Read lines 6-10 of 20 from test.txt');
expect(result.isTruncated).toBe(true);
expect(result.originalLineCount).toBe(20);
@@ -440,9 +441,6 @@ describe('fileUtils', () => {
const expectedContent = lines.slice(10, 20).join('\n');
expect(result.llmContent).toContain(expectedContent);
expect(result.llmContent).toContain(
'[File content truncated: showing lines 11-20 of 20 total lines. Use offset/limit parameters to view more.]',
);
expect(result.returnDisplay).toBe('Read lines 11-20 of 20 from test.txt');
expect(result.isTruncated).toBe(true); // This is the key check for the bug
expect(result.originalLineCount).toBe(20);
@@ -485,9 +483,6 @@ describe('fileUtils', () => {
longLine.substring(0, 2000) + '... [truncated]',
);
expect(result.llmContent).toContain('Another short line');
expect(result.llmContent).toContain(
'[File content partially truncated: some lines exceeded maximum length of 2000 characters.]',
);
expect(result.returnDisplay).toBe(
'Read all 3 lines from test.txt (some lines were shortened)',
);

View File

@@ -122,9 +122,10 @@ export async function detectFileType(
): Promise<'text' | 'image' | 'pdf' | 'audio' | 'video' | 'binary' | 'svg'> {
const ext = path.extname(filePath).toLowerCase();
// The mimetype for "ts" is MPEG transport stream (a video format) but we want
// to assume these are typescript files instead.
if (ext === '.ts') {
// The mimetype for various TypeScript extensions (ts, mts, cts, tsx) can be
// MPEG transport stream (a video format), but we want to assume these are
// TypeScript files instead.
if (['.ts', '.mts', '.cts'].includes(ext)) {
return 'text';
}
@@ -194,10 +195,18 @@ export async function detectFileType(
return 'text';
}
/** Structured, machine-readable categories for file-read failures. */
export enum FileErrorType {
  /** The requested path does not exist on disk. */
  FILE_NOT_FOUND = 'FILE_NOT_FOUND',
  /** The path points at a directory, not a regular file. */
  IS_DIRECTORY = 'IS_DIRECTORY',
  // Presumably set when the file exceeds a read-size limit — the enforcing
  // code is outside this view; TODO confirm the threshold and usage.
  FILE_TOO_LARGE = 'FILE_TOO_LARGE',
  // Catch-all for other I/O/processing failures — usage not visible here.
  READ_ERROR = 'READ_ERROR',
}
export interface ProcessedFileReadResult {
llmContent: PartUnion; // string for text, Part for image/pdf/unreadable binary
returnDisplay: string;
error?: string; // Optional error message for the LLM if file processing failed
errorType?: FileErrorType; // Structured error type using enum
isTruncated?: boolean; // For text files, indicates if content was truncated
originalLineCount?: number; // For text files
linesShown?: [number, number]; // For text files [startLine, endLine] (1-based for display)
@@ -224,6 +233,7 @@ export async function processSingleFileContent(
llmContent: '',
returnDisplay: 'File not found.',
error: `File not found: ${filePath}`,
errorType: FileErrorType.FILE_NOT_FOUND,
};
}
const stats = await fs.promises.stat(filePath);
@@ -232,6 +242,7 @@ export async function processSingleFileContent(
llmContent: '',
returnDisplay: 'Path is a directory.',
error: `Path is a directory, not a file: ${filePath}`,
errorType: FileErrorType.IS_DIRECTORY,
};
}
@@ -302,14 +313,7 @@ export async function processSingleFileContent(
const contentRangeTruncated =
startLine > 0 || endLine < originalLineCount;
const isTruncated = contentRangeTruncated || linesWereTruncatedInLength;
let llmTextContent = '';
if (contentRangeTruncated) {
llmTextContent += `[File content truncated: showing lines ${actualStartLine + 1}-${endLine} of ${originalLineCount} total lines. Use offset/limit parameters to view more.]\n`;
} else if (linesWereTruncatedInLength) {
llmTextContent += `[File content partially truncated: some lines exceeded maximum length of ${MAX_LINE_LENGTH_TEXT_FILE} characters.]\n`;
}
llmTextContent += formattedLines.join('\n');
const llmContent = formattedLines.join('\n');
// By default, return nothing to streamline the common case of a successful read_file.
let returnDisplay = '';
@@ -325,7 +329,7 @@ export async function processSingleFileContent(
}
return {
llmContent: llmTextContent,
llmContent,
returnDisplay,
isTruncated,
originalLineCount,

View File

@@ -26,6 +26,17 @@ describe('CrawlCache', () => {
const key2 = getCacheKey('/foo', 'baz');
expect(key1).not.toBe(key2);
});
it('should generate a different hash for different maxDepth values', () => {
const key1 = getCacheKey('/foo', 'bar', 1);
const key2 = getCacheKey('/foo', 'bar', 2);
const key3 = getCacheKey('/foo', 'bar', undefined);
const key4 = getCacheKey('/foo', 'bar');
expect(key1).not.toBe(key2);
expect(key1).not.toBe(key3);
expect(key2).not.toBe(key3);
expect(key3).toBe(key4);
});
});
describe('in-memory cache operations', () => {

View File

@@ -17,10 +17,14 @@ const cacheTimers = new Map<string, NodeJS.Timeout>();
/**
 * Builds the crawl-cache key: a SHA-256 digest over the crawl directory, the
 * ignore-rule fingerprint, and (when provided) the maximum crawl depth.
 * Folding `maxDepth` into the digest ensures crawls with different depth
 * limits never share a cache entry, while an `undefined` depth hashes
 * identically to omitting the argument entirely.
 */
export const getCacheKey = (
  directory: string,
  ignoreContent: string,
  maxDepth?: number,
): string => {
  const components = [directory, ignoreContent];
  if (maxDepth !== undefined) {
    components.push(String(maxDepth));
  }
  const digest = crypto.createHash('sha256');
  for (const component of components) {
    digest.update(component);
  }
  return digest.digest('hex');
};

View File

@@ -290,6 +290,30 @@ describe('FileSearch', () => {
expect(results).toEqual(['src/file1.js', 'src/file2.js']); // Assuming alphabetical sort
});
it('should use fzf for fuzzy matching when pattern does not contain wildcards', async () => {
tmpDir = await createTmpDir({
src: {
'main.js': '',
'util.ts': '',
'style.css': '',
},
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('sst');
expect(results).toEqual(['src/style.css']);
});
it('should return empty array when no matches are found', async () => {
tmpDir = await createTmpDir({
src: ['file1.js'],
@@ -446,6 +470,46 @@ describe('FileSearch', () => {
expect(crawlSpy).toHaveBeenCalledTimes(1);
});
it('should miss the cache when maxDepth changes', async () => {
tmpDir = await createTmpDir({ 'file1.js': '' });
const getOptions = (maxDepth?: number) => ({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: true,
cacheTtl: 10000,
maxDepth,
});
// 1. First search with maxDepth: 1, should trigger a crawl.
const fs1 = new FileSearch(getOptions(1));
const crawlSpy1 = vi.spyOn(
fs1 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs1.initialize();
expect(crawlSpy1).toHaveBeenCalledTimes(1);
// 2. Second search with maxDepth: 2, should be a cache miss and trigger a crawl.
const fs2 = new FileSearch(getOptions(2));
const crawlSpy2 = vi.spyOn(
fs2 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs2.initialize();
expect(crawlSpy2).toHaveBeenCalledTimes(1);
// 3. Third search with maxDepth: 1 again, should be a cache hit.
const fs3 = new FileSearch(getOptions(1));
const crawlSpy3 = vi.spyOn(
fs3 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs3.initialize();
expect(crawlSpy3).not.toHaveBeenCalled();
});
});
it('should handle empty or commented-only ignore files', async () => {
@@ -639,4 +703,109 @@ describe('FileSearch', () => {
// 3. Assert that the maxResults limit was respected, even with a cache hit.
expect(limitedResults).toEqual(['file1.js', 'file2.js']);
});
describe('with maxDepth', () => {
  // Fixture: three nested levels with one file at each depth plus one at the
  // root, so each maxDepth value cuts off a predictable slice of results.
  beforeEach(async () => {
    tmpDir = await createTmpDir({
      'file-root.txt': '',
      level1: {
        'file-level1.txt': '',
        level2: {
          'file-level2.txt': '',
          level3: {
            'file-level3.txt': '',
          },
        },
      },
    });
  });

  it('should only search top-level files when maxDepth is 0', async () => {
    const fileSearch = new FileSearch({
      projectRoot: tmpDir,
      useGitignore: false,
      useGeminiignore: false,
      ignoreDirs: [],
      cache: false,
      cacheTtl: 0,
      maxDepth: 0,
    });
    await fileSearch.initialize();
    const results = await fileSearch.search('');
    // Directories sort before files; level1's contents are excluded.
    expect(results).toEqual(['level1/', 'file-root.txt']);
  });

  it('should search one level deep when maxDepth is 1', async () => {
    const fileSearch = new FileSearch({
      projectRoot: tmpDir,
      useGitignore: false,
      useGeminiignore: false,
      ignoreDirs: [],
      cache: false,
      cacheTtl: 0,
      maxDepth: 1,
    });
    await fileSearch.initialize();
    const results = await fileSearch.search('');
    expect(results).toEqual([
      'level1/',
      'level1/level2/',
      'file-root.txt',
      'level1/file-level1.txt',
    ]);
  });

  it('should search two levels deep when maxDepth is 2', async () => {
    const fileSearch = new FileSearch({
      projectRoot: tmpDir,
      useGitignore: false,
      useGeminiignore: false,
      ignoreDirs: [],
      cache: false,
      cacheTtl: 0,
      maxDepth: 2,
    });
    await fileSearch.initialize();
    const results = await fileSearch.search('');
    expect(results).toEqual([
      'level1/',
      'level1/level2/',
      'level1/level2/level3/',
      'file-root.txt',
      'level1/file-level1.txt',
      'level1/level2/file-level2.txt',
    ]);
  });

  it('should perform a full recursive search when maxDepth is undefined', async () => {
    const fileSearch = new FileSearch({
      projectRoot: tmpDir,
      useGitignore: false,
      useGeminiignore: false,
      ignoreDirs: [],
      cache: false,
      cacheTtl: 0,
      maxDepth: undefined, // Explicitly undefined
    });
    await fileSearch.initialize();
    const results = await fileSearch.search('');
    expect(results).toEqual([
      'level1/',
      'level1/level2/',
      'level1/level2/level3/',
      'file-root.txt',
      'level1/file-level1.txt',
      'level1/level2/file-level2.txt',
      'level1/level2/level3/file-level3.txt',
    ]);
  });
});
});

View File

@@ -11,6 +11,7 @@ import picomatch from 'picomatch';
import { Ignore } from './ignore.js';
import { ResultCache } from './result-cache.js';
import * as cache from './crawlCache.js';
import { AsyncFzf, FzfResultItem } from 'fzf';
export type FileSearchOptions = {
projectRoot: string;
@@ -19,6 +20,7 @@ export type FileSearchOptions = {
useGeminiignore: boolean;
cache: boolean;
cacheTtl: number;
maxDepth?: number;
};
export class AbortError extends Error {
@@ -91,6 +93,7 @@ export class FileSearch {
private readonly ignore: Ignore = new Ignore();
private resultCache: ResultCache | undefined;
private allFiles: string[] = [];
private fzf: AsyncFzf<string[]> | undefined;
/**
* Constructs a new `FileSearch` instance.
@@ -122,22 +125,38 @@ export class FileSearch {
pattern: string,
options: SearchOptions = {},
): Promise<string[]> {
if (!this.resultCache) {
if (!this.resultCache || !this.fzf) {
throw new Error('Engine not initialized. Call initialize() first.');
}
pattern = pattern || '*';
let filteredCandidates;
const { files: candidates, isExactMatch } =
await this.resultCache!.get(pattern);
let filteredCandidates;
if (isExactMatch) {
// Use the cached result.
filteredCandidates = candidates;
} else {
// Apply the user's picomatch pattern filter
filteredCandidates = await filter(candidates, pattern, options.signal);
this.resultCache!.set(pattern, filteredCandidates);
let shouldCache = true;
if (pattern.includes('*')) {
filteredCandidates = await filter(candidates, pattern, options.signal);
} else {
filteredCandidates = await this.fzf
.find(pattern)
.then((results: Array<FzfResultItem<string>>) =>
results.map((entry: FzfResultItem<string>) => entry.item),
)
.catch(() => {
shouldCache = false;
return [];
});
}
if (shouldCache) {
this.resultCache!.set(pattern, filteredCandidates);
}
}
// Trade-off: We apply a two-stage filtering process.
@@ -215,6 +234,7 @@ export class FileSearch {
const cacheKey = cache.getCacheKey(
this.absoluteDir,
this.ignore.getFingerprint(),
this.options.maxDepth,
);
const cachedResults = cache.read(cacheKey);
@@ -230,6 +250,7 @@ export class FileSearch {
const cacheKey = cache.getCacheKey(
this.absoluteDir,
this.ignore.getFingerprint(),
this.options.maxDepth,
);
cache.write(cacheKey, this.allFiles, this.options.cacheTtl * 1000);
}
@@ -257,6 +278,10 @@ export class FileSearch {
return dirFilter(`${relativePath}/`);
});
if (this.options.maxDepth !== undefined) {
api.withMaxDepth(this.options.maxDepth);
}
return api.crawl(this.absoluteDir).withPromise();
}
@@ -265,5 +290,11 @@ export class FileSearch {
*/
private buildResultCache(): void {
this.resultCache = new ResultCache(this.allFiles, this.absoluteDir);
// The v1 algorithm is much faster since it only looks at the first
// occurence of the pattern. We use it for search spaces that have >20k
// files, because the v2 algorithm is just too slow in those cases.
this.fzf = new AsyncFzf(this.allFiles, {
fuzzy: this.allFiles.length > 20000 ? 'v1' : 'v2',
});
}
}

View File

@@ -335,5 +335,8 @@ export async function loadServerHierarchicalMemory(
logger.debug(
`Combined instructions (snippet): ${combinedInstructions.substring(0, 500)}...`,
);
return { memoryContent: combinedInstructions, fileCount: filePaths.length };
return {
memoryContent: combinedInstructions,
fileCount: contentsWithPaths.length,
};
}

View File

@@ -6,7 +6,7 @@
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
import { Content, GoogleGenAI, Models } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
@@ -248,6 +248,6 @@ describe('checkNextSpeaker', () => {
expect(mockGeminiClient.generateJson).toHaveBeenCalled();
const generateJsonCall = (mockGeminiClient.generateJson as Mock).mock
.calls[0];
expect(generateJsonCall[3]).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL);
expect(generateJsonCall[3]).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
});

View File

@@ -4,8 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { Content, SchemaUnion, Type } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
import { Content } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js';
@@ -16,16 +16,16 @@ const CHECK_PROMPT = `Analyze *only* the content and structure of your immediate
2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next.
3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next.`;
const RESPONSE_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
const RESPONSE_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
reasoning: {
type: Type.STRING,
type: 'string',
description:
"Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.",
},
next_speaker: {
type: Type.STRING,
type: 'string',
enum: ['user', 'model'],
description:
'Who should speak next based *only* on the preceding turn and the decision rules',
@@ -112,7 +112,7 @@ export async function checkNextSpeaker(
contents,
RESPONSE_SCHEMA,
abortSignal,
DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
)) as unknown as NextSpeakerResponse;
if (

View File

@@ -1,362 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import fs from 'node:fs/promises';
import { openaiLogger } from './openaiLogger.js';
/**
 * OpenAI API usage analytics
 *
 * This utility analyzes OpenAI API logs to provide insights into API usage
 * patterns, costs, and performance. All methods are static; log access goes
 * through the shared `openaiLogger`.
 */
export class OpenAIAnalytics {
  /**
   * Calculate statistics for OpenAI API usage.
   *
   * Scans every log file, skips entries older than the cutoff or without a
   * `timestamp`, and aggregates per-model request counts, token usage, an
   * estimated cost, error rates (as % of total requests), and an hourly
   * request histogram (UTC).
   *
   * @param days Number of days to analyze (default: 7)
   */
  static async calculateStats(days: number = 7): Promise<{
    totalRequests: number;
    successRate: number;
    avgResponseTime: number;
    requestsByModel: Record<string, number>;
    tokenUsage: {
      promptTokens: number;
      completionTokens: number;
      totalTokens: number;
    };
    estimatedCost: number;
    errorRates: Record<string, number>;
    timeDistribution: Record<string, number>;
  }> {
    const logs = await openaiLogger.getLogFiles();
    const now = new Date();
    // Entries older than `days` days are ignored.
    const cutoffDate = new Date(now.getTime() - days * 24 * 60 * 60 * 1000);
    let totalRequests = 0;
    let successfulRequests = 0;
    // NOTE(review): declared const and never updated anywhere below, so
    // avgResponseTime is always 0 — response times are not actually measured.
    const totalResponseTime = 0;
    const requestsByModel: Record<string, number> = {};
    const tokenUsage = { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
    const errorTypes: Record<string, number> = {};
    const hourDistribution: Record<string, number> = {};
    // Initialize hour distribution (0-23) so the histogram has all 24 buckets.
    for (let i = 0; i < 24; i++) {
      const hour = i.toString().padStart(2, '0');
      hourDistribution[hour] = 0;
    }
    // Model pricing estimates (per 1000 tokens). Hard-coded snapshot; models
    // not listed fall back to defaultPricing below.
    const pricing: Record<string, { input: number; output: number }> = {
      'gpt-4': { input: 0.03, output: 0.06 },
      'gpt-4-32k': { input: 0.06, output: 0.12 },
      'gpt-4-1106-preview': { input: 0.01, output: 0.03 },
      'gpt-4-0125-preview': { input: 0.01, output: 0.03 },
      'gpt-4-0613': { input: 0.03, output: 0.06 },
      'gpt-4-32k-0613': { input: 0.06, output: 0.12 },
      'gpt-3.5-turbo': { input: 0.0015, output: 0.002 },
      'gpt-3.5-turbo-16k': { input: 0.003, output: 0.004 },
      'gpt-3.5-turbo-0613': { input: 0.0015, output: 0.002 },
      'gpt-3.5-turbo-16k-0613': { input: 0.003, output: 0.004 },
    };
    // Default pricing for unknown models
    const defaultPricing = { input: 0.01, output: 0.03 };
    let estimatedCost = 0;
    for (const logFile of logs) {
      try {
        const logData = await openaiLogger.readLogFile(logFile);
        // Type guard to check if logData has the expected structure
        if (!isObjectWith<{ timestamp: string }>(logData, ['timestamp'])) {
          continue; // Skip malformed logs
        }
        const logDate = new Date(logData.timestamp);
        // Skip if log is older than the cutoff date
        if (logDate < cutoffDate) {
          continue;
        }
        totalRequests++;
        const hour = logDate.getUTCHours().toString().padStart(2, '0');
        hourDistribution[hour]++;
        // Check if request was successful: a response present and no error.
        // Note the guard requires BOTH keys to exist on the entry.
        if (
          isObjectWith<{ response?: unknown; error?: unknown }>(logData, [
            'response',
            'error',
          ]) &&
          logData.response &&
          !logData.error
        ) {
          successfulRequests++;
          // Extract model if available
          const model = getModelFromLog(logData);
          if (model) {
            requestsByModel[model] = (requestsByModel[model] || 0) + 1;
          }
          // Extract token usage if available
          const usage = getTokenUsageFromLog(logData);
          if (usage) {
            tokenUsage.promptTokens += usage.prompt_tokens || 0;
            tokenUsage.completionTokens += usage.completion_tokens || 0;
            tokenUsage.totalTokens += usage.total_tokens || 0;
            // Calculate cost if model is known
            const modelName = model || 'unknown';
            const modelPricing = pricing[modelName] || defaultPricing;
            const inputCost =
              ((usage.prompt_tokens || 0) / 1000) * modelPricing.input;
            const outputCost =
              ((usage.completion_tokens || 0) / 1000) * modelPricing.output;
            estimatedCost += inputCost + outputCost;
          }
        } else if (
          isObjectWith<{ error?: unknown }>(logData, ['error']) &&
          logData.error
        ) {
          // Categorize errors
          const errorType = getErrorTypeFromLog(logData);
          errorTypes[errorType] = (errorTypes[errorType] || 0) + 1;
        }
      } catch (error) {
        // A single unreadable log file must not abort the whole scan.
        console.error(`Error processing log file ${logFile}:`, error);
      }
    }
    // Calculate success rate and average response time
    const successRate =
      totalRequests > 0 ? (successfulRequests / totalRequests) * 100 : 0;
    const avgResponseTime =
      totalRequests > 0 ? totalResponseTime / totalRequests : 0;
    // Calculate error rates as percentages
    const errorRates: Record<string, number> = {};
    for (const [errorType, count] of Object.entries(errorTypes)) {
      errorRates[errorType] =
        totalRequests > 0 ? (count / totalRequests) * 100 : 0;
    }
    return {
      totalRequests,
      successRate,
      avgResponseTime,
      requestsByModel,
      tokenUsage,
      estimatedCost,
      errorRates,
      timeDistribution: hourDistribution,
    };
  }

  /**
   * Generate a human-readable (markdown) report of OpenAI API usage,
   * including overview figures, token usage, model breakdown, error types,
   * and an ASCII hourly histogram.
   * @param days Number of days to include in the report
   */
  static async generateReport(days: number = 7): Promise<string> {
    const stats = await this.calculateStats(days);
    let report = `# OpenAI API Usage Report\n`;
    report += `## Last ${days} days (${new Date().toISOString().split('T')[0]})\n\n`;
    report += `### Overview\n`;
    report += `- Total Requests: ${stats.totalRequests}\n`;
    report += `- Success Rate: ${stats.successRate.toFixed(2)}%\n`;
    report += `- Estimated Cost: $${stats.estimatedCost.toFixed(2)}\n\n`;
    report += `### Token Usage\n`;
    report += `- Prompt Tokens: ${stats.tokenUsage.promptTokens.toLocaleString()}\n`;
    report += `- Completion Tokens: ${stats.tokenUsage.completionTokens.toLocaleString()}\n`;
    report += `- Total Tokens: ${stats.tokenUsage.totalTokens.toLocaleString()}\n\n`;
    report += `### Models Used\n`;
    // Most-used models first.
    const sortedModels = Object.entries(stats.requestsByModel) as Array<
      [string, number]
    >;
    sortedModels.sort((a, b) => b[1] - a[1]);
    for (const [model, count] of sortedModels) {
      const percentage = ((count / stats.totalRequests) * 100).toFixed(1);
      report += `- ${model}: ${count} requests (${percentage}%)\n`;
    }
    if (Object.keys(stats.errorRates).length > 0) {
      report += `\n### Error Types\n`;
      const sortedErrors = Object.entries(stats.errorRates) as Array<
        [string, number]
      >;
      sortedErrors.sort((a, b) => b[1] - a[1]);
      for (const [errorType, rate] of sortedErrors) {
        report += `- ${errorType}: ${rate.toFixed(1)}%\n`;
      }
    }
    report += `\n### Usage by Hour (UTC)\n`;
    report += `\`\`\`\n`;
    // Bars are scaled relative to the busiest hour.
    const maxRequests = Math.max(...Object.values(stats.timeDistribution));
    const scale = 40; // max bar length
    for (let i = 0; i < 24; i++) {
      const hour = i.toString().padStart(2, '0');
      const requests = stats.timeDistribution[hour] || 0;
      const barLength =
        maxRequests > 0 ? Math.round((requests / maxRequests) * scale) : 0;
      const bar = '█'.repeat(barLength);
      report += `${hour}:00 ${bar.padEnd(scale)} ${requests}\n`;
    }
    report += `\`\`\`\n`;
    return report;
  }

  /**
   * Save an analytics report to a file.
   * @param days Number of days to include
   * @param outputPath File path for the report (defaults to logs/openai/analytics.md)
   * @returns The path the report was written to.
   */
  static async saveReport(
    days: number = 7,
    outputPath?: string,
  ): Promise<string> {
    const report = await this.generateReport(days);
    const reportPath =
      outputPath || path.join(process.cwd(), 'logs', 'openai', 'analytics.md');
    await fs.writeFile(reportPath, report, 'utf-8');
    return reportPath;
  }
}
/**
 * Type guard: true when `obj` is a non-null object that contains every key
 * in `keys` (own or inherited). Narrowing to T is the caller's assertion —
 * only key presence is verified, not value types.
 */
function isObjectWith<T extends object>(
  obj: unknown,
  keys: Array<keyof T>,
): obj is T {
  if (typeof obj !== 'object' || obj === null) {
    return false;
  }
  for (const key of keys) {
    if (!(key in obj)) {
      return false;
    }
  }
  return true;
}
/**
 * Extract the model name from a log entry.
 *
 * Looks at `request.model` first, then `response.model`, then the
 * Gemini-style `response.modelVersion`.
 *
 * Fix: the previous implementation required BOTH `request` and `response`
 * keys to be present before extracting anything, so a log holding only a
 * request (e.g. the call errored before a response was recorded) returned
 * `undefined` even though `request.model` was set. Each side is now
 * inspected independently.
 *
 * @param logData Parsed log entry of unknown shape.
 * @returns The model name, or `undefined` when none is recorded.
 */
function getModelFromLog(logData: unknown): string | undefined {
  if (typeof logData !== 'object' || logData === null) {
    return undefined;
  }
  const data = logData as {
    request?: { model?: string };
    response?: { model?: string; modelVersion?: string };
  };
  // `||` (not `??`) so an empty-string model falls through, matching the
  // original truthiness checks; trailing `|| undefined` normalises '' away.
  return (
    data.request?.model ||
    data.response?.model ||
    data.response?.modelVersion ||
    undefined
  );
}
/**
 * Extract token usage information from a log entry.
 *
 * Supports both the OpenAI response shape (`usage`, snake_case fields,
 * returned as-is and taking priority) and the Gemini shape
 * (`usageMetadata`), which is normalised to the OpenAI field names.
 *
 * Fix: `usage` was typed as bare `object`, which does not satisfy the
 * declared return type; it is now typed with the actual OpenAI fields.
 *
 * @param logData Parsed log entry of unknown shape.
 * @returns Normalised token counts, or `undefined` when none are recorded.
 */
function getTokenUsageFromLog(logData: unknown):
  | {
      prompt_tokens?: number;
      completion_tokens?: number;
      total_tokens?: number;
    }
  | undefined {
  if (typeof logData !== 'object' || logData === null) {
    return undefined;
  }
  const { response } = logData as {
    response?: {
      usage?: {
        prompt_tokens?: number;
        completion_tokens?: number;
        total_tokens?: number;
      };
      usageMetadata?: {
        promptTokenCount?: number;
        candidatesTokenCount?: number;
        totalTokenCount?: number;
      };
    };
  };
  if (!response) {
    return undefined;
  }
  if (response.usage) {
    // OpenAI-style payload is already in the shape callers expect.
    return response.usage;
  }
  if (response.usageMetadata) {
    // Gemini-style payload: map camelCase counters to snake_case names.
    const metadata = response.usageMetadata;
    return {
      prompt_tokens: metadata.promptTokenCount,
      completion_tokens: metadata.candidatesTokenCount,
      total_tokens: metadata.totalTokenCount,
    };
  }
  return undefined;
}
/**
 * Extract and categorize error types from a log entry.
 *
 * Fix: matching is now case-insensitive — providers emit messages like
 * "Rate limit exceeded" or "Request Timeout", which the previous
 * lowercase-only substring checks misclassified as 'other'.
 *
 * @param logData Parsed log entry of unknown shape.
 * @returns A category slug, 'other' for unrecognised messages, or
 *   'unknown' when no error object is present.
 */
function getErrorTypeFromLog(logData: unknown): string {
  if (typeof logData !== 'object' || logData === null || !('error' in logData)) {
    return 'unknown';
  }
  const { error } = logData as { error?: { message?: string } };
  if (!error) {
    return 'unknown';
  }
  const errorMsg = (error.message || '').toLowerCase();
  if (errorMsg.includes('rate limit')) return 'rate_limit';
  if (errorMsg.includes('timeout')) return 'timeout';
  if (errorMsg.includes('authentication')) return 'authentication';
  if (errorMsg.includes('quota')) return 'quota_exceeded';
  if (errorMsg.includes('invalid')) return 'invalid_request';
  if (errorMsg.includes('not available')) return 'model_unavailable';
  if (errorMsg.includes('content filter')) return 'content_filtered';
  return 'other';
}
// CLI interface when script is run directly
if (import.meta.url === `file://${process.argv[1]}`) {
  async function main() {
    const args = process.argv.slice(2);
    // Fix: a non-numeric argument previously produced NaN and was passed
    // through as `days`; fall back to the 7-day default instead.
    const parsedDays = Number.parseInt(args[0] ?? '', 10);
    const days = Number.isNaN(parsedDays) ? 7 : parsedDays;
    try {
      const reportPath = await OpenAIAnalytics.saveReport(days);
      console.log(`Analytics report saved to: ${reportPath}`);
      // Also print to console
      const report = await OpenAIAnalytics.generateReport(days);
      console.log(report);
    } catch (error) {
      console.error('Error generating analytics report:', error);
    }
  }
  main().catch(console.error);
}
export default OpenAIAnalytics;

View File

@@ -1,199 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { openaiLogger } from './openaiLogger.js';
/**
 * CLI utility for viewing and managing OpenAI logs.
 */
export class OpenAILogViewer {
  /**
   * List all available OpenAI logs with a one-line summary each.
   * @param limit Optional limit on the number of logs to display
   */
  static async listLogs(limit?: number): Promise<void> {
    try {
      const logs = await openaiLogger.getLogFiles(limit);
      if (logs.length === 0) {
        console.log('No OpenAI logs found');
        return;
      }
      console.log(`Found ${logs.length} OpenAI logs:`);
      for (let i = 0; i < logs.length; i++) {
        const filePath = logs[i];
        const filename = path.basename(filePath);
        const logData = await openaiLogger.readLogFile(filePath);
        // Type guard for logData
        if (typeof logData !== 'object' || logData === null) {
          // Fix: interpolate the computed filename — the previous literal
          // `$(unknown)` left `filename` unused and printed garbage.
          console.log(`${i + 1}. ${filename} - Invalid log data`);
          continue;
        }
        const data = logData as Record<string, unknown>;
        // Format the log entry summary
        const requestType = getRequestType(data.request);
        const status = data.error ? 'ERROR' : 'OK';
        console.log(
          `${i + 1}. ${filename} - ${requestType} - ${status} - ${data.timestamp}`,
        );
      }
    } catch (error) {
      console.error('Error listing logs:', error);
    }
  }
  /**
   * View details of a specific log file, pretty-printed as JSON.
   * @param identifier Either a log index (1-based) or a filename
   */
  static async viewLog(identifier: number | string): Promise<void> {
    try {
      let logFile: string | undefined;
      const logs = await openaiLogger.getLogFiles();
      if (logs.length === 0) {
        console.log('No OpenAI logs found');
        return;
      }
      if (typeof identifier === 'number') {
        // Adjust for 1-based indexing
        if (identifier < 1 || identifier > logs.length) {
          console.error(
            `Invalid log index. Please provide a number between 1 and ${logs.length}`,
          );
          return;
        }
        logFile = logs[identifier - 1];
      } else {
        // Find by filename
        logFile = logs.find((log) => path.basename(log) === identifier);
        if (!logFile) {
          console.error(`Log file '${identifier}' not found`);
          return;
        }
      }
      const logData = await openaiLogger.readLogFile(logFile);
      console.log(JSON.stringify(logData, null, 2));
    } catch (error) {
      console.error('Error viewing log:', error);
    }
  }
  /**
   * Clean up old logs, keeping only the most recent ones.
   * @param keepCount Number of recent logs to keep
   */
  static async cleanupLogs(keepCount: number = 50): Promise<void> {
    try {
      const allLogs = await openaiLogger.getLogFiles();
      if (allLogs.length === 0) {
        console.log('No OpenAI logs found');
        return;
      }
      if (allLogs.length <= keepCount) {
        console.log(`Only ${allLogs.length} logs exist, no cleanup needed`);
        return;
      }
      // getLogFiles returns newest-first, so everything past keepCount is old.
      const logsToDelete = allLogs.slice(keepCount);
      const fs = await import('node:fs/promises');
      for (const log of logsToDelete) {
        await fs.unlink(log);
      }
      console.log(
        `Deleted ${logsToDelete.length} old log files. Kept ${keepCount} most recent logs.`,
      );
    } catch (error) {
      console.error('Error cleaning up logs:', error);
    }
  }
}
/**
 * Helper function to determine the type of request in a log.
 *
 * Classification order matters: a truthy `contents` means content
 * generation; an embedding model name or an `input` field means embedding;
 * a `countTokens` key (or a present-but-falsy `contents`) means token
 * counting; anything else is a generic API call.
 */
function getRequestType(request: unknown): string {
  // Non-objects (including null/undefined and primitives) are unknowable.
  if (typeof request !== 'object' || request === null) {
    return 'unknown';
  }
  const payload = request as Record<string, unknown>;
  if (payload.contents) {
    return 'generate_content';
  }
  const model = payload.model;
  if (typeof model === 'string' && model.includes('embedding')) {
    return 'embedding';
  }
  if (payload.input) {
    return 'embedding';
  }
  if ('countTokens' in payload || 'contents' in payload) {
    return 'count_tokens';
  }
  return 'api_call';
}
// CLI interface when script is run directly
if (import.meta.url === `file://${process.argv[1]}`) {
  async function main() {
    const args = process.argv.slice(2);
    const command = args[0]?.toLowerCase();
    switch (command) {
      case 'list': {
        // Fix: a non-numeric limit previously became NaN; treat it as
        // "no limit" instead.
        const rawLimit = Number.parseInt(args[1] ?? '', 10);
        const limit = Number.isNaN(rawLimit) ? undefined : rawLimit;
        await OpenAILogViewer.listLogs(limit);
        break;
      }
      case 'view': {
        const identifier = args[1];
        if (!identifier) {
          console.error('Please provide a log index or filename to view');
          process.exit(1);
        }
        await OpenAILogViewer.viewLog(
          isNaN(Number(identifier)) ? identifier : Number(identifier),
        );
        break;
      }
      case 'cleanup': {
        // Fix: a non-numeric keepCount previously became NaN; fall back
        // to the default of 50.
        const rawKeep = Number.parseInt(args[1] ?? '', 10);
        const keepCount = Number.isNaN(rawKeep) ? 50 : rawKeep;
        await OpenAILogViewer.cleanupLogs(keepCount);
        break;
      }
      default:
        console.log('OpenAI Log Viewer');
        console.log('----------------');
        console.log('Commands:');
        console.log(
          '  list [limit]    - List all logs, optionally limiting to the specified number',
        );
        console.log(
          '  view <index|file> - View a specific log by index number or filename',
        );
        console.log(
          '  cleanup [keepCount] - Remove old logs, keeping only the specified number (default: 50)',
        );
        break;
    }
  }
  main().catch(console.error);
}
export default OpenAILogViewer;

View File

@@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { Schema } from '@google/genai';
import AjvPkg from 'ajv';
// Ajv's ESM/CJS interop: use 'any' for compatibility as recommended by Ajv docs
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -19,50 +18,18 @@ export class SchemaValidator {
* Returns null if the data confroms to the schema described by schema (or if schema
* is null). Otherwise, returns a string describing the error.
*/
static validate(schema: Schema | undefined, data: unknown): string | null {
static validate(schema: unknown | undefined, data: unknown): string | null {
if (!schema) {
return null;
}
if (typeof data !== 'object' || data === null) {
return 'Value of params must be an object';
}
const validate = ajValidator.compile(this.toObjectSchema(schema));
const validate = ajValidator.compile(schema);
const valid = validate(data);
if (!valid && validate.errors) {
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
}
return null;
}
/**
* Converts @google/genai's Schema to an object compatible with avj.
* This is necessary because it represents Types as an Enum (with
* UPPERCASE values) and minItems and minLength as strings, when they should be numbers.
*/
private static toObjectSchema(schema: Schema): object {
const newSchema: Record<string, unknown> = { ...schema };
if (newSchema.anyOf && Array.isArray(newSchema.anyOf)) {
newSchema.anyOf = newSchema.anyOf.map((v) => this.toObjectSchema(v));
}
if (newSchema.items) {
newSchema.items = this.toObjectSchema(newSchema.items);
}
if (newSchema.properties && typeof newSchema.properties === 'object') {
const newProperties: Record<string, unknown> = {};
for (const [key, value] of Object.entries(newSchema.properties)) {
newProperties[key] = this.toObjectSchema(value as Schema);
}
newSchema.properties = newProperties;
}
if (newSchema.type) {
newSchema.type = String(newSchema.type).toLowerCase();
}
if (newSchema.minItems) {
newSchema.minItems = Number(newSchema.minItems);
}
if (newSchema.minLength) {
newSchema.minLength = Number(newSchema.minLength);
}
return newSchema;
}
}