merge main and fix conflict

This commit is contained in:
koalazf.99
2025-08-26 13:50:29 +08:00
176 changed files with 9529 additions and 4561 deletions

View File

@@ -15,3 +15,4 @@ export {
IdeConnectionEvent,
IdeConnectionType,
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';

View File

@@ -20,7 +20,7 @@
"dist"
],
"dependencies": {
"@google/genai": "1.9.0",
"@google/genai": "1.13.0",
"@modelcontextprotocol/sdk": "^1.11.0",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/exporter-logs-otlp-grpc": "^0.52.0",

View File

@@ -16,9 +16,17 @@ const mockPaidTier: GeminiUserTier = {
id: UserTierId.STANDARD,
name: 'paid',
description: 'Paid tier',
isDefault: true,
};
describe('setupUser', () => {
const mockFreeTier: GeminiUserTier = {
id: UserTierId.FREE,
name: 'free',
description: 'Free tier',
isDefault: true,
};
describe('setupUser for existing user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
@@ -42,7 +50,7 @@ describe('setupUser', () => {
);
});
it('should use GOOGLE_CLOUD_PROJECT when set', async () => {
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
currentTier: mockPaidTier,
@@ -57,8 +65,8 @@ describe('setupUser', () => {
);
});
it('should treat empty GOOGLE_CLOUD_PROJECT as undefined and use project from server', async () => {
process.env.GOOGLE_CLOUD_PROJECT = '';
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
cloudaicompanionProject: 'server-project',
currentTier: mockPaidTier,
@@ -66,7 +74,7 @@ describe('setupUser', () => {
const projectId = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
'test-project',
{},
'',
undefined,
@@ -89,3 +97,119 @@ describe('setupUser', () => {
);
});
});
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'standard-tier',
cloudaicompanionProject: 'test-project',
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: 'test-project',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
delete process.env.GOOGLE_CLOUD_PROJECT;
mockLoad.mockResolvedValue({
allowedTiers: [mockFreeTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'free-tier',
cloudaicompanionProject: undefined,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'free-tier',
});
});
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: undefined,
},
});
const userData = await setupUser({} as OAuth2Client);
expect(userData).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
delete process.env.GOOGLE_CLOUD_PROJECT;
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {},
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});

View File

