Merge tag 'v0.1.18' of https://github.com/google-gemini/gemini-cli into chore/sync-gemini-cli-v0.1.18

tanzhenxin
2025-08-13 15:11:10 +08:00
94 changed files with 5258 additions and 4724 deletions

View File

@@ -33,6 +33,7 @@
"chardet": "^2.1.0",
"diff": "^7.0.0",
"dotenv": "^17.1.0",
"fdir": "^6.4.6",
"glob": "^10.4.5",
"google-auth-library": "^9.11.0",
"html-to-text": "^9.0.5",
@@ -42,6 +43,7 @@
"micromatch": "^4.0.8",
"open": "^10.1.2",
"openai": "^5.7.0",
"picomatch": "^4.0.1",
"shell-quote": "^1.8.3",
"simple-git": "^3.28.0",
"strip-ansi": "^7.1.0",
@@ -50,10 +52,12 @@
"ws": "^8.18.0"
},
"devDependencies": {
"@google/gemini-cli-test-utils": "file:../test-utils",
"@types/diff": "^7.0.2",
"@types/dotenv": "^6.1.1",
"@types/micromatch": "^4.0.8",
"@types/minimatch": "^5.1.2",
"@types/picomatch": "^4.0.1",
"@types/ws": "^8.5.10",
"typescript": "^5.3.3",
"vitest": "^3.1.1"

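The new runtime dependencies fdir and picomatch are a common pairing for fast directory crawling with glob-style filtering, and line up with the fileSearch module exported later in this diff. A minimal sketch of that pairing, independent of this package's actual FileSearch API (pattern and ignore list are illustrative):

import { fdir } from 'fdir';
import picomatch from 'picomatch';

// Compile the matcher once; reuse it across the crawl results.
const isMatch = picomatch('**/*.ts', { ignore: ['**/node_modules/**'] });

async function findTypeScriptFiles(root: string): Promise<string[]> {
  const paths = await new fdir()
    .withRelativePaths() // relative paths so glob patterns match cleanly
    .crawl(root)
    .withPromise();
  return paths.filter((p) => isMatch(p));
}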
View File

@@ -18,7 +18,18 @@ import {
} from '../core/contentGenerator.js';
import { GeminiClient } from '../core/client.js';
import { GitService } from '../services/gitService.js';
import { IdeClient } from '../ide/ide-client.js';
vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('fs')>();
return {
...actual,
existsSync: vi.fn().mockReturnValue(true),
statSync: vi.fn().mockReturnValue({
isDirectory: vi.fn().mockReturnValue(true),
}),
realpathSync: vi.fn((path) => path),
};
});
vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('fs')>();
@@ -120,7 +131,6 @@ describe('Server Config (config.ts)', () => {
telemetry: TELEMETRY_SETTINGS,
sessionId: SESSION_ID,
model: MODEL,
ideClient: IdeClient.getInstance(false),
};
beforeEach(() => {

View File

@@ -48,6 +48,8 @@ import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
import { MCPOAuthConfig } from '../mcp/oauth-provider.js';
import { IdeClient } from '../ide/ide-client.js';
import type { Content } from '@google/genai';
import { logIdeConnection } from '../telemetry/loggers.js';
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
// Re-export OAuth config type
export type { MCPOAuthConfig };
@@ -196,7 +198,6 @@ export interface ConfigParameters {
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
ideModeFeature?: boolean;
ideMode?: boolean;
ideClient?: IdeClient;
enableOpenAILogging?: boolean;
sampling_params?: Record<string, unknown>;
systemPromptMappings?: Array<{
@@ -209,6 +210,7 @@ export interface ConfigParameters {
maxRetries?: number;
};
cliVersion?: string;
loadMemoryFromIncludeDirectories?: boolean;
}
export class Config {
@@ -283,6 +285,8 @@ export class Config {
maxRetries?: number;
};
private readonly cliVersion?: string;
private readonly loadMemoryFromIncludeDirectories: boolean = false;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
this.embeddingModel =
@@ -345,15 +349,20 @@ export class Config {
this.summarizeToolOutput = params.summarizeToolOutput;
this.ideModeFeature = params.ideModeFeature ?? false;
this.ideMode = params.ideMode ?? false;
this.ideClient =
params.ideClient ??
IdeClient.getInstance(this.ideMode && this.ideModeFeature);
this.ideClient = IdeClient.getInstance();
if (this.ideMode && this.ideModeFeature) {
this.ideClient.connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.START));
}
this.systemPromptMappings = params.systemPromptMappings;
this.enableOpenAILogging = params.enableOpenAILogging ?? false;
this.sampling_params = params.sampling_params;
this.contentGenerator = params.contentGenerator;
this.cliVersion = params.cliVersion;
this.loadMemoryFromIncludeDirectories =
params.loadMemoryFromIncludeDirectories ?? false;
if (params.contextFileName) {
setGeminiMdFilename(params.contextFileName);
}
@@ -415,6 +424,10 @@ export class Config {
return this.sessionId;
}
shouldLoadMemoryFromIncludeDirectories(): boolean {
return this.loadMemoryFromIncludeDirectories;
}
getContentGeneratorConfig(): ContentGeneratorConfig {
return this.contentGeneratorConfig;
}
@@ -698,12 +711,14 @@ export class Config {
this.ideMode = value;
}
setIdeClientDisconnected(): void {
this.ideClient.setDisconnected();
}
setIdeClientConnected(): void {
this.ideClient.reconnect(this.ideMode && this.ideModeFeature);
async setIdeModeAndSyncConnection(value: boolean): Promise<void> {
this.ideMode = value;
if (value) {
await this.ideClient.connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.SESSION));
} else {
this.ideClient.disconnect();
}
}
getEnableOpenAILogging(): boolean {

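The ideClient constructor parameter and the setIdeClientConnected/setIdeClientDisconnected pair give way to a single setIdeModeAndSyncConnection method that flips the flag and synchronizes the connection in one step, emitting a SESSION-type IdeConnectionEvent on connect (START is logged once in the constructor path). A sketch of the new call site, assuming an initialized Config:

// Sketch: toggling IDE mode now keeps flag, connection, and telemetry in sync.
async function onIdeToggle(config: Config, enable: boolean): Promise<void> {
  await config.setIdeModeAndSyncConnection(enable);
  // enable === true  -> ideClient.connect() + IdeConnectionEvent('session')
  // enable === false -> ideClient.disconnect()
}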
View File

@@ -7,7 +7,6 @@
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { Config } from './config.js';
import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL } from './models.js';
import { IdeClient } from '../ide/ide-client.js';
import fs from 'node:fs';
vi.mock('node:fs');
@@ -26,7 +25,6 @@ describe('Flash Model Fallback Configuration', () => {
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
ideClient: IdeClient.getInstance(false),
});
// Initialize contentGeneratorConfig for testing
@@ -51,7 +49,6 @@ describe('Flash Model Fallback Configuration', () => {
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
ideClient: IdeClient.getInstance(false),
});
// Should not crash when contentGeneratorConfig is undefined
@@ -75,7 +72,6 @@ describe('Flash Model Fallback Configuration', () => {
debugMode: false,
cwd: '/test',
model: 'custom-model',
ideClient: IdeClient.getInstance(false),
});
expect(newConfig.getModel()).toBe('custom-model');

View File

@@ -136,6 +136,7 @@ describe('CoreToolScheduler', () => {
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
@@ -205,6 +206,7 @@ describe('CoreToolScheduler with payload', () => {
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
@@ -482,6 +484,7 @@ describe('CoreToolScheduler edit cancellation', () => {
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
@@ -571,6 +574,7 @@ describe('CoreToolScheduler YOLO mode', () => {
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();

View File

@@ -224,6 +224,7 @@ interface CoreToolSchedulerOptions {
onToolCallsUpdate?: ToolCallsUpdateHandler;
getPreferredEditor: () => EditorType | undefined;
config: Config;
onEditorClose: () => void;
}
export class CoreToolScheduler {
@@ -234,6 +235,7 @@ export class CoreToolScheduler {
private onToolCallsUpdate?: ToolCallsUpdateHandler;
private getPreferredEditor: () => EditorType | undefined;
private config: Config;
private onEditorClose: () => void;
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
@@ -242,6 +244,7 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
this.getPreferredEditor = options.getPreferredEditor;
this.onEditorClose = options.onEditorClose;
}
private setStatusInternal(
@@ -563,6 +566,7 @@ export class CoreToolScheduler {
modifyContext as ModifyContext<typeof waitingToolCall.request.args>,
editorType,
signal,
this.onEditorClose,
);
this.setArgsInternal(callId, updatedParams);
this.setStatusInternal(callId, 'awaiting_approval', {

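onEditorClose is now a required CoreToolSchedulerOptions field, stored on the scheduler and forwarded to modifyWithEditor when a user edits a proposed change in an external editor. A construction sketch; the options are abbreviated to those visible in the tests above, and the callback body is illustrative:

// Sketch (remaining options elided; see the test setup above).
const scheduler = new CoreToolScheduler({
  config,
  onAllToolCallsComplete,
  onToolCallsUpdate,
  getPreferredEditor: () => 'vscode',
  // New: fired when the external diff editor closes, so the caller
  // (e.g. the CLI UI) can restore focus or redraw.
  onEditorClose: () => refreshUi(), // refreshUi is a hypothetical hook
});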
View File

@@ -33,34 +33,58 @@ export enum IDEConnectionStatus {
* Manages the connection to and interaction with the IDE server.
*/
export class IdeClient {
client: Client | undefined = undefined;
private static instance: IdeClient;
private client: Client | undefined = undefined;
private state: IDEConnectionState = {
status: IDEConnectionStatus.Disconnected,
details:
'IDE integration is currently disabled. To enable it, run /ide enable.',
};
private static instance: IdeClient;
private readonly currentIde: DetectedIde | undefined;
private readonly currentIdeDisplayName: string | undefined;
constructor(ideMode: boolean) {
private constructor() {
this.currentIde = detectIde();
if (this.currentIde) {
this.currentIdeDisplayName = getIdeDisplayName(this.currentIde);
}
if (!ideMode) {
return;
}
this.init().catch((err) => {
logger.debug('Failed to initialize IdeClient:', err);
});
}
static getInstance(ideMode: boolean): IdeClient {
static getInstance(): IdeClient {
if (!IdeClient.instance) {
IdeClient.instance = new IdeClient(ideMode);
IdeClient.instance = new IdeClient();
}
return IdeClient.instance;
}
async connect(): Promise<void> {
this.setState(IDEConnectionStatus.Connecting);
if (!this.currentIde || !this.currentIdeDisplayName) {
this.setState(IDEConnectionStatus.Disconnected);
return;
}
if (!this.validateWorkspacePath()) {
return;
}
const port = this.getPortFromEnv();
if (!port) {
return;
}
await this.establishConnection(port);
}
disconnect() {
this.setState(
IDEConnectionStatus.Disconnected,
'IDE integration disabled. To enable it again, run /ide enable.',
);
this.client?.close();
}
getCurrentIde(): DetectedIde | undefined {
return this.currentIde;
}
@@ -70,45 +94,60 @@ export class IdeClient {
}
private setState(status: IDEConnectionStatus, details?: string) {
this.state = { status, details };
const isAlreadyDisconnected =
this.state.status === IDEConnectionStatus.Disconnected &&
status === IDEConnectionStatus.Disconnected;
// Only update details if the state wasn't already disconnected, so that
// the first detail message is preserved.
if (!isAlreadyDisconnected) {
this.state = { status, details };
}
if (status === IDEConnectionStatus.Disconnected) {
logger.debug('IDE integration is disconnected. ', details);
logger.debug('IDE integration disconnected:', details);
ideContext.clearIdeContext();
}
}
private validateWorkspacePath(): boolean {
const ideWorkspacePath = process.env['GEMINI_CLI_IDE_WORKSPACE_PATH'];
if (ideWorkspacePath === undefined) {
this.setState(
IDEConnectionStatus.Disconnected,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
);
return false;
}
if (ideWorkspacePath === '') {
this.setState(
IDEConnectionStatus.Disconnected,
`To use this feature, please open a single workspace folder in ${this.currentIdeDisplayName} and try again.`,
);
return false;
}
if (ideWorkspacePath !== process.cwd()) {
this.setState(
IDEConnectionStatus.Disconnected,
`Directory mismatch. Gemini CLI is running in a different location than the open workspace in ${this.currentIdeDisplayName}. Please run the CLI from the same directory as your project's root folder.`,
);
return false;
}
return true;
}
private getPortFromEnv(): string | undefined {
const port = process.env['GEMINI_CLI_IDE_SERVER_PORT'];
if (!port) {
this.setState(
IDEConnectionStatus.Disconnected,
'Gemini CLI Companion extension not found. Install via /ide install and restart the CLI in a fresh terminal window.',
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
);
return undefined;
}
return port;
}
private validateWorkspacePath(): boolean {
const ideWorkspacePath = process.env['GEMINI_CLI_IDE_WORKSPACE_PATH'];
if (!ideWorkspacePath) {
this.setState(
IDEConnectionStatus.Disconnected,
'IDE integration requires a single workspace folder to be open in the IDE. Please ensure one folder is open and try again.',
);
return false;
}
if (ideWorkspacePath !== process.cwd()) {
this.setState(
IDEConnectionStatus.Disconnected,
`Gemini CLI is running in a different directory (${process.cwd()}) from the IDE's open workspace (${ideWorkspacePath}). Please run Gemini CLI in the same directory.`,
);
return false;
}
return true;
}
private registerClientHandlers() {
if (!this.client) {
return;
@@ -120,20 +159,20 @@ export class IdeClient {
ideContext.setIdeContext(notification.params);
},
);
this.client.onerror = (_error) => {
this.setState(IDEConnectionStatus.Disconnected, 'Client error.');
this.setState(
IDEConnectionStatus.Disconnected,
`IDE connection error. The connection was lost unexpectedly. Please try reconnecting by running /ide enable`,
);
};
this.client.onclose = () => {
this.setState(IDEConnectionStatus.Disconnected, 'Connection closed.');
this.setState(
IDEConnectionStatus.Disconnected,
`IDE connection error. The connection was lost unexpectedly. Please try reconnecting by running /ide enable`,
);
};
}
async reconnect(ideMode: boolean) {
IdeClient.instance = new IdeClient(ideMode);
}
private async establishConnection(port: string) {
let transport: StreamableHTTPClientTransport | undefined;
try {
@@ -150,12 +189,12 @@ export class IdeClient {
this.registerClientHandlers();
await this.client.connect(transport);
this.registerClientHandlers();
this.setState(IDEConnectionStatus.Connected);
} catch (error) {
} catch (_error) {
this.setState(
IDEConnectionStatus.Disconnected,
`Failed to connect to IDE server: ${error}`,
`Failed to connect to IDE companion extension for ${this.currentIdeDisplayName}. Please ensure the extension is running and try refreshing your terminal. To install the extension, run /ide install.`,
);
if (transport) {
try {

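IdeClient is now a lazy singleton with a private constructor: getInstance() takes no arguments, and connecting is an explicit, awaitable step that validates the workspace and port before establishing the transport. A lifecycle sketch; the environment values are illustrative:

// Sketch of the new explicit lifecycle (env values illustrative).
process.env['GEMINI_CLI_IDE_WORKSPACE_PATH'] = process.cwd();
process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '3210';

const ide = IdeClient.getInstance(); // no ideMode flag anymore
await ide.connect(); // Connecting -> validates path/port -> Connected, or Disconnected with a reason
// ...later:
ide.disconnect(); // closes the client and records a user-facing reason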
View File

@@ -41,6 +41,7 @@ export * from './utils/shell-utils.js';
export * from './utils/systemEncoding.js';
export * from './utils/textUtils.js';
export * from './utils/formatters.js';
export * from './utils/filesearch/fileSearch.js';
// Export services
export * from './services/fileDiscoveryService.js';

View File

@@ -4,7 +4,17 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { vi } from 'vitest';
// Mock dependencies AT THE TOP
const mockOpenBrowserSecurely = vi.hoisted(() => vi.fn());
vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely,
}));
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http';
import * as crypto from 'node:crypto';
import {
@@ -15,14 +25,6 @@ import {
} from './oauth-provider.js';
import { MCPOAuthTokenStorage, MCPOAuthToken } from './oauth-token-storage.js';
// Mock dependencies
const mockOpenBrowserSecurely = vi.hoisted(() => vi.fn());
vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely,
}));
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
// Mock fetch globally
const mockFetch = vi.fn();
global.fetch = mockFetch;
@@ -46,6 +48,7 @@ describe('MCPOAuthProvider', () => {
tokenUrl: 'https://auth.example.com/token',
scopes: ['read', 'write'],
redirectUri: 'http://localhost:7777/oauth/callback',
audiences: ['https://api.example.com'],
};
const mockToken: MCPOAuthToken = {
@@ -720,6 +723,105 @@ describe('MCPOAuthProvider', () => {
expect(capturedUrl!).toContain('code_challenge_method=S256');
expect(capturedUrl!).toContain('scope=read+write');
expect(capturedUrl!).toContain('resource=https%3A%2F%2Fauth.example.com');
expect(capturedUrl!).toContain('audience=https%3A%2F%2Fapi.example.com');
});
it('should correctly append parameters to an authorization URL that already has query params', async () => {
// Mock to capture the URL that would be opened
let capturedUrl: string;
mockOpenBrowserSecurely.mockImplementation((url: string) => {
capturedUrl = url;
return Promise.resolve();
});
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
const configWithParamsInUrl = {
...mockConfig,
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
};
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('audience')).toBe('1234');
expect(url.searchParams.get('client_id')).toBe('test-client-id');
expect(url.search.startsWith('?audience=1234&')).toBe(true);
});
it('should correctly append parameters to a URL with a fragment', async () => {
// Mock to capture the URL that would be opened
let capturedUrl: string;
mockOpenBrowserSecurely.mockImplementation((url: string) => {
capturedUrl = url;
return Promise.resolve();
});
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
const configWithFragment = {
...mockConfig,
authorizationUrl: 'https://auth.example.com/authorize#login',
};
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('client_id')).toBe('test-client-id');
expect(url.hash).toBe('#login');
expect(url.pathname).toBe('/authorize');
});
});
});

View File

@@ -22,6 +22,7 @@ export interface MCPOAuthConfig {
authorizationUrl?: string;
tokenUrl?: string;
scopes?: string[];
audiences?: string[];
redirectUri?: string;
tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token
}
@@ -297,6 +298,10 @@ export class MCPOAuthProvider {
params.append('scope', config.scopes.join(' '));
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Use the MCP server URL if provided, otherwise fall back to authorization URL
const resourceUrl = mcpServerUrl || config.authorizationUrl!;
@@ -308,7 +313,11 @@ export class MCPOAuthProvider {
);
}
return `${config.authorizationUrl}?${params.toString()}`;
const url = new URL(config.authorizationUrl!);
params.forEach((value, key) => {
url.searchParams.append(key, value);
});
return url.toString();
}
/**
@@ -342,6 +351,10 @@ export class MCPOAuthProvider {
params.append('client_secret', config.clientSecret);
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Use the MCP server URL if provided, otherwise fall back to token URL
const resourceUrl = mcpServerUrl || config.tokenUrl!;
@@ -400,6 +413,10 @@ export class MCPOAuthProvider {
params.append('scope', config.scopes.join(' '));
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Use the MCP server URL if provided, otherwise fall back to token URL
const resourceUrl = mcpServerUrl || tokenUrl;

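The functional change behind the two new URL-handling tests is the switch from string concatenation to the WHATWG URL API when building the authorization URL: the old `${config.authorizationUrl}?${params}` form corrupts URLs that already carry a query string or fragment, while searchParams.append preserves both. A standalone illustration (values illustrative; the behavior is what the new tests assert):

const params = new URLSearchParams({ client_id: 'test-client-id' });

// Old form: produces a second '?' and leaves any fragment misplaced.
// 'https://auth.example.com/authorize?audience=1234' + '?' + params

// New form: appends alongside what is already there.
const url = new URL('https://auth.example.com/authorize?audience=1234#login');
params.forEach((value, key) => url.searchParams.append(key, value));
url.toString();
// -> 'https://auth.example.com/authorize?audience=1234&client_id=test-client-id#login'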
View File

@@ -53,4 +53,22 @@ export class PromptRegistry {
}
return serverPrompts.sort((a, b) => a.name.localeCompare(b.name));
}
/**
* Clears all the prompts from the registry.
*/
clear(): void {
this.prompts.clear();
}
/**
* Removes all prompts from a specific server.
*/
removePromptsByServer(serverName: string): void {
for (const [name, prompt] of this.prompts.entries()) {
if (prompt.serverName === serverName) {
this.prompts.delete(name);
}
}
}
}

View File

@@ -21,6 +21,7 @@ import {
NextSpeakerCheckEvent,
SlashCommandEvent,
MalformedJsonResponseEvent,
IdeConnectionEvent,
} from '../types.js';
import { EventMetadataKey } from './event-metadata-key.js';
import { Config } from '../../config/config.js';
@@ -44,6 +45,7 @@ const loop_detected_event_name = 'loop_detected';
const next_speaker_check_event_name = 'next_speaker_check';
const slash_command_event_name = 'slash_command';
const malformed_json_response_event_name = 'malformed_json_response';
const ide_connection_event_name = 'ide_connection';
export interface LogResponse {
nextRequestWaitMs?: number;
@@ -578,6 +580,18 @@ export class ClearcutLogger {
this.flushIfNeeded();
}
logIdeConnectionEvent(event: IdeConnectionEvent): void {
const data = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_IDE_CONNECTION_TYPE,
value: JSON.stringify(event.connection_type),
},
];
this.enqueueLogEvent(this.createLogEvent(ide_connection_event_name, data));
this.flushIfNeeded();
}
logEndSessionEvent(event: EndSessionEvent): void {
const data = [
{

View File

@@ -190,6 +190,13 @@ export enum EventMetadataKey {
// Logs the model that produced the malformed JSON response.
GEMINI_CLI_MALFORMED_JSON_RESPONSE_MODEL = 45,
// ==========================================================================
// IDE Connection Event Keys
// ===========================================================================
// Logs the type of the IDE connection.
GEMINI_CLI_IDE_CONNECTION_TYPE = 46,
}
export function getEventMetadataKey(

View File

@@ -15,6 +15,7 @@ export const EVENT_CLI_CONFIG = 'qwen-code.config';
export const EVENT_FLASH_FALLBACK = 'qwen-code.flash_fallback';
export const EVENT_NEXT_SPEAKER_CHECK = 'qwen-code.next_speaker_check';
export const EVENT_SLASH_COMMAND = 'qwen-code.slash_command';
export const EVENT_IDE_CONNECTION = 'qwen-code.ide_connection';
export const METRIC_TOOL_CALL_COUNT = 'qwen-code.tool.call.count';
export const METRIC_TOOL_CALL_LATENCY = 'qwen-code.tool.call.latency';

View File

@@ -12,6 +12,7 @@ import {
EVENT_API_REQUEST,
EVENT_API_RESPONSE,
EVENT_CLI_CONFIG,
EVENT_IDE_CONNECTION,
EVENT_TOOL_CALL,
EVENT_USER_PROMPT,
EVENT_FLASH_FALLBACK,
@@ -23,6 +24,7 @@ import {
ApiErrorEvent,
ApiRequestEvent,
ApiResponseEvent,
IdeConnectionEvent,
StartSessionEvent,
ToolCallEvent,
UserPromptEvent,
@@ -355,3 +357,23 @@ export function logSlashCommand(
};
logger.emit(logRecord);
}
export function logIdeConnection(
config: Config,
event: IdeConnectionEvent,
): void {
if (!isTelemetrySdkInitialized()) return;
const attributes: LogAttributes = {
...getCommonAttributes(config),
...event,
'event.name': EVENT_IDE_CONNECTION,
};
const logger = logs.getLogger(SERVICE_NAME);
const logRecord: LogRecord = {
body: `Ide connection. Type: ${event.connection_type}.`,
attributes,
};
logger.emit(logRecord);
}

View File

@@ -12,7 +12,6 @@ import {
} from './sdk.js';
import { Config } from '../config/config.js';
import { NodeSDK } from '@opentelemetry/sdk-node';
import { IdeClient } from '../ide/ide-client.js';
vi.mock('@opentelemetry/sdk-node');
vi.mock('../config/config.js');
@@ -30,7 +29,6 @@ describe('telemetry', () => {
targetDir: '/test/dir',
debugMode: false,
cwd: '/test/dir',
ideClient: IdeClient.getInstance(false),
});
vi.spyOn(mockConfig, 'getTelemetryEnabled').mockReturnValue(true);
vi.spyOn(mockConfig, 'getTelemetryOtlpEndpoint').mockReturnValue(

View File

@@ -314,6 +314,23 @@ export class MalformedJsonResponseEvent {
}
}
export enum IdeConnectionType {
START = 'start',
SESSION = 'session',
}
export class IdeConnectionEvent {
'event.name': 'ide_connection';
'event.timestamp': string; // ISO 8601
connection_type: IdeConnectionType;
constructor(connection_type: IdeConnectionType) {
this['event.name'] = 'ide_connection';
this['event.timestamp'] = new Date().toISOString();
this.connection_type = connection_type;
}
}
export type TelemetryEvent =
| StartSessionEvent
| EndSessionEvent
@@ -326,4 +343,5 @@ export type TelemetryEvent =
| LoopDetectedEvent
| NextSpeakerCheckEvent
| SlashCommandEvent
| MalformedJsonResponseEvent;
| MalformedJsonResponseEvent
| IdeConnectionEvent;

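IdeConnectionEvent stamps its own 'event.name' and ISO 8601 'event.timestamp' at construction, so emitting one is a two-liner when paired with logIdeConnection from loggers.ts. Sketch, assuming a Config with telemetry initialized:

// Sketch (config assumed initialized with telemetry enabled).
const event = new IdeConnectionEvent(IdeConnectionType.START);
logIdeConnection(config, event);
// event['event.name']      === 'ide_connection'
// event['event.timestamp'] is an ISO 8601 string set in the constructor
// event.connection_type    === 'start'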
View File

@@ -58,9 +58,7 @@ describe('mcp-client', () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {
// no-op
});
.mockImplementation(() => {});
const testError = new Error('Invalid tool name');
vi.mocked(DiscoveredMCPTool).mockImplementation(
@@ -113,12 +111,17 @@ describe('mcp-client', () => {
{ name: 'prompt2' },
],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledWith(
{ method: 'prompts/list', params: {} },
expect.anything(),
@@ -129,37 +132,67 @@ describe('mcp-client', () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {
// no-op
});
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should do nothing if the server has no prompt support', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).not.toHaveBeenCalled();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should log an error if discovery fails', async () => {
const testError = new Error('test error');
testError.message = 'test error';
const mockRequest = vi.fn().mockRejectedValue(testError);
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {
// no-op
});
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);

View File

@@ -496,6 +496,9 @@ export async function discoverPrompts(
promptRegistry: PromptRegistry,
): Promise<Prompt[]> {
try {
// Only request prompts if the server supports them.
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
const response = await mcpClient.request(
{ method: 'prompts/list', params: {} },
ListPromptsResultSchema,

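The guard reads `== null` on purpose: it short-circuits both when getServerCapabilities() returns undefined and when the capabilities object lacks a prompts entry, which is exactly what the new "no prompt support" test exercises. The check in isolation:

// The capability gate in isolation: '== null' matches undefined and null.
function serverSupportsPrompts(caps?: { prompts?: object }): boolean {
  return caps?.prompts != null;
}

serverSupportsPrompts(undefined);       // false -> skip prompts/list
serverSupportsPrompts({});              // false -> skip prompts/list
serverSupportsPrompts({ prompts: {} }); // true  -> issue the request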
View File

@@ -131,8 +131,11 @@ describe('DiscoveredMCPTool', () => {
success: true,
details: 'executed',
};
const mockFunctionResponseContent: Part[] = [
{ text: JSON.stringify(mockToolSuccessResultObject) },
const mockFunctionResponseContent = [
{
type: 'text',
text: JSON.stringify(mockToolSuccessResultObject),
},
];
const mockMcpToolResponseParts: Part[] = [
{
@@ -149,11 +152,13 @@ describe('DiscoveredMCPTool', () => {
expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params },
]);
expect(toolResult.llmContent).toEqual(mockMcpToolResponseParts);
const stringifiedResponseContent = JSON.stringify(
mockToolSuccessResultObject,
);
expect(toolResult.llmContent).toEqual([
{ text: stringifiedResponseContent },
]);
expect(toolResult.returnDisplay).toBe(stringifiedResponseContent);
});
@@ -170,6 +175,9 @@ describe('DiscoveredMCPTool', () => {
mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
const toolResult: ToolResult = await tool.execute(params);
expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
expect(toolResult.llmContent).toEqual([
{ text: '[Error: Could not parse tool response]' },
]);
});
it('should propagate rejection if mcpTool.callTool rejects', async () => {
@@ -186,6 +194,361 @@ describe('DiscoveredMCPTool', () => {
await expect(tool.execute(params)).rejects.toThrow(expectedError);
});
it('should handle a simple text response correctly', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { query: 'test' };
const successMessage = 'This is a success message.';
// Simulate the response from the GenAI SDK, which wraps the MCP
// response in a functionResponse Part.
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
// The `content` array contains MCP ContentBlocks.
content: [{ type: 'text', text: successMessage }],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
// 1. Assert that the llmContent sent to the scheduler is a clean Part array.
expect(toolResult.llmContent).toEqual([{ text: successMessage }]);
// 2. Assert that the display output is the simple text message.
expect(toolResult.returnDisplay).toBe(successMessage);
// 3. Verify that the underlying callTool was made correctly.
expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params },
]);
});
it('should handle an AudioBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'play' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{
type: 'audio',
data: 'BASE64_AUDIO_DATA',
mimeType: 'audio/mp3',
},
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{
text: `[Tool '${serverToolName}' provided the following audio data with mime-type: audio/mp3]`,
},
{
inlineData: {
mimeType: 'audio/mp3',
data: 'BASE64_AUDIO_DATA',
},
},
]);
expect(toolResult.returnDisplay).toBe('[Audio: audio/mp3]');
});
it('should handle a ResourceLinkBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{
type: 'resource_link',
uri: 'file:///path/to/thing',
name: 'resource-name',
title: 'My Resource',
},
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{
text: 'Resource Link: My Resource at file:///path/to/thing',
},
]);
expect(toolResult.returnDisplay).toBe(
'[Link to My Resource: file:///path/to/thing]',
);
});
it('should handle an embedded text ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{
type: 'resource',
resource: {
uri: 'file:///path/to/text.txt',
text: 'This is the text content.',
mimeType: 'text/plain',
},
},
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{ text: 'This is the text content.' },
]);
expect(toolResult.returnDisplay).toBe('This is the text content.');
});
it('should handle an embedded binary ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{
type: 'resource',
resource: {
uri: 'file:///path/to/data.bin',
blob: 'BASE64_BINARY_DATA',
mimeType: 'application/octet-stream',
},
},
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{
text: `[Tool '${serverToolName}' provided the following embedded resource with mime-type: application/octet-stream]`,
},
{
inlineData: {
mimeType: 'application/octet-stream',
data: 'BASE64_BINARY_DATA',
},
},
]);
expect(toolResult.returnDisplay).toBe(
'[Embedded Resource: application/octet-stream]',
);
});
it('should handle a mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'complex' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{ type: 'text', text: 'First part.' },
{
type: 'image',
data: 'BASE64_IMAGE_DATA',
mimeType: 'image/jpeg',
},
{ type: 'text', text: 'Second part.' },
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{ text: 'First part.' },
{
text: `[Tool '${serverToolName}' provided the following image data with mime-type: image/jpeg]`,
},
{
inlineData: {
mimeType: 'image/jpeg',
data: 'BASE64_IMAGE_DATA',
},
},
{ text: 'Second part.' },
]);
expect(toolResult.returnDisplay).toBe(
'First part.\n[Image: image/jpeg]\nSecond part.',
);
});
it('should ignore unknown content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'test' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{ type: 'text', text: 'Valid part.' },
{ type: 'future_block', data: 'some-data' },
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]);
expect(toolResult.returnDisplay).toBe(
'Valid part.\n[Unknown content type: future_block]',
);
});
it('should handle a complex mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'super-complex' };
const sdkResponse: Part[] = [
{
functionResponse: {
name: serverToolName,
response: {
content: [
{ type: 'text', text: 'Here is a resource.' },
{
type: 'resource_link',
uri: 'file:///path/to/resource',
name: 'resource-name',
title: 'My Resource',
},
{
type: 'resource',
resource: {
uri: 'file:///path/to/text.txt',
text: 'Embedded text content.',
mimeType: 'text/plain',
},
},
{
type: 'image',
data: 'BASE64_IMAGE_DATA',
mimeType: 'image/jpeg',
},
],
},
},
},
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
expect(toolResult.llmContent).toEqual([
{ text: 'Here is a resource.' },
{
text: 'Resource Link: My Resource at file:///path/to/resource',
},
{ text: 'Embedded text content.' },
{
text: `[Tool '${serverToolName}' provided the following image data with mime-type: image/jpeg]`,
},
{
inlineData: {
mimeType: 'image/jpeg',
data: 'BASE64_IMAGE_DATA',
},
},
]);
expect(toolResult.returnDisplay).toBe(
'Here is a resource.\n[Link to My Resource: file:///path/to/resource]\nEmbedded text content.\n[Image: image/jpeg]',
);
});
});
describe('shouldConfirmExecute', () => {

View File

@@ -22,6 +22,40 @@ import {
type ToolParams = Record<string, unknown>;
// Discriminated union for MCP Content Blocks to ensure type safety.
type McpTextBlock = {
type: 'text';
text: string;
};
type McpMediaBlock = {
type: 'image' | 'audio';
mimeType: string;
data: string;
};
type McpResourceBlock = {
type: 'resource';
resource: {
text?: string;
blob?: string;
mimeType?: string;
};
};
type McpResourceLinkBlock = {
type: 'resource_link';
uri: string;
title?: string;
name?: string;
};
type McpContentBlock =
| McpTextBlock
| McpMediaBlock
| McpResourceBlock
| McpResourceLinkBlock;
export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
private static readonly allowlist: Set<string> = new Set();
@@ -114,70 +148,145 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
},
];
const responseParts: Part[] = await this.mcpTool.callTool(functionCalls);
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: responseParts,
returnDisplay: getStringifiedResultForDisplay(responseParts),
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
}
}
/**
* Processes an array of `Part` objects, primarily from a tool's execution result,
* to generate a user-friendly string representation, typically for display in a CLI.
*
* The `result` array can contain various types of `Part` objects:
* 1. `FunctionResponse` parts:
* - If the `response.content` of a `FunctionResponse` is an array consisting solely
* of `TextPart` objects, their text content is concatenated into a single string.
* This is to present simple textual outputs directly.
* - If `response.content` is an array but contains other types of `Part` objects (or a mix),
* the `content` array itself is preserved. This handles structured data like JSON objects or arrays
* returned by a tool.
* - If `response.content` is not an array or is missing, the entire `functionResponse`
* object is preserved.
* 2. Other `Part` types (e.g., `TextPart` directly in the `result` array):
* - These are preserved as is.
*
* All processed parts are then collected into an array, which is JSON.stringify-ed
* with indentation and wrapped in a markdown JSON code block.
*/
function getStringifiedResultForDisplay(result: Part[]) {
if (!result || result.length === 0) {
return '```json\n[]\n```';
function transformTextBlock(block: McpTextBlock): Part {
return { text: block.text };
}
function transformImageAudioBlock(
block: McpMediaBlock,
toolName: string,
): Part[] {
return [
{
text: `[Tool '${toolName}' provided the following ${
block.type
} data with mime-type: ${block.mimeType}]`,
},
{
inlineData: {
mimeType: block.mimeType,
data: block.data,
},
},
];
}
function transformResourceBlock(
block: McpResourceBlock,
toolName: string,
): Part | Part[] | null {
const resource = block.resource;
if (resource?.text) {
return { text: resource.text };
}
if (resource?.blob) {
const mimeType = resource.mimeType || 'application/octet-stream';
return [
{
text: `[Tool '${toolName}' provided the following embedded resource with mime-type: ${mimeType}]`,
},
{
inlineData: {
mimeType,
data: resource.blob,
},
},
];
}
return null;
}
const processFunctionResponse = (part: Part) => {
if (part.functionResponse) {
const responseContent = part.functionResponse.response?.content;
if (responseContent && Array.isArray(responseContent)) {
// Check if all parts in responseContent are simple TextParts
const allTextParts = responseContent.every(
(p: Part) => p.text !== undefined,
);
if (allTextParts) {
return responseContent.map((p: Part) => p.text).join('');
}
// If not all simple text parts, return the array of these content parts for JSON stringification
return responseContent;
}
// If no content, or not an array, or not a functionResponse, stringify the whole functionResponse part for inspection
return part.functionResponse;
}
return part; // Fallback for unexpected structure or non-FunctionResponsePart
function transformResourceLinkBlock(block: McpResourceLinkBlock): Part {
return {
text: `Resource Link: ${block.title || block.name} at ${block.uri}`,
};
}
const processedResults =
result.length === 1
? processFunctionResponse(result[0])
: result.map(processFunctionResponse);
if (typeof processedResults === 'string') {
return processedResults;
/**
* Transforms the raw MCP content blocks from the SDK response into a
* standard GenAI Part array.
* @param sdkResponse The raw Part[] array from `mcpTool.callTool()`.
* @returns A clean Part[] array ready for the scheduler.
*/
function transformMcpContentToParts(sdkResponse: Part[]): Part[] {
const funcResponse = sdkResponse?.[0]?.functionResponse;
const mcpContent = funcResponse?.response?.content as McpContentBlock[];
const toolName = funcResponse?.name || 'unknown tool';
if (!Array.isArray(mcpContent)) {
return [{ text: '[Error: Could not parse tool response]' }];
}
return '```json\n' + JSON.stringify(processedResults, null, 2) + '\n```';
const transformed = mcpContent.flatMap(
(block: McpContentBlock): Part | Part[] | null => {
switch (block.type) {
case 'text':
return transformTextBlock(block);
case 'image':
case 'audio':
return transformImageAudioBlock(block, toolName);
case 'resource':
return transformResourceBlock(block, toolName);
case 'resource_link':
return transformResourceLinkBlock(block);
default:
return null;
}
},
);
return transformed.filter((part): part is Part => part !== null);
}
/**
* Processes the raw response from the MCP tool to generate a clean,
* human-readable string for display in the CLI. It summarizes non-text
* content and presents text directly.
*
* @param rawResponse The raw Part[] array from the GenAI SDK.
* @returns A formatted string representing the tool's output.
*/
function getStringifiedResultForDisplay(rawResponse: Part[]): string {
const mcpContent = rawResponse?.[0]?.functionResponse?.response
?.content as McpContentBlock[];
if (!Array.isArray(mcpContent)) {
return '```json\n' + JSON.stringify(rawResponse, null, 2) + '\n```';
}
const displayParts = mcpContent.map((block: McpContentBlock): string => {
switch (block.type) {
case 'text':
return block.text;
case 'image':
return `[Image: ${block.mimeType}]`;
case 'audio':
return `[Audio: ${block.mimeType}]`;
case 'resource_link':
return `[Link to ${block.title || block.name}: ${block.uri}]`;
case 'resource':
if (block.resource?.text) {
return block.resource.text;
}
return `[Embedded Resource: ${
block.resource?.mimeType || 'unknown type'
}]`;
default:
return `[Unknown content type: ${(block as { type: string }).type}]`;
}
});
return displayParts.join('\n');
}
/** Visible for testing */

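Together, transformMcpContentToParts and getStringifiedResultForDisplay derive two views from one raw SDK response: typed Parts for the model and a flat string for the terminal. A worked example over a small response, using only the shapes defined above (values illustrative):

const raw = [
  {
    functionResponse: {
      name: 'search',
      response: {
        content: [
          { type: 'text', text: 'Found it.' },
          { type: 'image', mimeType: 'image/png', data: 'BASE64' },
        ],
      },
    },
  },
];

// transformMcpContentToParts(raw) yields:
//   [ { text: 'Found it.' },
//     { text: "[Tool 'search' provided the following image data with mime-type: image/png]" },
//     { inlineData: { mimeType: 'image/png', data: 'BASE64' } } ]
//
// getStringifiedResultForDisplay(raw) yields:
//   'Found it.\n[Image: image/png]'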
View File

@@ -94,6 +94,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mockModifyContext.getCurrentContent).toHaveBeenCalledWith(
@@ -148,6 +149,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
const stats = await fsp.stat(diffDir);
@@ -165,6 +167,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mkdirSpy).not.toHaveBeenCalled();
@@ -183,6 +186,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mockCreatePatch).toHaveBeenCalledWith(
@@ -211,6 +215,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mockCreatePatch).toHaveBeenCalledWith(
@@ -241,6 +246,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
),
).rejects.toThrow('Editor failed to open');
@@ -267,6 +273,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(consoleErrorSpy).toHaveBeenCalledTimes(2);
@@ -290,6 +297,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mockOpenDiff).toHaveBeenCalledOnce();
@@ -311,6 +319,7 @@ describe('modifyWithEditor', () => {
mockModifyContext,
'vscode' as EditorType,
abortSignal,
vi.fn(),
);
expect(mockOpenDiff).toHaveBeenCalledOnce();

View File

@@ -138,6 +138,7 @@ export async function modifyWithEditor<ToolParams>(
modifyContext: ModifyContext<ToolParams>,
editorType: EditorType,
_abortSignal: AbortSignal,
onEditorClose: () => void,
): Promise<ModifyResult<ToolParams>> {
const currentContent = await modifyContext.getCurrentContent(originalParams);
const proposedContent =
@@ -150,7 +151,7 @@ export async function modifyWithEditor<ToolParams>(
);
try {
await openDiff(oldPath, newPath, editorType);
await openDiff(oldPath, newPath, editorType, onEditorClose);
const result = getUpdatedParams(
oldPath,
newPath,

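openDiff gains a fourth parameter and modifyWithEditor a fifth: the onEditorClose callback threaded down from the scheduler. The updated tests pass vi.fn() or a no-op; a real caller can hook terminal cleanup there. Sketch with an illustrative callback:

// Sketch: the new trailing callback (body illustrative).
await openDiff(oldPath, newPath, 'vscode', () => {
  // Runs once the external diff editor closes; a no-op is valid,
  // as the updated tests demonstrate with vi.fn() / () => {}.
});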
View File

@@ -477,4 +477,139 @@ describe('ReadManyFilesTool', () => {
fs.rmSync(tempDir2, { recursive: true, force: true });
});
});
describe('Batch Processing', () => {
const createMultipleFiles = (count: number, contentPrefix = 'Content') => {
const files: string[] = [];
for (let i = 0; i < count; i++) {
const fileName = `file${i}.txt`;
createFile(fileName, `${contentPrefix} ${i}`);
files.push(fileName);
}
return files;
};
const createFile = (filePath: string, content = '') => {
const fullPath = path.join(tempRootDir, filePath);
fs.mkdirSync(path.dirname(fullPath), { recursive: true });
fs.writeFileSync(fullPath, content);
};
it('should process files in parallel for performance', async () => {
// Mock detectFileType to add artificial delay to simulate I/O
const detectFileTypeSpy = vi.spyOn(
await import('../utils/fileUtils.js'),
'detectFileType',
);
// Create files
const fileCount = 4;
const files = createMultipleFiles(fileCount, 'Batch test');
// Mock with 100ms delay per file to simulate I/O operations
detectFileTypeSpy.mockImplementation(async (_filePath: string) => {
await new Promise((resolve) => setTimeout(resolve, 100));
return 'text';
});
const startTime = Date.now();
const params = { paths: files };
const result = await tool.execute(params, new AbortController().signal);
const endTime = Date.now();
const processingTime = endTime - startTime;
console.log(
`Processing time: ${processingTime}ms for ${fileCount} files`,
);
// Verify parallel processing performance improvement
// Parallel processing should complete in ~100ms (single file time)
// Sequential would take ~400ms (4 files × 100ms each)
expect(processingTime).toBeLessThan(200); // Should PASS with parallel implementation
// Verify all files were processed
const content = result.llmContent as string[];
expect(content).toHaveLength(fileCount);
// Cleanup mock
detectFileTypeSpy.mockRestore();
});
it('should handle batch processing errors gracefully', async () => {
// Create mix of valid and problematic files
createFile('valid1.txt', 'Valid content 1');
createFile('valid2.txt', 'Valid content 2');
createFile('valid3.txt', 'Valid content 3');
const params = {
paths: [
'valid1.txt',
'valid2.txt',
'nonexistent-file.txt', // This will fail
'valid3.txt',
],
};
const result = await tool.execute(params, new AbortController().signal);
const content = result.llmContent as string[];
// Should successfully process valid files despite one failure
expect(content.length).toBeGreaterThanOrEqual(3);
expect(result.returnDisplay).toContain('Successfully read');
// Verify valid files were processed
const expectedPath1 = path.join(tempRootDir, 'valid1.txt');
const expectedPath3 = path.join(tempRootDir, 'valid3.txt');
expect(content.some((c) => c.includes(expectedPath1))).toBe(true);
expect(content.some((c) => c.includes(expectedPath3))).toBe(true);
});
it('should execute file operations concurrently', async () => {
// Track execution order to verify concurrency
const executionOrder: string[] = [];
const detectFileTypeSpy = vi.spyOn(
await import('../utils/fileUtils.js'),
'detectFileType',
);
const files = ['file1.txt', 'file2.txt', 'file3.txt'];
files.forEach((file) => createFile(file, 'test content'));
// Mock to track concurrent vs sequential execution
detectFileTypeSpy.mockImplementation(async (filePath: string) => {
const fileName = filePath.split('/').pop() || '';
executionOrder.push(`start:${fileName}`);
// Add delay to make timing differences visible
await new Promise((resolve) => setTimeout(resolve, 50));
executionOrder.push(`end:${fileName}`);
return 'text';
});
await tool.execute({ paths: files }, new AbortController().signal);
console.log('Execution order:', executionOrder);
// Verify concurrent execution pattern
// In parallel execution: all "start:" events should come before all "end:" events
// In sequential execution: "start:file1", "end:file1", "start:file2", "end:file2", etc.
const startEvents = executionOrder.filter((e) =>
e.startsWith('start:'),
).length;
const firstEndIndex = executionOrder.findIndex((e) =>
e.startsWith('end:'),
);
const startsBeforeFirstEnd = executionOrder
.slice(0, firstEndIndex)
.filter((e) => e.startsWith('start:')).length;
// For parallel processing, ALL start events should happen before the first end event
expect(startsBeforeFirstEnd).toBe(startEvents); // Should PASS with parallel implementation
detectFileTypeSpy.mockRestore();
});
});
});

View File

@@ -70,6 +70,27 @@ export interface ReadManyFilesParams {
};
}
/**
* Result type for file processing operations
*/
type FileProcessingResult =
| {
success: true;
filePath: string;
relativePathForDisplay: string;
fileReadResult: NonNullable<
Awaited<ReturnType<typeof processSingleFileContent>>
>;
reason?: undefined;
}
| {
success: false;
filePath: string;
relativePathForDisplay: string;
fileReadResult?: undefined;
reason: string;
};
/**
* Default exclusion patterns for commonly ignored directories and binary file types.
* These are compatible with glob ignore patterns.
@@ -413,66 +434,124 @@ Use this tool when the user's query implies needing the content of several files
const sortedFiles = Array.from(filesToConsider).sort();
for (const filePath of sortedFiles) {
const relativePathForDisplay = path
.relative(this.config.getTargetDir(), filePath)
.replace(/\\/g, '/');
const fileProcessingPromises = sortedFiles.map(
async (filePath): Promise<FileProcessingResult> => {
try {
const relativePathForDisplay = path
.relative(this.config.getTargetDir(), filePath)
.replace(/\\/g, '/');
const fileType = await detectFileType(filePath);
const fileType = await detectFileType(filePath);
if (fileType === 'image' || fileType === 'pdf') {
const fileExtension = path.extname(filePath).toLowerCase();
const fileNameWithoutExtension = path.basename(filePath, fileExtension);
const requestedExplicitly = inputPatterns.some(
(pattern: string) =>
pattern.toLowerCase().includes(fileExtension) ||
pattern.includes(fileNameWithoutExtension),
);
if (fileType === 'image' || fileType === 'pdf') {
const fileExtension = path.extname(filePath).toLowerCase();
const fileNameWithoutExtension = path.basename(
filePath,
fileExtension,
);
const requestedExplicitly = inputPatterns.some(
(pattern: string) =>
pattern.toLowerCase().includes(fileExtension) ||
pattern.includes(fileNameWithoutExtension),
);
if (!requestedExplicitly) {
skippedFiles.push({
path: relativePathForDisplay,
reason:
'asset file (image/pdf) was not explicitly requested by name or extension',
});
continue;
}
}
if (!requestedExplicitly) {
return {
success: false,
filePath,
relativePathForDisplay,
reason:
'asset file (image/pdf) was not explicitly requested by name or extension',
};
}
}
// Use processSingleFileContent for all file types now
const fileReadResult = await processSingleFileContent(
filePath,
this.config.getTargetDir(),
);
if (fileReadResult.error) {
skippedFiles.push({
path: relativePathForDisplay,
reason: `Read error: ${fileReadResult.error}`,
});
} else {
if (typeof fileReadResult.llmContent === 'string') {
const separator = DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace(
'{filePath}',
// Use processSingleFileContent for all file types now
const fileReadResult = await processSingleFileContent(
filePath,
this.config.getTargetDir(),
);
contentParts.push(`${separator}\n\n${fileReadResult.llmContent}\n\n`);
} else {
contentParts.push(fileReadResult.llmContent); // This is a Part for image/pdf
if (fileReadResult.error) {
return {
success: false,
filePath,
relativePathForDisplay,
reason: `Read error: ${fileReadResult.error}`,
};
}
return {
success: true,
filePath,
relativePathForDisplay,
fileReadResult,
};
} catch (error) {
const relativePathForDisplay = path
.relative(this.config.getTargetDir(), filePath)
.replace(/\\/g, '/');
return {
success: false,
filePath,
relativePathForDisplay,
reason: `Unexpected error: ${error instanceof Error ? error.message : String(error)}`,
};
}
processedFilesRelativePaths.push(relativePathForDisplay);
const lines =
typeof fileReadResult.llmContent === 'string'
? fileReadResult.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(filePath);
recordFileOperationMetric(
this.config,
FileOperation.READ,
lines,
mimetype,
path.extname(filePath),
);
},
);
const results = await Promise.allSettled(fileProcessingPromises);
for (const result of results) {
if (result.status === 'fulfilled') {
const fileResult = result.value;
if (!fileResult.success) {
// Handle skipped files (images/PDFs not requested or read errors)
skippedFiles.push({
path: fileResult.relativePathForDisplay,
reason: fileResult.reason,
});
} else {
// Handle successfully processed files
const { filePath, relativePathForDisplay, fileReadResult } =
fileResult;
if (typeof fileReadResult.llmContent === 'string') {
const separator = DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace(
'{filePath}',
filePath,
);
contentParts.push(
`${separator}\n\n${fileReadResult.llmContent}\n\n`,
);
} else {
contentParts.push(fileReadResult.llmContent); // This is a Part for image/pdf
}
processedFilesRelativePaths.push(relativePathForDisplay);
const lines =
typeof fileReadResult.llmContent === 'string'
? fileReadResult.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(filePath);
recordFileOperationMetric(
this.config,
FileOperation.READ,
lines,
mimetype,
path.extname(filePath),
);
}
} else {
// Handle Promise rejection (unexpected errors)
skippedFiles.push({
path: 'unknown',
reason: `Unexpected error: ${result.reason}`,
});
}
}

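The sequential for-loop becomes an eager map of async work awaited with Promise.allSettled, so one slow or failing file neither serializes nor aborts the batch; the inner try/catch already folds failures into FileProcessingResult, and allSettled is a second line of defense against anything that still rejects. The pattern, reduced to its essentials:

type Result =
  | { ok: true; path: string; content: string }
  | { ok: false; path: string; reason: string };

async function readAll(
  paths: string[],
  readOne: (p: string) => Promise<string>,
): Promise<Result[]> {
  const settled = await Promise.allSettled(
    paths.map(async (path): Promise<Result> => {
      try {
        return { ok: true, path, content: await readOne(path) };
      } catch (err) {
        // Failures become values, so one bad file cannot sink the batch.
        return { ok: false, path, reason: String(err) };
      }
    }),
  );
  return settled.map((s) =>
    s.status === 'fulfilled'
      ? s.value
      : { ok: false, path: 'unknown', reason: String(s.reason) },
  );
}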
View File

@@ -543,3 +543,37 @@ describe('validateToolParams', () => {
expect(result).toContain('is not a registered workspace directory');
});
});
describe('validateToolParams', () => {
it('should return null for valid directory', () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
getTargetDir: () => '/root',
getWorkspaceContext: () =>
createMockWorkspaceContext('/root', ['/users/test']),
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.validateToolParams({
command: 'ls',
directory: 'test',
});
expect(result).toBeNull();
});
it('should return error for directory outside workspace', () => {
const config = {
getCoreTools: () => undefined,
getExcludeTools: () => undefined,
getTargetDir: () => '/root',
getWorkspaceContext: () =>
createMockWorkspaceContext('/root', ['/users/test']),
} as unknown as Config;
const shellTool = new ShellTool(config);
const result = shellTool.validateToolParams({
command: 'ls',
directory: 'test2',
});
expect(result).toContain('is not a registered workspace directory');
});
});

View File

@@ -30,7 +30,6 @@ import {
Schema,
} from '@google/genai';
import { spawn } from 'node:child_process';
import { IdeClient } from '../ide/ide-client.js';
import fs from 'node:fs';
vi.mock('node:fs');
@@ -140,7 +139,6 @@ const baseConfigParams: ConfigParameters = {
geminiMdFileCount: 0,
approvalMode: ApprovalMode.DEFAULT,
sessionId: 'test-session-id',
ideClient: IdeClient.getInstance(false),
};
describe('ToolRegistry', () => {
@@ -172,6 +170,10 @@ describe('ToolRegistry', () => {
);
vi.spyOn(config, 'getMcpServers');
vi.spyOn(config, 'getMcpServerCommand');
vi.spyOn(config, 'getPromptRegistry').mockReturnValue({
clear: vi.fn(),
removePromptsByServer: vi.fn(),
} as any);
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
@@ -353,7 +355,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
undefined,
config.getPromptRegistry(),
false,
);
});
@@ -376,7 +378,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
undefined,
config.getPromptRegistry(),
false,
);
});

View File

@@ -150,6 +150,14 @@ export class ToolRegistry {
this.tools.set(tool.name, tool);
}
private removeDiscoveredTools(): void {
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
}
}
}
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.
@@ -157,11 +165,9 @@ export class ToolRegistry {
*/
async discoverAllTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
}
}
this.removeDiscoveredTools();
this.config.getPromptRegistry().clear();
await this.discoverAndRegisterToolsFromCommand();
@@ -182,11 +188,9 @@ export class ToolRegistry {
*/
async discoverMcpTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
}
}
this.removeDiscoveredTools();
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await discoverMcpTools(
@@ -210,6 +214,8 @@ export class ToolRegistry {
}
}
this.config.getPromptRegistry().removePromptsByServer(serverName);
const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {

View File

@@ -331,7 +331,7 @@ describe('editor utils', () => {
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await openDiff('old.txt', 'new.txt', editor);
await openDiff('old.txt', 'new.txt', editor, () => {});
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
expect(spawn).toHaveBeenCalledWith(
diffCommand.command,
@@ -361,9 +361,9 @@ describe('editor utils', () => {
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await expect(openDiff('old.txt', 'new.txt', editor)).rejects.toThrow(
'spawn error',
);
await expect(
openDiff('old.txt', 'new.txt', editor, () => {}),
).rejects.toThrow('spawn error');
});
it(`should reject if ${editor} exits with non-zero code`, async () => {
@@ -375,9 +375,9 @@ describe('editor utils', () => {
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await expect(openDiff('old.txt', 'new.txt', editor)).rejects.toThrow(
`${editor} exited with code 1`,
);
await expect(
openDiff('old.txt', 'new.txt', editor, () => {}),
).rejects.toThrow(`${editor} exited with code 1`);
});
}
@@ -385,7 +385,7 @@ describe('editor utils', () => {
for (const editor of execSyncEditors) {
it(`should call execSync for ${editor} on non-windows`, async () => {
Object.defineProperty(process, 'platform', { value: 'linux' });
await openDiff('old.txt', 'new.txt', editor);
await openDiff('old.txt', 'new.txt', editor, () => {});
expect(execSync).toHaveBeenCalledTimes(1);
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
const expectedCommand = `${
@@ -399,7 +399,7 @@ describe('editor utils', () => {
it(`should call execSync for ${editor} on windows`, async () => {
Object.defineProperty(process, 'platform', { value: 'win32' });
await openDiff('old.txt', 'new.txt', editor);
await openDiff('old.txt', 'new.txt', editor, () => {});
expect(execSync).toHaveBeenCalledTimes(1);
const diffCommand = getDiffCommand('old.txt', 'new.txt', editor)!;
const expectedCommand = `${diffCommand.command} ${diffCommand.args.join(
@@ -417,11 +417,46 @@ describe('editor utils', () => {
.spyOn(console, 'error')
.mockImplementation(() => {});
// @ts-expect-error Testing unsupported editor
await openDiff('old.txt', 'new.txt', 'foobar');
await openDiff('old.txt', 'new.txt', 'foobar', () => {});
expect(consoleErrorSpy).toHaveBeenCalledWith(
'No diff tool available. Install a supported editor.',
);
});
describe('onEditorClose callback', () => {
it('should call onEditorClose for execSync editors', async () => {
(execSync as Mock).mockReturnValue(Buffer.from(`/usr/bin/`));
const onEditorClose = vi.fn();
await openDiff('old.txt', 'new.txt', 'vim', onEditorClose);
expect(execSync).toHaveBeenCalledTimes(1);
expect(onEditorClose).toHaveBeenCalledTimes(1);
});
it('should call onEditorClose for execSync editors when an error is thrown', async () => {
(execSync as Mock).mockImplementation(() => {
throw new Error('test error');
});
const onEditorClose = vi.fn();
await openDiff('old.txt', 'new.txt', 'vim', onEditorClose);
expect(execSync).toHaveBeenCalledTimes(1);
expect(onEditorClose).toHaveBeenCalledTimes(1);
});
it('should not call onEditorClose for spawn editors', async () => {
const onEditorClose = vi.fn();
const mockSpawn = {
on: vi.fn((event, cb) => {
if (event === 'close') {
cb(0);
}
}),
};
(spawn as Mock).mockReturnValue(mockSpawn);
await openDiff('old.txt', 'new.txt', 'vscode', onEditorClose);
expect(spawn).toHaveBeenCalledTimes(1);
expect(onEditorClose).not.toHaveBeenCalled();
});
});
});
describe('allowEditorTypeInSandbox', () => {

View File

@@ -164,6 +164,7 @@ export async function openDiff(
oldPath: string,
newPath: string,
editor: EditorType,
onEditorClose: () => void,
): Promise<void> {
const diffCommand = getDiffCommand(oldPath, newPath, editor);
if (!diffCommand) {
@@ -206,10 +207,16 @@ export async function openDiff(
process.platform === 'win32'
? `${diffCommand.command} ${diffCommand.args.join(' ')}`
: `${diffCommand.command} ${diffCommand.args.map((arg) => `"${arg}"`).join(' ')}`;
execSync(command, {
stdio: 'inherit',
encoding: 'utf8',
});
try {
execSync(command, {
stdio: 'inherit',
encoding: 'utf8',
});
} catch (e) {
console.error('Error running diff command:', e);
} finally {
onEditorClose();
}
break;
}
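A minimal sketch of the updated call shape, assuming 'vim' is a valid EditorType for the execSync path; the file paths and the callback body are illustrative, not taken from the source:

import { openDiff } from './editor.js';

// The fourth argument is the new onEditorClose callback; for execSync
// editors it fires in the finally block even when the diff command fails.
await openDiff('/tmp/old.txt', '/tmp/new.txt', 'vim', () => {
  // Hypothetical cleanup: let the caller restore its terminal UI here.
  console.log('editor closed');
});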

View File

@@ -426,6 +426,29 @@ describe('fileUtils', () => {
expect(result.linesShown).toEqual([6, 10]);
});
it('should identify truncation when reading the end of a file', async () => {
const lines = Array.from({ length: 20 }, (_, i) => `Line ${i + 1}`);
actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));
// Read from line 11 to 20. The start is not 0, so it's truncated.
const result = await processSingleFileContent(
testTextFilePath,
tempRootDir,
10,
10,
);
const expectedContent = lines.slice(10, 20).join('\n');
expect(result.llmContent).toContain(expectedContent);
expect(result.llmContent).toContain(
'[File content truncated: showing lines 11-20 of 20 total lines. Use offset/limit parameters to view more.]',
);
expect(result.returnDisplay).toBe('Read lines 11-20 of 20 from test.txt');
expect(result.isTruncated).toBe(true); // The key assertion for this bug fix
expect(result.originalLineCount).toBe(20);
expect(result.linesShown).toEqual([11, 20]);
});
it('should handle limit exceeding file length', async () => {
const lines = ['Line 1', 'Line 2'];
actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n'));

View File

@@ -299,7 +299,8 @@ export async function processSingleFileContent(
return line;
});
const contentRangeTruncated = endLine < originalLineCount;
const contentRangeTruncated =
startLine > 0 || endLine < originalLineCount;
const isTruncated = contentRangeTruncated || linesWereTruncatedInLength;
let llmTextContent = '';
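To see why the old predicate missed the case flagged in the test above, consider reading the final ten lines of a 20-line file (the same offset/limit the new test uses; variable values below are illustrative):

// offset = 10, limit = 10 on a 20-line file
const startLine = 10; // 0-based index of the first returned line
const endLine = 20; // exclusive end of the returned range
const originalLineCount = 20;

// Old check: endLine < originalLineCount -> 20 < 20 -> false (misses truncation)
// New check: startLine > 0 || endLine < originalLineCount -> true
const contentRangeTruncated = startLine > 0 || endLine < originalLineCount;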

View File

@@ -0,0 +1,112 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';
import { getCacheKey, read, write, clear } from './crawlCache.js';
describe('CrawlCache', () => {
describe('getCacheKey', () => {
it('should generate a consistent hash', () => {
const key1 = getCacheKey('/foo', 'bar');
const key2 = getCacheKey('/foo', 'bar');
expect(key1).toBe(key2);
});
it('should generate a different hash for different directories', () => {
const key1 = getCacheKey('/foo', 'bar');
const key2 = getCacheKey('/bar', 'bar');
expect(key1).not.toBe(key2);
});
it('should generate a different hash for different ignore content', () => {
const key1 = getCacheKey('/foo', 'bar');
const key2 = getCacheKey('/foo', 'baz');
expect(key1).not.toBe(key2);
});
});
describe('in-memory cache operations', () => {
beforeEach(() => {
// Ensure a clean slate before each test
clear();
});
afterEach(() => {
// Restore real timers after each test that uses fake ones
vi.useRealTimers();
});
it('should write and read data from the cache', () => {
const key = 'test-key';
const data = ['foo', 'bar'];
write(key, data, 10000); // 10 second TTL
const cachedData = read(key);
expect(cachedData).toEqual(data);
});
it('should return undefined for a nonexistent key', () => {
const cachedData = read('nonexistent-key');
expect(cachedData).toBeUndefined();
});
it('should clear the cache', () => {
const key = 'test-key';
const data = ['foo', 'bar'];
write(key, data, 10000);
clear();
const cachedData = read(key);
expect(cachedData).toBeUndefined();
});
it('should automatically evict a cache entry after its TTL expires', async () => {
vi.useFakeTimers();
const key = 'ttl-key';
const data = ['foo'];
const ttl = 5000; // 5 seconds
write(key, data, ttl);
// Should exist immediately after writing
expect(read(key)).toEqual(data);
// Advance time just before expiration
await vi.advanceTimersByTimeAsync(ttl - 1);
expect(read(key)).toEqual(data);
// Advance time past expiration
await vi.advanceTimersByTimeAsync(1);
expect(read(key)).toBeUndefined();
});
it('should reset the timer when an entry is updated', async () => {
vi.useFakeTimers();
const key = 'update-key';
const initialData = ['initial'];
const updatedData = ['updated'];
const ttl = 5000; // 5 seconds
// Write initial data
write(key, initialData, ttl);
// Advance time, but not enough to expire
await vi.advanceTimersByTimeAsync(3000);
expect(read(key)).toEqual(initialData);
// Update the data, which should reset the timer
write(key, updatedData, ttl);
expect(read(key)).toEqual(updatedData);
// Advance time again. If the timer wasn't reset, the total elapsed
// time (3000 + 3000 = 6000) would cause an eviction.
await vi.advanceTimersByTimeAsync(3000);
expect(read(key)).toEqual(updatedData);
// Advance past the new expiration time
await vi.advanceTimersByTimeAsync(2001);
expect(read(key)).toBeUndefined();
});
});
});

View File

@@ -0,0 +1,65 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import crypto from 'node:crypto';
const crawlCache = new Map<string, string[]>();
const cacheTimers = new Map<string, NodeJS.Timeout>();
/**
* Generates a unique cache key based on the project directory and the content
* of ignore files. This ensures that the cache is invalidated if the project
* or ignore rules change.
*/
export const getCacheKey = (
directory: string,
ignoreContent: string,
): string => {
const hash = crypto.createHash('sha256');
hash.update(directory);
hash.update(ignoreContent);
return hash.digest('hex');
};
/**
* Reads cached data from the in-memory cache.
* Returns undefined if the key is not found.
*/
export const read = (key: string): string[] | undefined => crawlCache.get(key);
/**
* Writes data to the in-memory cache and sets a timer to evict it after the TTL.
*/
export const write = (key: string, results: string[], ttlMs: number): void => {
// Clear any existing timer for this key to prevent premature deletion
if (cacheTimers.has(key)) {
clearTimeout(cacheTimers.get(key)!);
}
// Store the new data
crawlCache.set(key, results);
// Set a timer to automatically delete the cache entry after the TTL
const timerId = setTimeout(() => {
crawlCache.delete(key);
cacheTimers.delete(key);
}, ttlMs);
// Store the timer handle so we can clear it if the entry is updated
cacheTimers.set(key, timerId);
};
/**
* Clears the entire cache and all active timers.
* Primarily used for testing.
*/
export const clear = (): void => {
for (const timerId of cacheTimers.values()) {
clearTimeout(timerId);
}
crawlCache.clear();
cacheTimers.clear();
};
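Taken together, a caller might drive the module as below. This is a sketch: crawlProject is a hypothetical stand-in for the actual file-system crawl, and the TTL value is illustrative.

import fs from 'node:fs';
import { getCacheKey, read, write } from './crawlCache.js';

declare function crawlProject(): Promise<string[]>; // assumed crawler

async function getFiles(projectRoot: string): Promise<string[]> {
  const ignoreContent = fs.existsSync('.gitignore')
    ? fs.readFileSync('.gitignore', 'utf8')
    : '';
  const key = getCacheKey(projectRoot, ignoreContent);

  const cached = read(key);
  if (cached) return cached; // still within TTL

  const files = await crawlProject();
  write(key, files, 30_000); // evict after 30 seconds
  return files;
}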

View File

@@ -0,0 +1,642 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import * as fs from 'fs/promises';
import * as path from 'path';
import * as cache from './crawlCache.js';
import { FileSearch, AbortError, filter } from './fileSearch.js';
import { createTmpDir, cleanupTmpDir } from '@google/gemini-cli-test-utils';
type FileSearchWithPrivateMethods = FileSearch & {
performCrawl: () => Promise<void>;
};
describe('FileSearch', () => {
let tmpDir: string;
afterEach(async () => {
if (tmpDir) {
await cleanupTmpDir(tmpDir);
}
vi.restoreAllMocks();
});
it('should use .geminiignore rules', async () => {
tmpDir = await createTmpDir({
'.geminiignore': 'dist/',
dist: ['ignored.js'],
src: ['not-ignored.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: true,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual(['src/', '.geminiignore', 'src/not-ignored.js']);
});
it('should combine .gitignore and .geminiignore rules', async () => {
tmpDir = await createTmpDir({
'.gitignore': 'dist/',
'.geminiignore': 'build/',
dist: ['ignored-by-git.js'],
build: ['ignored-by-gemini.js'],
src: ['not-ignored.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: true,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual([
'src/',
'.geminiignore',
'.gitignore',
'src/not-ignored.js',
]);
});
it('should use ignoreDirs option', async () => {
tmpDir = await createTmpDir({
logs: ['some.log'],
src: ['main.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: ['logs'],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual(['src/', 'src/main.js']);
});
it('should handle negated directories', async () => {
tmpDir = await createTmpDir({
'.gitignore': ['build/**', '!build/public', '!build/public/**'].join(
'\n',
),
build: {
'private.js': '',
public: ['index.html'],
},
src: ['main.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual([
'build/',
'build/public/',
'src/',
'.gitignore',
'build/public/index.html',
'src/main.js',
]);
});
it('should filter results with a search pattern', async () => {
tmpDir = await createTmpDir({
src: {
'main.js': '',
'util.ts': '',
'style.css': '',
},
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('**/*.js');
expect(results).toEqual(['src/main.js']);
});
it('should handle root-level file negation', async () => {
tmpDir = await createTmpDir({
'.gitignore': ['*.mk', '!Foo.mk'].join('\n'),
'bar.mk': '',
'Foo.mk': '',
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual(['.gitignore', 'Foo.mk']);
});
it('should handle directory negation with glob', async () => {
tmpDir = await createTmpDir({
'.gitignore': [
'third_party/**',
'!third_party/foo',
'!third_party/foo/bar',
'!third_party/foo/bar/baz_buffer',
].join('\n'),
third_party: {
foo: {
bar: {
baz_buffer: '',
},
},
ignore_this: '',
},
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual([
'third_party/',
'third_party/foo/',
'third_party/foo/bar/',
'.gitignore',
'third_party/foo/bar/baz_buffer',
]);
});
it('should correctly handle negated patterns in .gitignore', async () => {
tmpDir = await createTmpDir({
'.gitignore': ['dist/**', '!dist/keep.js'].join('\n'),
dist: ['ignore.js', 'keep.js'],
src: ['main.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual([
'dist/',
'src/',
'.gitignore',
'dist/keep.js',
'src/main.js',
]);
});
// New test cases start here
it('should initialize correctly when ignore files are missing', async () => {
tmpDir = await createTmpDir({
src: ['file1.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: true,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
// Expect no errors to be thrown during initialization
await expect(fileSearch.initialize()).resolves.toBeUndefined();
const results = await fileSearch.search('');
expect(results).toEqual(['src/', 'src/file1.js']);
});
it('should respect maxResults option in search', async () => {
tmpDir = await createTmpDir({
src: {
'file1.js': '',
'file2.js': '',
'file3.js': '',
'file4.js': '',
},
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('**/*.js', { maxResults: 2 });
expect(results).toEqual(['src/file1.js', 'src/file2.js']); // Assuming alphabetical sort
});
it('should return empty array when no matches are found', async () => {
tmpDir = await createTmpDir({
src: ['file1.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('nonexistent-file.xyz');
expect(results).toEqual([]);
});
it('should throw AbortError when filter is aborted', async () => {
const controller = new AbortController();
const dummyPaths = Array.from({ length: 5000 }, (_, i) => `file${i}.js`); // Large array to ensure yielding
const filterPromise = filter(dummyPaths, '*.js', controller.signal);
// Abort after a short delay to ensure filter has started
setTimeout(() => controller.abort(), 1);
await expect(filterPromise).rejects.toThrow(AbortError);
});
describe('with in-memory cache', () => {
beforeEach(() => {
cache.clear();
});
afterEach(() => {
vi.useRealTimers();
});
it('should throw an error if search is called before initialization', async () => {
tmpDir = await createTmpDir({});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await expect(fileSearch.search('')).rejects.toThrow(
'Engine not initialized. Call initialize() first.',
);
});
it('should hit the cache for subsequent searches', async () => {
tmpDir = await createTmpDir({ 'file1.js': '' });
const getOptions = () => ({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: true,
cacheTtl: 10,
});
const fs1 = new FileSearch(getOptions());
const crawlSpy1 = vi.spyOn(
fs1 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs1.initialize();
expect(crawlSpy1).toHaveBeenCalledTimes(1);
// Second search should hit the cache because the options are identical
const fs2 = new FileSearch(getOptions());
const crawlSpy2 = vi.spyOn(
fs2 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs2.initialize();
expect(crawlSpy2).not.toHaveBeenCalled();
});
it('should miss the cache when ignore rules change', async () => {
tmpDir = await createTmpDir({
'.gitignore': 'a.txt',
'a.txt': '',
'b.txt': '',
});
const options = {
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: true,
cacheTtl: 10000,
};
// Initial search to populate the cache
const fs1 = new FileSearch(options);
const crawlSpy1 = vi.spyOn(
fs1 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs1.initialize();
const results1 = await fs1.search('');
expect(crawlSpy1).toHaveBeenCalledTimes(1);
expect(results1).toEqual(['.gitignore', 'b.txt']);
// Modify the ignore file
await fs.writeFile(path.join(tmpDir, '.gitignore'), 'b.txt');
// Second search should miss the cache and trigger a recrawl
const fs2 = new FileSearch(options);
const crawlSpy2 = vi.spyOn(
fs2 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs2.initialize();
const results2 = await fs2.search('');
expect(crawlSpy2).toHaveBeenCalledTimes(1);
expect(results2).toEqual(['.gitignore', 'a.txt']);
});
it('should miss the cache after TTL expires', async () => {
vi.useFakeTimers();
tmpDir = await createTmpDir({ 'file1.js': '' });
const options = {
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: true,
cacheTtl: 10, // 10 seconds
};
// Initial search to populate the cache
const fs1 = new FileSearch(options);
await fs1.initialize();
// Advance time past the TTL
await vi.advanceTimersByTimeAsync(11000);
// Second search should miss the cache and trigger a recrawl
const fs2 = new FileSearch(options);
const crawlSpy = vi.spyOn(
fs2 as FileSearchWithPrivateMethods,
'performCrawl',
);
await fs2.initialize();
expect(crawlSpy).toHaveBeenCalledTimes(1);
});
});
it('should handle empty or commented-only ignore files', async () => {
tmpDir = await createTmpDir({
'.gitignore': '# This is a comment\n\n \n',
src: ['main.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: true,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual(['src/', '.gitignore', 'src/main.js']);
});
it('should always ignore the .git directory', async () => {
tmpDir = await createTmpDir({
'.git': ['config', 'HEAD'],
src: ['main.js'],
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false, // Explicitly disable .gitignore to isolate this rule
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const results = await fileSearch.search('');
expect(results).toEqual(['src/', 'src/main.js']);
});
it('should be cancellable via AbortSignal', async () => {
const largeDir: Record<string, string> = {};
for (let i = 0; i < 100; i++) {
largeDir[`file${i}.js`] = '';
}
tmpDir = await createTmpDir(largeDir);
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
const controller = new AbortController();
const searchPromise = fileSearch.search('**/*.js', {
signal: controller.signal,
});
// Yield to allow the search to start before aborting.
await new Promise((resolve) => setImmediate(resolve));
controller.abort();
await expect(searchPromise).rejects.toThrow(AbortError);
});
it('should leverage ResultCache for bestBaseQuery optimization', async () => {
tmpDir = await createTmpDir({
src: {
'foo.js': '',
'bar.ts': '',
nested: {
'baz.js': '',
},
},
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: true, // Enable caching for this test
cacheTtl: 0,
});
await fileSearch.initialize();
// Perform a broad search to prime the cache
const broadResults = await fileSearch.search('src/**');
expect(broadResults).toEqual([
'src/',
'src/nested/',
'src/bar.ts',
'src/foo.js',
'src/nested/baz.js',
]);
// Perform a more specific search that should leverage the broad search's cached results
const specificResults = await fileSearch.search('src/**/*.js');
expect(specificResults).toEqual(['src/foo.js', 'src/nested/baz.js']);
// Although we can't directly inspect ResultCache.hits/misses from here,
// the correctness of specificResults after a broad search implicitly
// verifies that the caching mechanism, including bestBaseQuery, is working.
});
it('should be case-insensitive by default', async () => {
tmpDir = await createTmpDir({
'File1.Js': '',
'file2.js': '',
'FILE3.JS': '',
'other.txt': '',
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: false,
cacheTtl: 0,
});
await fileSearch.initialize();
// Search with a lowercase pattern
let results = await fileSearch.search('file*.js');
expect(results).toHaveLength(3);
expect(results).toEqual(
expect.arrayContaining(['File1.Js', 'file2.js', 'FILE3.JS']),
);
// Search with an uppercase pattern
results = await fileSearch.search('FILE*.JS');
expect(results).toHaveLength(3);
expect(results).toEqual(
expect.arrayContaining(['File1.Js', 'file2.js', 'FILE3.JS']),
);
// Search with a mixed-case pattern
results = await fileSearch.search('FiLe*.Js');
expect(results).toHaveLength(3);
expect(results).toEqual(
expect.arrayContaining(['File1.Js', 'file2.js', 'FILE3.JS']),
);
});
it('should respect maxResults even when the cache returns an exact match', async () => {
tmpDir = await createTmpDir({
'file1.js': '',
'file2.js': '',
'file3.js': '',
'file4.js': '',
'file5.js': '',
});
const fileSearch = new FileSearch({
projectRoot: tmpDir,
useGitignore: false,
useGeminiignore: false,
ignoreDirs: [],
cache: true, // Ensure caching is enabled
cacheTtl: 10000,
});
await fileSearch.initialize();
// 1. Perform a broad search to populate the cache with an exact match.
const initialResults = await fileSearch.search('*.js');
expect(initialResults).toEqual([
'file1.js',
'file2.js',
'file3.js',
'file4.js',
'file5.js',
]);
// 2. Perform the same search again, but this time with a maxResults limit.
const limitedResults = await fileSearch.search('*.js', { maxResults: 2 });
// 3. Assert that the maxResults limit was respected, even with a cache hit.
expect(limitedResults).toEqual(['file1.js', 'file2.js']);
});
});

View File

@@ -0,0 +1,269 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import fs from 'node:fs';
import { fdir } from 'fdir';
import picomatch from 'picomatch';
import { Ignore } from './ignore.js';
import { ResultCache } from './result-cache.js';
import * as cache from './crawlCache.js';
export type FileSearchOptions = {
projectRoot: string;
ignoreDirs: string[];
useGitignore: boolean;
useGeminiignore: boolean;
cache: boolean;
cacheTtl: number;
};
export class AbortError extends Error {
constructor(message = 'Search aborted') {
super(message);
this.name = 'AbortError';
}
}
/**
* Filters a list of paths based on a given pattern.
* @param allPaths The list of all paths to filter.
* @param pattern The picomatch pattern to filter by.
* @param signal An AbortSignal to cancel the operation.
* @returns A promise that resolves to the filtered and sorted list of paths.
*/
export async function filter(
allPaths: string[],
pattern: string,
signal: AbortSignal | undefined,
): Promise<string[]> {
const patternFilter = picomatch(pattern, {
dot: true,
contains: true,
nocase: true,
});
const results: string[] = [];
for (const [i, p] of allPaths.entries()) {
// Yield control to the event loop periodically to prevent blocking.
if (i % 1000 === 0) {
await new Promise((resolve) => setImmediate(resolve));
if (signal?.aborted) {
throw new AbortError();
}
}
if (patternFilter(p)) {
results.push(p);
}
}
results.sort((a, b) => {
const aIsDir = a.endsWith('/');
const bIsDir = b.endsWith('/');
if (aIsDir && !bIsDir) return -1;
if (!aIsDir && bIsDir) return 1;
// Plain code-unit comparison is ~40% faster than localeCompare, and
// all localeCompare would add is locale-aware ordering, which we
// don't need here.
return a < b ? -1 : a > b ? 1 : 0;
});
return results;
}
export type SearchOptions = {
signal?: AbortSignal;
maxResults?: number;
};
/**
* Provides a fast and efficient way to search for files within a project,
* respecting .gitignore and .geminiignore rules, and utilizing caching
* for improved performance.
*/
export class FileSearch {
private readonly absoluteDir: string;
private readonly ignore: Ignore = new Ignore();
private resultCache: ResultCache | undefined;
private allFiles: string[] = [];
/**
* Constructs a new `FileSearch` instance.
* @param options Configuration options for the file search.
*/
constructor(private readonly options: FileSearchOptions) {
this.absoluteDir = path.resolve(options.projectRoot);
}
/**
* Initializes the file search engine by loading ignore rules, crawling the
* file system, and building the in-memory cache. This method must be called
* before performing any searches.
*/
async initialize(): Promise<void> {
this.loadIgnoreRules();
await this.crawlFiles();
this.buildResultCache();
}
/**
* Searches for files matching a given pattern.
* @param pattern The picomatch pattern to search for (e.g., '*.js', 'src/**').
* @param options Search options, including an AbortSignal and maxResults.
* @returns A promise that resolves to a list of matching file paths, relative
* to the project root.
*/
async search(
pattern: string,
options: SearchOptions = {},
): Promise<string[]> {
if (!this.resultCache) {
throw new Error('Engine not initialized. Call initialize() first.');
}
pattern = pattern || '*';
const { files: candidates, isExactMatch } =
await this.resultCache!.get(pattern);
let filteredCandidates;
if (isExactMatch) {
filteredCandidates = candidates;
} else {
// Apply the user's picomatch pattern filter
filteredCandidates = await filter(candidates, pattern, options.signal);
this.resultCache!.set(pattern, filteredCandidates);
}
// Trade-off: We apply a two-stage filtering process.
// 1. During the file system crawl (`performCrawl`), we only apply directory-level
// ignore rules (e.g., `node_modules/`, `dist/`). This is because applying
// a full ignore filter (which includes file-specific patterns like `*.log`)
// during the crawl can significantly slow down `fdir`.
// 2. Here, in the `search` method, we apply the full ignore filter
// (including file patterns) to the `filteredCandidates` (which have already
// been filtered by the user's search pattern and sorted). For autocomplete,
// the number of displayed results is small (MAX_SUGGESTIONS_TO_SHOW),
// so applying the full filter to this truncated list is much more efficient
// than applying it to every file during the initial crawl.
const fileFilter = this.ignore.getFileFilter();
const results: string[] = [];
for (const [i, candidate] of filteredCandidates.entries()) {
// Yield to the event loop to avoid blocking on large result sets.
if (i % 1000 === 0) {
await new Promise((resolve) => setImmediate(resolve));
if (options.signal?.aborted) {
throw new AbortError();
}
}
if (results.length >= (options.maxResults ?? Infinity)) {
break;
}
// The `ignore` library throws an error if the path is '.', so we skip it.
if (candidate === '.') {
continue;
}
if (!fileFilter(candidate)) {
results.push(candidate);
}
}
return results;
}
/**
* Loads ignore rules from .gitignore and .geminiignore files, and applies
* any additional ignore directories specified in the options.
*/
private loadIgnoreRules(): void {
if (this.options.useGitignore) {
const gitignorePath = path.join(this.absoluteDir, '.gitignore');
if (fs.existsSync(gitignorePath)) {
this.ignore.add(fs.readFileSync(gitignorePath, 'utf8'));
}
}
if (this.options.useGeminiignore) {
const geminiignorePath = path.join(this.absoluteDir, '.geminiignore');
if (fs.existsSync(geminiignorePath)) {
this.ignore.add(fs.readFileSync(geminiignorePath, 'utf8'));
}
}
const ignoreDirs = ['.git', ...this.options.ignoreDirs];
this.ignore.add(
ignoreDirs.map((dir) => {
if (dir.endsWith('/')) {
return dir;
}
return `${dir}/`;
}),
);
}
/**
* Crawls the file system to get a list of all files and directories,
* optionally using a cache for faster initialization.
*/
private async crawlFiles(): Promise<void> {
if (this.options.cache) {
const cacheKey = cache.getCacheKey(
this.absoluteDir,
this.ignore.getFingerprint(),
);
const cachedResults = cache.read(cacheKey);
if (cachedResults) {
this.allFiles = cachedResults;
return;
}
}
this.allFiles = await this.performCrawl();
if (this.options.cache) {
const cacheKey = cache.getCacheKey(
this.absoluteDir,
this.ignore.getFingerprint(),
);
cache.write(cacheKey, this.allFiles, this.options.cacheTtl * 1000);
}
}
/**
* Performs the actual file system crawl using `fdir`, applying directory
* ignore rules.
* @returns A promise that resolves to a list of all files and directories.
*/
private async performCrawl(): Promise<string[]> {
const dirFilter = this.ignore.getDirectoryFilter();
// We use `fdir` for fast file system traversal. A key performance
// optimization for large workspaces is to exclude entire directories
// early in the traversal process. This is why we apply directory-specific
// ignore rules (e.g., `node_modules/`, `dist/`) directly to `fdir`'s
// exclude filter.
const api = new fdir()
.withRelativePaths()
.withDirs()
.withPathSeparator('/') // Always use Unix-style paths
.exclude((_, dirPath) => {
const relativePath = path.relative(this.absoluteDir, dirPath);
return dirFilter(`${relativePath}/`);
});
return api.crawl(this.absoluteDir).withPromise();
}
/**
* Builds the in-memory cache for fast pattern matching.
*/
private buildResultCache(): void {
this.resultCache = new ResultCache(this.allFiles, this.absoluteDir);
}
}
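A usage sketch for the engine as defined above; the option values, search pattern, and result limit are illustrative:

import { FileSearch } from './fileSearch.js';

const fileSearch = new FileSearch({
  projectRoot: process.cwd(),
  ignoreDirs: ['logs'],
  useGitignore: true,
  useGeminiignore: true,
  cache: true,
  cacheTtl: 30, // seconds; multiplied by 1000 before reaching crawlCache
});

// initialize() must complete before search(), or search() throws.
await fileSearch.initialize();

const controller = new AbortController();
const matches = await fileSearch.search('src/**/*.ts', {
  signal: controller.signal, // lets the caller abort long filters
  maxResults: 20, // cap results, e.g. for autocomplete
});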

View File

@@ -0,0 +1,65 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { Ignore } from './ignore.js';
describe('Ignore', () => {
describe('getDirectoryFilter', () => {
it('should ignore directories matching directory patterns', () => {
const ig = new Ignore().add(['foo/', 'bar/']);
const dirFilter = ig.getDirectoryFilter();
expect(dirFilter('foo/')).toBe(true);
expect(dirFilter('bar/')).toBe(true);
expect(dirFilter('baz/')).toBe(false);
});
it('should not ignore directories with file patterns', () => {
const ig = new Ignore().add(['foo.js', '*.log']);
const dirFilter = ig.getDirectoryFilter();
expect(dirFilter('foo.js')).toBe(false);
expect(dirFilter('foo.log')).toBe(false);
});
});
describe('getFileFilter', () => {
it('should not ignore files with directory patterns', () => {
const ig = new Ignore().add(['foo/', 'bar/']);
const fileFilter = ig.getFileFilter();
expect(fileFilter('foo')).toBe(false);
expect(fileFilter('foo/file.txt')).toBe(false);
});
it('should ignore files matching file patterns', () => {
const ig = new Ignore().add(['*.log', 'foo.js']);
const fileFilter = ig.getFileFilter();
expect(fileFilter('foo.log')).toBe(true);
expect(fileFilter('foo.js')).toBe(true);
expect(fileFilter('bar.txt')).toBe(false);
});
});
it('should accumulate patterns across multiple add() calls', () => {
const ig = new Ignore().add('foo.js');
ig.add('bar.js');
const fileFilter = ig.getFileFilter();
expect(fileFilter('foo.js')).toBe(true);
expect(fileFilter('bar.js')).toBe(true);
expect(fileFilter('baz.js')).toBe(false);
});
it('should return a stable and consistent fingerprint', () => {
const ig1 = new Ignore().add(['foo', '!bar']);
const ig2 = new Ignore().add('foo\n!bar');
// Fingerprints should be identical for the same rules.
expect(ig1.getFingerprint()).toBe(ig2.getFingerprint());
// Adding a new rule should change the fingerprint.
ig2.add('baz');
expect(ig1.getFingerprint()).not.toBe(ig2.getFingerprint());
});
});

View File

@@ -0,0 +1,93 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import ignore from 'ignore';
import picomatch from 'picomatch';
// Heuristic: does the pattern's last segment contain a '.' (or a '*')?
const hasFileExtension = picomatch('**/*[*.]*');
export class Ignore {
private readonly allPatterns: string[] = [];
private dirIgnorer = ignore();
private fileIgnorer = ignore();
/**
* Adds one or more ignore patterns.
* @param patterns A single pattern string or an array of pattern strings.
* Each pattern can be a glob-like string similar to .gitignore rules.
* @returns The `Ignore` instance for chaining.
*/
add(patterns: string | string[]): this {
if (typeof patterns === 'string') {
patterns = patterns.split(/\r?\n/);
}
for (const p of patterns) {
const pattern = p.trim();
if (pattern === '' || pattern.startsWith('#')) {
continue;
}
this.allPatterns.push(pattern);
const isPositiveDirPattern =
pattern.endsWith('/') && !pattern.startsWith('!');
if (isPositiveDirPattern) {
this.dirIgnorer.add(pattern);
} else {
// An ambiguous pattern (e.g., "build") could match a file or a
// directory. To optimize the file system crawl, we use a heuristic:
// patterns without a dot in the last segment are included in the
// directory exclusion check.
//
// This heuristic can fail. For example, an ignore pattern of "my.assets"
// intended to exclude a directory will not be treated as a directory
// pattern because it contains a ".". This results in crawling a
// directory that should have been excluded, reducing efficiency.
// Correctness is still maintained. The incorrectly crawled directory
// will be filtered out by the final ignore check.
//
// For maximum crawl efficiency, users should explicitly mark directory
// patterns with a trailing slash (e.g., "my.assets/").
this.fileIgnorer.add(pattern);
if (!hasFileExtension(pattern)) {
this.dirIgnorer.add(pattern);
}
}
}
return this;
}
/**
* Returns a predicate that matches explicit directory ignore patterns (patterns ending with '/').
* @returns {(dirPath: string) => boolean}
*/
getDirectoryFilter(): (dirPath: string) => boolean {
return (dirPath: string) => this.dirIgnorer.ignores(dirPath);
}
/**
* Returns a predicate that matches file ignore patterns (all patterns not ending with '/').
* Note: This may also match directories if a file pattern matches a directory name, but all explicit directory patterns are handled by getDirectoryFilter.
* @returns {(filePath: string) => boolean}
*/
getFileFilter(): (filePath: string) => boolean {
return (filePath: string) => this.fileIgnorer.ignores(filePath);
}
/**
* Returns a string representing the current set of ignore patterns.
* This can be used to generate a unique identifier for the ignore configuration,
* useful for caching purposes.
* @returns A string fingerprint of the ignore patterns.
*/
getFingerprint(): string {
return this.allPatterns.join('\n');
}
}
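The heuristic described in add() shows up directly in the two filters. A short sketch, with illustrative patterns:

import { Ignore } from './ignore.js';

const ig = new Ignore().add(['build', 'my.assets']);
const dirFilter = ig.getDirectoryFilter();
const fileFilter = ig.getFileFilter();

// "build" has no dot in its last segment, so the heuristic also
// excludes it from the crawl:
dirFilter('build/'); // true

// "my.assets" contains a dot, so the directory is still crawled...
dirFilter('my.assets/'); // false
// ...but its entries are caught by the final file-level check:
fileFilter('my.assets'); // true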

View File

@@ -0,0 +1,56 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { test, expect } from 'vitest';
import { ResultCache } from './result-cache.js';
test('ResultCache basic usage', async () => {
const files = [
'foo.txt',
'bar.js',
'baz.md',
'subdir/file.txt',
'subdir/other.js',
'subdir/nested/file.md',
];
const cache = new ResultCache(files, path.resolve('.'));
const { files: resultFiles, isExactMatch } = await cache.get('*.js');
expect(resultFiles).toEqual(files);
expect(isExactMatch).toBe(false);
});
test('ResultCache cache hit/miss', async () => {
const files = ['foo.txt', 'bar.js', 'baz.md'];
const cache = new ResultCache(files, path.resolve('.'));
// First call: miss
const { files: result1Files, isExactMatch: isExactMatch1 } =
await cache.get('*.js');
expect(result1Files).toEqual(files);
expect(isExactMatch1).toBe(false);
// Simulate FileSearch applying the filter and setting the result
cache.set('*.js', ['bar.js']);
// Second call: hit
const { files: result2Files, isExactMatch: isExactMatch2 } =
await cache.get('*.js');
expect(result2Files).toEqual(['bar.js']);
expect(isExactMatch2).toBe(true);
});
test('ResultCache best base query', async () => {
const files = ['foo.txt', 'foobar.js', 'baz.md'];
const cache = new ResultCache(files, path.resolve('.'));
// Cache a broader query
cache.set('foo', ['foo.txt', 'foobar.js']);
// Search for a more specific query that starts with the broader one
const { files: resultFiles, isExactMatch } = await cache.get('foobar');
expect(resultFiles).toEqual(['foo.txt', 'foobar.js']);
expect(isExactMatch).toBe(false);
});

View File

@@ -0,0 +1,70 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Implements an in-memory cache for file search results.
* This cache optimizes subsequent searches by leveraging previously computed results.
*/
export class ResultCache {
private readonly cache: Map<string, string[]>;
private hits = 0;
private misses = 0;
constructor(
private readonly allFiles: string[],
private readonly absoluteDir: string,
) {
this.cache = new Map();
}
/**
* Retrieves cached search results for a given query, or provides a base set
* of files to search from.
* @param query The search query pattern.
* @returns An object containing the files to search and a boolean indicating
* if the result is an exact cache hit.
*/
async get(
query: string,
): Promise<{ files: string[]; isExactMatch: boolean }> {
const isCacheHit = this.cache.has(query);
if (isCacheHit) {
this.hits++;
return { files: this.cache.get(query)!, isExactMatch: true };
}
this.misses++;
// This is the core optimization of the memory cache.
// If a user first searches for "foo", and then for "foobar",
// we don't need to search through all files again. We can start
// from the results of the "foo" search.
// This finds the most specific, already-cached query that is a prefix
// of the current query.
let bestBaseQuery = '';
for (const key of this.cache.keys()) {
if (query.startsWith(key) && key.length > bestBaseQuery.length) {
bestBaseQuery = key;
}
}
const filesToSearch = bestBaseQuery
? this.cache.get(bestBaseQuery)!
: this.allFiles;
return { files: filesToSearch, isExactMatch: false };
}
/**
* Stores search results in the cache.
* @param query The search query pattern.
* @param results The matching file paths to cache.
*/
set(query: string, results: string[]): void {
this.cache.set(query, results);
}
}
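The bestBaseQuery path can be exercised the same way the test above does; a short sketch with illustrative file names:

import { ResultCache } from './result-cache.js';

const cache = new ResultCache(['foo.txt', 'foobar.js', 'baz.md'], '/repo');

// First query misses: the caller filters all files, then stores the result.
await cache.get('foo'); // { files: all three entries, isExactMatch: false }
cache.set('foo', ['foo.txt', 'foobar.js']);

// A longer query reuses the cached prefix instead of the full list.
const { files, isExactMatch } = await cache.get('foobar');
// files -> ['foo.txt', 'foobar.js'], isExactMatch -> false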

View File

@@ -17,7 +17,8 @@ import {
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
import { IdeClient } from '../ide/ide-client.js';
vi.mock('node:fs');
vi.mock('node:fs');
@@ -35,7 +36,6 @@ describe('Flash Fallback Integration', () => {
debugMode: false,
cwd: '/test',
model: 'gemini-2.5-pro',
ideClient: IdeClient.getInstance(false),
});
// Reset simulation state for each test

View File

@@ -67,6 +67,7 @@ describe('loadServerHierarchicalMemory', () => {
it('should return empty memory and count if no context files are found', async () => {
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
@@ -85,14 +86,13 @@ describe('loadServerHierarchicalMemory', () => {
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, defaultContextFile)} ---
default context content
--- End of Context from: ${path.relative(cwd, defaultContextFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, defaultContextFile)} ---\ndefault context content\n--- End of Context from: ${path.relative(cwd, defaultContextFile)} ---`,
fileCount: 1,
});
});
@@ -108,14 +108,13 @@ default context content
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, customContextFile)} ---
custom context content
--- End of Context from: ${path.relative(cwd, customContextFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, customContextFile)} ---\ncustom context content\n--- End of Context from: ${path.relative(cwd, customContextFile)} ---`,
fileCount: 1,
});
});
@@ -135,18 +134,13 @@ custom context content
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, projectContextFile)} ---
project context content
--- End of Context from: ${path.relative(cwd, projectContextFile)} ---
--- Context from: ${path.relative(cwd, cwdContextFile)} ---
cwd context content
--- End of Context from: ${path.relative(cwd, cwdContextFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, projectContextFile)} ---\nproject context content\n--- End of Context from: ${path.relative(cwd, projectContextFile)} ---\n\n--- Context from: ${path.relative(cwd, cwdContextFile)} ---\ncwd context content\n--- End of Context from: ${path.relative(cwd, cwdContextFile)} ---`,
fileCount: 2,
});
});
@@ -163,18 +157,13 @@ cwd context content
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${customFilename} ---
CWD custom memory
--- End of Context from: ${customFilename} ---
--- Context from: ${path.join('subdir', customFilename)} ---
Subdir custom memory
--- End of Context from: ${path.join('subdir', customFilename)} ---`,
memoryContent: `--- Context from: ${customFilename} ---\nCWD custom memory\n--- End of Context from: ${customFilename} ---\n\n--- Context from: ${path.join('subdir', customFilename)} ---\nSubdir custom memory\n--- End of Context from: ${path.join('subdir', customFilename)} ---`,
fileCount: 2,
});
});
@@ -191,18 +180,13 @@ Subdir custom memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, projectRootGeminiFile)} ---
Project root memory
--- End of Context from: ${path.relative(cwd, projectRootGeminiFile)} ---
--- Context from: ${path.relative(cwd, srcGeminiFile)} ---
Src directory memory
--- End of Context from: ${path.relative(cwd, srcGeminiFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, projectRootGeminiFile)} ---\nProject root memory\n--- End of Context from: ${path.relative(cwd, projectRootGeminiFile)} ---\n\n--- Context from: ${path.relative(cwd, srcGeminiFile)} ---\nSrc directory memory\n--- End of Context from: ${path.relative(cwd, srcGeminiFile)} ---`,
fileCount: 2,
});
});
@@ -219,18 +203,13 @@ Src directory memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${DEFAULT_CONTEXT_FILENAME} ---
CWD memory
--- End of Context from: ${DEFAULT_CONTEXT_FILENAME} ---
--- Context from: ${path.join('subdir', DEFAULT_CONTEXT_FILENAME)} ---
Subdir memory
--- End of Context from: ${path.join('subdir', DEFAULT_CONTEXT_FILENAME)} ---`,
memoryContent: `--- Context from: ${DEFAULT_CONTEXT_FILENAME} ---\nCWD memory\n--- End of Context from: ${DEFAULT_CONTEXT_FILENAME} ---\n\n--- Context from: ${path.join('subdir', DEFAULT_CONTEXT_FILENAME)} ---\nSubdir memory\n--- End of Context from: ${path.join('subdir', DEFAULT_CONTEXT_FILENAME)} ---`,
fileCount: 2,
});
});
@@ -259,30 +238,13 @@ Subdir memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, defaultContextFile)} ---
default context content
--- End of Context from: ${path.relative(cwd, defaultContextFile)} ---
--- Context from: ${path.relative(cwd, rootGeminiFile)} ---
Project parent memory
--- End of Context from: ${path.relative(cwd, rootGeminiFile)} ---
--- Context from: ${path.relative(cwd, projectRootGeminiFile)} ---
Project root memory
--- End of Context from: ${path.relative(cwd, projectRootGeminiFile)} ---
--- Context from: ${path.relative(cwd, cwdGeminiFile)} ---
CWD memory
--- End of Context from: ${path.relative(cwd, cwdGeminiFile)} ---
--- Context from: ${path.relative(cwd, subDirGeminiFile)} ---
Subdir memory
--- End of Context from: ${path.relative(cwd, subDirGeminiFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, defaultContextFile)} ---\ndefault context content\n--- End of Context from: ${path.relative(cwd, defaultContextFile)} ---\n\n--- Context from: ${path.relative(cwd, rootGeminiFile)} ---\nProject parent memory\n--- End of Context from: ${path.relative(cwd, rootGeminiFile)} ---\n\n--- Context from: ${path.relative(cwd, projectRootGeminiFile)} ---\nProject root memory\n--- End of Context from: ${path.relative(cwd, projectRootGeminiFile)} ---\n\n--- Context from: ${path.relative(cwd, cwdGeminiFile)} ---\nCWD memory\n--- End of Context from: ${path.relative(cwd, cwdGeminiFile)} ---\n\n--- Context from: ${path.relative(cwd, subDirGeminiFile)} ---\nSubdir memory\n--- End of Context from: ${path.relative(cwd, subDirGeminiFile)} ---`,
fileCount: 5,
});
});
@@ -302,6 +264,7 @@ Subdir memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
[],
@@ -314,9 +277,7 @@ Subdir memory
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, regularSubDirGeminiFile)} ---
My code memory
--- End of Context from: ${path.relative(cwd, regularSubDirGeminiFile)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, regularSubDirGeminiFile)} ---\nMy code memory\n--- End of Context from: ${path.relative(cwd, regularSubDirGeminiFile)} ---`,
fileCount: 1,
});
});
@@ -333,6 +294,7 @@ My code memory
// Pass the custom limit directly to the function
await loadServerHierarchicalMemory(
cwd,
[],
true,
new FileDiscoveryService(projectRoot),
[],
@@ -353,6 +315,7 @@ My code memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
);
@@ -371,15 +334,36 @@ My code memory
const result = await loadServerHierarchicalMemory(
cwd,
[],
false,
new FileDiscoveryService(projectRoot),
[extensionFilePath],
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, extensionFilePath)} ---
Extension memory content
--- End of Context from: ${path.relative(cwd, extensionFilePath)} ---`,
memoryContent: `--- Context from: ${path.relative(cwd, extensionFilePath)} ---\nExtension memory content\n--- End of Context from: ${path.relative(cwd, extensionFilePath)} ---`,
fileCount: 1,
});
});
it('should load memory from included directories', async () => {
const includedDir = await createEmptyDir(
path.join(testRootDir, 'included'),
);
const includedFile = await createTestFile(
path.join(includedDir, DEFAULT_CONTEXT_FILENAME),
'included directory memory',
);
const result = await loadServerHierarchicalMemory(
cwd,
[includedDir],
false,
new FileDiscoveryService(projectRoot),
);
expect(result).toEqual({
memoryContent: `--- Context from: ${path.relative(cwd, includedFile)} ---\nincluded directory memory\n--- End of Context from: ${path.relative(cwd, includedFile)} ---`,
fileCount: 1,
});
});

View File

@@ -83,6 +83,36 @@ async function findProjectRoot(startDir: string): Promise<string | null> {
async function getGeminiMdFilePathsInternal(
currentWorkingDirectory: string,
includeDirectoriesToReadGemini: readonly string[],
userHomePath: string,
debugMode: boolean,
fileService: FileDiscoveryService,
extensionContextFilePaths: string[] = [],
fileFilteringOptions: FileFilteringOptions,
maxDirs: number,
): Promise<string[]> {
const dirs = new Set<string>([
...includeDirectoriesToReadGemini,
currentWorkingDirectory,
]);
const paths = [];
for (const dir of dirs) {
const pathsByDir = await getGeminiMdFilePathsInternalForEachDir(
dir,
userHomePath,
debugMode,
fileService,
extensionContextFilePaths,
fileFilteringOptions,
maxDirs,
);
paths.push(...pathsByDir);
}
return Array.from(new Set<string>(paths));
}
async function getGeminiMdFilePathsInternalForEachDir(
dir: string,
userHomePath: string,
debugMode: boolean,
fileService: FileDiscoveryService,
@@ -115,8 +145,8 @@ async function getGeminiMdFilePathsInternal(
// FIX: Only perform the workspace search (upward and downward scans)
// if a valid directory is provided.
if (currentWorkingDirectory) {
const resolvedCwd = path.resolve(currentWorkingDirectory);
if (dir) {
const resolvedCwd = path.resolve(dir);
if (debugMode)
logger.debug(
`Searching for ${geminiMdFilename} starting from CWD: ${resolvedCwd}`,
@@ -257,6 +287,7 @@ function concatenateInstructions(
*/
export async function loadServerHierarchicalMemory(
currentWorkingDirectory: string,
includeDirectoriesToReadGemini: readonly string[],
debugMode: boolean,
fileService: FileDiscoveryService,
extensionContextFilePaths: string[] = [],
@@ -274,6 +305,7 @@ export async function loadServerHierarchicalMemory(
const userHomePath = homedir();
const filePaths = await getGeminiMdFilePathsInternal(
currentWorkingDirectory,
includeDirectoriesToReadGemini,
userHomePath,
debugMode,
fileService,

View File

@@ -15,6 +15,8 @@ import * as path from 'path';
export class WorkspaceContext {
private directories: Set<string>;
private initialDirectories: Set<string>;
/**
* Creates a new WorkspaceContext with the given initial directory and optional additional directories.
* @param initialDirectory The initial working directory (usually cwd)
@@ -22,11 +24,14 @@ export class WorkspaceContext {
*/
constructor(initialDirectory: string, additionalDirectories: string[] = []) {
this.directories = new Set<string>();
this.initialDirectories = new Set<string>();
this.addDirectoryInternal(initialDirectory);
this.addInitialDirectoryInternal(initialDirectory);
for (const dir of additionalDirectories) {
this.addDirectoryInternal(dir);
this.addInitialDirectoryInternal(dir);
}
}
@@ -69,6 +74,33 @@ export class WorkspaceContext {
this.directories.add(realPath);
}
private addInitialDirectoryInternal(
directory: string,
basePath: string = process.cwd(),
): void {
const absolutePath = path.isAbsolute(directory)
? directory
: path.resolve(basePath, directory);
if (!fs.existsSync(absolutePath)) {
throw new Error(`Directory does not exist: ${absolutePath}`);
}
const stats = fs.statSync(absolutePath);
if (!stats.isDirectory()) {
throw new Error(`Path is not a directory: ${absolutePath}`);
}
let realPath: string;
try {
realPath = fs.realpathSync(absolutePath);
} catch (_error) {
throw new Error(`Failed to resolve path: ${absolutePath}`);
}
this.initialDirectories.add(realPath);
}
/**
* Gets a copy of all workspace directories.
* @returns Array of absolute directory paths
@@ -77,6 +109,17 @@ export class WorkspaceContext {
return Array.from(this.directories);
}
getInitialDirectories(): readonly string[] {
return Array.from(this.initialDirectories);
}
setDirectories(directories: readonly string[]): void {
this.directories.clear();
for (const dir of directories) {
this.addDirectoryInternal(dir);
}
}
/**
* Checks if a given path is within any of the workspace directories.
* @param pathToCheck The path to validate