@@ -33,32 +33,58 @@ export interface UserData {
* @returns the user's actual project id
*/
export async function setupUser(client: OAuth2Client): Promise<UserData> {
let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
const projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
const clientMetadata: ClientMetadata = {
const coreClientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: projectId,
};
const loadRes = await caServer.loadCodeAssist({
cloudaicompanionProject: projectId,
metadata: clientMetadata,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
});
if (!projectId && loadRes.cloudaicompanionProject) {
projectId = loadRes.cloudaicompanionProject;
if (loadRes.currentTier) {
if (!loadRes.cloudaicompanionProject) {
if (projectId) {
return {
projectId,
userTier: loadRes.currentTier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: loadRes.cloudaicompanionProject,
userTier: loadRes.currentTier.id,
};
}
const tier = getOnboardTier(loadRes);
const onboardReq: OnboardUserRequest = {
tierId: tier.id,
cloudaicompanionProject: projectId,
metadata: clientMetadata,
};
let onboardReq: OnboardUserRequest;
if (tier.id === UserTierId.FREE) {
// The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: undefined,
metadata: coreClientMetadata,
};
} else {
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
};
}
// Poll onboardUser until long running operation is complete.
let lroRes = await caServer.onboardUser(onboardReq);
@@ -67,20 +93,23 @@ export async function setupUser(client: OAuth2Client): Promise<UserData> {
lroRes = await caServer.onboardUser(onboardReq);
}
if (!lroRes.response?.cloudaicompanionProject?.id && !projectId) {
if (!lroRes.response?.cloudaicompanionProject?.id) {
if (projectId) {
return {
projectId,
userTier: tier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: lroRes.response?.cloudaicompanionProject?.id || projectId!,
projectId: lroRes.response.cloudaicompanionProject.id,
userTier: tier.id,
};
}
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
if (res.currentTier) {
return res.currentTier;
}
for (const tier of res.allowedTiers || []) {
if (tier.isDefault) {
return tier;

View File

@@ -4,7 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { Mock } from 'vitest';
import { Config, ConfigParameters, SandboxConfig } from './config.js';
import * as path from 'path';
import { setGeminiMdFilename as mockSetGeminiMdFilename } from '../tools/memoryTool.js';
@@ -14,6 +15,7 @@ import {
} from '../telemetry/index.js';
import {
AuthType,
ContentGeneratorConfig,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
import { GeminiClient } from '../core/client.js';
@@ -131,6 +133,7 @@ describe('Server Config (config.ts)', () => {
telemetry: TELEMETRY_SETTINGS,
sessionId: SESSION_ID,
model: MODEL,
usageStatisticsEnabled: false,
};
beforeEach(() => {
@@ -254,6 +257,7 @@ describe('Server Config (config.ts)', () => {
// Verify that history was restored to the new client
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: false },
);
});
@@ -287,6 +291,92 @@ describe('Server Config (config.ts)', () => {
// Verify that setHistory was not called since there was no existing history
expect(mockNewClient.setHistory).not.toHaveBeenCalled();
});
it('should strip thoughts when switching from GenAI to Vertex', async () => {
const config = new Config(baseParams);
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
authType: AuthType.USE_GEMINI,
};
(
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
).contentGeneratorConfig = mockContentConfig;
(createContentGeneratorConfig as Mock).mockReturnValue({
...mockContentConfig,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
const mockExistingHistory = [
{ role: 'user', parts: [{ text: 'Hello' }] },
];
const mockExistingClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
};
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
(
config as unknown as { geminiClient: typeof mockExistingClient }
).geminiClient = mockExistingClient;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: true },
);
});
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
const config = new Config(baseParams);
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
authType: AuthType.LOGIN_WITH_GOOGLE,
};
(
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
).contentGeneratorConfig = mockContentConfig;
(createContentGeneratorConfig as Mock).mockReturnValue({
...mockContentConfig,
authType: AuthType.USE_GEMINI,
});
const mockExistingHistory = [
{ role: 'user', parts: [{ text: 'Hello' }] },
];
const mockExistingClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
};
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
(
config as unknown as { geminiClient: typeof mockExistingClient }
).geminiClient = mockExistingClient;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(AuthType.USE_GEMINI);
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: false },
);
});
});
it('Config constructor should store userMemory correctly', () => {
@@ -384,6 +474,28 @@ describe('Server Config (config.ts)', () => {
expect(fileService).toBeDefined();
});
describe('Usage Statistics', () => {
it('defaults usage statistics to enabled if not specified', () => {
const config = new Config({
...baseParams,
usageStatisticsEnabled: undefined,
});
expect(config.getUsageStatisticsEnabled()).toBe(true);
});
it.each([{ enabled: true }, { enabled: false }])(
'sets usage statistics based on the provided value (enabled: $enabled)',
({ enabled }) => {
const config = new Config({
...baseParams,
usageStatisticsEnabled: enabled,
});
expect(config.getUsageStatisticsEnabled()).toBe(enabled);
},
);
});
describe('Telemetry Settings', () => {
it('should return default telemetry target if not provided', () => {
const params: ConfigParameters = {

View File

@@ -193,13 +193,12 @@ export interface ConfigParameters {
extensionContextFilePaths?: string[];
maxSessionTurns?: number;
sessionTokenLimit?: number;
experimentalAcp?: boolean;
experimentalZedIntegration?: boolean;
listExtensions?: boolean;
extensions?: GeminiCLIExtension[];
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
noBrowser?: boolean;
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
ideModeFeature?: boolean;
folderTrustFeature?: boolean;
folderTrust?: boolean;
ideMode?: boolean;
@@ -222,6 +221,7 @@ export interface ConfigParameters {
tavilyApiKey?: string;
chatCompression?: ChatCompressionSettings;
interactive?: boolean;
trustedFolder?: boolean;
}
export class Config {
@@ -265,7 +265,6 @@ export class Config {
private readonly model: string;
private readonly extensionContextFilePaths: string[];
private readonly noBrowser: boolean;
private readonly ideModeFeature: boolean;
private readonly folderTrustFeature: boolean;
private readonly folderTrust: boolean;
private ideMode: boolean;
@@ -289,7 +288,6 @@ export class Config {
private readonly summarizeToolOutput:
| Record<string, SummarizeToolOutputSettings>
| undefined;
private readonly experimentalAcp: boolean = false;
private readonly enableOpenAILogging: boolean;
private readonly contentGenerator?: {
timeout?: number;
@@ -297,10 +295,12 @@ export class Config {
samplingParams?: Record<string, unknown>;
};
private readonly cliVersion?: string;
private readonly experimentalZedIntegration: boolean = false;
private readonly loadMemoryFromIncludeDirectories: boolean = false;
private readonly tavilyApiKey?: string;
private readonly chatCompression: ChatCompressionSettings | undefined;
private readonly interactive: boolean;
private readonly trustedFolder: boolean | undefined;
private initialized: boolean = false;
constructor(params: ConfigParameters) {
@@ -356,13 +356,13 @@ export class Config {
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
this.experimentalAcp = params.experimentalAcp ?? false;
this.experimentalZedIntegration =
params.experimentalZedIntegration ?? false;
this.listExtensions = params.listExtensions ?? false;
this._extensions = params.extensions ?? [];
this._blockedMcpServers = params.blockedMcpServers ?? [];
this.noBrowser = params.noBrowser ?? false;
this.summarizeToolOutput = params.summarizeToolOutput;
this.ideModeFeature = params.ideModeFeature ?? false;
this.folderTrustFeature = params.folderTrustFeature ?? false;
this.folderTrust = params.folderTrust ?? false;
this.ideMode = params.ideMode ?? false;
@@ -376,6 +376,7 @@ export class Config {
params.loadMemoryFromIncludeDirectories ?? false;
this.chatCompression = params.chatCompression;
this.interactive = params.interactive ?? false;
this.trustedFolder = params.trustedFolder;
// Web search
this.tavilyApiKey = params.tavilyApiKey;
@@ -431,13 +432,21 @@ export class Config {
const newGeminiClient = new GeminiClient(this);
await newGeminiClient.initialize(newContentGeneratorConfig);
// Vertex and Genai have incompatible encryption and sending history with
// throughtSignature from Genai to Vertex will fail, we need to strip them
const fromGenaiToVertex =
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE;
// Only assign to instance properties after successful initialization
this.contentGeneratorConfig = newContentGeneratorConfig;
this.geminiClient = newGeminiClient;
// Restore the conversation history to the new client
if (existingHistory.length > 0) {
this.geminiClient.setHistory(existingHistory);
this.geminiClient.setHistory(existingHistory, {
stripThoughts: fromGenaiToVertex,
});
}
// Reset the session flag since we're explicitly changing auth and using default model
@@ -685,8 +694,8 @@ export class Config {
return this.extensionContextFilePaths;
}
getExperimentalAcp(): boolean {
return this.experimentalAcp;
getExperimentalZedIntegration(): boolean {
return this.experimentalZedIntegration;
}
getListExtensions(): boolean {
@@ -720,10 +729,6 @@ export class Config {
return this.tavilyApiKey;
}
getIdeModeFeature(): boolean {
return this.ideModeFeature;
}
getIdeClient(): IdeClient {
return this.ideClient;
}
@@ -740,6 +745,10 @@ export class Config {
return this.folderTrust;
}
isTrustedFolder(): boolean | undefined {
return this.trustedFolder;
}
setIdeMode(value: boolean): void {
this.ideMode = value;
}

View File

@@ -708,7 +708,7 @@ describe('Gemini Client (client.ts)', () => {
});
describe('sendMessageStream', () => {
it('should include editor context when ideModeFeature is enabled', async () => {
it('should include editor context when ideMode is enabled', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -732,7 +732,7 @@ describe('Gemini Client (client.ts)', () => {
},
});
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
@@ -792,7 +792,7 @@ ${JSON.stringify(
});
});
it('should not add context if ideModeFeature is enabled but no open files', async () => {
it('should not add context if ideMode is enabled but no open files', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -800,7 +800,7 @@ ${JSON.stringify(
},
});
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
@@ -839,7 +839,7 @@ ${JSON.stringify(
);
});
it('should add context if ideModeFeature is enabled and there is one active file', async () => {
it('should add context if ideMode is enabled and there is one active file', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -855,7 +855,7 @@ ${JSON.stringify(
},
});
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
@@ -914,7 +914,7 @@ ${JSON.stringify(
});
});
it('should add context if ideModeFeature is enabled and there are open files but no active file', async () => {
it('should add context if ideMode is enabled and there are open files but no active file', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -931,7 +931,7 @@ ${JSON.stringify(
},
});
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
@@ -1267,7 +1267,7 @@ ${JSON.stringify(
beforeEach(() => {
client['forceFullIdeContext'] = false; // Reset before each delta test
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
vi.spyOn(client['config'], 'getIdeModeFeature').mockReturnValue(true);
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
@@ -1637,4 +1637,73 @@ ${JSON.stringify(
);
});
});
describe('setHistory', () => {
it('should strip thought signatures when stripThoughts is true', () => {
const mockChat = {
setHistory: vi.fn(),
};
client['chat'] = mockChat as unknown as GeminiChat;
const historyWithThoughts: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...', thoughtSignature: 'thought-123' },
{
functionCall: { name: 'test', args: {} },
thoughtSignature: 'thought-456',
},
],
},
];
client.setHistory(historyWithThoughts, { stripThoughts: true });
const expectedHistory: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...' },
{ functionCall: { name: 'test', args: {} } },
],
},
];
expect(mockChat.setHistory).toHaveBeenCalledWith(expectedHistory);
});
it('should not strip thought signatures when stripThoughts is false', () => {
const mockChat = {
setHistory: vi.fn(),
};
client['chat'] = mockChat as unknown as GeminiChat;
const historyWithThoughts: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...', thoughtSignature: 'thought-123' },
{ text: 'ok', thoughtSignature: 'thought-456' },
],
},
];
client.setHistory(historyWithThoughts, { stripThoughts: false });
expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts);
});
});
});

View File

@@ -162,8 +162,32 @@ export class GeminiClient {
return this.getChat().getHistory();
}
setHistory(history: Content[]) {
this.getChat().setHistory(history);
setHistory(
history: Content[],
{ stripThoughts = false }: { stripThoughts?: boolean } = {},
) {
const historyToSet = stripThoughts
? history.map((content) => {
const newContent = { ...content };
if (newContent.parts) {
newContent.parts = newContent.parts.map((part) => {
if (
part &&
typeof part === 'object' &&
'thoughtSignature' in part
) {
const newPart = { ...part };
delete (newPart as { thoughtSignature?: string })
.thoughtSignature;
return newPart;
}
return part;
});
}
return newContent;
})
: history;
this.getChat().setHistory(historyToSet);
this.forceFullIdeContext = true;
}
@@ -498,11 +522,7 @@ export class GeminiClient {
lastMessage.role === 'model' &&
(lastMessage.parts?.some((p) => 'functionCall' in p) || false);
if (
this.config.getIdeModeFeature() &&
this.config.getIdeMode() &&
!hasPendingToolCall
) {
if (this.config.getIdeMode() && !hasPendingToolCall) {
const { contextParts, newIdeContext } = this.getIdeContextParts(
this.forceFullIdeContext || this.getHistory().length === 0,
);

View File

@@ -16,7 +16,7 @@ import {
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_GEMINI_MODEL, DEFAULT_QWEN_MODEL } from '../config/models.js';
import { Config } from '../config/config.js';
import { getEffectiveModel } from './modelCheck.js';
import { UserTierId } from '../code_assist/types.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
@@ -71,6 +71,7 @@ export type ContentGeneratorConfig = {
max_tokens?: number;
};
proxy?: string | undefined;
userAgent?: string;
};
export function createContentGeneratorConfig(
@@ -112,11 +113,6 @@ export function createContentGeneratorConfig(
if (authType === AuthType.USE_GEMINI && geminiApiKey) {
contentGeneratorConfig.apiKey = geminiApiKey;
contentGeneratorConfig.vertexai = false;
getEffectiveModel(
contentGeneratorConfig.apiKey,
contentGeneratorConfig.model,
contentGeneratorConfig.proxy,
);
return contentGeneratorConfig;
}

View File

@@ -9,7 +9,6 @@ import { describe, it, expect, vi } from 'vitest';
import {
CoreToolScheduler,
ToolCall,
ValidatingToolCall,
convertToFunctionResponse,
} from './coreToolScheduler.js';
import {
@@ -19,7 +18,7 @@ import {
ToolConfirmationPayload,
ToolResult,
Config,
Icon,
Kind,
ApprovalMode,
} from '../index.js';
import { Part, PartListUnion } from '@google/genai';
@@ -54,7 +53,9 @@ class MockModifiableTool
};
}
async shouldConfirmExecute(): Promise<ToolCallConfirmationDetails | false> {
override async shouldConfirmExecute(): Promise<
ToolCallConfirmationDetails | false
> {
if (this.shouldConfirm) {
return {
type: 'edit',
@@ -121,8 +122,6 @@ describe('CoreToolScheduler', () => {
abortController.abort();
await scheduler.schedule([request], abortController.signal);
const _waitingCall = onToolCallsUpdate.mock
.calls[1][0][0] as ValidatingToolCall;
const confirmationDetails = await mockTool.shouldConfirmExecute(
{},
abortController.signal,
@@ -389,12 +388,12 @@ describe('CoreToolScheduler edit cancellation', () => {
'mockEditTool',
'mockEditTool',
'A mock edit tool',
Icon.Pencil,
Kind.Edit,
{},
);
}
async shouldConfirmExecute(
override async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {

View File

@@ -1,24 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
// 移除未使用的导入
/**
* Checks if the default "pro" model is rate-limited and returns a fallback "flash"
* model if necessary. This function is designed to be silent.
* @param apiKey The API key to use for the check.
* @param currentConfiguredModel The model currently configured in settings.
* @returns An object indicating the model to use, whether a switch occurred,
* and the original model if a switch happened.
*/
export async function getEffectiveModel(
_apiKey: string,
currentConfiguredModel: string,
_proxy: string | undefined,
): Promise<string> {
// Disable Google API Model Check
return currentConfiguredModel;
}

View File

@@ -679,7 +679,7 @@ describe('OpenAIContentGenerator', () => {
model: 'text-embedding-ada-002',
};
const _result = await generator.embedContent(request);
await generator.embedContent(request);
expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({
model: 'text-embedding-ada-002',
@@ -1627,7 +1627,7 @@ describe('OpenAIContentGenerator', () => {
describe('error suppression functionality', () => {
it('should allow subclasses to suppress error logging', async () => {
class TestGenerator extends OpenAIContentGenerator {
protected shouldSuppressErrorLogging(): boolean {
protected override shouldSuppressErrorLogging(): boolean {
return true; // Always suppress for this test
}
}

View File

@@ -0,0 +1,92 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, afterEach, vi } from 'vitest';
import { detectIde, DetectedIde } from './detect-ide.js';
describe('detectIde', () => {
afterEach(() => {
vi.unstubAllEnvs();
});
it.each([
{
env: {},
expected: DetectedIde.VSCode,
},
{
env: { __COG_BASHRC_SOURCED: '1' },
expected: DetectedIde.Devin,
},
{
env: { REPLIT_USER: 'test' },
expected: DetectedIde.Replit,
},
{
env: { CURSOR_TRACE_ID: 'test' },
expected: DetectedIde.Cursor,
},
{
env: { CODESPACES: 'true' },
expected: DetectedIde.Codespaces,
},
{
env: { EDITOR_IN_CLOUD_SHELL: 'true' },
expected: DetectedIde.CloudShell,
},
{
env: { CLOUD_SHELL: 'true' },
expected: DetectedIde.CloudShell,
},
{
env: { TERM_PRODUCT: 'Trae' },
expected: DetectedIde.Trae,
},
{
env: { FIREBASE_DEPLOY_AGENT: 'true' },
expected: DetectedIde.FirebaseStudio,
},
{
env: { MONOSPACE_ENV: 'true' },
expected: DetectedIde.FirebaseStudio,
},
])('detects the IDE for $expected', ({ env, expected }) => {
// Clear all environment variables first
vi.unstubAllEnvs();
// Set TERM_PROGRAM to vscode (required for all IDE detection)
vi.stubEnv('TERM_PROGRAM', 'vscode');
// Explicitly stub all environment variables that detectIde() checks to undefined
// This ensures no real environment variables interfere with the tests
vi.stubEnv('__COG_BASHRC_SOURCED', undefined);
vi.stubEnv('REPLIT_USER', undefined);
vi.stubEnv('CURSOR_TRACE_ID', undefined);
vi.stubEnv('CODESPACES', undefined);
vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined);
vi.stubEnv('CLOUD_SHELL', undefined);
vi.stubEnv('TERM_PRODUCT', undefined);
vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined);
vi.stubEnv('MONOSPACE_ENV', undefined);
// Set only the specific environment variables for this test case
for (const [key, value] of Object.entries(env)) {
vi.stubEnv(key, value);
}
expect(detectIde()).toBe(expected);
});
it('returns undefined for non-vscode', () => {
// Clear all environment variables first
vi.unstubAllEnvs();
// Set TERM_PROGRAM to something other than vscode
vi.stubEnv('TERM_PROGRAM', 'definitely-not-vscode');
expect(detectIde()).toBeUndefined();
});
});

View File

@@ -5,34 +5,54 @@
*/
export enum DetectedIde {
Devin = 'devin',
Replit = 'replit',
VSCode = 'vscode',
VSCodium = 'vscodium',
Cursor = 'cursor',
CloudShell = 'cloudshell',
Codespaces = 'codespaces',
Windsurf = 'windsurf',
FirebaseStudio = 'firebasestudio',
Trae = 'trae',
}
export function getIdeDisplayName(ide: DetectedIde): string {
export interface IdeInfo {
displayName: string;
}
export function getIdeInfo(ide: DetectedIde): IdeInfo {
switch (ide) {
case DetectedIde.Devin:
return {
displayName: 'Devin',
};
case DetectedIde.Replit:
return {
displayName: 'Replit',
};
case DetectedIde.VSCode:
return 'VS Code';
case DetectedIde.VSCodium:
return 'VSCodium';
return {
displayName: 'VS Code',
};
case DetectedIde.Cursor:
return 'Cursor';
return {
displayName: 'Cursor',
};
case DetectedIde.CloudShell:
return 'Cloud Shell';
return {
displayName: 'Cloud Shell',
};
case DetectedIde.Codespaces:
return 'GitHub Codespaces';
case DetectedIde.Windsurf:
return 'Windsurf';
return {
displayName: 'GitHub Codespaces',
};
case DetectedIde.FirebaseStudio:
return 'Firebase Studio';
return {
displayName: 'Firebase Studio',
};
case DetectedIde.Trae:
return 'Trae';
return {
displayName: 'Trae',
};
default: {
// This ensures that if a new IDE is added to the enum, we get a compile-time error.
const exhaustiveCheck: never = ide;
@@ -46,19 +66,25 @@ export function detectIde(): DetectedIde | undefined {
if (process.env.TERM_PROGRAM !== 'vscode') {
return undefined;
}
if (process.env.__COG_BASHRC_SOURCED) {
return DetectedIde.Devin;
}
if (process.env.REPLIT_USER) {
return DetectedIde.Replit;
}
if (process.env.CURSOR_TRACE_ID) {
return DetectedIde.Cursor;
}
if (process.env.CODESPACES) {
return DetectedIde.Codespaces;
}
if (process.env.EDITOR_IN_CLOUD_SHELL) {
if (process.env.EDITOR_IN_CLOUD_SHELL || process.env.CLOUD_SHELL) {
return DetectedIde.CloudShell;
}
if (process.env.TERM_PRODUCT === 'Trae') {
return DetectedIde.Trae;
}
if (process.env.FIREBASE_DEPLOY_AGENT) {
if (process.env.FIREBASE_DEPLOY_AGENT || process.env.MONOSPACE_ENV) {
return DetectedIde.FirebaseStudio;
}
return DetectedIde.VSCode;

View File

@@ -6,11 +6,7 @@
import * as fs from 'node:fs';
import * as path from 'node:path';
import {
detectIde,
DetectedIde,
getIdeDisplayName,
} from '../ide/detect-ide.js';
import { detectIde, DetectedIde, getIdeInfo } from '../ide/detect-ide.js';
import {
ideContext,
IdeContextNotificationSchema,
@@ -68,7 +64,7 @@ export class IdeClient {
private constructor() {
this.currentIde = detectIde();
if (this.currentIde) {
this.currentIdeDisplayName = getIdeDisplayName(this.currentIde);
this.currentIdeDisplayName = getIdeInfo(this.currentIde).displayName;
}
}
@@ -86,7 +82,7 @@ export class IdeClient {
`IDE integration is not supported in your current environment. To use this feature, run Qwen Code in one of these supported IDEs: ${Object.values(
DetectedIde,
)
.map((ide) => getIdeDisplayName(ide))
.map((ide) => getIdeInfo(ide).displayName)
.join(', ')}`,
false,
);

View File

@@ -24,11 +24,6 @@ describe('ide-installer', () => {
expect(installer).toBeInstanceOf(Object);
});
it('should return null for "vscodium" (not implemented)', () => {
const installer = getIdeInstaller(DetectedIde.VSCodium);
expect(installer).toBeNull();
});
it('should return null for an unknown IDE', () => {
const installer = getIdeInstaller('unknown' as DetectedIde);
expect(installer).toBeNull();

View File

@@ -42,6 +42,7 @@ export * from './utils/systemEncoding.js';
export * from './utils/textUtils.js';
export * from './utils/formatters.js';
export * from './utils/filesearch/fileSearch.js';
export * from './utils/errorParsing.js';
// Export services
export * from './services/fileDiscoveryService.js';
@@ -51,8 +52,8 @@ export * from './services/gitService.js';
export * from './ide/ide-client.js';
export * from './ide/ideContext.js';
export * from './ide/ide-installer.js';
export { getIdeDisplayName, DetectedIde } from './ide/detect-ide.js';
export * from './ide/constants.js';
export { getIdeInfo, DetectedIde, IdeInfo } from './ide/detect-ide.js';
// Export Shell Execution Service
export * from './services/shellExecutionService.js';

View File

@@ -91,7 +91,6 @@ export class MCPOAuthProvider {
private static readonly REDIRECT_PORT = 7777;
private static readonly REDIRECT_PATH = '/oauth/callback';
private static readonly HTTP_OK = 200;
private static readonly HTTP_REDIRECT = 302;
/**
* Register a client dynamically with the OAuth server.

View File

@@ -0,0 +1,9 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { setupServer } from 'msw/node';
export const server = setupServer();

View File

@@ -70,7 +70,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
/**
* Override error logging behavior to suppress auth errors during token refresh
*/
protected shouldSuppressErrorLogging(
protected override shouldSuppressErrorLogging(
error: unknown,
_request: GenerateContentParameters,
): boolean {
@@ -81,7 +81,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
/**
* Override to use dynamic token and endpoint
*/
async generateContent(
override async generateContent(
request: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
@@ -105,7 +105,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
/**
* Override to use dynamic token and endpoint
*/
async generateContentStream(
override async generateContentStream(
request: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
@@ -132,7 +132,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
/**
* Override to use dynamic token and endpoint
*/
async countTokens(
override async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.withValidToken(async (token) => {
@@ -153,7 +153,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
/**
* Override to use dynamic token and endpoint
*/
async embedContent(
override async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.withValidToken(async (token) => {

View File

@@ -223,17 +223,9 @@ describe('Type Guards', () => {
describe('QwenOAuth2Client', () => {
let client: QwenOAuth2Client;
let _mockConfig: Config;
let originalFetch: typeof global.fetch;
beforeEach(() => {
// Setup mock config
_mockConfig = {
getQwenClientId: vi.fn().mockReturnValue('test-client-id'),
isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false),
getProxy: vi.fn().mockReturnValue(undefined),
} as unknown as Config;
// Create client instance
client = new QwenOAuth2Client({ proxy: undefined });
@@ -1010,7 +1002,6 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => {
describe('authWithQwenDeviceFlow - Comprehensive Testing', () => {
let mockConfig: Config;
let originalFetch: typeof global.fetch;
let _client: QwenOAuth2Client;
beforeEach(() => {
mockConfig = {
@@ -1018,7 +1009,7 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => {
isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false),
} as unknown as Config;
_client = new QwenOAuth2Client({ proxy: undefined });
new QwenOAuth2Client({ proxy: undefined });
originalFetch = global.fetch;
global.fetch = vi.fn();

View File

@@ -234,11 +234,8 @@ export interface IQwenOAuth2Client {
*/
export class QwenOAuth2Client implements IQwenOAuth2Client {
private credentials: QwenCredentials = {};
private proxy?: string;
constructor(options: { proxy?: string }) {
this.proxy = options.proxy;
}
constructor(_options?: { proxy?: string }) {}
setCredentials(credentials: QwenCredentials): void {
this.credentials = credentials;

View File

@@ -4,176 +4,306 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as https from 'https';
import { ClientRequest, IncomingMessage } from 'http';
import { Readable, Writable } from 'stream';
import {
ClearcutLogger,
LogResponse,
LogEventEntry,
} from './clearcut-logger.js';
import { Config } from '../../config/config.js';
vi,
describe,
it,
expect,
afterEach,
beforeAll,
afterAll,
} from 'vitest';
import { ClearcutLogger, LogEventEntry, TEST_ONLY } from './clearcut-logger.js';
import { ConfigParameters } from '../../config/config.js';
import * as userAccount from '../../utils/user_account.js';
import * as userId from '../../utils/user_id.js';
import { EventMetadataKey } from './event-metadata-key.js';
import { makeFakeConfig } from '../../test-utils/config.js';
import { http, HttpResponse } from 'msw';
import { server } from '../../mocks/msw.js';
// Mock dependencies
vi.mock('https-proxy-agent');
vi.mock('https');
vi.mock('../../utils/user_account');
vi.mock('../../utils/user_id');
const mockHttps = vi.mocked(https);
const mockUserAccount = vi.mocked(userAccount);
const mockUserId = vi.mocked(userId);
describe('ClearcutLogger', () => {
let mockConfig: Config;
let logger: ClearcutLogger | undefined;
// TODO(richieforeman): Consider moving this to test setup globally.
beforeAll(() => {
server.listen({});
});
afterEach(() => {
server.resetHandlers();
});
afterAll(() => {
server.close();
});
describe('ClearcutLogger', () => {
const NEXT_WAIT_MS = 1234;
const CLEARCUT_URL = 'https://play.googleapis.com/log';
const MOCK_DATE = new Date('2025-01-02T00:00:00.000Z');
const EXAMPLE_RESPONSE = `["${NEXT_WAIT_MS}",null,[[["ANDROID_BACKUP",0],["BATTERY_STATS",0],["SMART_SETUP",0],["TRON",0]],-3334737594024971225],[]]`;
// A helper to get the internal events array for testing
const getEvents = (l: ClearcutLogger): LogEventEntry[][] =>
l['events'].toArray() as LogEventEntry[][];
const getEventsSize = (l: ClearcutLogger): number => l['events'].size;
const getMaxEvents = (l: ClearcutLogger): number => l['max_events'];
const getMaxRetryEvents = (l: ClearcutLogger): number =>
l['max_retry_events'];
const requeueFailedEvents = (l: ClearcutLogger, events: LogEventEntry[][]) =>
l['requeueFailedEvents'](events);
beforeEach(() => {
vi.useFakeTimers();
vi.setSystemTime(new Date());
mockConfig = {
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
getDebugMode: vi.fn().mockReturnValue(false),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getProxy: vi.fn().mockReturnValue(undefined),
} as unknown as Config;
mockUserAccount.getCachedGoogleAccount.mockReturnValue('test@google.com');
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(1);
mockUserId.getInstallationId.mockReturnValue('test-installation-id');
logger = ClearcutLogger.getInstance(mockConfig);
expect(logger).toBeDefined();
afterEach(() => {
vi.unstubAllEnvs();
});
function setup({
config = {} as Partial<ConfigParameters>,
lifetimeGoogleAccounts = 1,
cachedGoogleAccount = 'test@google.com',
installationId = 'test-installation-id',
} = {}) {
server.resetHandlers(
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
);
vi.useFakeTimers();
vi.setSystemTime(MOCK_DATE);
const loggerConfig = makeFakeConfig({
...config,
});
ClearcutLogger.clearInstance();
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
lifetimeGoogleAccounts,
);
mockUserId.getInstallationId.mockReturnValue(installationId);
const logger = ClearcutLogger.getInstance(loggerConfig);
return { logger, loggerConfig };
}
afterEach(() => {
ClearcutLogger.clearInstance();
vi.useRealTimers();
vi.restoreAllMocks();
});
it('should not return an instance if usage statistics are disabled', () => {
ClearcutLogger.clearInstance();
vi.spyOn(mockConfig, 'getUsageStatisticsEnabled').mockReturnValue(false);
const disabledLogger = ClearcutLogger.getInstance(mockConfig);
expect(disabledLogger).toBeUndefined();
describe('getInstance', () => {
it.each([
{ usageStatisticsEnabled: false, expectedValue: undefined },
{
usageStatisticsEnabled: true,
expectedValue: expect.any(ClearcutLogger),
},
])(
'returns an instance if usage statistics are enabled',
({ usageStatisticsEnabled, expectedValue }) => {
ClearcutLogger.clearInstance();
const { logger } = setup({
config: {
usageStatisticsEnabled,
},
});
expect(logger).toEqual(expectedValue);
},
);
it('is a singleton', () => {
ClearcutLogger.clearInstance();
const { loggerConfig } = setup();
const logger1 = ClearcutLogger.getInstance(loggerConfig);
const logger2 = ClearcutLogger.getInstance(loggerConfig);
expect(logger1).toBe(logger2);
});
});
describe('createLogEvent', () => {
it('logs the total number of google accounts', () => {
const { logger } = setup({
lifetimeGoogleAccounts: 9001,
});
const event = logger?.createLogEvent('abc', []);
expect(event?.event_metadata[0][0]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: '9001',
});
});
it('logs the current surface from a github action', () => {
const { logger } = setup({});
vi.stubEnv('GITHUB_SHA', '8675309');
const event = logger?.createLogEvent('abc', []);
expect(event?.event_metadata[0][1]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: 'GitHub',
});
});
it('honors the value from env.SURFACE over all others', () => {
const { logger } = setup({});
vi.stubEnv('TERM_PROGRAM', 'vscode');
vi.stubEnv('SURFACE', 'ide-1234');
const event = logger?.createLogEvent('abc', []);
expect(event?.event_metadata[0][1]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: 'ide-1234',
});
});
it.each([
{
env: {
CURSOR_TRACE_ID: 'abc123',
GITHUB_SHA: undefined,
},
expectedValue: 'cursor',
},
{
env: {
TERM_PROGRAM: 'vscode',
GITHUB_SHA: undefined,
},
expectedValue: 'vscode',
},
{
env: {
MONOSPACE_ENV: 'true',
GITHUB_SHA: undefined,
},
expectedValue: 'firebasestudio',
},
{
env: {
__COG_BASHRC_SOURCED: 'true',
GITHUB_SHA: undefined,
},
expectedValue: 'devin',
},
{
env: {
CLOUD_SHELL: 'true',
GITHUB_SHA: undefined,
},
expectedValue: 'cloudshell',
},
])(
'logs the current surface for as $expectedValue, preempting vscode detection',
({ env, expectedValue }) => {
const { logger } = setup({});
// Clear all environment variables that could interfere with surface detection
vi.stubEnv('SURFACE', undefined);
vi.stubEnv('GITHUB_SHA', undefined);
vi.stubEnv('CURSOR_TRACE_ID', undefined);
vi.stubEnv('__COG_BASHRC_SOURCED', undefined);
vi.stubEnv('REPLIT_USER', undefined);
vi.stubEnv('CODESPACES', undefined);
vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined);
vi.stubEnv('CLOUD_SHELL', undefined);
vi.stubEnv('TERM_PRODUCT', undefined);
vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined);
vi.stubEnv('MONOSPACE_ENV', undefined);
// Set the specific environment variables for this test case
for (const [key, value] of Object.entries(env)) {
vi.stubEnv(key, value);
}
vi.stubEnv('TERM_PROGRAM', 'vscode');
const event = logger?.createLogEvent('abc', []);
expect(event?.event_metadata[0][1]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
});
},
);
});
describe('enqueueLogEvent', () => {
it('should add events to the queue', () => {
const { logger } = setup();
logger!.enqueueLogEvent({ test: 'event1' });
expect(getEventsSize(logger!)).toBe(1);
});
it('should evict the oldest event when the queue is full', () => {
const maxEvents = getMaxEvents(logger!);
const { logger } = setup();
for (let i = 0; i < maxEvents; i++) {
for (let i = 0; i < TEST_ONLY.MAX_EVENTS; i++) {
logger!.enqueueLogEvent({ event_id: i });
}
expect(getEventsSize(logger!)).toBe(maxEvents);
expect(getEventsSize(logger!)).toBe(TEST_ONLY.MAX_EVENTS);
const firstEvent = JSON.parse(
getEvents(logger!)[0][0].source_extension_json,
);
expect(firstEvent.event_id).toBe(0);
// This should push out the first event
logger!.enqueueLogEvent({ event_id: maxEvents });
logger!.enqueueLogEvent({ event_id: TEST_ONLY.MAX_EVENTS });
expect(getEventsSize(logger!)).toBe(maxEvents);
expect(getEventsSize(logger!)).toBe(TEST_ONLY.MAX_EVENTS);
const newFirstEvent = JSON.parse(
getEvents(logger!)[0][0].source_extension_json,
);
expect(newFirstEvent.event_id).toBe(1);
const lastEvent = JSON.parse(
getEvents(logger!)[maxEvents - 1][0].source_extension_json,
getEvents(logger!)[TEST_ONLY.MAX_EVENTS - 1][0].source_extension_json,
);
expect(lastEvent.event_id).toBe(maxEvents);
expect(lastEvent.event_id).toBe(TEST_ONLY.MAX_EVENTS);
});
});
describe('flushToClearcut', () => {
let mockRequest: Writable;
let mockResponse: Readable & Partial<IncomingMessage>;
beforeEach(() => {
mockRequest = new Writable({
write(chunk, encoding, callback) {
callback();
it('allows for usage with a configured proxy agent', async () => {
const { logger } = setup({
config: {
proxy: 'http://mycoolproxy.whatever.com:3128',
},
});
vi.spyOn(mockRequest, 'on');
vi.spyOn(mockRequest, 'end').mockReturnThis();
vi.spyOn(mockRequest, 'destroy').mockReturnThis();
mockResponse = new Readable({ read() {} }) as Readable &
Partial<IncomingMessage>;
logger!.enqueueLogEvent({ event_id: 1 });
mockHttps.request.mockImplementation(
(
_options: string | https.RequestOptions | URL,
...args: unknown[]
): ClientRequest => {
const callback = args.find((arg) => typeof arg === 'function') as
| ((res: IncomingMessage) => void)
| undefined;
const response = await logger!.flushToClearcut();
if (callback) {
callback(mockResponse as IncomingMessage);
}
return mockRequest as ClientRequest;
},
);
expect(response.nextRequestWaitMs).toBe(NEXT_WAIT_MS);
});
it('should clear events on successful flush', async () => {
mockResponse.statusCode = 200;
const mockResponseBody = { nextRequestWaitMs: 1000 };
// Encoded protobuf for {nextRequestWaitMs: 1000} which is `08 E8 07`
const encodedResponse = Buffer.from([8, 232, 7]);
const { logger } = setup();
logger!.enqueueLogEvent({ event_id: 1 });
const flushPromise = logger!.flushToClearcut();
const response = await logger!.flushToClearcut();
mockResponse.push(encodedResponse);
mockResponse.push(null); // End the stream
const response: LogResponse = await flushPromise;
expect(getEventsSize(logger!)).toBe(0);
expect(response.nextRequestWaitMs).toBe(
mockResponseBody.nextRequestWaitMs,
);
expect(getEvents(logger!)).toEqual([]);
expect(response.nextRequestWaitMs).toBe(NEXT_WAIT_MS);
});
it('should handle a network error and requeue events', async () => {
const { logger } = setup();
server.resetHandlers(http.post(CLEARCUT_URL, () => HttpResponse.error()));
logger!.enqueueLogEvent({ event_id: 1 });
logger!.enqueueLogEvent({ event_id: 2 });
expect(getEventsSize(logger!)).toBe(2);
const flushPromise = logger!.flushToClearcut();
mockRequest.emit('error', new Error('Network error'));
await flushPromise;
const x = logger!.flushToClearcut();
await x;
expect(getEventsSize(logger!)).toBe(2);
const events = getEvents(logger!);
@@ -181,18 +311,28 @@ describe('ClearcutLogger', () => {
});
it('should handle an HTTP error and requeue events', async () => {
mockResponse.statusCode = 500;
mockResponse.statusMessage = 'Internal Server Error';
const { logger } = setup();
server.resetHandlers(
http.post(
CLEARCUT_URL,
() =>
new HttpResponse(
{ 'the system is down': true },
{
status: 500,
},
),
),
);
logger!.enqueueLogEvent({ event_id: 1 });
logger!.enqueueLogEvent({ event_id: 2 });
expect(getEventsSize(logger!)).toBe(2);
const flushPromise = logger!.flushToClearcut();
mockResponse.emit('end'); // End the response to trigger promise resolution
await flushPromise;
expect(getEvents(logger!).length).toBe(2);
await logger!.flushToClearcut();
expect(getEventsSize(logger!)).toBe(2);
expect(getEvents(logger!).length).toBe(2);
const events = getEvents(logger!);
expect(JSON.parse(events[0][0].source_extension_json).event_id).toBe(1);
});
@@ -200,7 +340,8 @@ describe('ClearcutLogger', () => {
describe('requeueFailedEvents logic', () => {
it('should limit the number of requeued events to max_retry_events', () => {
const maxRetryEvents = getMaxRetryEvents(logger!);
const { logger } = setup();
const maxRetryEvents = TEST_ONLY.MAX_RETRY_EVENTS;
const eventsToLogCount = maxRetryEvents + 5;
const eventsToSend: LogEventEntry[][] = [];
for (let i = 0; i < eventsToLogCount; i++) {
@@ -225,7 +366,8 @@ describe('ClearcutLogger', () => {
});
it('should not requeue more events than available space in the queue', () => {
const maxEvents = getMaxEvents(logger!);
const { logger } = setup();
const maxEvents = TEST_ONLY.MAX_EVENTS;
const spaceToLeave = 5;
const initialEventCount = maxEvents - spaceToLeave;
for (let i = 0; i < initialEventCount; i++) {

View File

@@ -4,10 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { Buffer } from 'buffer';
import * as https from 'https';
import { HttpsProxyAgent } from 'https-proxy-agent';
import {
StartSessionEvent,
EndSessionEvent,
@@ -22,6 +19,7 @@ import {
SlashCommandEvent,
MalformedJsonResponseEvent,
IdeConnectionEvent,
KittySequenceOverflowEvent,
} from '../types.js';
import { EventMetadataKey } from './event-metadata-key.js';
import { Config } from '../../config/config.js';
@@ -32,6 +30,7 @@ import {
} from '../../utils/user_account.js';
import { getInstallationId } from '../../utils/user_id.js';
import { FixedDeque } from 'mnemonist';
import { DetectedIde, detectIde } from '../../ide/detect-ide.js';
const start_session_event_name = 'start_session';
const new_prompt_event_name = 'new_prompt';
@@ -46,6 +45,7 @@ const next_speaker_check_event_name = 'next_speaker_check';
const slash_command_event_name = 'slash_command';
const malformed_json_response_event_name = 'malformed_json_response';
const ide_connection_event_name = 'ide_connection';
const kitty_sequence_overflow_event_name = 'kitty_sequence_overflow';
export interface LogResponse {
nextRequestWaitMs?: number;
@@ -56,19 +56,25 @@ export interface LogEventEntry {
source_extension_json: string;
}
export type EventValue = {
export interface EventValue {
gemini_cli_key: EventMetadataKey | string;
value: string;
};
}
export type LogEvent = {
console_type: string;
export interface LogEvent {
console_type: 'GEMINI_CLI';
application: number;
event_name: string;
event_metadata: EventValue[][];
client_email?: string;
client_install_id?: string;
};
}
export interface LogRequest {
log_source_name: 'CONCORD';
request_time_ms: number;
log_event: LogEventEntry[][];
}
/**
* Determine the surface that the user is currently using. Surface is effectively the
@@ -80,31 +86,70 @@ export type LogEvent = {
* methods might have in their runtimes.
*/
function determineSurface(): string {
if (process.env.CLOUD_SHELL === 'true') {
return 'CLOUD_SHELL';
} else if (process.env.MONOSPACE_ENV === 'true') {
return 'FIREBASE_STUDIO';
if (process.env.SURFACE) {
return process.env.SURFACE;
} else if (process.env.GITHUB_SHA) {
return 'GitHub';
} else if (process.env.TERM_PROGRAM === 'vscode') {
return detectIde() || DetectedIde.VSCode;
} else {
return process.env.SURFACE || 'SURFACE_NOT_SET';
return 'SURFACE_NOT_SET';
}
}
/**
* Clearcut URL to send logging events to.
*/
const CLEARCUT_URL = 'https://play.googleapis.com/log?format=json&hasfast=true';
/**
* Interval in which buffered events are sent to clearcut.
*/
const FLUSH_INTERVAL_MS = 1000 * 60;
/**
* Maximum amount of events to keep in memory. Events added after this amount
* are dropped until the next flush to clearcut, which happens periodically as
* defined by {@link FLUSH_INTERVAL_MS}.
*/
const MAX_EVENTS = 1000;
/**
* Maximum events to retry after a failed clearcut flush
*/
const MAX_RETRY_EVENTS = 100;
// Singleton class for batch posting log events to Clearcut. When a new event comes in, the elapsed time
// is checked and events are flushed to Clearcut if at least a minute has passed since the last flush.
export class ClearcutLogger {
private static instance: ClearcutLogger;
private config?: Config;
/**
* Queue of pending events that need to be flushed to the server. New events
* are added to this queue and then flushed on demand (via `flushToClearcut`)
*/
private readonly events: FixedDeque<LogEventEntry[]>;
private last_flush_time: number = Date.now();
private flush_interval_ms: number = 1000 * 60; // Wait at least a minute before flushing events.
private readonly max_events: number = 1000; // Maximum events to keep in memory
private readonly max_retry_events: number = 100; // Maximum failed events to retry
private flushing: boolean = false; // Prevent concurrent flush operations
private pendingFlush: boolean = false; // Track if a flush was requested during an ongoing flush
/**
* The last time that the events were successfully flushed to the server.
*/
private lastFlushTime: number = Date.now();
/**
* the value is true when there is a pending flush happening. This prevents
* concurrent flush operations.
*/
private flushing: boolean = false;
/**
* This value is true when a flush was requested during an ongoing flush.
*/
private pendingFlush: boolean = false;
private constructor(config?: Config) {
this.config = config;
this.events = new FixedDeque<LogEventEntry[]>(Array, this.max_events);
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
}
static getInstance(config?: Config): ClearcutLogger | undefined {
@@ -125,7 +170,7 @@ export class ClearcutLogger {
enqueueLogEvent(event: object): void {
try {
// Manually handle overflow for FixedDeque, which throws when full.
const wasAtCapacity = this.events.size >= this.max_events;
const wasAtCapacity = this.events.size >= MAX_EVENTS;
if (wasAtCapacity) {
this.events.shift(); // Evict oldest element to make space.
@@ -150,31 +195,14 @@ export class ClearcutLogger {
}
}
addDefaultFields(data: EventValue[]): void {
const totalAccounts = getLifetimeGoogleAccounts();
const surface = determineSurface();
const defaultLogMetadata = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: totalAccounts.toString(),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
},
];
data.push(...defaultLogMetadata);
}
createLogEvent(name: string, data: EventValue[]): LogEvent {
const email = getCachedGoogleAccount();
// Add default fields that should exist for all logs
this.addDefaultFields(data);
data = addDefaultFields(data);
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
application: 102,
application: 102, // GEMINI_CLI
event_name: name,
event_metadata: [data],
};
@@ -190,7 +218,7 @@ export class ClearcutLogger {
}
flushIfNeeded(): void {
if (Date.now() - this.last_flush_time < this.flush_interval_ms) {
if (Date.now() - this.lastFlushTime < FLUSH_INTERVAL_MS) {
return;
}
@@ -217,140 +245,67 @@ export class ClearcutLogger {
const eventsToSend = this.events.toArray() as LogEventEntry[][];
this.events.clear();
return new Promise<{ buffer: Buffer; statusCode?: number }>(
(resolve, reject) => {
const request = [
{
log_source_name: 'CONCORD',
request_time_ms: Date.now(),
log_event: eventsToSend,
},
];
const body = safeJsonStringify(request);
const options = {
hostname: 'play.googleapis.com',
path: '/log',
method: 'POST',
headers: { 'Content-Length': Buffer.byteLength(body) },
timeout: 30000, // 30-second timeout
};
const bufs: Buffer[] = [];
const req = https.request(
{
...options,
agent: this.getProxyAgent(),
},
(res) => {
res.on('error', reject); // Handle stream errors
res.on('data', (buf) => bufs.push(buf));
res.on('end', () => {
try {
const buffer = Buffer.concat(bufs);
// Check if we got a successful response
if (
res.statusCode &&
res.statusCode >= 200 &&
res.statusCode < 300
) {
resolve({ buffer, statusCode: res.statusCode });
} else {
// HTTP error - reject with status code for retry handling
reject(
new Error(`HTTP ${res.statusCode}: ${res.statusMessage}`),
);
}
} catch (e) {
reject(e);
}
});
},
);
req.on('error', (e) => {
// Network-level error
reject(e);
});
req.on('timeout', () => {
if (!req.destroyed) {
req.destroy(new Error('Request timeout after 30 seconds'));
}
});
req.end(body);
const request: LogRequest[] = [
{
log_source_name: 'CONCORD',
request_time_ms: Date.now(),
log_event: eventsToSend,
},
)
.then(({ buffer }) => {
try {
this.last_flush_time = Date.now();
return this.decodeLogResponse(buffer) || {};
} catch (error: unknown) {
console.error('Error decoding log response:', error);
return {};
}
})
.catch((error: unknown) => {
// Handle both network-level and HTTP-level errors
];
let result: LogResponse = {};
try {
const response = await fetch(CLEARCUT_URL, {
method: 'POST',
body: safeJsonStringify(request),
headers: {
'Content-Type': 'application/json',
},
});
const responseBody = await response.text();
if (response.status >= 200 && response.status < 300) {
this.lastFlushTime = Date.now();
const nextRequestWaitMs = Number(JSON.parse(responseBody)[0]);
result = {
...result,
nextRequestWaitMs,
};
} else {
if (this.config?.getDebugMode()) {
console.error('Error flushing log events:', error);
console.error(
`Error flushing log events: HTTP ${response.status}: ${response.statusText}`,
);
}
// Re-queue failed events for retry
this.requeueFailedEvents(eventsToSend);
}
} catch (e: unknown) {
if (this.config?.getDebugMode()) {
console.error('Error flushing log events:', e as Error);
}
// Return empty response to maintain the Promise<LogResponse> contract
return {};
})
.finally(() => {
this.flushing = false;
// Re-queue failed events for retry
this.requeueFailedEvents(eventsToSend);
}
// If a flush was requested while we were flushing, flush again
if (this.pendingFlush) {
this.pendingFlush = false;
// Fire and forget the pending flush
this.flushToClearcut().catch((error) => {
if (this.config?.getDebugMode()) {
console.debug('Error in pending flush to Clearcut:', error);
}
});
this.flushing = false;
// If a flush was requested while we were flushing, flush again
if (this.pendingFlush) {
this.pendingFlush = false;
// Fire and forget the pending flush
this.flushToClearcut().catch((error) => {
if (this.config?.getDebugMode()) {
console.debug('Error in pending flush to Clearcut:', error);
}
});
}
// Visible for testing. Decodes protobuf-encoded response from Clearcut server.
decodeLogResponse(buf: Buffer): LogResponse | undefined {
// TODO(obrienowen): return specific errors to facilitate debugging.
if (buf.length < 1) {
return undefined;
}
// The first byte of the buffer is `field<<3 | type`. We're looking for field
// 1, with type varint, represented by type=0. If the first byte isn't 8, that
// means field 1 is missing or the message is corrupted. Either way, we return
// undefined.
if (buf.readUInt8(0) !== 8) {
return undefined;
}
let ms = BigInt(0);
let cont = true;
// In each byte, the most significant bit is the continuation bit. If it's
// set, we keep going. The lowest 7 bits, are data bits. They are concatenated
// in reverse order to form the final number.
for (let i = 1; cont && i < buf.length; i++) {
const byte = buf.readUInt8(i);
ms |= BigInt(byte & 0x7f) << BigInt(7 * (i - 1));
cont = (byte & 0x80) !== 0;
}
if (cont) {
// We have fallen off the buffer without seeing a terminating byte. The
// message is corrupted.
return undefined;
}
const returnVal = {
nextRequestWaitMs: Number(ms),
};
return returnVal;
return result;
}
logStartSessionEvent(event: StartSessionEvent): void {
@@ -687,6 +642,13 @@ export class ClearcutLogger {
});
}
if (event.status) {
data.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SLASH_COMMAND_STATUS,
value: JSON.stringify(event.status),
});
}
this.enqueueLogEvent(this.createLogEvent(slash_command_event_name, data));
this.flushIfNeeded();
}
@@ -718,6 +680,24 @@ export class ClearcutLogger {
this.flushIfNeeded();
}
logKittySequenceOverflowEvent(event: KittySequenceOverflowEvent): void {
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_KITTY_SEQUENCE_LENGTH,
value: event.sequence_length.toString(),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_KITTY_TRUNCATED_SEQUENCE,
value: event.truncated_sequence,
},
];
this.enqueueLogEvent(
this.createLogEvent(kitty_sequence_overflow_event_name, data),
);
this.flushIfNeeded();
}
logEndSessionEvent(event: EndSessionEvent): void {
const data: EventValue[] = [
{
@@ -752,24 +732,21 @@ export class ClearcutLogger {
private requeueFailedEvents(eventsToSend: LogEventEntry[][]): void {
// Add the events back to the front of the queue to be retried, but limit retry queue size
const eventsToRetry = eventsToSend.slice(-this.max_retry_events); // Keep only the most recent events
const eventsToRetry = eventsToSend.slice(-MAX_RETRY_EVENTS); // Keep only the most recent events
// Log a warning if we're dropping events
if (
eventsToSend.length > this.max_retry_events &&
this.config?.getDebugMode()
) {
if (eventsToSend.length > MAX_RETRY_EVENTS && this.config?.getDebugMode()) {
console.warn(
`ClearcutLogger: Dropping ${
eventsToSend.length - this.max_retry_events
eventsToSend.length - MAX_RETRY_EVENTS
} events due to retry queue limit. Total events: ${
eventsToSend.length
}, keeping: ${this.max_retry_events}`,
}, keeping: ${MAX_RETRY_EVENTS}`,
);
}
// Determine how many events can be re-queued
const availableSpace = this.max_events - this.events.size;
const availableSpace = MAX_EVENTS - this.events.size;
const numEventsToRequeue = Math.min(eventsToRetry.length, availableSpace);
if (numEventsToRequeue === 0) {
@@ -792,7 +769,7 @@ export class ClearcutLogger {
this.events.unshift(eventsToRequeue[i]);
}
// Clear any potential overflow
while (this.events.size > this.max_events) {
while (this.events.size > MAX_EVENTS) {
this.events.pop();
}
@@ -803,3 +780,28 @@ export class ClearcutLogger {
}
}
}
/**
* Adds default fields to data, and returns a new data array. This fields
* should exist on all log events.
*/
function addDefaultFields(data: EventValue[]): EventValue[] {
const totalAccounts = getLifetimeGoogleAccounts();
const surface = determineSurface();
const defaultLogMetadata: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${totalAccounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
},
];
return [...data, ...defaultLogMetadata];
}
export const TEST_ONLY = {
MAX_RETRY_EVENTS,
MAX_EVENTS,
};

View File

@@ -174,6 +174,9 @@ export enum EventMetadataKey {
// Logs the subcommand of the slash command.
GEMINI_CLI_SLASH_COMMAND_SUBCOMMAND = 42,
// Logs the status of the slash command (e.g. 'success', 'error')
GEMINI_CLI_SLASH_COMMAND_STATUS = 51,
// ==========================================================================
// Next Speaker Check Event Keys
// ===========================================================================
@@ -209,6 +212,16 @@ export enum EventMetadataKey {
// Logs user removed lines in edit/write tool response.
GEMINI_CLI_USER_REMOVED_LINES = 50,
// ==========================================================================
// Kitty Sequence Overflow Event Keys
// ===========================================================================
// Logs the length of the kitty sequence that overflowed.
GEMINI_CLI_KITTY_SEQUENCE_LENGTH = 53,
// Logs the truncated kitty sequence.
GEMINI_CLI_KITTY_TRUNCATED_SEQUENCE = 52,
}
export function getEventMetadataKey(

View File

@@ -28,6 +28,7 @@ export {
logApiResponse,
logFlashFallback,
logSlashCommand,
logKittySequenceOverflow,
} from './loggers.js';
export {
StartSessionEvent,
@@ -39,7 +40,10 @@ export {
ApiResponseEvent,
TelemetryEvent,
FlashFallbackEvent,
KittySequenceOverflowEvent,
SlashCommandEvent,
makeSlashCommandEvent,
SlashCommandStatus,
} from './types.js';
export { SpanStatusCode, ValueType } from '@opentelemetry/api';
export { SemanticAttributes } from '@opentelemetry/semantic-conventions';

View File

@@ -4,45 +4,47 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { logs, LogRecord, LogAttributes } from '@opentelemetry/api-logs';
import { LogAttributes, LogRecord, logs } from '@opentelemetry/api-logs';
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
import { Config } from '../config/config.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { ClearcutLogger } from './clearcut-logger/clearcut-logger.js';
import {
EVENT_API_ERROR,
EVENT_API_REQUEST,
EVENT_API_RESPONSE,
EVENT_CLI_CONFIG,
EVENT_FLASH_FALLBACK,
EVENT_IDE_CONNECTION,
EVENT_NEXT_SPEAKER_CHECK,
EVENT_SLASH_COMMAND,
EVENT_TOOL_CALL,
EVENT_USER_PROMPT,
EVENT_FLASH_FALLBACK,
EVENT_NEXT_SPEAKER_CHECK,
SERVICE_NAME,
EVENT_SLASH_COMMAND,
} from './constants.js';
import {
recordApiErrorMetrics,
recordApiResponseMetrics,
recordTokenUsageMetrics,
recordToolCallMetrics,
} from './metrics.js';
import { QwenLogger } from './qwen-logger/qwen-logger.js';
import { isTelemetrySdkInitialized } from './sdk.js';
import {
ApiErrorEvent,
ApiRequestEvent,
ApiResponseEvent,
FlashFallbackEvent,
IdeConnectionEvent,
KittySequenceOverflowEvent,
LoopDetectedEvent,
NextSpeakerCheckEvent,
SlashCommandEvent,
StartSessionEvent,
ToolCallEvent,
UserPromptEvent,
FlashFallbackEvent,
NextSpeakerCheckEvent,
LoopDetectedEvent,
SlashCommandEvent,
} from './types.js';
import {
recordApiErrorMetrics,
recordTokenUsageMetrics,
recordApiResponseMetrics,
recordToolCallMetrics,
} from './metrics.js';
import { isTelemetrySdkInitialized } from './sdk.js';
import { uiTelemetryService, UiEvent } from './uiTelemetry.js';
import { QwenLogger } from './qwen-logger/qwen-logger.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { UiEvent, uiTelemetryService } from './uiTelemetry.js';
const shouldLogUserPrompts = (config: Config): boolean =>
config.getTelemetryLogPromptsEnabled();
@@ -377,3 +379,21 @@ export function logIdeConnection(
};
logger.emit(logRecord);
}
export function logKittySequenceOverflow(
config: Config,
event: KittySequenceOverflowEvent,
): void {
ClearcutLogger.getInstance(config)?.logKittySequenceOverflowEvent(event);
if (!isTelemetrySdkInitialized()) return;
const attributes: LogAttributes = {
...getCommonAttributes(config),
...event,
};
const logger = logs.getLogger(SERVICE_NAME);
const logRecord: LogRecord = {
body: `Kitty sequence buffer overflow: ${event.sequence_length} bytes`,
attributes,
};
logger.emit(logRecord);
}

View File

@@ -124,24 +124,32 @@ export function initializeTelemetry(config: Config): void {
try {
sdk.start();
console.log('OpenTelemetry SDK started successfully.');
if (config.getDebugMode()) {
console.log('OpenTelemetry SDK started successfully.');
}
telemetryInitialized = true;
initializeMetrics(config);
} catch (error) {
console.error('Error starting OpenTelemetry SDK:', error);
}
process.on('SIGTERM', shutdownTelemetry);
process.on('SIGINT', shutdownTelemetry);
process.on('SIGTERM', () => {
shutdownTelemetry(config);
});
process.on('SIGINT', () => {
shutdownTelemetry(config);
});
}
export async function shutdownTelemetry(): Promise<void> {
export async function shutdownTelemetry(config: Config): Promise<void> {
if (!telemetryInitialized || !sdk) {
return;
}
try {
await sdk.shutdown();
console.log('OpenTelemetry SDK shut down successfully.');
if (config.getDebugMode()) {
console.log('OpenTelemetry SDK shut down successfully.');
}
} catch (error) {
console.error('Error shutting down SDK:', error);
} finally {

View File

@@ -45,7 +45,7 @@ describe('telemetry', () => {
afterEach(async () => {
// Ensure we shut down telemetry even if a test fails.
if (isTelemetrySdkInitialized()) {
await shutdownTelemetry();
await shutdownTelemetry(mockConfig);
}
});
@@ -57,7 +57,7 @@ describe('telemetry', () => {
it('should shutdown the telemetry service', async () => {
initializeTelemetry(mockConfig);
await shutdownTelemetry();
await shutdownTelemetry(mockConfig);
expect(mockNodeSdk.shutdown).toHaveBeenCalled();
});

View File

@@ -14,9 +14,17 @@ import {
ToolCallDecision,
} from './tool-call-decision.js';
export class StartSessionEvent {
interface BaseTelemetryEvent {
'event.name': string;
/** Current timestamp in ISO 8601 format */
'event.timestamp': string;
}
type CommonFields = keyof BaseTelemetryEvent;
export class StartSessionEvent implements BaseTelemetryEvent {
'event.name': 'cli_config';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
model: string;
embedding_model: string;
sandbox_enabled: boolean;
@@ -60,9 +68,9 @@ export class StartSessionEvent {
}
}
export class EndSessionEvent {
export class EndSessionEvent implements BaseTelemetryEvent {
'event.name': 'end_session';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
session_id?: string;
constructor(config?: Config) {
@@ -72,9 +80,9 @@ export class EndSessionEvent {
}
}
export class UserPromptEvent {
export class UserPromptEvent implements BaseTelemetryEvent {
'event.name': 'user_prompt';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
prompt_length: number;
prompt_id: string;
auth_type?: string;
@@ -95,9 +103,9 @@ export class UserPromptEvent {
}
}
export class ToolCallEvent {
export class ToolCallEvent implements BaseTelemetryEvent {
'event.name': 'tool_call';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
function_name: string;
function_args: Record<string, unknown>;
duration_ms: number;
@@ -142,9 +150,9 @@ export class ToolCallEvent {
}
}
export class ApiRequestEvent {
export class ApiRequestEvent implements BaseTelemetryEvent {
'event.name': 'api_request';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
model: string;
prompt_id: string;
request_text?: string;
@@ -158,7 +166,7 @@ export class ApiRequestEvent {
}
}
export class ApiErrorEvent {
export class ApiErrorEvent implements BaseTelemetryEvent {
'event.name': 'api_error';
'event.timestamp': string; // ISO 8601
response_id?: string;
@@ -193,7 +201,7 @@ export class ApiErrorEvent {
}
}
export class ApiResponseEvent {
export class ApiResponseEvent implements BaseTelemetryEvent {
'event.name': 'api_response';
'event.timestamp': string; // ISO 8601
response_id: string;
@@ -240,9 +248,9 @@ export class ApiResponseEvent {
}
}
export class FlashFallbackEvent {
export class FlashFallbackEvent implements BaseTelemetryEvent {
'event.name': 'flash_fallback';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
auth_type: string;
constructor(auth_type: string) {
@@ -258,9 +266,9 @@ export enum LoopType {
LLM_DETECTED_LOOP = 'llm_detected_loop',
}
export class LoopDetectedEvent {
export class LoopDetectedEvent implements BaseTelemetryEvent {
'event.name': 'loop_detected';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
loop_type: LoopType;
prompt_id: string;
@@ -272,9 +280,9 @@ export class LoopDetectedEvent {
}
}
export class NextSpeakerCheckEvent {
export class NextSpeakerCheckEvent implements BaseTelemetryEvent {
'event.name': 'next_speaker_check';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
prompt_id: string;
finish_reason: string;
result: string;
@@ -288,23 +296,36 @@ export class NextSpeakerCheckEvent {
}
}
export class SlashCommandEvent {
export interface SlashCommandEvent extends BaseTelemetryEvent {
'event.name': 'slash_command';
'event.timestamp': string; // ISO 8106
command: string;
subcommand?: string;
constructor(command: string, subcommand?: string) {
this['event.name'] = 'slash_command';
this['event.timestamp'] = new Date().toISOString();
this.command = command;
this.subcommand = subcommand;
}
status?: SlashCommandStatus;
}
export class MalformedJsonResponseEvent {
export function makeSlashCommandEvent({
command,
subcommand,
status,
}: Omit<SlashCommandEvent, CommonFields>): SlashCommandEvent {
return {
'event.name': 'slash_command',
'event.timestamp': new Date().toISOString(),
command,
subcommand,
status,
};
}
export enum SlashCommandStatus {
SUCCESS = 'success',
ERROR = 'error',
}
export class MalformedJsonResponseEvent implements BaseTelemetryEvent {
'event.name': 'malformed_json_response';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
model: string;
constructor(model: string) {
@@ -321,7 +342,7 @@ export enum IdeConnectionType {
export class IdeConnectionEvent {
'event.name': 'ide_connection';
'event.timestamp': string; // ISO 8601
'event.timestamp': string;
connection_type: IdeConnectionType;
constructor(connection_type: IdeConnectionType) {
@@ -331,6 +352,20 @@ export class IdeConnectionEvent {
}
}
export class KittySequenceOverflowEvent {
'event.name': 'kitty_sequence_overflow';
'event.timestamp': string; // ISO 8601
sequence_length: number;
truncated_sequence: string;
constructor(sequence_length: number, truncated_sequence: string) {
this['event.name'] = 'kitty_sequence_overflow';
this['event.timestamp'] = new Date().toISOString();
this.sequence_length = sequence_length;
// Truncate to first 20 chars for logging (avoid logging sensitive data)
this.truncated_sequence = truncated_sequence.substring(0, 20);
}
}
export type TelemetryEvent =
| StartSessionEvent
| EndSessionEvent
@@ -342,6 +377,7 @@ export type TelemetryEvent =
| FlashFallbackEvent
| LoopDetectedEvent
| NextSpeakerCheckEvent
| SlashCommandEvent
| KittySequenceOverflowEvent
| MalformedJsonResponseEvent
| IdeConnectionEvent;
| IdeConnectionEvent
| SlashCommandEvent;

View File

@@ -0,0 +1,36 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { Config, ConfigParameters } from '../config/config.js';
/**
* Default parameters used for {@link FAKE_CONFIG}
*/
export const DEFAULT_CONFIG_PARAMETERS: ConfigParameters = {
usageStatisticsEnabled: true,
debugMode: false,
sessionId: 'test-session-id',
proxy: undefined,
model: 'gemini-9001-super-duper',
targetDir: '/',
cwd: '/',
};
/**
* Produces a config. Default paramters are set to
* {@link DEFAULT_CONFIG_PARAMETERS}, optionally, fields can be specified to
* override those defaults.
*/
export function makeFakeConfig(
config: Partial<ConfigParameters> = {
...DEFAULT_CONFIG_PARAMETERS,
},
): Config {
return new Config({
...DEFAULT_CONFIG_PARAMETERS,
...config,
});
}

View File

@@ -7,9 +7,9 @@
import { vi } from 'vitest';
import {
BaseTool,
Icon,
ToolCallConfirmationDetails,
ToolResult,
Kind,
} from '../tools/tools.js';
import { Schema, Type } from '@google/genai';
@@ -29,7 +29,7 @@ export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> {
properties: { param: { type: Type.STRING } },
},
) {
super(name, displayName ?? name, description, Icon.Hammer, params);
super(name, displayName ?? name, description, Kind.Other, params);
}
async execute(
@@ -45,7 +45,7 @@ export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> {
);
}
async shouldConfirmExecute(
override async shouldConfirmExecute(
_params: { [key: string]: unknown },
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {

View File

@@ -62,7 +62,6 @@ describe('EditTool', () => {
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getIdeClient: () => undefined,
getIdeMode: () => false,
getIdeModeFeature: () => false,
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
// Add other properties/methods of Config if EditTool uses them
// Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses:
@@ -810,7 +809,6 @@ describe('EditTool', () => {
}),
};
(mockConfig as any).getIdeMode = () => true;
(mockConfig as any).getIdeModeFeature = () => true;
(mockConfig as any).getIdeClient = () => ideClient;
});

View File

@@ -9,7 +9,7 @@ import * as path from 'path';
import * as Diff from 'diff';
import {
BaseDeclarativeTool,
Icon,
Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
@@ -250,7 +250,6 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
);
const ideClient = this.config.getIdeClient();
const ideConfirmation =
this.config.getIdeModeFeature() &&
this.config.getIdeMode() &&
ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected
? ideClient.openDiff(this.params.file_path, editData.newContent)
@@ -436,7 +435,7 @@ Expectation for required parameters:
4. NEVER escape \`old_string\` or \`new_string\`, that would break the exact literal text requirement.
**Important:** If ANY of the above are not satisfied, the tool will fail. CRITICAL for \`old_string\`: Must uniquely identify the single instance to change. Include at least 3 lines of context BEFORE and AFTER the target text, matching whitespace and indentation precisely. If this string matches multiple locations, or does not match exactly, the tool will fail.
**Multiple replacements:** Set \`expected_replacements\` to the number of occurrences you want to replace. The tool will replace ALL occurrences that match \`old_string\` exactly. Ensure the number of replacements matches your expectation.`,
Icon.Pencil,
Kind.Edit,
{
properties: {
file_path: {
@@ -472,7 +471,7 @@ Expectation for required parameters:
* @param params Parameters to validate
* @returns Error message string or null if valid
*/
validateToolParams(params: EditToolParams): string | null {
override validateToolParams(params: EditToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,

View File

@@ -11,7 +11,7 @@ import { SchemaValidator } from '../utils/schemaValidator.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
@@ -248,7 +248,7 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
GlobTool.Name,
'FindFiles',
'Efficiently finds files matching specific glob patterns (e.g., `src/**/*.ts`, `**/*.md`), returning absolute paths sorted by modification time (newest first). Ideal for quickly locating files based on their name or path structure, especially in large codebases.',
Icon.FileSearch,
Kind.Search,
{
properties: {
pattern: {
@@ -281,7 +281,7 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
/**
* Validates the parameters for the tool.
*/
validateToolParams(params: GlobToolParams): string | null {
override validateToolParams(params: GlobToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,

View File

@@ -13,7 +13,7 @@ import { globStream } from 'glob';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
@@ -594,7 +594,7 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
GrepTool.Name,
'SearchText',
'Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.',
Icon.Regex,
Kind.Search,
{
properties: {
pattern: {
@@ -672,7 +672,7 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: GrepToolParams): string | null {
override validateToolParams(params: GrepToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,

View File

@@ -74,9 +74,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/project/src',
};
const error = lsTool.validateToolParams(params);
expect(error).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
});
it('should reject relative paths', () => {
@@ -84,8 +86,9 @@ describe('LSTool', () => {
path: './src',
};
const error = lsTool.validateToolParams(params);
expect(error).toBe('Path must be absolute: ./src');
expect(() => lsTool.build(params)).toThrow(
'Path must be absolute: ./src',
);
});
it('should reject paths outside workspace with clear error message', () => {
@@ -93,8 +96,7 @@ describe('LSTool', () => {
path: '/etc/passwd',
};
const error = lsTool.validateToolParams(params);
expect(error).toBe(
expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project',
);
});
@@ -103,9 +105,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/other-project/lib',
};
const error = lsTool.validateToolParams(params);
expect(error).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
});
});
@@ -133,10 +137,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('[DIR] subdir');
expect(result.llmContent).toContain('file1.ts');
@@ -161,10 +163,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('module1.js');
expect(result.llmContent).toContain('module2.js');
@@ -179,10 +179,8 @@ describe('LSTool', () => {
} as fs.Stats);
vi.mocked(fs.readdirSync).mockReturnValue([]);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toBe(
'Directory /home/user/project/empty is empty.',
@@ -207,10 +205,11 @@ describe('LSTool', () => {
});
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath, ignore: ['*.spec.js'] },
new AbortController().signal,
);
const invocation = lsTool.build({
path: testPath,
ignore: ['*.spec.js'],
});
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test.js');
expect(result.llmContent).toContain('index.js');
@@ -238,10 +237,8 @@ describe('LSTool', () => {
(path: string) => path.includes('ignored.js'),
);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@@ -269,10 +266,8 @@ describe('LSTool', () => {
(path: string) => path.includes('private.js'),
);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@@ -287,10 +282,8 @@ describe('LSTool', () => {
isDirectory: () => false,
} as fs.Stats);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Path is not a directory');
expect(result.returnDisplay).toBe('Error: Path is not a directory.');
@@ -303,10 +296,8 @@ describe('LSTool', () => {
throw new Error('ENOENT: no such file or directory');
});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
@@ -336,10 +327,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
const lines = (
typeof result.llmContent === 'string' ? result.llmContent : ''
@@ -361,24 +350,18 @@ describe('LSTool', () => {
throw new Error('EACCES: permission denied');
});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.llmContent).toContain('permission denied');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
});
it('should validate parameters and return error for invalid params', async () => {
const result = await lsTool.execute(
{ path: '../outside' },
new AbortController().signal,
it('should throw for invalid params at build time', async () => {
expect(() => lsTool.build({ path: '../outside' })).toThrow(
'Path must be absolute: ../outside',
);
expect(result.llmContent).toContain('Invalid parameters provided');
expect(result.returnDisplay).toBe('Error: Failed to execute tool.');
});
it('should handle errors accessing individual files during listing', async () => {
@@ -406,10 +389,8 @@ describe('LSTool', () => {
.spyOn(console, 'error')
.mockImplementation(() => {});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
// Should still list the accessible file
expect(result.llmContent).toContain('accessible.ts');
@@ -428,19 +409,25 @@ describe('LSTool', () => {
describe('getDescription', () => {
it('should return shortened relative path', () => {
const params = {
path: path.join(mockPrimaryDir, 'deeply', 'nested', 'directory'),
path: `${mockPrimaryDir}/deeply/nested/directory`,
};
const description = lsTool.getDescription(params);
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('deeply', 'nested', 'directory'));
});
it('should handle paths in secondary workspace', () => {
const params = {
path: path.join(mockSecondaryDir, 'lib'),
path: `${mockSecondaryDir}/lib`,
};
const description = lsTool.getDescription(params);
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('..', 'other-project', 'lib'));
});
});
@@ -448,22 +435,25 @@ describe('LSTool', () => {
describe('workspace boundary validation', () => {
it('should accept paths in primary workspace directory', () => {
const params = { path: `${mockPrimaryDir}/src` };
expect(lsTool.validateToolParams(params)).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
});
it('should accept paths in secondary workspace directory', () => {
const params = { path: `${mockSecondaryDir}/lib` };
expect(lsTool.validateToolParams(params)).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
});
it('should reject paths outside all workspace directories', () => {
const params = { path: '/etc/passwd' };
const error = lsTool.validateToolParams(params);
expect(error).toContain(
expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories',
);
expect(error).toContain(mockPrimaryDir);
expect(error).toContain(mockSecondaryDir);
});
it('should list files from secondary workspace directory', async () => {
@@ -483,10 +473,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test1.spec.ts');
expect(result.llmContent).toContain('test2.spec.ts');

View File

@@ -6,7 +6,13 @@
import fs from 'fs';
import path from 'path';
import { BaseTool, Icon, ToolResult } from './tools.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@@ -64,79 +70,12 @@ export interface FileEntry {
modifiedTime: Date;
}
/**
* Implementation of the LS tool logic
*/
export class LSTool extends BaseTool<LSToolParams, ToolResult> {
static readonly Name = 'list_directory';
constructor(private config: Config) {
super(
LSTool.Name,
'ReadFolder',
'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
Icon.Folder,
{
properties: {
path: {
description:
'The absolute path to the directory to list (must be absolute, not relative)',
type: 'string',
},
ignore: {
description: 'List of glob patterns to ignore',
items: {
type: 'string',
},
type: 'array',
},
file_filtering_options: {
description:
'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: 'boolean',
},
},
},
},
required: ['path'],
type: 'object',
},
);
}
/**
* Validates the parameters for the tool
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: LSToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!path.isAbsolute(params.path)) {
return `Path must be absolute: ${params.path}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.path)) {
const directories = workspaceContext.getDirectories();
return `Path must be within one of the workspace directories: ${directories.join(', ')}`;
}
return null;
class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor(
private readonly config: Config,
params: LSToolParams,
) {
super(params);
}
/**
@@ -165,11 +104,13 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
/**
* Gets a description of the file reading operation
* @param params Parameters for the file reading
* @returns A string describing the file being read
*/
getDescription(params: LSToolParams): string {
const relativePath = makeRelative(params.path, this.config.getTargetDir());
getDescription(): string {
const relativePath = makeRelative(
this.params.path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
@@ -184,49 +125,37 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
/**
* Executes the LS operation with the given parameters
* @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
async execute(
params: LSToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return this.errorResult(
`Error: Invalid parameters provided. Reason: ${validationError}`,
`Failed to execute tool.`,
);
}
async execute(_signal: AbortSignal): Promise<ToolResult> {
try {
const stats = fs.statSync(params.path);
const stats = fs.statSync(this.params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
`Error: Directory not found or inaccessible: ${params.path}`,
`Error: Directory not found or inaccessible: ${this.params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
`Error: Path is not a directory: ${params.path}`,
`Error: Path is not a directory: ${this.params.path}`,
`Path is not a directory.`,
);
}
const files = fs.readdirSync(params.path);
const files = fs.readdirSync(this.params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
this.params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
this.params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
@@ -241,17 +170,17 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
llmContent: `Directory ${params.path} is empty.`,
llmContent: `Directory ${this.params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
if (this.shouldIgnore(file, params.ignore)) {
if (this.shouldIgnore(file, this.params.ignore)) {
continue;
}
const fullPath = path.join(params.path, file);
const fullPath = path.join(this.params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
@@ -301,7 +230,7 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
let resultMessage = `Directory listing for ${this.params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
@@ -329,3 +258,87 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
}
}
}
/**
* Implementation of the LS tool logic
*/
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = 'list_directory';
constructor(private config: Config) {
super(
LSTool.Name,
'ReadFolder',
'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.',
Kind.Search,
{
properties: {
path: {
description:
'The absolute path to the directory to list (must be absolute, not relative)',
type: 'string',
},
ignore: {
description: 'List of glob patterns to ignore',
items: {
type: 'string',
},
type: 'array',
},
file_filtering_options: {
description:
'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore',
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: 'boolean',
},
},
},
},
required: ['path'],
type: 'object',
},
);
}
/**
* Validates the parameters for the tool
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
override validateToolParams(params: LSToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!path.isAbsolute(params.path)) {
return `Path must be absolute: ${params.path}`;
}
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.path)) {
const directories = workspaceContext.getDirectories();
return `Path must be within one of the workspace directories: ${directories.join(
', ',
)}`;
}
return null;
}
protected createInvocation(
params: LSToolParams,
): ToolInvocation<LSToolParams, ToolResult> {
return new LSToolInvocation(this.config, params);
}
}

View File

@@ -73,11 +73,21 @@ describe('DiscoveredMCPTool', () => {
required: ['param'],
};
let tool: DiscoveredMCPTool;
beforeEach(() => {
mockCallTool.mockClear();
mockToolMethod.mockClear();
tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
(DiscoveredMCPTool as any).allowlist.clear();
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.clear();
});
afterEach(() => {
@@ -86,14 +96,6 @@ describe('DiscoveredMCPTool', () => {
describe('constructor', () => {
it('should set properties correctly', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
expect(tool.name).toBe(serverToolName);
expect(tool.schema.name).toBe(serverToolName);
expect(tool.schema.description).toBe(baseDescription);
@@ -105,7 +107,7 @@ describe('DiscoveredMCPTool', () => {
it('should accept and store a custom timeout', () => {
const customTimeout = 5000;
const tool = new DiscoveredMCPTool(
const toolWithTimeout = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@@ -113,19 +115,12 @@ describe('DiscoveredMCPTool', () => {
inputSchema,
customTimeout,
);
expect(tool.timeout).toBe(customTimeout);
expect(toolWithTimeout.timeout).toBe(customTimeout);
});
});
describe('execute', () => {
it('should call mcpTool.callTool with correct parameters and format display output', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' };
const mockToolSuccessResultObject = {
success: true,
@@ -147,7 +142,10 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
const toolResult: ToolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params },
@@ -163,17 +161,13 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle empty result from getStringifiedResultForDisplay', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' };
const mockMcpToolResponsePartsEmpty: Part[] = [];
mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
const toolResult: ToolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
expect(toolResult.llmContent).toEqual([
{ text: '[Error: Could not parse tool response]' },
@@ -181,28 +175,17 @@ describe('DiscoveredMCPTool', () => {
});
it('should propagate rejection if mcpTool.callTool rejects', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'failCase' };
const expectedError = new Error('MCP call failed');
mockCallTool.mockRejectedValue(expectedError);
await expect(tool.execute(params)).rejects.toThrow(expectedError);
const invocation = tool.build(params);
await expect(
invocation.execute(new AbortController().signal),
).rejects.toThrow(expectedError);
});
it('should handle a simple text response correctly', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { query: 'test' };
const successMessage = 'This is a success message.';
@@ -221,7 +204,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
// 1. Assert that the llmContent sent to the scheduler is a clean Part array.
expect(toolResult.llmContent).toEqual([{ text: successMessage }]);
@@ -236,13 +220,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an AudioBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'play' };
const sdkResponse: Part[] = [
{
@@ -262,7 +239,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -279,13 +257,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a ResourceLinkBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -306,7 +277,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -319,13 +291,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded text ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -348,7 +313,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'This is the text content.' },
@@ -357,13 +323,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded binary ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@@ -386,7 +345,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@@ -405,13 +365,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'complex' };
const sdkResponse: Part[] = [
{
@@ -433,7 +386,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'First part.' },
@@ -454,13 +408,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should ignore unknown content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'test' };
const sdkResponse: Part[] = [
{
@@ -477,7 +424,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]);
expect(toolResult.returnDisplay).toBe(
@@ -486,13 +434,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a complex mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'super-complex' };
const sdkResponse: Part[] = [
{
@@ -527,7 +468,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'Here is a resource.' },
@@ -552,10 +494,8 @@ describe('DiscoveredMCPTool', () => {
});
describe('shouldConfirmExecute', () => {
// beforeEach is already clearing allowlist
it('should return false if trust is true', async () => {
const tool = new DiscoveredMCPTool(
const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@@ -564,50 +504,32 @@ describe('DiscoveredMCPTool', () => {
undefined,
true,
);
const invocation = trustedTool.build({});
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if server is allowlisted', async () => {
(DiscoveredMCPTool as any).allowlist.add(serverName);
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.add(serverName);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if tool is allowlisted', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`;
(DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey);
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.add(toolAllowlistKey);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return confirmation details if not trusted and not allowlisted', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({});
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -629,15 +551,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should add server to allowlist on ProceedAlwaysServer', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -650,7 +565,7 @@ describe('DiscoveredMCPTool', () => {
await confirmation.onConfirm(
ToolConfirmationOutcome.ProceedAlwaysServer,
);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true);
expect(invocation.constructor.allowlist.has(serverName)).toBe(true);
} else {
throw new Error(
'Confirmation details or onConfirm not in expected format',
@@ -659,16 +574,9 @@ describe('DiscoveredMCPTool', () => {
});
it('should add tool to allowlist on ProceedAlwaysTool', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const toolAllowlistKey = `${serverName}.${serverToolName}`;
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -679,7 +587,7 @@ describe('DiscoveredMCPTool', () => {
typeof confirmation.onConfirm === 'function'
) {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool);
expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe(
expect(invocation.constructor.allowlist.has(toolAllowlistKey)).toBe(
true,
);
} else {
@@ -690,15 +598,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle Cancel confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -710,11 +611,9 @@ describe('DiscoveredMCPTool', () => {
) {
// Cancel should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
(DiscoveredMCPTool as any).allowlist.has(
invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
@@ -726,15 +625,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle ProceedOnce confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@@ -746,11 +638,9 @@ describe('DiscoveredMCPTool', () => {
) {
// ProceedOnce should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
(DiscoveredMCPTool as any).allowlist.has(
invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);

View File

@@ -5,14 +5,16 @@
*/
import {
BaseTool,
ToolResult,
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolInvocation,
ToolMcpConfirmationDetails,
Icon,
ToolResult,
} from './tools.js';
import { CallableTool, Part, FunctionCall } from '@google/genai';
import { CallableTool, FunctionCall, Part } from '@google/genai';
type ToolParams = Record<string, unknown>;
@@ -50,15 +52,90 @@ type McpContentBlock =
| McpResourceBlock
| McpResourceLinkBlock;
export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
class DiscoveredMCPToolInvocation extends BaseToolInvocation<
ToolParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set();
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly serverToolName: string,
readonly displayName: string,
readonly timeout?: number,
readonly trust?: boolean,
params: ToolParams = {},
) {
super(params);
}
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
return false; // server is trusted, no confirmation needed
}
if (
DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) ||
DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: this.params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
}
getDescription(): string {
return this.displayName;
}
}
export class DiscoveredMCPTool extends BaseDeclarativeTool<
ToolParams,
ToolResult
> {
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly serverToolName: string,
description: string,
readonly parameterSchema: unknown,
override readonly parameterSchema: unknown,
readonly timeout?: number,
readonly trust?: boolean,
nameOverride?: string,
@@ -67,7 +144,7 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
nameOverride ?? generateValidName(serverToolName),
`${serverToolName} (${serverName} MCP Server)`,
description,
Icon.Hammer,
Kind.Other,
parameterSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
@@ -87,56 +164,18 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
);
}
async shouldConfirmExecute(
_params: ToolParams,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
return false; // server is trusted, no confirmation needed
}
if (
DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
DiscoveredMCPTool.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.name, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPTool.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPTool.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(params: ToolParams): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
protected createInvocation(
params: ToolParams,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredMCPToolInvocation(
this.mcpTool,
this.serverName,
this.serverToolName,
this.displayName,
this.timeout,
this.trust,
params,
);
}
}

View File

@@ -202,9 +202,11 @@ describe('MemoryTool', () => {
expect(memoryTool.schema.parametersJsonSchema).toBeDefined();
});
it('should call performAddMemoryEntry with correct parameters and return success', async () => {
it('should call performAddMemoryEntry with correct parameters and return success for global scope', async () => {
const params = { fact: 'The sky is blue', scope: 'global' as const };
const result = await memoryTool.execute(params, mockAbortSignal);
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join(
os.homedir(),
@@ -231,16 +233,44 @@ describe('MemoryTool', () => {
expect(result.returnDisplay).toBe(successMessage);
});
it('should call performAddMemoryEntry with correct parameters and return success for project scope', async () => {
const params = { fact: 'The sky is blue', scope: 'project' as const };
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
// For project scope, expect the file to be in current working directory
const expectedFilePath = path.join(
process.cwd(),
getCurrentGeminiMdFilename(),
);
// For this test, we expect the actual fs methods to be passed
const expectedFsArgument = {
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
};
expect(performAddMemoryEntrySpy).toHaveBeenCalledWith(
params.fact,
expectedFilePath,
expectedFsArgument,
);
const successMessage = `Okay, I've remembered that in project memory: "${params.fact}"`;
expect(result.llmContent).toBe(
JSON.stringify({ success: true, message: successMessage }),
);
expect(result.returnDisplay).toBe(successMessage);
});
it('should return an error if fact is empty', async () => {
const params = { fact: ' ' }; // Empty fact
const result = await memoryTool.execute(params, mockAbortSignal);
const errorMessage = 'Parameter "fact" must be a non-empty string.';
expect(performAddMemoryEntrySpy).not.toHaveBeenCalled();
expect(result.llmContent).toBe(
JSON.stringify({ success: false, error: errorMessage }),
expect(memoryTool.validateToolParams(params)).toBe(
'Parameter "fact" must be a non-empty string.',
);
expect(() => memoryTool.build(params)).toThrow(
'Parameter "fact" must be a non-empty string.',
);
expect(result.returnDisplay).toBe(`Error: ${errorMessage}`);
});
it('should handle errors from performAddMemoryEntry', async () => {
@@ -250,7 +280,8 @@ describe('MemoryTool', () => {
);
performAddMemoryEntrySpy.mockRejectedValue(underlyingError);
const result = await memoryTool.execute(params, mockAbortSignal);
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
expect(result.llmContent).toBe(
JSON.stringify({
@@ -262,6 +293,18 @@ describe('MemoryTool', () => {
`Error saving memory: ${underlyingError.message}`,
);
});
it('should return error when executing without scope parameter', async () => {
const params = { fact: 'Test fact' };
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
expect(result.llmContent).toContain(
'Please specify where to save this memory',
);
expect(result.returnDisplay).toContain('Global:');
expect(result.returnDisplay).toContain('Project:');
});
});
describe('shouldConfirmExecute', () => {
@@ -269,18 +312,19 @@ describe('MemoryTool', () => {
beforeEach(() => {
memoryTool = new MemoryTool();
// Clear the allowlist before each test
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.clear();
// Mock fs.readFile to return empty string (file doesn't exist)
vi.mocked(fs.readFile).mockResolvedValue('');
// Clear allowlist before each test to ensure clean state
const invocation = memoryTool.build({ fact: 'test', scope: 'global' });
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.clear();
});
it('should return confirmation details when memory file is not allowlisted', async () => {
it('should return confirmation details when memory file is not allowlisted for global scope', async () => {
const params = { fact: 'Test fact', scope: 'global' as const };
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -301,7 +345,30 @@ describe('MemoryTool', () => {
}
});
it('should return false when memory file is already allowlisted', async () => {
it('should return confirmation details when memory file is not allowlisted for project scope', async () => {
const params = { fact: 'Test fact', scope: 'project' as const };
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
if (result && result.type === 'edit') {
const expectedPath = path.join(process.cwd(), 'QWEN.md');
expect(result.title).toBe(
`Confirm Memory Save: ${expectedPath} (project)`,
);
expect(result.fileName).toBe(expectedPath);
expect(result.fileDiff).toContain('Index: QWEN.md');
expect(result.fileDiff).toContain('+## Qwen Added Memories');
expect(result.fileDiff).toContain('+- Test fact');
expect(result.originalContent).toBe('');
expect(result.newContent).toContain('## Qwen Added Memories');
expect(result.newContent).toContain('- Test fact');
}
});
it('should return false when memory file is already allowlisted for global scope', async () => {
const params = { fact: 'Test fact', scope: 'global' as const };
const memoryFilePath = path.join(
os.homedir(),
@@ -309,20 +376,36 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
// Add the memory file to the allowlist with the new key format
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.add(
`${memoryFilePath}_global`,
);
const invocation = memoryTool.build(params);
// Add the memory file to the allowlist with the scope-specific key format
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.add(`${memoryFilePath}_global`);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBe(false);
});
it('should add memory file to allowlist when ProceedAlways is confirmed', async () => {
it('should return false when memory file is already allowlisted for project scope', async () => {
const params = { fact: 'Test fact', scope: 'project' as const };
const memoryFilePath = path.join(
process.cwd(),
getCurrentGeminiMdFilename(),
);
const invocation = memoryTool.build(params);
// Add the memory file to the allowlist with the scope-specific key format
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.add(
`${memoryFilePath}_project`,
);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBe(false);
});
it('should add memory file to allowlist when ProceedAlways is confirmed for global scope', async () => {
const params = { fact: 'Test fact', scope: 'global' as const };
const memoryFilePath = path.join(
os.homedir(),
@@ -330,10 +413,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -342,27 +423,53 @@ describe('MemoryTool', () => {
// Simulate the onConfirm callback
await result.onConfirm(ToolConfirmationOutcome.ProceedAlways);
// Check that the memory file was added to the allowlist with the new key format
// Check that the memory file was added to the allowlist with the scope-specific key format
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.has(
`${memoryFilePath}_global`,
),
).toBe(true);
}
});
it('should add memory file to allowlist when ProceedAlways is confirmed for project scope', async () => {
const params = { fact: 'Test fact', scope: 'project' as const };
const memoryFilePath = path.join(
process.cwd(),
getCurrentGeminiMdFilename(),
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
if (result && result.type === 'edit') {
// Simulate the onConfirm callback
await result.onConfirm(ToolConfirmationOutcome.ProceedAlways);
// Check that the memory file was added to the allowlist with the scope-specific key format
expect(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.has(
`${memoryFilePath}_project`,
),
).toBe(true);
}
});
it('should not add memory file to allowlist when other outcomes are confirmed', async () => {
const params = { fact: 'Test fact' };
const params = { fact: 'Test fact', scope: 'global' as const };
const memoryFilePath = path.join(
os.homedir(),
'.qwen',
getCurrentGeminiMdFilename(),
);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -370,22 +477,16 @@ describe('MemoryTool', () => {
if (result && result.type === 'edit') {
// Simulate the onConfirm callback with different outcomes
await result.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
).toBe(false);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const allowlist = (invocation.constructor as any).allowlist;
expect(allowlist.has(`${memoryFilePath}_global`)).toBe(false);
await result.onConfirm(ToolConfirmationOutcome.Cancel);
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
).toBe(false);
expect(allowlist.has(`${memoryFilePath}_global`)).toBe(false);
}
});
it('should handle existing memory file with content', async () => {
it('should handle existing memory file with content for global scope', async () => {
const params = { fact: 'New fact', scope: 'global' as const };
const existingContent =
'Some existing content.\n\n## Qwen Added Memories\n- Old fact\n';
@@ -393,10 +494,8 @@ describe('MemoryTool', () => {
// Mock fs.readFile to return existing content
vi.mocked(fs.readFile).mockResolvedValue(existingContent);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -416,10 +515,8 @@ describe('MemoryTool', () => {
it('should prompt for scope selection when scope is not specified', async () => {
const params = { fact: 'Test fact' };
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@@ -438,15 +535,61 @@ describe('MemoryTool', () => {
}
});
it('should return error when executing without scope parameter', async () => {
it('should show correct file paths in scope selection prompt', async () => {
const params = { fact: 'Test fact' };
const result = await memoryTool.execute(params, mockAbortSignal);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result.llmContent).toContain(
'Please specify where to save this memory',
expect(result).toBeDefined();
expect(result).not.toBe(false);
if (result && result.type === 'edit') {
const globalPath = path.join('~', '.qwen', 'QWEN.md');
const projectPath = path.join(process.cwd(), 'QWEN.md');
expect(result.fileDiff).toContain(`Global: ${globalPath}`);
expect(result.fileDiff).toContain(`Project: ${projectPath}`);
expect(result.fileDiff).toContain('(shared across all projects)');
expect(result.fileDiff).toContain('(current project only)');
}
});
});
describe('getDescription', () => {
let memoryTool: MemoryTool;
beforeEach(() => {
memoryTool = new MemoryTool();
});
it('should return correct description for global scope', () => {
const params = { fact: 'Test fact', scope: 'global' as const };
const invocation = memoryTool.build(params);
const description = invocation.getDescription();
const expectedPath = path.join('~', '.qwen', 'QWEN.md');
expect(description).toBe(`${expectedPath} (global)`);
});
it('should return correct description for project scope', () => {
const params = { fact: 'Test fact', scope: 'project' as const };
const invocation = memoryTool.build(params);
const description = invocation.getDescription();
const expectedPath = path.join(process.cwd(), 'QWEN.md');
expect(description).toBe(`${expectedPath} (project)`);
});
it('should show choice prompt when scope is not specified', () => {
const params = { fact: 'Test fact' };
const invocation = memoryTool.build(params);
const description = invocation.getDescription();
const globalPath = path.join('~', '.qwen', 'QWEN.md');
const projectPath = path.join(process.cwd(), 'QWEN.md');
expect(description).toBe(
`CHOOSE: ${globalPath} (global) OR ${projectPath} (project)`,
);
expect(result.returnDisplay).toContain('Global:');
expect(result.returnDisplay).toContain('Project:');
});
});
});

View File

@@ -5,11 +5,12 @@
*/
import {
BaseTool,
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolResult,
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
Icon,
} from './tools.js';
import { FunctionDeclaration } from '@google/genai';
import * as fs from 'fs/promises';
@@ -19,6 +20,7 @@ import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@@ -131,117 +133,117 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n';
}
export class MemoryTool
extends BaseTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
private static readonly allowlist: Set<string> = new Set();
/**
* Reads the current content of the memory file
*/
async function readMemoryFileContent(
scope: 'global' | 'project' = 'global',
): Promise<string> {
try {
return await fs.readFile(getMemoryFilePath(scope), 'utf-8');
} catch (err) {
const error = err as Error & { code?: string };
if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
return '';
}
}
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Icon.LightBulb,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
/**
* Computes the new content that would result from adding a memory entry
*/
function computeNewContent(currentContent: string, fact: string): string {
let processedText = fact.trim();
processedText = processedText.replace(/^(-+\s*)+/, '').trim();
const newMemoryItem = `- ${processedText}`;
const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
if (headerIndex === -1) {
// Header not found, append header and then the entry
const separator = ensureNewlineSeparation(currentContent);
return (
currentContent +
`${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
);
} else {
// Header found, find where to insert the new memory entry
const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
let endOfSectionIndex = currentContent.indexOf(
'\n## ',
startOfSectionContent,
);
if (endOfSectionIndex === -1) {
endOfSectionIndex = currentContent.length; // End of file
}
const beforeSectionMarker = currentContent
.substring(0, startOfSectionContent)
.trimEnd();
let sectionContent = currentContent
.substring(startOfSectionContent, endOfSectionIndex)
.trimEnd();
const afterSectionMarker = currentContent.substring(endOfSectionIndex);
sectionContent += `\n${newMemoryItem}`;
return (
`${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
'\n'
);
}
}
getDescription(params: SaveMemoryParams): string {
if (!params.scope) {
class MemoryToolInvocation extends BaseToolInvocation<
SaveMemoryParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set();
getDescription(): string {
if (!this.params.scope) {
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
return `CHOOSE: ${globalPath} (global) OR ${projectPath} (project)`;
}
const scope = params.scope;
const scope = this.params.scope;
const memoryFilePath = getMemoryFilePath(scope);
return `in ${tildeifyPath(memoryFilePath)} (${scope})`;
return `${tildeifyPath(memoryFilePath)} (${scope})`;
}
/**
* Reads the current content of the memory file
*/
private async readMemoryFileContent(
scope: 'global' | 'project' = 'global',
): Promise<string> {
try {
return await fs.readFile(getMemoryFilePath(scope), 'utf-8');
} catch (err) {
const error = err as Error & { code?: string };
if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
return '';
}
}
/**
* Computes the new content that would result from adding a memory entry
*/
private computeNewContent(currentContent: string, fact: string): string {
let processedText = fact.trim();
processedText = processedText.replace(/^(-+\s*)+/, '').trim();
const newMemoryItem = `- ${processedText}`;
const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER);
if (headerIndex === -1) {
// Header not found, append header and then the entry
const separator = ensureNewlineSeparation(currentContent);
return (
currentContent +
`${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n`
);
} else {
// Header found, find where to insert the new memory entry
const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length;
let endOfSectionIndex = currentContent.indexOf(
'\n## ',
startOfSectionContent,
);
if (endOfSectionIndex === -1) {
endOfSectionIndex = currentContent.length; // End of file
}
const beforeSectionMarker = currentContent
.substring(0, startOfSectionContent)
.trimEnd();
let sectionContent = currentContent
.substring(startOfSectionContent, endOfSectionIndex)
.trimEnd();
const afterSectionMarker = currentContent.substring(endOfSectionIndex);
sectionContent += `\n${newMemoryItem}`;
return (
`${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() +
'\n'
);
}
}
async shouldConfirmExecute(
params: SaveMemoryParams,
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolEditConfirmationDetails | false> {
// When scope is not specified, show a choice dialog defaulting to global
if (!params.scope) {
if (!this.params.scope) {
// Show preview of what would be added to global by default
const defaultScope = 'global';
const currentContent = await this.readMemoryFileContent(defaultScope);
const newContent = this.computeNewContent(currentContent, params.fact);
const fileName = path.basename(getMemoryFilePath(defaultScope));
const fileDiff = Diff.createPatch(
fileName,
currentContent,
newContent,
'Current',
'Proposed (Global)',
DEFAULT_DIFF_OPTIONS,
);
const currentContent = await readMemoryFileContent(defaultScope);
const newContent = computeNewContent(currentContent, this.params.fact);
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
const fileName = path.basename(getMemoryFilePath(defaultScope));
const choiceText = `Choose where to save this memory:
"${this.params.fact}"
Options:
- Global: ${globalPath} (shared across all projects)
- Project: ${projectPath} (current project only)
Preview of changes to be made to GLOBAL memory:
`;
const fileDiff =
choiceText +
Diff.createPatch(
fileName,
currentContent,
newContent,
'Current',
'Proposed (Global)',
DEFAULT_DIFF_OPTIONS,
);
const confirmationDetails: ToolEditConfirmationDetails = {
type: 'edit',
title: `Choose Memory Location: GLOBAL (${globalPath}) or PROJECT (${projectPath})`,
@@ -258,19 +260,19 @@ export class MemoryTool
}
// Only check allowlist when scope is specified
const scope = params.scope!; // We know scope is specified at this point
const scope = this.params.scope;
const memoryFilePath = getMemoryFilePath(scope);
const allowlistKey = `${memoryFilePath}_${scope}`;
if (MemoryTool.allowlist.has(allowlistKey)) {
if (MemoryToolInvocation.allowlist.has(allowlistKey)) {
return false;
}
// Read current content of the memory file
const currentContent = await this.readMemoryFileContent(scope);
const currentContent = await readMemoryFileContent(scope);
// Calculate the new content that will be written to the memory file
const newContent = this.computeNewContent(currentContent, params.fact);
const newContent = computeNewContent(currentContent, this.params.fact);
const fileName = path.basename(memoryFilePath);
const fileDiff = Diff.createPatch(
@@ -292,13 +294,128 @@ export class MemoryTool
newContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
MemoryTool.allowlist.add(allowlistKey);
MemoryToolInvocation.allowlist.add(allowlistKey);
}
},
};
return confirmationDetails;
}
async execute(_signal: AbortSignal): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = this.params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
return {
llmContent: JSON.stringify({ success: false, error: errorMessage }),
returnDisplay: `Error: ${errorMessage}`,
};
}
// If scope is not specified and user didn't modify content, return error prompting for choice
if (!this.params.scope && !modified_by_user) {
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
const errorMessage = `Please specify where to save this memory:
Global: ${globalPath} (shared across all projects)
Project: ${projectPath} (current project only)`;
return {
llmContent: JSON.stringify({
success: false,
error: 'Please specify where to save this memory',
}),
returnDisplay: errorMessage,
};
}
const scope = this.params.scope || 'global';
const memoryFilePath = getMemoryFilePath(scope);
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(memoryFilePath), {
recursive: true,
});
await fs.writeFile(memoryFilePath, modified_content, 'utf-8');
const successMessage = `Okay, I've updated the ${scope} memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(fact, memoryFilePath, {
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
});
const successMessage = `Okay, I've remembered that in ${scope} memory: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}" in ${scope}: ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
}
export class MemoryTool
extends BaseDeclarativeTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
override validateToolParams(params: SaveMemoryParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (params.fact.trim() === '') {
return 'Parameter "fact" must be a non-empty string.';
}
return null;
}
protected createInvocation(params: SaveMemoryParams) {
return new MemoryToolInvocation(params);
}
static async performAddMemoryEntry(
text: string,
memoryFilePath: string,
@@ -369,88 +486,6 @@ export class MemoryTool
}
}
async execute(
params: SaveMemoryParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
return {
llmContent: JSON.stringify({ success: false, error: errorMessage }),
returnDisplay: `Error: ${errorMessage}`,
};
}
// If scope is not specified and user didn't modify content, return error prompting for choice
if (!params.scope && !params.modified_by_user) {
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
const errorMessage = `Please specify where to save this memory:
Global: ${globalPath} (shared across all projects)
Project: ${projectPath} (current project only)`;
return {
llmContent: JSON.stringify({
success: false,
error: 'Please specify where to save this memory',
}),
returnDisplay: errorMessage,
};
}
const scope = params.scope || 'global';
const memoryFilePath = getMemoryFilePath(scope);
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(memoryFilePath), {
recursive: true,
});
await fs.writeFile(memoryFilePath, modified_content, 'utf-8');
const successMessage = `Okay, I've updated the ${scope} memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(fact, memoryFilePath, {
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
});
const successMessage = `Okay, I've remembered that in ${scope} memory: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}" in ${scope}: ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
return {
getFilePath: (params: SaveMemoryParams) => {
@@ -474,14 +509,14 @@ Project: ${projectPath} (current project only)`;
);
if (scopeMatch) {
const scope = scopeMatch[1].toLowerCase() as 'global' | 'project';
const content = await this.readMemoryFileContent(scope);
const content = await readMemoryFileContent(scope);
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${content}`;
}
}
const scope = params.scope || 'global';
const content = await this.readMemoryFileContent(scope);
const content = await readMemoryFileContent(scope);
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${content}`;
@@ -499,8 +534,8 @@ Project: ${projectPath} (current project only)`;
}
}
const currentContent = await this.readMemoryFileContent(scope);
const newContent = this.computeNewContent(currentContent, params.fact);
const currentContent = await readMemoryFileContent(scope);
const newContent = computeNewContent(currentContent, params.fact);
const globalPath = tildeifyPath(getMemoryFilePath('global'));
const projectPath = tildeifyPath(getMemoryFilePath('project'));
return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${newContent}`;

View File

@@ -10,7 +10,7 @@ import { makeRelative, shortenPath } from '../utils/paths.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Icon,
Kind,
ToolInvocation,
ToolLocation,
ToolResult,
@@ -173,7 +173,7 @@ export class ReadFileTool extends BaseDeclarativeTool<
ReadFileTool.Name,
'ReadFile',
`Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), and PDF files. For text files, it can read specific line ranges.`,
Icon.FileSearch,
Kind.Read,
{
properties: {
absolute_path: {
@@ -198,7 +198,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
);
}
protected validateToolParams(params: ReadFileToolParams): string | null {
protected override validateToolParams(
params: ReadFileToolParams,
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,

View File

@@ -121,66 +121,71 @@ describe('ReadManyFilesTool', () => {
}
});
describe('validateParams', () => {
it('should return null for valid relative paths within root', () => {
describe('build', () => {
it('should return an invocation for valid relative paths within root', () => {
const params = { paths: ['file1.txt', 'subdir/file2.txt'] };
expect(tool.validateParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
});
it('should return null for valid glob patterns within root', () => {
it('should return an invocation for valid glob patterns within root', () => {
const params = { paths: ['*.txt', 'subdir/**/*.js'] };
expect(tool.validateParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
});
it('should return null for paths trying to escape the root (e.g., ../) as execute handles this', () => {
it('should return an invocation for paths trying to escape the root (e.g., ../) as execute handles this', () => {
const params = { paths: ['../outside.txt'] };
expect(tool.validateParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
});
it('should return null for absolute paths as execute handles this', () => {
it('should return an invocation for absolute paths as execute handles this', () => {
const params = { paths: [path.join(tempDirOutsideRoot, 'absolute.txt')] };
expect(tool.validateParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
});
it('should return error if paths array is empty', () => {
it('should throw error if paths array is empty', () => {
const params = { paths: [] };
expect(tool.validateParams(params)).toBe(
expect(() => tool.build(params)).toThrow(
'params/paths must NOT have fewer than 1 items',
);
});
it('should return null for valid exclude and include patterns', () => {
it('should return an invocation for valid exclude and include patterns', () => {
const params = {
paths: ['src/**/*.ts'],
exclude: ['**/*.test.ts'],
include: ['src/utils/*.ts'],
};
expect(tool.validateParams(params)).toBeNull();
const invocation = tool.build(params);
expect(invocation).toBeDefined();
});
it('should return error if paths array contains an empty string', () => {
it('should throw error if paths array contains an empty string', () => {
const params = { paths: ['file1.txt', ''] };
expect(tool.validateParams(params)).toBe(
expect(() => tool.build(params)).toThrow(
'params/paths/1 must NOT have fewer than 1 characters',
);
});
it('should return error if include array contains non-string elements', () => {
it('should throw error if include array contains non-string elements', () => {
const params = {
paths: ['file1.txt'],
include: ['*.ts', 123] as string[],
};
expect(tool.validateParams(params)).toBe(
expect(() => tool.build(params)).toThrow(
'params/include/1 must be string',
);
});
it('should return error if exclude array contains non-string elements', () => {
it('should throw error if exclude array contains non-string elements', () => {
const params = {
paths: ['file1.txt'],
exclude: ['*.log', {}] as string[],
};
expect(tool.validateParams(params)).toBe(
expect(() => tool.build(params)).toThrow(
'params/exclude/1 must be string',
);
});
@@ -201,7 +206,8 @@ describe('ReadManyFilesTool', () => {
it('should read a single specified file', async () => {
createFile('file1.txt', 'Content of file1');
const params = { paths: ['file1.txt'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const expectedPath = path.join(tempRootDir, 'file1.txt');
expect(result.llmContent).toEqual([
`--- ${expectedPath} ---\n\nContent of file1\n\n`,
@@ -215,7 +221,8 @@ describe('ReadManyFilesTool', () => {
createFile('file1.txt', 'Content1');
createFile('subdir/file2.js', 'Content2');
const params = { paths: ['file1.txt', 'subdir/file2.js'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath1 = path.join(tempRootDir, 'file1.txt');
const expectedPath2 = path.join(tempRootDir, 'subdir/file2.js');
@@ -239,7 +246,8 @@ describe('ReadManyFilesTool', () => {
createFile('another.txt', 'Another text');
createFile('sub/data.json', '{}');
const params = { paths: ['*.txt'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath1 = path.join(tempRootDir, 'file.txt');
const expectedPath2 = path.join(tempRootDir, 'another.txt');
@@ -263,7 +271,8 @@ describe('ReadManyFilesTool', () => {
createFile('src/main.ts', 'Main content');
createFile('src/main.test.ts', 'Test content');
const params = { paths: ['src/**/*.ts'], exclude: ['**/*.test.ts'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath = path.join(tempRootDir, 'src/main.ts');
expect(content).toEqual([`--- ${expectedPath} ---\n\nMain content\n\n`]);
@@ -277,7 +286,8 @@ describe('ReadManyFilesTool', () => {
it('should handle nonexistent specific files gracefully', async () => {
const params = { paths: ['nonexistent-file.txt'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toEqual([
'No files matching the criteria were found or all were skipped.',
]);
@@ -290,7 +300,8 @@ describe('ReadManyFilesTool', () => {
createFile('node_modules/some-lib/index.js', 'lib code');
createFile('src/app.js', 'app code');
const params = { paths: ['**/*.js'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath = path.join(tempRootDir, 'src/app.js');
expect(content).toEqual([`--- ${expectedPath} ---\n\napp code\n\n`]);
@@ -306,7 +317,8 @@ describe('ReadManyFilesTool', () => {
createFile('node_modules/some-lib/index.js', 'lib code');
createFile('src/app.js', 'app code');
const params = { paths: ['**/*.js'], useDefaultExcludes: false };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath1 = path.join(
tempRootDir,
@@ -334,7 +346,8 @@ describe('ReadManyFilesTool', () => {
Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]),
);
const params = { paths: ['*.png'] }; // Explicitly requesting .png
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toEqual([
{
inlineData: {
@@ -356,7 +369,8 @@ describe('ReadManyFilesTool', () => {
Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]),
);
const params = { paths: ['myExactImage.png'] }; // Explicitly requesting by full name
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toEqual([
{
inlineData: {
@@ -373,7 +387,8 @@ describe('ReadManyFilesTool', () => {
createBinaryFile('document.pdf', Buffer.from('%PDF-1.4...'));
createFile('notes.txt', 'text notes');
const params = { paths: ['*'] }; // Generic glob, not specific to .pdf
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const expectedPath = path.join(tempRootDir, 'notes.txt');
expect(
@@ -392,7 +407,8 @@ describe('ReadManyFilesTool', () => {
it('should include PDF files as inlineData parts if explicitly requested by extension', async () => {
createBinaryFile('important.pdf', Buffer.from('%PDF-1.4...'));
const params = { paths: ['*.pdf'] }; // Explicitly requesting .pdf files
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toEqual([
{
inlineData: {
@@ -406,7 +422,8 @@ describe('ReadManyFilesTool', () => {
it('should include PDF files as inlineData parts if explicitly requested by name', async () => {
createBinaryFile('report-final.pdf', Buffer.from('%PDF-1.4...'));
const params = { paths: ['report-final.pdf'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toEqual([
{
inlineData: {
@@ -422,7 +439,8 @@ describe('ReadManyFilesTool', () => {
createFile('bar.ts', '');
createFile('foo.quux', '');
const params = { paths: ['foo.bar', 'bar.ts', 'foo.quux'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.returnDisplay).not.toContain('foo.bar');
expect(result.returnDisplay).not.toContain('foo.quux');
expect(result.returnDisplay).toContain('bar.ts');
@@ -451,7 +469,8 @@ describe('ReadManyFilesTool', () => {
fs.writeFileSync(path.join(tempDir2, 'file2.txt'), 'Content2');
const params = { paths: ['*.txt'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
if (!Array.isArray(content)) {
throw new Error(`llmContent is not an array: ${content}`);
@@ -486,7 +505,8 @@ describe('ReadManyFilesTool', () => {
createFile('large-file.txt', longContent);
const params = { paths: ['*.txt'] };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
const normalFileContent = content.find((c) => c.includes('file1.txt'));
@@ -541,7 +561,8 @@ describe('ReadManyFilesTool', () => {
});
const params = { paths: files };
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
// Verify all files were processed
const content = result.llmContent as string[];
@@ -569,7 +590,8 @@ describe('ReadManyFilesTool', () => {
],
};
const result = await tool.execute(params, new AbortController().signal);
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
const content = result.llmContent as string[];
// Should successfully process valid files despite one failure
@@ -606,7 +628,8 @@ describe('ReadManyFilesTool', () => {
return 'text';
});
await tool.execute({ paths: files }, new AbortController().signal);
const invocation = tool.build({ paths: files });
await invocation.execute(new AbortController().signal);
console.log('Execution order:', executionOrder);

View File

@@ -4,7 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTool, Icon, ToolResult } from './tools.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import * as path from 'path';
@@ -138,120 +144,28 @@ const DEFAULT_EXCLUDES: string[] = [
const DEFAULT_OUTPUT_SEPARATOR_FORMAT = '--- {filePath} ---';
/**
* Tool implementation for finding and reading multiple text files from the local filesystem
* within a specified target directory. The content is concatenated.
* It is intended to run in an environment with access to the local file system (e.g., a Node.js backend).
*/
export class ReadManyFilesTool extends BaseTool<
class ReadManyFilesToolInvocation extends BaseToolInvocation<
ReadManyFilesParams,
ToolResult
> {
static readonly Name: string = 'read_many_files';
constructor(private config: Config) {
const parameterSchema = {
type: 'object',
properties: {
paths: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
minItems: 1,
description:
"Required. An array of glob patterns or paths relative to the tool's target directory. Examples: ['src/**/*.ts'], ['README.md', 'docs/']",
},
include: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
description:
'Optional. Additional glob patterns to include. These are merged with `paths`. Example: ["*.test.ts"] to specifically add test files if they were broadly excluded.',
default: [],
},
exclude: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
description:
'Optional. Glob patterns for files/directories to exclude. Added to default excludes if useDefaultExcludes is true. Example: ["**/*.log", "temp/"]',
default: [],
},
recursive: {
type: 'boolean',
description:
'Optional. Whether to search recursively (primarily controlled by `**` in glob patterns). Defaults to true.',
default: true,
},
useDefaultExcludes: {
type: 'boolean',
description:
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
default: true,
},
file_filtering_options: {
description:
'Whether to respect ignore patterns from .gitignore or .geminiignore',
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: 'boolean',
},
},
},
},
required: ['paths'],
};
super(
ReadManyFilesTool.Name,
'ReadManyFiles',
`Reads content from multiple files specified by paths or glob patterns within a configured target directory. For text files, it concatenates their content into a single string. It is primarily designed for text-based files. However, it can also process image (e.g., .png, .jpg) and PDF (.pdf) files if their file names or extensions are explicitly included in the 'paths' argument. For these explicitly requested non-text files, their data is read and included in a format suitable for model consumption (e.g., base64 encoded).
This tool is useful when you need to understand or analyze a collection of files, such as:
- Getting an overview of a codebase or parts of it (e.g., all TypeScript files in the 'src' directory).
- Finding where specific functionality is implemented if the user asks broad questions about code.
- Reviewing documentation files (e.g., all Markdown files in the 'docs' directory).
- Gathering context from multiple configuration files.
- When the user asks to "read all files in X directory" or "show me the content of all Y files".
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Icon.FileSearch,
parameterSchema,
);
constructor(
private readonly config: Config,
params: ReadManyFilesParams,
) {
super(params);
}
validateParams(params: ReadManyFilesParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
return null;
}
getDescription(params: ReadManyFilesParams): string {
const allPatterns = [...params.paths, ...(params.include || [])];
const pathDesc = `using patterns: \`${allPatterns.join('`, `')}\` (within target directory: \`${this.config.getTargetDir()}\`)`;
getDescription(): string {
const allPatterns = [...this.params.paths, ...(this.params.include || [])];
const pathDesc = `using patterns:
${allPatterns.join('`, `')}
(within target directory:
${this.config.getTargetDir()}
) `;
// Determine the final list of exclusion patterns exactly as in execute method
const paramExcludes = params.exclude || [];
const paramUseDefaultExcludes = params.useDefaultExcludes !== false;
const paramExcludes = this.params.exclude || [];
const paramUseDefaultExcludes = this.params.useDefaultExcludes !== false;
const geminiIgnorePatterns = this.config
.getFileService()
.getGeminiIgnorePatterns();
@@ -260,7 +174,16 @@ Use this tool when the user's query implies needing the content of several files
? [...DEFAULT_EXCLUDES, ...paramExcludes, ...geminiIgnorePatterns]
: [...paramExcludes, ...geminiIgnorePatterns];
let excludeDesc = `Excluding: ${finalExclusionPatternsForDescription.length > 0 ? `patterns like \`${finalExclusionPatternsForDescription.slice(0, 2).join('`, `')}${finalExclusionPatternsForDescription.length > 2 ? '...`' : '`'}` : 'none specified'}`;
let excludeDesc = `Excluding: ${
finalExclusionPatternsForDescription.length > 0
? `patterns like
${finalExclusionPatternsForDescription
.slice(0, 2)
.join(
'`, `',
)}${finalExclusionPatternsForDescription.length > 2 ? '...`' : '`'}`
: 'none specified'
}`;
// Add a note if .geminiignore patterns contributed to the final list of exclusions
if (geminiIgnorePatterns.length > 0) {
@@ -272,37 +195,29 @@ Use this tool when the user's query implies needing the content of several files
}
}
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace(
'{filePath}',
'path/to/file.ext',
)}".`;
}
async execute(
params: ReadManyFilesParams,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters for ${this.displayName}. Reason: ${validationError}`,
returnDisplay: `## Parameter Error\n\n${validationError}`,
};
}
async execute(signal: AbortSignal): Promise<ToolResult> {
const {
paths: inputPatterns,
include = [],
exclude = [],
useDefaultExcludes = true,
} = params;
} = this.params;
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
this.params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore, // Use the property from the returned object
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
this.params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore, // Use the property from the returned object
};
// Get centralized file discovery service
@@ -614,3 +529,119 @@ Use this tool when the user's query implies needing the content of several files
};
}
}
/**
* Tool implementation for finding and reading multiple text files from the local filesystem
* within a specified target directory. The content is concatenated.
* It is intended to run in an environment with access to the local file system (e.g., a Node.js backend).
*/
export class ReadManyFilesTool extends BaseDeclarativeTool<
ReadManyFilesParams,
ToolResult
> {
static readonly Name: string = 'read_many_files';
constructor(private config: Config) {
const parameterSchema = {
type: 'object',
properties: {
paths: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
minItems: 1,
description:
"Required. An array of glob patterns or paths relative to the tool's target directory. Examples: ['src/**/*.ts'], ['README.md', 'docs/']",
},
include: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
description:
'Optional. Additional glob patterns to include. These are merged with `paths`. Example: "*.test.ts" to specifically add test files if they were broadly excluded.',
default: [],
},
exclude: {
type: 'array',
items: {
type: 'string',
minLength: 1,
},
description:
'Optional. Glob patterns for files/directories to exclude. Added to default excludes if useDefaultExcludes is true. Example: "**/*.log", "temp/"',
default: [],
},
recursive: {
type: 'boolean',
description:
'Optional. Whether to search recursively (primarily controlled by `**` in glob patterns). Defaults to true.',
default: true,
},
useDefaultExcludes: {
type: 'boolean',
description:
'Optional. Whether to apply a list of default exclusion patterns (e.g., node_modules, .git, binary files). Defaults to true.',
default: true,
},
file_filtering_options: {
description:
'Whether to respect ignore patterns from .gitignore or .geminiignore',
type: 'object',
properties: {
respect_git_ignore: {
description:
'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.',
type: 'boolean',
},
respect_gemini_ignore: {
description:
'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.',
type: 'boolean',
},
},
},
},
required: ['paths'],
};
super(
ReadManyFilesTool.Name,
'ReadManyFiles',
`Reads content from multiple files specified by paths or glob patterns within a configured target directory. For text files, it concatenates their content into a single string. It is primarily designed for text-based files. However, it can also process image (e.g., .png, .jpg) and PDF (.pdf) files if their file names or extensions are explicitly included in the 'paths' argument. For these explicitly requested non-text files, their data is read and included in a format suitable for model consumption (e.g., base64 encoded).
This tool is useful when you need to understand or analyze a collection of files, such as:
- Getting an overview of a codebase or parts of it (e.g., all TypeScript files in the 'src' directory).
- Finding where specific functionality is implemented if the user asks broad questions about code.
- Reviewing documentation files (e.g., all Markdown files in the 'docs' directory).
- Gathering context from multiple configuration files.
- When the user asks to "read all files in X directory" or "show me the content of all Y files".
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Kind.Read,
parameterSchema,
);
}
protected override validateToolParams(
params: ReadManyFilesParams,
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
return null;
}
protected createInvocation(
params: ReadManyFilesParams,
): ToolInvocation<ReadManyFilesParams, ToolResult> {
return new ReadManyFilesToolInvocation(this.config, params);
}
}

View File

@@ -25,7 +25,6 @@ vi.mock('../utils/summarizer.js');
import { isCommandAllowed } from '../utils/shell-utils.js';
import { ShellTool } from './shell.js';
import { ToolErrorType } from './tool-error.js';
import { type Config } from '../config/config.js';
import {
type ShellExecutionResult,
@@ -98,22 +97,25 @@ describe('ShellTool', () => {
});
});
describe('validateToolParams', () => {
it('should return null for a valid command', () => {
expect(shellTool.validateToolParams({ command: 'ls -l' })).toBeNull();
describe('build', () => {
it('should return an invocation for a valid command', () => {
const invocation = shellTool.build({ command: 'ls -l' });
expect(invocation).toBeDefined();
});
it('should return an error for an empty command', () => {
expect(shellTool.validateToolParams({ command: ' ' })).toBe(
it('should throw an error for an empty command', () => {
expect(() => shellTool.build({ command: ' ' })).toThrow(
'Command cannot be empty.',
);
});
it('should return an error for a non-existent directory', () => {
it('should throw an error for a non-existent directory', () => {
vi.mocked(fs.existsSync).mockReturnValue(false);
expect(
shellTool.validateToolParams({ command: 'ls', directory: 'rel/path' }),
).toBe("Directory 'rel/path' is not a registered workspace directory.");
expect(() =>
shellTool.build({ command: 'ls', directory: 'rel/path' }),
).toThrow(
"Directory 'rel/path' is not a registered workspace directory.",
);
});
});
@@ -139,10 +141,8 @@ describe('ShellTool', () => {
};
it('should wrap command on linux and parse pgrep output', async () => {
const promise = shellTool.execute(
{ command: 'my-command &' },
mockAbortSignal,
);
const invocation = shellTool.build({ command: 'my-command &' });
const promise = invocation.execute(mockAbortSignal);
resolveShellExecution({ pid: 54321 });
vi.mocked(fs.existsSync).mockReturnValue(true);
@@ -164,8 +164,9 @@ describe('ShellTool', () => {
it('should not wrap command on windows', async () => {
vi.mocked(os.platform).mockReturnValue('win32');
const promise = shellTool.execute({ command: 'dir' }, mockAbortSignal);
resolveExecutionPromise({
const invocation = shellTool.build({ command: 'dir' });
const promise = invocation.execute(mockAbortSignal);
resolveShellExecution({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
@@ -187,10 +188,8 @@ describe('ShellTool', () => {
it('should format error messages correctly', async () => {
const error = new Error('wrapped command failed');
const promise = shellTool.execute(
{ command: 'user-command' },
mockAbortSignal,
);
const invocation = shellTool.build({ command: 'user-command' });
const promise = invocation.execute(mockAbortSignal);
resolveShellExecution({
error,
exitCode: 1,
@@ -209,40 +208,19 @@ describe('ShellTool', () => {
expect(result.llmContent).not.toContain('pgrep');
});
it('should return error with error property for invalid parameters', async () => {
const result = await shellTool.execute(
{ command: '' }, // Empty command is invalid
mockAbortSignal,
it('should throw an error for invalid parameters', () => {
expect(() => shellTool.build({ command: '' })).toThrow(
'Command cannot be empty.',
);
expect(result.llmContent).toContain(
'Could not execute command due to invalid parameters:',
);
expect(result.returnDisplay).toBe('Command cannot be empty.');
expect(result.error).toEqual({
message: 'Command cannot be empty.',
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should return error with error property for invalid directory', async () => {
it('should throw an error for invalid directory', () => {
vi.mocked(fs.existsSync).mockReturnValue(false);
const result = await shellTool.execute(
{ command: 'ls', directory: 'nonexistent' },
mockAbortSignal,
expect(() =>
shellTool.build({ command: 'ls', directory: 'nonexistent' }),
).toThrow(
`Directory 'nonexistent' is not a registered workspace directory.`,
);
expect(result.llmContent).toContain(
'Could not execute command due to invalid parameters:',
);
expect(result.returnDisplay).toBe(
"Directory 'nonexistent' is not a registered workspace directory.",
);
expect(result.error).toEqual({
message:
"Directory 'nonexistent' is not a registered workspace directory.",
type: ToolErrorType.INVALID_TOOL_PARAMS,
});
});
it('should summarize output when configured', async () => {
@@ -253,7 +231,8 @@ describe('ShellTool', () => {
'summarized output',
);
const promise = shellTool.execute({ command: 'ls' }, mockAbortSignal);
const invocation = shellTool.build({ command: 'ls' });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
output: 'long output',
rawOutput: Buffer.from('long output'),
@@ -285,9 +264,8 @@ describe('ShellTool', () => {
});
vi.mocked(fs.existsSync).mockReturnValue(true); // Pretend the file exists
await expect(
shellTool.execute({ command: 'a-command' }, mockAbortSignal),
).rejects.toThrow(error);
const invocation = shellTool.build({ command: 'a-command' });
await expect(invocation.execute(mockAbortSignal)).rejects.toThrow(error);
const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp');
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
@@ -304,11 +282,8 @@ describe('ShellTool', () => {
});
it('should throttle text output updates', async () => {
const promise = shellTool.execute(
{ command: 'stream' },
mockAbortSignal,
updateOutputMock,
);
const invocation = shellTool.build({ command: 'stream' });
const promise = invocation.execute(mockAbortSignal, updateOutputMock);
// First chunk, should be throttled.
mockShellOutputCallback({
@@ -347,11 +322,8 @@ describe('ShellTool', () => {
});
it('should immediately show binary detection message and throttle progress', async () => {
const promise = shellTool.execute(
{ command: 'cat img' },
mockAbortSignal,
updateOutputMock,
);
const invocation = shellTool.build({ command: 'cat img' });
const promise = invocation.execute(mockAbortSignal, updateOutputMock);
mockShellOutputCallback({ type: 'binary_detected' });
expect(updateOutputMock).toHaveBeenCalledOnce();
@@ -394,13 +366,260 @@ describe('ShellTool', () => {
await promise;
});
});
describe('addCoAuthorToGitCommit', () => {
it('should add co-author to git commit with double quotes', async () => {
const command = 'git commit -m "Initial commit"';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
// Mock the shell execution to return success
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
// Verify that the command was executed with co-author added
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining(
'Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>',
),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should add co-author to git commit with single quotes', async () => {
const command = "git commit -m 'Fix bug'";
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining(
'Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>',
),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should handle git commit with additional flags', async () => {
const command = 'git commit -a -m "Add feature"';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining(
'Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>',
),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should not modify non-git commands', async () => {
const command = 'npm install';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
// On Linux, commands are wrapped with pgrep functionality
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining('npm install'),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should not modify git commands without -m flag', async () => {
const command = 'git commit';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
// On Linux, commands are wrapped with pgrep functionality
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining('git commit'),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should handle git commit with escaped quotes in message', async () => {
const command = 'git commit -m "Fix \\"quoted\\" text"';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining(
'Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>',
),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should not add co-author when disabled in config', async () => {
// Mock config with disabled co-author
(mockConfig.getGitCoAuthor as Mock).mockReturnValue({
enabled: false,
name: 'Qwen-Coder',
email: 'qwen-coder@alibabacloud.com',
});
const command = 'git commit -m "Initial commit"';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
// On Linux, commands are wrapped with pgrep functionality
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining('git commit -m "Initial commit"'),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
it('should use custom name and email from config', async () => {
// Mock config with custom co-author details
(mockConfig.getGitCoAuthor as Mock).mockReturnValue({
enabled: true,
name: 'Custom Bot',
email: 'custom@example.com',
});
const command = 'git commit -m "Test commit"';
const invocation = shellTool.build({ command });
const promise = invocation.execute(mockAbortSignal);
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
stdout: '',
stderr: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
});
await promise;
expect(mockShellExecutionService).toHaveBeenCalledWith(
expect.stringContaining(
'Co-authored-by: Custom Bot <custom@example.com>',
),
expect.any(String),
expect.any(Function),
mockAbortSignal,
);
});
});
});
describe('shouldConfirmExecute', () => {
it('should request confirmation for a new command and whitelist it on "Always"', async () => {
const params = { command: 'npm install' };
const confirmation = await shellTool.shouldConfirmExecute(
params,
const invocation = shellTool.build(params);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
@@ -413,136 +632,15 @@ describe('ShellTool', () => {
);
// Should now be whitelisted
const secondConfirmation = await shellTool.shouldConfirmExecute(
{ command: 'npm test' },
const secondInvocation = shellTool.build({ command: 'npm test' });
const secondConfirmation = await secondInvocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(secondConfirmation).toBe(false);
});
it('should skip confirmation if validation fails', async () => {
const confirmation = await shellTool.shouldConfirmExecute(
{ command: '' },
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
});
describe('addCoAuthorToGitCommit', () => {
it('should add co-author to git commit with double quotes', () => {
const command = 'git commit -m "Initial commit"';
// Use public test method
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe(
`git commit -m "Initial commit
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>"`,
);
});
it('should add co-author to git commit with single quotes', () => {
const command = "git commit -m 'Fix bug'";
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe(
`git commit -m 'Fix bug
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>'`,
);
});
it('should handle git commit with additional flags', () => {
const command = 'git commit -a -m "Add feature"';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe(
`git commit -a -m "Add feature
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>"`,
);
});
it('should not modify non-git commands', () => {
const command = 'npm install';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe('npm install');
});
it('should not modify git commands without -m flag', () => {
const command = 'git commit';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe('git commit');
});
it('should handle git commit with escaped quotes in message', () => {
const command = 'git commit -m "Fix \\"quoted\\" text"';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe(
`git commit -m "Fix \\"quoted\\" text
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>"`,
);
});
it('should not add co-author when disabled in config', () => {
// Mock config with disabled co-author
(mockConfig.getGitCoAuthor as Mock).mockReturnValue({
enabled: false,
name: 'Qwen-Coder',
email: 'qwen-coder@alibabacloud.com',
});
const command = 'git commit -m "Initial commit"';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe('git commit -m "Initial commit"');
});
it('should use custom name and email from config', () => {
// Mock config with custom co-author details
(mockConfig.getGitCoAuthor as Mock).mockReturnValue({
enabled: true,
name: 'Custom Bot',
email: 'custom@example.com',
});
const command = 'git commit -m "Test commit"';
const result = (
shellTool as unknown as {
addCoAuthorToGitCommit: (command: string) => string;
}
).addCoAuthorToGitCommit(command);
expect(result).toBe(
`git commit -m "Test commit
Co-authored-by: Custom Bot <custom@example.com>"`,
);
it('should throw an error if validation fails', () => {
expect(() => shellTool.build({ command: '' })).toThrow();
});
});
});
@@ -581,8 +679,8 @@ describe('validateToolParams', () => {
});
});
describe('validateToolParams', () => {
it('should return null for valid directory', () => {
describe('build', () => {
it('should return an invocation for valid directory', () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
@@ -591,14 +689,14 @@ describe('validateToolParams', () => {
createMockWorkspaceContext('/root', ['/users/test']),
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.validateToolParams({
const invocation = shellTool.build({
command: 'ls',
directory: 'test',
});
expect(result).toBeNull();
expect(invocation).toBeDefined();
});
it('should return error for directory outside workspace', () => {
it('should throw an error for directory outside workspace', () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
@@ -607,10 +705,11 @@ describe('validateToolParams', () => {
createMockWorkspaceContext('/root', ['/users/test']),
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.validateToolParams({
command: 'ls',
directory: 'test2',
});
expect(result).toContain('is not a registered workspace directory');
expect(() =>
shellTool.build({
command: 'ls',
directory: 'test2',
}),
).toThrow('is not a registered workspace directory');
});
});

View File

@@ -10,14 +10,15 @@ import os from 'os';
import crypto from 'crypto';
import { Config } from '../config/config.js';
import {
BaseTool,
BaseDeclarativeTool,
BaseToolInvocation,
ToolInvocation,
ToolResult,
ToolCallConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolConfirmationOutcome,
Icon,
Kind,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { summarizeToolOutput } from '../utils/summarizer.js';
@@ -40,120 +41,36 @@ export interface ShellToolParams {
directory?: string;
}
export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
static Name: string = 'run_shell_command';
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
super(
ShellTool.Name,
'Shell',
`This tool executes a given shell command as \`bash -c <command>\`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
The following information is returned:
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\``,
Icon.Terminal,
{
type: 'object',
properties: {
command: {
type: 'string',
description: 'Exact bash command to execute as `bash -c <command>`',
},
description: {
type: 'string',
description:
'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.',
},
directory: {
type: 'string',
description:
'(OPTIONAL) Directory to run the command in, if not the project root directory. Must be relative to the project root directory and must already exist.',
},
},
required: ['command'],
},
false, // output is not markdown
true, // output can be updated
);
class ShellToolInvocation extends BaseToolInvocation<
ShellToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
params: ShellToolParams,
private readonly allowlist: Set<string>,
) {
super(params);
}
getDescription(params: ShellToolParams): string {
let description = `${params.command}`;
getDescription(): string {
let description = `${this.params.command}`;
// append optional [in directory]
// note description is needed even if validation fails due to absolute path
if (params.directory) {
description += ` [in ${params.directory}]`;
if (this.params.directory) {
description += ` [in ${this.params.directory}]`;
}
// append optional (description), replacing any line breaks with spaces
if (params.description) {
description += ` (${params.description.replace(/\n/g, ' ')})`;
if (this.params.description) {
description += ` (${this.params.description.replace(/\n/g, ' ')})`;
}
return description;
}
validateToolParams(params: ShellToolParams): string | null {
const commandCheck = isCommandAllowed(params.command, this.config);
if (!commandCheck.allowed) {
if (!commandCheck.reason) {
console.error(
'Unexpected: isCommandAllowed returned false without a reason',
);
return `Command is not allowed: ${params.command}`;
}
return commandCheck.reason;
}
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!params.command.trim()) {
return 'Command cannot be empty.';
}
if (getCommandRoots(params.command).length === 0) {
return 'Could not identify command root to obtain permission from user.';
}
if (params.directory) {
if (path.isAbsolute(params.directory)) {
return 'Directory cannot be absolute. Please refer to workspace directories by their name.';
}
const workspaceDirs = this.config.getWorkspaceContext().getDirectories();
const matchingDirs = workspaceDirs.filter(
(dir) => path.basename(dir) === params.directory,
);
if (matchingDirs.length === 0) {
return `Directory '${params.directory}' is not a registered workspace directory.`;
}
if (matchingDirs.length > 1) {
return `Directory name '${params.directory}' is ambiguous as it matches multiple workspace directories.`;
}
}
return null;
}
async shouldConfirmExecute(
params: ShellToolParams,
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.validateToolParams(params)) {
return false; // skip confirmation, execute call will fail immediately
}
const command = stripShellWrapper(params.command);
const command = stripShellWrapper(this.params.command);
const rootCommands = [...new Set(getCommandRoots(command))];
const commandsToConfirm = rootCommands.filter(
(command) => !this.allowlist.has(command),
@@ -166,7 +83,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
const confirmationDetails: ToolExecuteConfirmationDetails = {
type: 'exec',
title: 'Confirm Shell Command',
command: params.command,
command: this.params.command,
rootCommand: commandsToConfirm.join(', '),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
@@ -178,25 +95,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}
async execute(
params: ShellToolParams,
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<ToolResult> {
const strippedCommand = stripShellWrapper(params.command);
const validationError = this.validateToolParams({
...params,
command: strippedCommand,
});
if (validationError) {
return {
llmContent: `Could not execute command due to invalid parameters: ${validationError}`,
returnDisplay: validationError,
error: {
message: validationError,
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
}
const strippedCommand = stripShellWrapper(this.params.command);
if (signal.aborted) {
return {
@@ -227,7 +129,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
const cwd = path.resolve(
this.config.getTargetDir(),
params.directory || '',
this.params.directory || '',
);
let cumulativeStdout = '';
@@ -327,12 +229,12 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
// Create a formatted error string for display, replacing the wrapper command
// with the user-facing command.
const finalError = result.error
? result.error.message.replace(commandToExecute, params.command)
? result.error.message.replace(commandToExecute, this.params.command)
: '(none)';
llmContent = [
`Command: ${params.command}`,
`Directory: ${params.directory || '(root)'}`,
`Command: ${this.params.command}`,
`Directory: ${this.params.directory || '(root)'}`,
`Stdout: ${result.stdout || '(empty)'}`,
`Stderr: ${result.stderr || '(empty)'}`,
`Error: ${finalError}`, // Use the cleaned error string.
@@ -369,12 +271,12 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}
const summarizeConfig = this.config.getSummarizeToolOutputConfig();
if (summarizeConfig && summarizeConfig[this.name]) {
if (summarizeConfig && summarizeConfig[ShellTool.Name]) {
const summary = await summarizeToolOutput(
llmContent,
this.config.getGeminiClient(),
signal,
summarizeConfig[this.name].tokenBudget,
summarizeConfig[ShellTool.Name].tokenBudget,
);
return {
llmContent: summary,
@@ -429,3 +331,104 @@ Co-authored-by: ${gitCoAuthorSettings.name} <${gitCoAuthorSettings.email}>`;
return command;
}
}
export class ShellTool extends BaseDeclarativeTool<
ShellToolParams,
ToolResult
> {
static Name: string = 'run_shell_command';
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
super(
ShellTool.Name,
'Shell',
`This tool executes a given shell command as \`bash -c <command>\`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.
The following information is returned:
Command: Executed command.
Directory: Directory (relative to project root) where command was executed, or \`(root)\`.
Stdout: Output on stdout stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Stderr: Output on stderr stream. Can be \`(empty)\` or partial on error and for any unwaited background processes.
Error: Error or \`(none)\` if no error was reported for the subprocess.
Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
Background PIDs: List of background processes started or \`(none)\`.
Process Group PGID: Process group started or \`(none)\``,
Kind.Execute,
{
type: 'object',
properties: {
command: {
type: 'string',
description: 'Exact bash command to execute as `bash -c <command>`',
},
description: {
type: 'string',
description:
'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.',
},
directory: {
type: 'string',
description:
'(OPTIONAL) Directory to run the command in, if not the project root directory. Must be relative to the project root directory and must already exist.',
},
},
required: ['command'],
},
false, // output is not markdown
true, // output can be updated
);
}
override validateToolParams(params: ShellToolParams): string | null {
const commandCheck = isCommandAllowed(params.command, this.config);
if (!commandCheck.allowed) {
if (!commandCheck.reason) {
console.error(
'Unexpected: isCommandAllowed returned false without a reason',
);
return `Command is not allowed: ${params.command}`;
}
return commandCheck.reason;
}
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (!params.command.trim()) {
return 'Command cannot be empty.';
}
if (getCommandRoots(params.command).length === 0) {
return 'Could not identify command root to obtain permission from user.';
}
if (params.directory) {
if (path.isAbsolute(params.directory)) {
return 'Directory cannot be absolute. Please refer to workspace directories by their name.';
}
const workspaceDirs = this.config.getWorkspaceContext().getDirectories();
const matchingDirs = workspaceDirs.filter(
(dir) => path.basename(dir) === params.directory,
);
if (matchingDirs.length === 0) {
return `Directory '${params.directory}' is not a registered workspace directory.`;
}
if (matchingDirs.length > 1) {
return `Directory name '${params.directory}' is ambiguous as it matches multiple workspace directories.`;
}
}
return null;
}
protected createInvocation(
params: ShellToolParams,
): ToolInvocation<ShellToolParams, ToolResult> {
return new ShellToolInvocation(this.config, params, this.allowlist);
}
}

View File

@@ -5,7 +5,7 @@
*/
import { FunctionDeclaration } from '@google/genai';
import { AnyDeclarativeTool, Icon, ToolResult, BaseTool } from './tools.js';
import { AnyDeclarativeTool, Kind, ToolResult, BaseTool } from './tools.js';
import { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
@@ -19,8 +19,8 @@ export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
constructor(
private readonly config: Config,
name: string,
readonly description: string,
readonly parameterSchema: Record<string, unknown>,
override readonly description: string,
override readonly parameterSchema: Record<string, unknown>,
) {
const discoveryCmd = config.getToolDiscoveryCommand()!;
const callCommand = config.getToolCallCommand()!;
@@ -44,7 +44,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
name,
name,
description,
Icon.Hammer,
Kind.Other,
parameterSchema,
false, // isOutputMarkdown
false, // canUpdateOutput
@@ -158,6 +158,18 @@ export class ToolRegistry {
}
}
/**
* Removes all tools from a specific MCP server.
* @param serverName The name of the server to remove tools from.
*/
removeMcpToolsByServer(serverName: string): void {
for (const [name, tool] of this.tools.entries()) {
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
this.tools.delete(name);
}
}
}
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.

View File

@@ -145,9 +145,9 @@ export interface ToolBuilder<
description: string;
/**
* The icon to display when interacting via ACP.
* The kind of tool for categorization and permissions
*/
icon: Icon;
kind: Kind;
/**
* Function declaration schema from @google/genai.
@@ -185,7 +185,7 @@ export abstract class DeclarativeTool<
readonly name: string,
readonly displayName: string,
readonly description: string,
readonly icon: Icon,
readonly kind: Kind,
readonly parameterSchema: unknown,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
@@ -284,19 +284,19 @@ export abstract class BaseTool<
* @param parameterSchema JSON Schema defining the parameters
*/
constructor(
readonly name: string,
readonly displayName: string,
readonly description: string,
readonly icon: Icon,
readonly parameterSchema: unknown,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
override readonly name: string,
override readonly displayName: string,
override readonly description: string,
override readonly kind: Kind,
override readonly parameterSchema: unknown,
override readonly isOutputMarkdown: boolean = true,
override readonly canUpdateOutput: boolean = false,
) {
super(
name,
displayName,
description,
icon,
kind,
parameterSchema,
isOutputMarkdown,
canUpdateOutput,
@@ -320,7 +320,7 @@ export abstract class BaseTool<
* @returns An error message string if invalid, null otherwise
*/
// eslint-disable-next-line @typescript-eslint/no-unused-vars
validateToolParams(params: TParams): string | null {
override validateToolParams(params: TParams): string | null {
// Implementation would typically use a JSON Schema validator
// This is a placeholder that should be implemented by derived classes
return null;
@@ -570,15 +570,16 @@ export enum ToolConfirmationOutcome {
Cancel = 'cancel',
}
export enum Icon {
FileSearch = 'fileSearch',
Folder = 'folder',
Globe = 'globe',
Hammer = 'hammer',
LightBulb = 'lightBulb',
Pencil = 'pencil',
Regex = 'regex',
Terminal = 'terminal',
export enum Kind {
Read = 'read',
Edit = 'edit',
Delete = 'delete',
Move = 'move',
Search = 'search',
Execute = 'execute',
Think = 'think',
Fetch = 'fetch',
Other = 'other',
}
export interface ToolLocation {

View File

@@ -23,7 +23,10 @@ describe('WebFetchTool', () => {
url: 'https://example.com',
prompt: 'summarize this page',
};
const confirmationDetails = await tool.shouldConfirmExecute(params);
const invocation = tool.build(params);
const confirmationDetails = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmationDetails).toEqual({
type: 'info',
@@ -41,7 +44,10 @@ describe('WebFetchTool', () => {
url: 'https://github.com/google/gemini-react/blob/main/README.md',
prompt: 'summarize the README',
};
const confirmationDetails = await tool.shouldConfirmExecute(params);
const invocation = tool.build(params);
const confirmationDetails = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmationDetails).toEqual({
type: 'info',
@@ -62,7 +68,10 @@ describe('WebFetchTool', () => {
url: 'https://example.com',
prompt: 'summarize this page',
};
const confirmationDetails = await tool.shouldConfirmExecute(params);
const invocation = tool.build(params);
const confirmationDetails = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmationDetails).toBe(false);
});
@@ -77,7 +86,10 @@ describe('WebFetchTool', () => {
url: 'https://example.com',
prompt: 'summarize this page',
};
const confirmationDetails = await tool.shouldConfirmExecute(params);
const invocation = tool.build(params);
const confirmationDetails = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
if (
confirmationDetails &&

View File

@@ -6,15 +6,18 @@
import { SchemaValidator } from '../utils/schemaValidator.js';
import {
BaseTool,
ToolResult,
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
Icon,
ToolInvocation,
ToolResult,
} from './tools.js';
import { Config, ApprovalMode } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { fetchWithTimeout } from '../utils/fetch.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { convert } from 'html-to-text';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
@@ -35,18 +38,158 @@ export interface WebFetchToolParams {
prompt: string;
}
/**
* Implementation of the WebFetch tool invocation logic
*/
class WebFetchToolInvocation extends BaseToolInvocation<
WebFetchToolParams,
ToolResult
> {
constructor(
private readonly config: Config,
params: WebFetchToolParams,
) {
super(params);
}
private async executeDirectFetch(signal: AbortSignal): Promise<ToolResult> {
let url = this.params.url;
// Convert GitHub blob URL to raw URL
if (url.includes('github.com') && url.includes('/blob/')) {
url = url
.replace('github.com', 'raw.githubusercontent.com')
.replace('/blob/', '/');
console.debug(
`[WebFetchTool] Converted GitHub blob URL to raw URL: ${url}`,
);
}
try {
console.debug(`[WebFetchTool] Fetching content from: ${url}`);
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!response.ok) {
const errorMessage = `Request failed with status code ${response.status} ${response.statusText}`;
console.error(`[WebFetchTool] ${errorMessage}`);
throw new Error(errorMessage);
}
console.debug(`[WebFetchTool] Successfully fetched content from ${url}`);
const html = await response.text();
const textContent = convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
}).substring(0, MAX_CONTENT_LENGTH);
console.debug(
`[WebFetchTool] Converted HTML to text (${textContent.length} characters)`,
);
const geminiClient = this.config.getGeminiClient();
const fallbackPrompt = `The user requested the following: "${this.params.prompt}".
I have fetched the content from ${this.params.url}. Please use the following content to answer the user's request.
---
${textContent}
---`;
console.debug(
`[WebFetchTool] Processing content with prompt: "${this.params.prompt}"`,
);
const result = await geminiClient.generateContent(
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
{},
signal,
);
const resultText = getResponseText(result) || '';
console.debug(
`[WebFetchTool] Successfully processed content from ${this.params.url}`,
);
return {
llmContent: resultText,
returnDisplay: `Content from ${this.params.url} processed successfully.`,
};
} catch (e) {
const error = e as Error;
const errorMessage = `Error during fetch for ${url}: ${error.message}`;
console.error(`[WebFetchTool] ${errorMessage}`, error);
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
};
}
}
override getDescription(): string {
const displayPrompt =
this.params.prompt.length > 100
? this.params.prompt.substring(0, 97) + '...'
: this.params.prompt;
return `Fetching content from ${this.params.url} and processing with prompt: "${displayPrompt}"`;
}
override async shouldConfirmExecute(): Promise<
ToolCallConfirmationDetails | false
> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm Web Fetch`,
prompt: `Fetch content from ${this.params.url} and process with: ${this.params.prompt}`,
urls: [this.params.url],
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
},
};
return confirmationDetails;
}
async execute(signal: AbortSignal): Promise<ToolResult> {
// Check if URL is private/localhost
const isPrivate = isPrivateIp(this.params.url);
if (isPrivate) {
console.debug(
`[WebFetchTool] Private IP detected for ${this.params.url}, using direct fetch`,
);
} else {
console.debug(
`[WebFetchTool] Public URL detected for ${this.params.url}, using direct fetch`,
);
}
return this.executeDirectFetch(signal);
}
}
/**
* Implementation of the WebFetch tool logic
*/
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
export class WebFetchTool extends BaseDeclarativeTool<
WebFetchToolParams,
ToolResult
> {
static readonly Name: string = 'web_fetch';
constructor(private readonly config: Config) {
super(
WebFetchTool.Name,
'WebFetch',
'Fetches content from a specified URL and processes it using an AI model\n- Takes a URL and a prompt as input\n- Fetches the URL content, converts HTML to markdown\n- Processes the content with the prompt using a small, fast model\n- Returns the model\'s response about the content\n- Use this tool when you need to retrieve and analyze web content\n\nUsage notes:\n - IMPORTANT: If an MCP-provided web fetch tool is available, prefer using that tool instead of this one, as it may have fewer restrictions. All MCP-provided tools start with "mcp__".\n - The URL must be a fully-formed valid URL\n - The prompt should describe what information you want to extract from the page\n - This tool is read-only and does not modify any files\n - Results may be summarized if the content is very large',
Icon.Globe,
'Fetches content from a specified URL and processes it using an AI model\n- Takes a URL and a prompt as input\n- Fetches the URL content, converts HTML to markdown\n- Processes the content with the prompt using a small, fast model\n- Returns the model\'s response about the content\n- Use this tool when you need to retrieve and analyze web content\n\nUsage notes:\n - IMPORTANT: If an MCP-provided web fetch tool is available, prefer using that tool instead of this one, as it may have fewer restrictions. All MCP-provided tools start with "mcp__".\n - The URL must be a fully-formed valid URL\n - The prompt should describe what information you want to extract from the page\n - This tool is read-only and does not modify any files\n - Results may be summarized if the content is very large\n - Supports both public and private/localhost URLs using direct fetch',
Kind.Fetch,
{
properties: {
url: {
@@ -68,64 +211,9 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
}
}
private async executeFetch(
protected override validateToolParams(
params: WebFetchToolParams,
signal: AbortSignal,
): Promise<ToolResult> {
let url = params.url;
// Convert GitHub blob URL to raw URL
if (url.includes('github.com') && url.includes('/blob/')) {
url = url
.replace('github.com', 'raw.githubusercontent.com')
.replace('/blob/', '/');
}
try {
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!response.ok) {
throw new Error(
`Request failed with status code ${response.status} ${response.statusText}`,
);
}
const html = await response.text();
const textContent = convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
}).substring(0, MAX_CONTENT_LENGTH);
const geminiClient = this.config.getGeminiClient();
const fallbackPrompt = `The user requested the following: "${params.prompt}".
I have fetched the content from ${params.url}. Please use the following content to answer the user's request.
---
${textContent}
---`;
const result = await geminiClient.generateContent(
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
{},
signal,
);
const resultText = getResponseText(result) || '';
return {
llmContent: resultText,
returnDisplay: `Content from ${params.url} processed successfully.`,
};
} catch (e) {
const error = e as Error;
const errorMessage = `Error during fetch for ${url}: ${error.message}`;
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
};
}
}
validateParams(params: WebFetchToolParams): string | null {
): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
@@ -148,52 +236,9 @@ ${textContent}
return null;
}
getDescription(params: WebFetchToolParams): string {
const displayPrompt =
params.prompt.length > 100
? params.prompt.substring(0, 97) + '...'
: params.prompt;
return `Fetching content from ${params.url} and processing with prompt: "${displayPrompt}"`;
}
async shouldConfirmExecute(
protected createInvocation(
params: WebFetchToolParams,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
const validationError = this.validateParams(params);
if (validationError) {
return false;
}
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm Web Fetch`,
prompt: `Fetch content from ${params.url} and process with: ${params.prompt}`,
urls: [params.url],
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
},
};
return confirmationDetails;
}
async execute(
params: WebFetchToolParams,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
return this.executeFetch(params, signal);
): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation(this.config, params);
}
}

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTool, Icon, ToolResult } from './tools.js';
import { BaseTool, Kind, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
@@ -54,7 +54,7 @@ export class WebSearchTool extends BaseTool<
WebSearchTool.Name,
'TavilySearch',
'Performs a web search using the Tavily API and returns a concise answer with sources. Requires the TAVILY_API_KEY environment variable.',
Icon.Globe,
Kind.Search,
{
type: 'object',
properties: {
@@ -88,7 +88,7 @@ export class WebSearchTool extends BaseTool<
return null;
}
getDescription(params: WebSearchToolParams): string {
override getDescription(params: WebSearchToolParams): string {
return `Searching the web for: "${params.query}"`;
}

View File

@@ -58,7 +58,6 @@ const mockConfigInternal = {
getGeminiClient: vi.fn(), // Initialize as a plain mock function
getIdeClient: vi.fn(),
getIdeMode: vi.fn(() => false),
getIdeModeFeature: vi.fn(() => false),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getApiKey: () => 'test-key',
getModel: () => 'test-model',

View File

@@ -15,7 +15,8 @@ import {
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
ToolCallConfirmationDetails,
Icon,
Kind,
ToolLocation,
} from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
@@ -82,7 +83,7 @@ export class WriteFileTool
`Writes content to a specified file in the local filesystem.
The user has the ability to modify \`content\`. If modified, this will be stated in the response.`,
Icon.Pencil,
Kind.Edit,
{
properties: {
file_path: {
@@ -101,7 +102,11 @@ export class WriteFileTool
);
}
validateToolParams(params: WriteFileToolParams): string | null {
override toolLocations(params: WriteFileToolParams): ToolLocation[] {
return [{ path: params.file_path }];
}
override validateToolParams(params: WriteFileToolParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
@@ -139,7 +144,7 @@ export class WriteFileTool
return null;
}
getDescription(params: WriteFileToolParams): string {
override getDescription(params: WriteFileToolParams): string {
if (!params.file_path) {
return `Model did not provide valid parameters for write file tool, missing or empty "file_path"`;
}
@@ -153,7 +158,7 @@ export class WriteFileTool
/**
* Handles the confirmation prompt for the WriteFile tool.
*/
async shouldConfirmExecute(
override async shouldConfirmExecute(
params: WriteFileToolParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
@@ -195,7 +200,6 @@ export class WriteFileTool
const ideClient = this.config.getIdeClient();
const ideConfirmation =
this.config.getIdeModeFeature() &&
this.config.getIdeMode() &&
ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected
? ideClient.openDiff(params.file_path, correctedContent)

View File

@@ -79,6 +79,14 @@ describe('getEnvironmentContext', () => {
vi.useFakeTimers();
vi.setSystemTime(new Date('2025-08-05T12:00:00Z'));
// Mock the locale to ensure consistent English date formatting
vi.stubGlobal('Intl', {
...global.Intl,
DateTimeFormat: vi.fn().mockImplementation(() => ({
format: vi.fn().mockReturnValue('Tuesday, August 5, 2025'),
})),
});
mockToolRegistry = {
getTool: vi.fn(),
};
@@ -97,6 +105,7 @@ describe('getEnvironmentContext', () => {
afterEach(() => {
vi.useRealTimers();
vi.unstubAllGlobals();
vi.resetAllMocks();
});
@@ -106,7 +115,8 @@ describe('getEnvironmentContext', () => {
expect(parts.length).toBe(1);
const context = parts[0].text;
expect(context).toContain("Today's date is Tuesday, August 5, 2025");
// Use a more flexible date assertion that works with different locales
expect(context).toMatch(/Today's date is .*2025.*/);
expect(context).toContain(`My operating system is: ${process.platform}`);
expect(context).toContain(
"I'm currently working in the directory: /test/dir",

View File

@@ -0,0 +1,375 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { parseAndFormatApiError } from './errorParsing.js';
import { isProQuotaExceededError } from './quotaErrorDetection.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { AuthType } from '../core/contentGenerator.js';
import { StructuredError } from '../core/turn.js';
describe('parseAndFormatApiError', () => {
const vertexMessage = 'request a quota increase through Vertex';
const geminiMessage = 'request a quota increase through AI Studio';
it('should format a valid API error JSON', () => {
const errorMessage =
'got status: 400 Bad Request. {"error":{"code":400,"message":"API key not valid. Please pass a valid API key.","status":"INVALID_ARGUMENT"}}';
const expected =
'[API Error: API key not valid. Please pass a valid API key. (Status: INVALID_ARGUMENT)]';
expect(parseAndFormatApiError(errorMessage)).toBe(expected);
});
it('should format a 429 API error with the default message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
undefined,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
});
it('should format a 429 API error with the personal message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
});
it('should format a 429 API error with the vertex message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(errorMessage, AuthType.USE_VERTEX_AI);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(vertexMessage);
});
it('should return the original message if it is not a JSON error', () => {
const errorMessage = 'This is a plain old error message';
expect(parseAndFormatApiError(errorMessage)).toBe(
`[API Error: ${errorMessage}]`,
);
});
it('should return the original message for malformed JSON', () => {
const errorMessage = '[Stream Error: {"error": "malformed}';
expect(parseAndFormatApiError(errorMessage)).toBe(
`[API Error: ${errorMessage}]`,
);
});
it('should handle JSON that does not match the ApiError structure', () => {
const errorMessage = '[Stream Error: {"not_an_error": "some other json"}]';
expect(parseAndFormatApiError(errorMessage)).toBe(
`[API Error: ${errorMessage}]`,
);
});
it('should format a nested API error', () => {
const nestedErrorMessage = JSON.stringify({
error: {
code: 429,
message:
"Gemini 2.5 Pro Preview doesn't have a free quota tier. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits.",
status: 'RESOURCE_EXHAUSTED',
},
});
const errorMessage = JSON.stringify({
error: {
code: 429,
message: nestedErrorMessage,
status: 'Too Many Requests',
},
});
const result = parseAndFormatApiError(errorMessage, AuthType.USE_GEMINI);
expect(result).toContain('Gemini 2.5 Pro Preview');
expect(result).toContain(geminiMessage);
});
it('should format a StructuredError', () => {
const error: StructuredError = {
message: 'A structured error occurred',
status: 500,
};
const expected = '[API Error: A structured error occurred]';
expect(parseAndFormatApiError(error)).toBe(expected);
});
it('should format a 429 StructuredError with the vertex message', () => {
const error: StructuredError = {
message: 'Rate limit exceeded',
status: 429,
};
const result = parseAndFormatApiError(error, AuthType.USE_VERTEX_AI);
expect(result).toContain('[API Error: Rate limit exceeded]');
expect(result).toContain(vertexMessage);
});
it('should handle an unknown error type', () => {
const error = 12345;
const expected = '[API Error: An unknown error occurred.]';
expect(parseAndFormatApiError(error)).toBe(expected);
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Free tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
it('should format a regular 429 API error with standard message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
expect(result).not.toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
});
it('should format a 429 API error with generic quota exceeded message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).not.toContain(
'You have reached your daily Gemini 2.5 Pro quota limit',
);
});
it('should prioritize Pro quota message over generic quota message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).not.toContain('You have reached your daily quota limit');
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Legacy tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.LEGACY,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
it('should handle different Gemini 2.5 version strings in Pro quota exceeded errors', () => {
const errorMessage25 =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const errorMessagePreview =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5-preview Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result25 = parseAndFormatApiError(
errorMessage25,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
const resultPreview = parseAndFormatApiError(
errorMessagePreview,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-preview-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result25).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(resultPreview).toContain(
'You have reached your daily gemini-2.5-preview-pro quota limit',
);
expect(result25).toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
expect(resultPreview).toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
it('should not match non-Pro models with similar version strings', () => {
// Test that Flash models with similar version strings don't match
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Flash Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5-preview Flash Requests' and limit",
),
).toBe(false);
// Test other model types
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Ultra Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Standard Requests' and limit",
),
).toBe(false);
// Test generic quota messages
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'GenerationRequests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'EmbeddingRequests' and limit",
),
).toBe(false);
});
it('should format a generic quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
it('should format a regular 429 API error with standard message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain(
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
);
});
});

View File

@@ -0,0 +1,166 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isApiError,
isStructuredError,
} from './quotaErrorDetection.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
} from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { AuthType } from '../core/contentGenerator.js';
// Free Tier message functions
const getRateLimitErrorMessageGoogleFree = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
const getRateLimitErrorMessageGoogleProQuotaFree = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaFree = () =>
`\nYou have reached your daily quota limit. To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
// Legacy/Standard Tier message functions
const getRateLimitErrorMessageGooglePaid = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI.`;
const getRateLimitErrorMessageGoogleProQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
) =>
`\nYou have reached your daily quota limit. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI =
'\nPlease wait and try again later. To increase your limits, request a quota increase through AI Studio, or switch to another /auth method';
const RATE_LIMIT_ERROR_MESSAGE_VERTEX =
'\nPlease wait and try again later. To increase your limits, request a quota increase through Vertex, or switch to another /auth method';
const getRateLimitErrorMessageDefault = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
function getRateLimitMessage(
authType?: AuthType,
error?: unknown,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
switch (authType) {
case AuthType.LOGIN_WITH_GOOGLE: {
// Determine if user is on a paid tier (Legacy or Standard) - default to FREE if not specified
const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
if (isProQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleProQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
)
: getRateLimitErrorMessageGoogleProQuotaFree(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
);
} else if (isGenericQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleGenericQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
)
: getRateLimitErrorMessageGoogleGenericQuotaFree();
} else {
return isPaidTier
? getRateLimitErrorMessageGooglePaid(fallbackModel)
: getRateLimitErrorMessageGoogleFree(fallbackModel);
}
}
case AuthType.USE_GEMINI:
return RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI;
case AuthType.USE_VERTEX_AI:
return RATE_LIMIT_ERROR_MESSAGE_VERTEX;
default:
return getRateLimitErrorMessageDefault(fallbackModel);
}
}
export function parseAndFormatApiError(
error: unknown,
authType?: AuthType,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
if (isStructuredError(error)) {
let text = `[API Error: ${error.message}]`;
if (error.status === 429) {
text += getRateLimitMessage(
authType,
error,
userTier,
currentModel,
fallbackModel,
);
}
return text;
}
// The error message might be a string containing a JSON object.
if (typeof error === 'string') {
const jsonStart = error.indexOf('{');
if (jsonStart === -1) {
return `[API Error: ${error}]`; // Not a JSON error, return as is.
}
const jsonString = error.substring(jsonStart);
try {
const parsedError = JSON.parse(jsonString) as unknown;
if (isApiError(parsedError)) {
let finalMessage = parsedError.error.message;
try {
// See if the message is a stringified JSON with another error
const nestedError = JSON.parse(finalMessage) as unknown;
if (isApiError(nestedError)) {
finalMessage = nestedError.error.message;
}
} catch (_e) {
// It's not a nested JSON error, so we just use the message as is.
}
let text = `[API Error: ${finalMessage} (Status: ${parsedError.error.status})]`;
if (parsedError.error.code === 429) {
text += getRateLimitMessage(
authType,
parsedError,
userTier,
currentModel,
fallbackModel,
);
}
return text;
}
} catch (_e) {
// Not a valid JSON, fall through and return the original message.
}
return `[API Error: ${error}]`;
}
return '[API Error: An unknown error occurred.]';
}

View File

@@ -289,7 +289,7 @@ export class FileSearch {
* Builds the in-memory cache for fast pattern matching.
*/
private buildResultCache(): void {
this.resultCache = new ResultCache(this.allFiles, this.absoluteDir);
this.resultCache = new ResultCache(this.allFiles);
// The v1 algorithm is much faster since it only looks at the first
// occurence of the pattern. We use it for search spaces that have >20k
// files, because the v2 algorithm is just too slow in those cases.

View File

@@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { test, expect } from 'vitest';
import { ResultCache } from './result-cache.js';
@@ -17,7 +16,7 @@ test('ResultCache basic usage', async () => {
'subdir/other.js',
'subdir/nested/file.md',
];
const cache = new ResultCache(files, path.resolve('.'));
const cache = new ResultCache(files);
const { files: resultFiles, isExactMatch } = await cache.get('*.js');
expect(resultFiles).toEqual(files);
expect(isExactMatch).toBe(false);
@@ -25,7 +24,7 @@ test('ResultCache basic usage', async () => {
test('ResultCache cache hit/miss', async () => {
const files = ['foo.txt', 'bar.js', 'baz.md'];
const cache = new ResultCache(files, path.resolve('.'));
const cache = new ResultCache(files);
// First call: miss
const { files: result1Files, isExactMatch: isExactMatch1 } =
await cache.get('*.js');
@@ -44,7 +43,7 @@ test('ResultCache cache hit/miss', async () => {
test('ResultCache best base query', async () => {
const files = ['foo.txt', 'foobar.js', 'baz.md'];
const cache = new ResultCache(files, path.resolve('.'));
const cache = new ResultCache(files);
// Cache a broader query
cache.set('foo', ['foo.txt', 'foobar.js']);

View File

@@ -13,10 +13,7 @@ export class ResultCache {
private hits = 0;
private misses = 0;
constructor(
private readonly allFiles: string[],
private readonly absoluteDir: string,
) {
constructor(private readonly allFiles: string[]) {
this.cache = new Map();
}

View File

@@ -11,7 +11,7 @@ import { marked } from 'marked';
import { processImports, validateImportPath } from './memoryImportProcessor.js';
// Helper function to create platform-agnostic test paths
const testPath = (...segments: string[]) => {
function testPath(...segments: string[]): string {
// Start with the first segment as is (might be an absolute path on Windows)
let result = segments[0];
@@ -27,9 +27,8 @@ const testPath = (...segments: string[]) => {
}
return path.normalize(result);
};
}
// Mock fs/promises
vi.mock('fs/promises');
const mockedFs = vi.mocked(fs);
@@ -509,21 +508,21 @@ describe('memoryImportProcessor', () => {
expect(result.importTree.imports).toHaveLength(2);
// First import: nested.md
// Prefix with underscore to indicate they're intentionally unused
const _expectedNestedPath = testPath(projectRoot, 'src', 'nested.md');
const _expectedInnerPath = testPath(projectRoot, 'src', 'inner.md');
const _expectedSimplePath = testPath(projectRoot, 'src', 'simple.md');
// Check that the paths match using includes to handle potential absolute/relative differences
expect(result.importTree.imports![0].path).toContain('nested.md');
const expectedNestedPath = testPath(projectRoot, 'src', 'nested.md');
expect(result.importTree.imports![0].path).toContain(expectedNestedPath);
expect(result.importTree.imports![0].imports).toHaveLength(1);
const expectedInnerPath = testPath(projectRoot, 'src', 'inner.md');
expect(result.importTree.imports![0].imports![0].path).toContain(
'inner.md',
expectedInnerPath,
);
expect(result.importTree.imports![0].imports![0].imports).toBeUndefined();
// Second import: simple.md
expect(result.importTree.imports![1].path).toContain('simple.md');
const expectedSimplePath = testPath(projectRoot, 'src', 'simple.md');
expect(result.importTree.imports![1].path).toContain(expectedSimplePath);
expect(result.importTree.imports![1].imports).toBeUndefined();
});
@@ -724,21 +723,20 @@ describe('memoryImportProcessor', () => {
expect(result.importTree.imports).toHaveLength(2);
// First import: nested.md
// Prefix with underscore to indicate they're intentionally unused
const _expectedNestedPath = testPath(projectRoot, 'src', 'nested.md');
const _expectedInnerPath = testPath(projectRoot, 'src', 'inner.md');
const _expectedSimplePath = testPath(projectRoot, 'src', 'simple.md');
const expectedNestedPath = testPath(projectRoot, 'src', 'nested.md');
const expectedInnerPath = testPath(projectRoot, 'src', 'inner.md');
const expectedSimplePath = testPath(projectRoot, 'src', 'simple.md');
// Check that the paths match using includes to handle potential absolute/relative differences
expect(result.importTree.imports![0].path).toContain('nested.md');
expect(result.importTree.imports![0].path).toContain(expectedNestedPath);
expect(result.importTree.imports![0].imports).toHaveLength(1);
expect(result.importTree.imports![0].imports![0].path).toContain(
'inner.md',
expectedInnerPath,
);
expect(result.importTree.imports![0].imports![0].imports).toBeUndefined();
// Second import: simple.md
expect(result.importTree.imports![1].path).toContain('simple.md');
expect(result.importTree.imports![1].path).toContain(expectedSimplePath);
expect(result.importTree.imports![1].imports).toBeUndefined();
});
@@ -899,7 +897,7 @@ describe('memoryImportProcessor', () => {
// Test relative paths - resolve them against basePath
const relativePath = './file.md';
const _resolvedRelativePath = path.resolve(basePath, relativePath);
path.resolve(basePath, relativePath);
expect(validateImportPath(relativePath, basePath, [basePath])).toBe(true);
// Test parent directory access (should be allowed if parent is in allowed paths)
@@ -907,12 +905,12 @@ describe('memoryImportProcessor', () => {
if (parentPath !== basePath) {
// Only test if parent is different
const parentRelativePath = '../file.md';
const _resolvedParentPath = path.resolve(basePath, parentRelativePath);
path.resolve(basePath, parentRelativePath);
expect(
validateImportPath(parentRelativePath, basePath, [parentPath]),
).toBe(true);
const _resolvedSubPath = path.resolve(basePath, 'sub');
path.resolve(basePath, 'sub');
const resultSub = validateImportPath('sub', basePath, [basePath]);
expect(resultSub).toBe(true);
}

View File

@@ -261,7 +261,7 @@ export async function processImports(
// Process imports in reverse order to handle indices correctly
for (let i = imports.length - 1; i >= 0; i--) {
const { start, _end, path: importPath } = imports[i];
const { start, path: importPath } = imports[i];
// Skip if inside a code region
if (

View File

@@ -4,6 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { StructuredError } from '../core/turn.js';
export interface ApiError {
error: {
code: number;
@@ -13,11 +15,6 @@ export interface ApiError {
};
}
interface StructuredError {
message: string;
status?: number;
}
export function isApiError(error: unknown): error is ApiError {
return (
typeof error === 'object' &&