Merge branch 'main' into chore/sync-gemini-cli-v0.1.21

tanzhenxin
2025-08-22 10:20:01 +08:00
22 changed files with 984 additions and 701 deletions

View File

@@ -52,6 +52,7 @@
"strip-ansi": "^7.1.0",
"tiktoken": "^1.0.21",
"undici": "^7.10.0",
"uuid": "^9.0.1",
"ws": "^8.18.0"
},
"devDependencies": {

View File

@@ -203,7 +203,6 @@ export interface ConfigParameters {
folderTrust?: boolean;
ideMode?: boolean;
enableOpenAILogging?: boolean;
sampling_params?: Record<string, unknown>;
systemPromptMappings?: Array<{
baseUrls: string[];
modelNames: string[];
@@ -212,6 +211,9 @@ export interface ConfigParameters {
contentGenerator?: {
timeout?: number;
maxRetries?: number;
samplingParams?: {
[key: string]: unknown;
};
};
cliVersion?: string;
loadMemoryFromIncludeDirectories?: boolean;
@@ -287,10 +289,10 @@ export class Config {
| Record<string, SummarizeToolOutputSettings>
| undefined;
private readonly enableOpenAILogging: boolean;
private readonly sampling_params?: Record<string, unknown>;
private readonly contentGenerator?: {
timeout?: number;
maxRetries?: number;
samplingParams?: Record<string, unknown>;
};
private readonly cliVersion?: string;
private readonly experimentalZedIntegration: boolean = false;
@@ -367,7 +369,6 @@ export class Config {
this.ideClient = IdeClient.getInstance();
this.systemPromptMappings = params.systemPromptMappings;
this.enableOpenAILogging = params.enableOpenAILogging ?? false;
this.sampling_params = params.sampling_params;
this.contentGenerator = params.contentGenerator;
this.cliVersion = params.cliVersion;
@@ -766,10 +767,6 @@ export class Config {
return this.enableOpenAILogging;
}
getSamplingParams(): Record<string, unknown> | undefined {
return this.sampling_params;
}
getContentGeneratorTimeout(): number | undefined {
return this.contentGenerator?.timeout;
}
@@ -778,6 +775,12 @@ export class Config {
return this.contentGenerator?.maxRetries;
}
getContentGeneratorSamplingParams(): ContentGeneratorConfig['samplingParams'] {
return this.contentGenerator?.samplingParams as
| ContentGeneratorConfig['samplingParams']
| undefined;
}
getCliVersion(): string | undefined {
return this.cliVersion;
}
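In short, the flat sampling_params option is folded into the contentGenerator block, and getSamplingParams() is replaced by getContentGeneratorSamplingParams(). A minimal sketch of the new shape, with illustrative values and the remaining required ConfigParameters omitted:

const params: ConfigParameters = {
  // ...other required ConfigParameters omitted for brevity
  contentGenerator: {
    timeout: 120000,
    maxRetries: 3,
    samplingParams: { temperature: 0.7, top_p: 0.9 },
  },
};
const config = new Config(params);
const sampling = config.getContentGeneratorSamplingParams(); // replaces the removed getSamplingParams()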

View File

@@ -7,6 +7,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -41,9 +42,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -60,7 +58,12 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create generator instance
generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
});
afterEach(() => {
@@ -237,12 +240,18 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
describe('timeout configuration', () => {
it('should use default timeout configuration', () => {
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
// Verify OpenAI client was created with timeout config
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
@@ -253,18 +262,23 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
it('should use custom timeout from config', () => {
const customConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
timeout: 300000, // 5 minutes
maxRetries: 5,
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'http://localhost:8080',
authType: AuthType.USE_OPENAI,
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
@@ -279,11 +293,17 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', noTimeoutConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {
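The recurring change in these tests: instead of new OpenAIContentGenerator(apiKey, model, config), a ContentGeneratorConfig object is built and passed alongside the Config mock. A condensed sketch with values taken from the tests above (timeout and maxRetries fall back to 120000 and 3 when omitted):

const contentGeneratorConfig = {
  model: 'gpt-4',
  apiKey: 'test-key',
  baseUrl: 'http://localhost:8080',
  authType: AuthType.USE_OPENAI,
  timeout: 300000,
  maxRetries: 5,
};
const generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);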

View File

@@ -589,10 +589,7 @@ export class GeminiClient {
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
try {
const userMemory = this.config.getUserMemory();
const systemPromptMappings = this.config.getSystemPromptMappings();
const systemInstruction = getCoreSystemPrompt(userMemory, {
systemPromptMappings,
});
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
abortSignal,
...this.generateContentConfig,
@@ -680,10 +677,7 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemPromptMappings = this.config.getSystemPromptMappings();
const systemInstruction = getCoreSystemPrompt(userMemory, {
systemPromptMappings,
});
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
abortSignal,

View File

@@ -84,6 +84,7 @@ describe('createContentGeneratorConfig', () => {
getSamplingParams: vi.fn().mockReturnValue(undefined),
getContentGeneratorTimeout: vi.fn().mockReturnValue(undefined),
getContentGeneratorMaxRetries: vi.fn().mockReturnValue(undefined),
getContentGeneratorSamplingParams: vi.fn().mockReturnValue(undefined),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;

View File

@@ -53,6 +53,7 @@ export enum AuthType {
export type ContentGeneratorConfig = {
model: string;
apiKey?: string;
baseUrl?: string;
vertexai?: boolean;
authType?: AuthType | undefined;
enableOpenAILogging?: boolean;
@@ -77,11 +78,16 @@ export function createContentGeneratorConfig(
config: Config,
authType: AuthType | undefined,
): ContentGeneratorConfig {
// google auth
const geminiApiKey = process.env.GEMINI_API_KEY || undefined;
const googleApiKey = process.env.GOOGLE_API_KEY || undefined;
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined;
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined;
// openai auth
const openaiApiKey = process.env.OPENAI_API_KEY;
const openaiBaseUrl = process.env.OPENAI_BASE_URL || undefined;
const openaiModel = process.env.OPENAI_MODEL || undefined;
// Use runtime model from config if available; otherwise, fall back to parameter or default
const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL;
@@ -93,7 +99,7 @@ export function createContentGeneratorConfig(
enableOpenAILogging: config.getEnableOpenAILogging(),
timeout: config.getContentGeneratorTimeout(),
maxRetries: config.getContentGeneratorMaxRetries(),
samplingParams: config.getSamplingParams(),
samplingParams: config.getContentGeneratorSamplingParams(),
};
// If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now
@@ -123,8 +129,8 @@ export function createContentGeneratorConfig(
if (authType === AuthType.USE_OPENAI && openaiApiKey) {
contentGeneratorConfig.apiKey = openaiApiKey;
contentGeneratorConfig.model =
process.env.OPENAI_MODEL || DEFAULT_GEMINI_MODEL;
contentGeneratorConfig.baseUrl = openaiBaseUrl;
contentGeneratorConfig.model = openaiModel || DEFAULT_QWEN_MODEL;
return contentGeneratorConfig;
}
@@ -192,7 +198,7 @@ export async function createContentGenerator(
);
// Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag
return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig);
return new OpenAIContentGenerator(config, gcConfig);
}
if (config.authType === AuthType.QWEN_OAUTH) {
@@ -213,7 +219,7 @@ export async function createContentGenerator(
const qwenClient = await getQwenOauthClient(gcConfig);
// Create the content generator with dynamic token management
return new QwenContentGenerator(qwenClient, config.model, gcConfig);
return new QwenContentGenerator(qwenClient, config, gcConfig);
} catch (error) {
throw new Error(
`Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`,
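For the USE_OPENAI branch, the environment now feeds the config object up front rather than being read inside the generator. A sketch of the flow with illustrative values; the call shape follows only what this diff shows:

// OPENAI_API_KEY  -> apiKey
// OPENAI_BASE_URL -> baseUrl
// OPENAI_MODEL    -> model (falls back to DEFAULT_QWEN_MODEL)
process.env.OPENAI_API_KEY = 'your-api-key';
process.env.OPENAI_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1';
process.env.OPENAI_MODEL = 'qwen-turbo';

const generatorConfig = createContentGeneratorConfig(config, AuthType.USE_OPENAI);
const generator = await createContentGenerator(generatorConfig, config);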

View File

@@ -7,6 +7,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import { Config } from '../config/config.js';
import { AuthType } from './contentGenerator.js';
import OpenAI from 'openai';
import type {
GenerateContentParameters,
@@ -84,7 +85,20 @@ describe('OpenAIContentGenerator', () => {
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create generator instance
generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
samplingParams: {
temperature: 0.7,
max_tokens: 1000,
top_p: 0.9,
},
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
});
afterEach(() => {
@@ -95,7 +109,7 @@ describe('OpenAIContentGenerator', () => {
it('should initialize with basic configuration', () => {
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: undefined,
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
@@ -105,9 +119,16 @@ describe('OpenAIContentGenerator', () => {
});
it('should handle custom base URL', () => {
vi.stubEnv('OPENAI_BASE_URL', 'https://api.custom.com');
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'https://api.custom.com',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
@@ -121,9 +142,16 @@ describe('OpenAIContentGenerator', () => {
});
it('should configure OpenRouter headers when using OpenRouter', () => {
vi.stubEnv('OPENAI_BASE_URL', 'https://openrouter.ai/api/v1');
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'https://openrouter.ai/api/v1',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
@@ -147,11 +175,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: undefined,
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
@@ -906,9 +941,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -1029,9 +1069,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -1587,7 +1632,23 @@ describe('OpenAIContentGenerator', () => {
}
}
const testGenerator = new TestGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
samplingParams: {
temperature: 0.7,
max_tokens: 1000,
top_p: 0.9,
},
};
const testGenerator = new TestGenerator(
contentGeneratorConfig,
mockConfig,
);
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
@@ -1908,9 +1969,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
samplingParams: {
temperature: 0.8,
max_tokens: 500,
},
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -2093,9 +2163,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -2350,9 +2425,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
samplingParams: {
temperature: undefined,
max_tokens: undefined,
top_p: undefined,
},
};
const testGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
configWithUndefined,
);
@@ -2408,9 +2492,22 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
samplingParams: {
temperature: 0.8,
max_tokens: 1500,
top_p: 0.95,
top_k: 40,
repetition_penalty: 1.1,
presence_penalty: 0.5,
frequency_penalty: 0.3,
},
};
const testGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
fullSamplingConfig,
);
@@ -2489,9 +2586,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2528,12 +2630,6 @@ describe('OpenAIContentGenerator', () => {
});
it('should include metadata when baseURL is dashscope openai compatible mode', async () => {
// Mock environment to set dashscope base URL BEFORE creating the generator
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai', // Not QWEN_OAUTH
@@ -2543,9 +2639,15 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
dashscopeConfig,
);
@@ -2604,9 +2706,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
regularConfig,
);
@@ -2650,9 +2761,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const otherGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
otherAuthConfig,
);
@@ -2699,9 +2819,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const otherBaseUrlGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
otherBaseUrlConfig,
);
@@ -2748,9 +2877,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2804,8 +2942,6 @@ describe('OpenAIContentGenerator', () => {
sessionId: 'streaming-session-id',
promptId: 'streaming-prompt-id',
},
stream: true,
stream_options: { include_usage: true },
}),
);
@@ -2827,9 +2963,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
regularConfig,
);
@@ -2901,9 +3046,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2955,9 +3109,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const noBaseUrlGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
noBaseUrlConfig,
);
@@ -3004,9 +3167,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const undefinedAuthGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
undefinedAuthConfig,
);
@@ -3050,9 +3222,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const undefinedConfigGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
undefinedConfig,
);
@@ -3089,4 +3270,232 @@ describe('OpenAIContentGenerator', () => {
);
});
});
describe('cache control for DashScope', () => {
it('should add cache control to system message for DashScope providers', async () => {
// Mock environment to set dashscope base URL
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
dashscopeConfig,
);
// Mock the client's baseURL property to return the expected value
Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
writable: true,
});
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'qwen-turbo',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
config: {
systemInstruction: 'You are a helpful assistant.',
},
model: 'qwen-turbo',
};
await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
// Should include cache control in system message
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: expect.arrayContaining([
expect.objectContaining({
type: 'text',
text: 'You are a helpful assistant.',
cache_control: { type: 'ephemeral' },
}),
]),
}),
]),
}),
);
});
it('should add cache control to last message for DashScope providers', async () => {
// Mock environment to set dashscope base URL
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
dashscopeConfig,
);
// Mock the client's baseURL property to return the expected value
Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
writable: true,
});
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'qwen-turbo',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }],
model: 'qwen-turbo',
};
await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
// Should include cache control in last message
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: expect.arrayContaining([
expect.objectContaining({
type: 'text',
text: 'Hello, how are you?',
cache_control: { type: 'ephemeral' },
}),
]),
}),
]),
}),
);
});
it('should NOT add cache control for non-DashScope providers', async () => {
const regularConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('regular-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
regularConfig,
);
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'gpt-4',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
config: {
systemInstruction: 'You are a helpful assistant.',
},
model: 'gpt-4',
};
await regularGenerator.generateContent(request, 'regular-prompt-id');
// Should NOT include cache control (messages should be strings, not arrays)
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: 'You are a helpful assistant.',
}),
expect.objectContaining({
role: 'user',
content: 'Hello',
}),
]),
}),
);
});
});
});
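What these tests pin down about the request body: for DashScope-style providers (QWEN_OAUTH auth or a DashScope baseUrl) the system message and the last message are promoted from plain strings to content-part arrays carrying cache_control, while every other provider keeps string content. Illustrative shapes:

// DashScope / QWEN_OAUTH
{
  role: 'system',
  content: [
    {
      type: 'text',
      text: 'You are a helpful assistant.',
      cache_control: { type: 'ephemeral' },
    },
  ],
}

// Any other provider: content stays a plain string
{ role: 'system', content: 'You are a helpful assistant.' }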

View File

@@ -20,7 +20,11 @@ import {
FunctionCall,
FunctionResponse,
} from '@google/genai';
import { AuthType, ContentGenerator } from './contentGenerator.js';
import {
AuthType,
ContentGenerator,
ContentGeneratorConfig,
} from './contentGenerator.js';
import OpenAI from 'openai';
import { logApiError, logApiResponse } from '../telemetry/loggers.js';
import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js';
@@ -28,6 +32,17 @@ import { Config } from '../config/config.js';
import { openaiLogger } from '../utils/openaiLogger.js';
import { safeJsonParse } from '../utils/safeJsonParse.js';
// Extended types to support cache_control
interface ChatCompletionContentPartTextWithCache
extends OpenAI.Chat.ChatCompletionContentPartText {
cache_control?: { type: 'ephemeral' };
}
type ChatCompletionContentPartWithCache =
| ChatCompletionContentPartTextWithCache
| OpenAI.Chat.ChatCompletionContentPartImage
| OpenAI.Chat.ChatCompletionContentPartRefusal;
// OpenAI API type definitions for logging
interface OpenAIToolCall {
id: string;
@@ -38,9 +53,15 @@ interface OpenAIToolCall {
};
}
interface OpenAIContentItem {
type: 'text';
text: string;
cache_control?: { type: 'ephemeral' };
}
interface OpenAIMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
content: string | null;
content: string | null | OpenAIContentItem[];
tool_calls?: OpenAIToolCall[];
tool_call_id?: string;
}
@@ -60,15 +81,6 @@ interface OpenAIChoice {
finish_reason: string;
}
interface OpenAIRequestFormat {
model: string;
messages: OpenAIMessage[];
temperature?: number;
max_tokens?: number;
top_p?: number;
tools?: unknown[];
}
interface OpenAIResponseFormat {
id: string;
object: string;
@@ -81,6 +93,7 @@ interface OpenAIResponseFormat {
export class OpenAIContentGenerator implements ContentGenerator {
protected client: OpenAI;
private model: string;
private contentGeneratorConfig: ContentGeneratorConfig;
private config: Config;
private streamingToolCalls: Map<
number,
@@ -91,50 +104,40 @@ export class OpenAIContentGenerator implements ContentGenerator {
}
> = new Map();
constructor(apiKey: string, model: string, config: Config) {
this.model = model;
this.config = config;
const baseURL = process.env.OPENAI_BASE_URL || '';
constructor(
contentGeneratorConfig: ContentGeneratorConfig,
gcConfig: Config,
) {
this.model = contentGeneratorConfig.model;
this.contentGeneratorConfig = contentGeneratorConfig;
this.config = gcConfig;
// Configure timeout settings - using progressive timeouts
const timeoutConfig = {
// Base timeout for most requests (2 minutes)
timeout: 120000,
// Maximum retries for failed requests
maxRetries: 3,
// HTTP client options
httpAgent: undefined, // Let the client use default agent
};
// Allow config to override timeout settings
const contentGeneratorConfig = this.config.getContentGeneratorConfig();
if (contentGeneratorConfig?.timeout) {
timeoutConfig.timeout = contentGeneratorConfig.timeout;
}
if (contentGeneratorConfig?.maxRetries !== undefined) {
timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries;
}
const version = config.getCliVersion() || 'unknown';
const version = gcConfig.getCliVersion() || 'unknown';
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
// Check if using OpenRouter and add required headers
const isOpenRouter = baseURL.includes('openrouter.ai');
const isOpenRouterProvider = this.isOpenRouterProvider();
const isDashScopeProvider = this.isDashScopeProvider();
const defaultHeaders = {
'User-Agent': userAgent,
...(isOpenRouter
...(isOpenRouterProvider
? {
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
'X-Title': 'Qwen Code',
}
: {}),
: isDashScopeProvider
? {
'X-DashScope-CacheControl': 'enable',
}
: {}),
};
this.client = new OpenAI({
apiKey,
baseURL,
timeout: timeoutConfig.timeout,
maxRetries: timeoutConfig.maxRetries,
apiKey: contentGeneratorConfig.apiKey,
baseURL: contentGeneratorConfig.baseUrl,
timeout: contentGeneratorConfig.timeout ?? 120000,
maxRetries: contentGeneratorConfig.maxRetries ?? 3,
defaultHeaders,
});
}
@@ -185,22 +188,25 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
private isOpenRouterProvider(): boolean {
const baseURL = this.contentGeneratorConfig.baseUrl || '';
return baseURL.includes('openrouter.ai');
}
/**
* Determine if metadata should be included in the request.
* Only include the `metadata` field if the provider is QWEN_OAUTH
* or the baseUrl is 'https://dashscope.aliyuncs.com/compatible-mode/v1'.
* This is because some models/providers do not support metadata or need extra configuration.
* Determine if this is a DashScope provider.
* DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs.
*
* @returns true if metadata should be included, false otherwise
* @returns true if this is a DashScope provider, false otherwise
*/
private shouldIncludeMetadata(): boolean {
const authType = this.config.getContentGeneratorConfig?.()?.authType;
// baseUrl may be undefined; default to empty string if so
const baseUrl = this.client?.baseURL || '';
private isDashScopeProvider(): boolean {
const authType = this.contentGeneratorConfig.authType;
const baseUrl = this.contentGeneratorConfig.baseUrl;
return (
authType === AuthType.QWEN_OAUTH ||
baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1'
baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' ||
baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
);
}
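Provider detection therefore reads only the injected ContentGeneratorConfig, not the live client or process.env. A condensed sketch of the check and what it switches (URLs as in the diff):

const isDashScope =
  contentGeneratorConfig.authType === AuthType.QWEN_OAUTH ||
  contentGeneratorConfig.baseUrl ===
    'https://dashscope.aliyuncs.com/compatible-mode/v1' ||
  contentGeneratorConfig.baseUrl ===
    'https://dashscope-intl.aliyuncs.com/compatible-mode/v1';
// true  -> constructor sends 'X-DashScope-CacheControl: enable' and
//          buildMetadata() attaches { sessionId, promptId } to each request
// false -> no cache header; OpenRouter URLs instead get HTTP-Referer / X-Title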
@@ -213,7 +219,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
private buildMetadata(
userPromptId: string,
): { metadata: { sessionId?: string; promptId: string } } | undefined {
if (!this.shouldIncludeMetadata()) {
if (!this.isDashScopeProvider()) {
return undefined;
}
@@ -225,35 +231,44 @@ export class OpenAIContentGenerator implements ContentGenerator {
};
}
private async buildCreateParams(
request: GenerateContentParameters,
userPromptId: string,
): Promise<Parameters<typeof this.client.chat.completions.create>[0]> {
const messages = this.convertToOpenAIFormat(request);
// Build sampling parameters with clear priority:
// 1. Request-level parameters (highest priority)
// 2. Config-level sampling parameters (medium priority)
// 3. Default values (lowest priority)
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
return createParams;
}
async generateContent(
request: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
const messages = this.convertToOpenAIFormat(request);
const createParams = await this.buildCreateParams(request, userPromptId);
try {
// Build sampling parameters with clear priority:
// 1. Request-level parameters (highest priority)
// 2. Config-level sampling parameters (medium priority)
// 3. Default values (lowest priority)
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
// console.log('createParams', createParams);
const completion = (await this.client.chat.completions.create(
createParams,
)) as OpenAI.Chat.ChatCompletion;
@@ -267,15 +282,15 @@ export class OpenAIContentGenerator implements ContentGenerator {
this.model,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
response.usageMetadata,
);
logApiResponse(this.config, responseEvent);
// Log interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
const openaiRequest = createParams;
const openaiResponse = this.convertGeminiResponseToOpenAI(response);
await openaiLogger.logInteraction(openaiRequest, openaiResponse);
}
@@ -300,7 +315,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -309,10 +324,9 @@ export class OpenAIContentGenerator implements ContentGenerator {
logApiError(this.config, errorEvent);
// Log error interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
await openaiLogger.logInteraction(
openaiRequest,
createParams,
undefined,
error as Error,
);
@@ -343,29 +357,12 @@ export class OpenAIContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
const messages = this.convertToOpenAIFormat(request);
const createParams = await this.buildCreateParams(request, userPromptId);
createParams.stream = true;
createParams.stream_options = { include_usage: true };
try {
// Build sampling parameters with clear priority
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
stream: true,
stream_options: { include_usage: true },
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
const stream = (await this.client.chat.completions.create(
createParams,
)) as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>;
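Request assembly is now shared: both entry points call buildCreateParams(), and the streaming path only layers the stream flags on top. The same object is what enableOpenAILogging records, which is why the separate convertGeminiRequestToOpenAI path further down could be deleted. Roughly:

// generateContent
const createParams = await this.buildCreateParams(request, userPromptId);
const completion = await this.client.chat.completions.create(createParams);

// generateContentStream
const streamParams = await this.buildCreateParams(request, userPromptId);
streamParams.stream = true;
streamParams.stream_options = { include_usage: true };
const stream = await this.client.chat.completions.create(streamParams);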
@@ -397,16 +394,15 @@ export class OpenAIContentGenerator implements ContentGenerator {
this.model,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
finalUsageMetadata,
);
logApiResponse(this.config, responseEvent);
// Log interaction if enabled (same as generateContent method)
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest =
await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
const openaiRequest = createParams;
// For streaming, we combine all responses into a single response for logging
const combinedResponse =
this.combineStreamResponsesForLogging(responses);
@@ -433,7 +429,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -442,11 +438,9 @@ export class OpenAIContentGenerator implements ContentGenerator {
logApiError(this.config, errorEvent);
// Log error interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest =
await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
await openaiLogger.logInteraction(
openaiRequest,
createParams,
undefined,
error as Error,
);
@@ -487,7 +481,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -563,7 +557,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
// Add combined text if any
if (combinedText) {
combinedParts.push({ text: combinedText.trimEnd() });
combinedParts.push({ text: combinedText });
}
// Add function calls
@@ -944,7 +938,114 @@ export class OpenAIContentGenerator implements ContentGenerator {
// Clean up orphaned tool calls and merge consecutive assistant messages
const cleanedMessages = this.cleanOrphanedToolCalls(messages);
return this.mergeConsecutiveAssistantMessages(cleanedMessages);
const mergedMessages =
this.mergeConsecutiveAssistantMessages(cleanedMessages);
// Add cache control to system and last messages for DashScope providers
return this.addCacheControlFlag(mergedMessages, 'both');
}
/**
* Add cache control flag to specified message(s) for DashScope providers
*/
private addCacheControlFlag(
messages: OpenAI.Chat.ChatCompletionMessageParam[],
target: 'system' | 'last' | 'both' = 'both',
): OpenAI.Chat.ChatCompletionMessageParam[] {
if (!this.isDashScopeProvider() || messages.length === 0) {
return messages;
}
let updatedMessages = [...messages];
// Add cache control to system message if requested
if (target === 'system' || target === 'both') {
updatedMessages = this.addCacheControlToMessage(
updatedMessages,
'system',
);
}
// Add cache control to last message if requested
if (target === 'last' || target === 'both') {
updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last');
}
return updatedMessages;
}
/**
* Helper method to add cache control to a specific message
*/
private addCacheControlToMessage(
messages: OpenAI.Chat.ChatCompletionMessageParam[],
target: 'system' | 'last',
): OpenAI.Chat.ChatCompletionMessageParam[] {
const updatedMessages = [...messages];
let messageIndex: number;
if (target === 'system') {
// Find the first system message
messageIndex = messages.findIndex((msg) => msg.role === 'system');
if (messageIndex === -1) {
return updatedMessages;
}
} else {
// Get the last message
messageIndex = messages.length - 1;
}
const message = updatedMessages[messageIndex];
// Only process messages that have content
if ('content' in message && message.content !== null) {
if (typeof message.content === 'string') {
// Convert string content to array format with cache control
const messageWithArrayContent = {
...message,
content: [
{
type: 'text',
text: message.content,
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache,
],
};
updatedMessages[messageIndex] =
messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam;
} else if (Array.isArray(message.content)) {
// If content is already an array, add cache_control to the last item
const contentArray = [
...message.content,
] as ChatCompletionContentPartWithCache[];
if (contentArray.length > 0) {
const lastItem = contentArray[contentArray.length - 1];
if (lastItem.type === 'text') {
// Add cache_control to the last text item
contentArray[contentArray.length - 1] = {
...lastItem,
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache;
} else {
// If the last item is not text, add a new text item with cache_control
contentArray.push({
type: 'text',
text: '',
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache);
}
const messageWithCache = {
...message,
content: contentArray,
};
updatedMessages[messageIndex] =
messageWithCache as OpenAI.Chat.ChatCompletionMessageParam;
}
}
}
return updatedMessages;
}
/**
@@ -1164,11 +1265,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
// Handle text content
if (choice.message.content) {
if (typeof choice.message.content === 'string') {
parts.push({ text: choice.message.content.trimEnd() });
} else {
parts.push({ text: choice.message.content });
}
parts.push({ text: choice.message.content });
}
// Handle tool calls
@@ -1253,11 +1350,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
// Handle text content
if (choice.delta?.content) {
if (typeof choice.delta.content === 'string') {
parts.push({ text: choice.delta.content.trimEnd() });
} else {
parts.push({ text: choice.delta.content });
}
parts.push({ text: choice.delta.content });
}
// Handle tool calls - only accumulate during streaming, emit when complete
@@ -1376,8 +1469,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
private buildSamplingParameters(
request: GenerateContentParameters,
): Record<string, unknown> {
const configSamplingParams =
this.config.getContentGeneratorConfig()?.samplingParams;
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
const params = {
// Temperature: config > request > default
@@ -1439,313 +1531,6 @@ export class OpenAIContentGenerator implements ContentGenerator {
return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED;
}
/**
* Convert Gemini request format to OpenAI chat completion format for logging
*/
private async convertGeminiRequestToOpenAI(
request: GenerateContentParameters,
): Promise<OpenAIRequestFormat> {
const messages: OpenAIMessage[] = [];
// Handle system instruction
if (request.config?.systemInstruction) {
const systemInstruction = request.config.systemInstruction;
let systemText = '';
if (Array.isArray(systemInstruction)) {
systemText = systemInstruction
.map((content) => {
if (typeof content === 'string') return content;
if ('parts' in content) {
const contentObj = content as Content;
return (
contentObj.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || ''
);
}
return '';
})
.join('\n');
} else if (typeof systemInstruction === 'string') {
systemText = systemInstruction;
} else if (
typeof systemInstruction === 'object' &&
'parts' in systemInstruction
) {
const systemContent = systemInstruction as Content;
systemText =
systemContent.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || '';
}
if (systemText) {
messages.push({
role: 'system',
content: systemText,
});
}
}
// Handle contents
if (Array.isArray(request.contents)) {
for (const content of request.contents) {
if (typeof content === 'string') {
messages.push({ role: 'user', content });
} else if ('role' in content && 'parts' in content) {
const functionCalls: FunctionCall[] = [];
const functionResponses: FunctionResponse[] = [];
const textParts: string[] = [];
for (const part of content.parts || []) {
if (typeof part === 'string') {
textParts.push(part);
} else if ('text' in part && part.text) {
textParts.push(part.text);
} else if ('functionCall' in part && part.functionCall) {
functionCalls.push(part.functionCall);
} else if ('functionResponse' in part && part.functionResponse) {
functionResponses.push(part.functionResponse);
}
}
// Handle function responses (tool results)
if (functionResponses.length > 0) {
for (const funcResponse of functionResponses) {
messages.push({
role: 'tool',
tool_call_id: funcResponse.id || '',
content:
typeof funcResponse.response === 'string'
? funcResponse.response
: JSON.stringify(funcResponse.response),
});
}
}
// Handle model messages with function calls
else if (content.role === 'model' && functionCalls.length > 0) {
const toolCalls = functionCalls.map((fc, index) => ({
id: fc.id || `call_${index}`,
type: 'function' as const,
function: {
name: fc.name || '',
arguments: JSON.stringify(fc.args || {}),
},
}));
messages.push({
role: 'assistant',
content: textParts.join('\n') || null,
tool_calls: toolCalls,
});
}
// Handle regular text messages
else {
const role = content.role === 'model' ? 'assistant' : 'user';
const text = textParts.join('\n');
if (text) {
messages.push({ role, content: text });
}
}
}
}
} else if (request.contents) {
if (typeof request.contents === 'string') {
messages.push({ role: 'user', content: request.contents });
} else if ('role' in request.contents && 'parts' in request.contents) {
const content = request.contents;
const role = content.role === 'model' ? 'assistant' : 'user';
const text =
content.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || '';
messages.push({ role, content: text });
}
}
// Clean up orphaned tool calls and merge consecutive assistant messages
const cleanedMessages = this.cleanOrphanedToolCallsForLogging(messages);
const mergedMessages =
this.mergeConsecutiveAssistantMessagesForLogging(cleanedMessages);
const openaiRequest: OpenAIRequestFormat = {
model: this.model,
messages: mergedMessages,
};
// Add sampling parameters using the same logic as actual API calls
const samplingParams = this.buildSamplingParameters(request);
Object.assign(openaiRequest, samplingParams);
// Convert tools if present
if (request.config?.tools) {
openaiRequest.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
return openaiRequest;
}
/**
* Clean up orphaned tool calls for logging purposes
*/
private cleanOrphanedToolCallsForLogging(
messages: OpenAIMessage[],
): OpenAIMessage[] {
const cleaned: OpenAIMessage[] = [];
const toolCallIds = new Set<string>();
const toolResponseIds = new Set<string>();
// First pass: collect all tool call IDs and tool response IDs
for (const message of messages) {
if (message.role === 'assistant' && message.tool_calls) {
for (const toolCall of message.tool_calls) {
if (toolCall.id) {
toolCallIds.add(toolCall.id);
}
}
} else if (message.role === 'tool' && message.tool_call_id) {
toolResponseIds.add(message.tool_call_id);
}
}
// Second pass: filter out orphaned messages
for (const message of messages) {
if (message.role === 'assistant' && message.tool_calls) {
// Filter out tool calls that don't have corresponding responses
const validToolCalls = message.tool_calls.filter(
(toolCall) => toolCall.id && toolResponseIds.has(toolCall.id),
);
if (validToolCalls.length > 0) {
// Keep the message but only with valid tool calls
const cleanedMessage = { ...message };
cleanedMessage.tool_calls = validToolCalls;
cleaned.push(cleanedMessage);
} else if (
typeof message.content === 'string' &&
message.content.trim()
) {
// Keep the message if it has text content, but remove tool calls
const cleanedMessage = { ...message };
delete cleanedMessage.tool_calls;
cleaned.push(cleanedMessage);
}
// If no valid tool calls and no content, skip the message entirely
} else if (message.role === 'tool' && message.tool_call_id) {
// Only keep tool responses that have corresponding tool calls
if (toolCallIds.has(message.tool_call_id)) {
cleaned.push(message);
}
} else {
// Keep all other messages as-is
cleaned.push(message);
}
}
// Final validation: ensure every assistant message with tool_calls has corresponding tool responses
const finalCleaned: OpenAIMessage[] = [];
const finalToolCallIds = new Set<string>();
// Collect all remaining tool call IDs
for (const message of cleaned) {
if (message.role === 'assistant' && message.tool_calls) {
for (const toolCall of message.tool_calls) {
if (toolCall.id) {
finalToolCallIds.add(toolCall.id);
}
}
}
}
// Verify all tool calls have responses
const finalToolResponseIds = new Set<string>();
for (const message of cleaned) {
if (message.role === 'tool' && message.tool_call_id) {
finalToolResponseIds.add(message.tool_call_id);
}
}
// Remove any remaining orphaned tool calls
for (const message of cleaned) {
if (message.role === 'assistant' && message.tool_calls) {
const finalValidToolCalls = message.tool_calls.filter(
(toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id),
);
if (finalValidToolCalls.length > 0) {
const cleanedMessage = { ...message };
cleanedMessage.tool_calls = finalValidToolCalls;
finalCleaned.push(cleanedMessage);
} else if (
typeof message.content === 'string' &&
message.content.trim()
) {
const cleanedMessage = { ...message };
delete cleanedMessage.tool_calls;
finalCleaned.push(cleanedMessage);
}
} else {
finalCleaned.push(message);
}
}
return finalCleaned;
}
/**
* Merge consecutive assistant messages to combine split text and tool calls for logging
*/
private mergeConsecutiveAssistantMessagesForLogging(
messages: OpenAIMessage[],
): OpenAIMessage[] {
const merged: OpenAIMessage[] = [];
for (const message of messages) {
if (message.role === 'assistant' && merged.length > 0) {
const lastMessage = merged[merged.length - 1];
// If the last message is also an assistant message, merge them
if (lastMessage.role === 'assistant') {
// Combine content
const combinedContent = [
lastMessage.content || '',
message.content || '',
]
.filter(Boolean)
.join('');
// Combine tool calls
const combinedToolCalls = [
...(lastMessage.tool_calls || []),
...(message.tool_calls || []),
];
// Update the last message with combined data
lastMessage.content = combinedContent || null;
if (combinedToolCalls.length > 0) {
lastMessage.tool_calls = combinedToolCalls;
}
continue; // Skip adding the current message since it's been merged
}
}
// Add the message as-is if no merging is needed
merged.push(message);
}
return merged;
}
/**
* Convert Gemini response format to OpenAI chat completion format for logging
*/
@@ -1776,7 +1561,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
}
}
messageContent = textParts.join('').trimEnd();
messageContent = textParts.join('');
}
const choice: OpenAIChoice = {

View File

@@ -54,15 +54,39 @@ describe('detectIde', () => {
expected: DetectedIde.FirebaseStudio,
},
])('detects the IDE for $expected', ({ env, expected }) => {
// Clear all environment variables first
vi.unstubAllEnvs();
// Set TERM_PROGRAM to vscode (required for all IDE detection)
vi.stubEnv('TERM_PROGRAM', 'vscode');
// Explicitly stub all environment variables that detectIde() checks to undefined
// This ensures no real environment variables interfere with the tests
vi.stubEnv('__COG_BASHRC_SOURCED', undefined);
vi.stubEnv('REPLIT_USER', undefined);
vi.stubEnv('CURSOR_TRACE_ID', undefined);
vi.stubEnv('CODESPACES', undefined);
vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined);
vi.stubEnv('CLOUD_SHELL', undefined);
vi.stubEnv('TERM_PRODUCT', undefined);
vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined);
vi.stubEnv('MONOSPACE_ENV', undefined);
// Set only the specific environment variables for this test case
for (const [key, value] of Object.entries(env)) {
vi.stubEnv(key, value);
}
expect(detectIde()).toBe(expected);
});
it('returns undefined for non-vscode', () => {
// Clear all environment variables first
vi.unstubAllEnvs();
// Set TERM_PROGRAM to something other than vscode
vi.stubEnv('TERM_PROGRAM', 'definitely-not-vscode');
expect(detectIde()).toBeUndefined();
});
});

View File

@@ -21,6 +21,7 @@ import {
} from '@google/genai';
import { QwenContentGenerator } from './qwenContentGenerator.js';
import { Config } from '../config/config.js';
import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js';
// Mock the OpenAIContentGenerator parent class
vi.mock('../core/openaiContentGenerator.js', () => ({
@@ -30,10 +31,13 @@ vi.mock('../core/openaiContentGenerator.js', () => ({
baseURL: string;
};
constructor(apiKey: string, _model: string, _config: Config) {
constructor(
contentGeneratorConfig: ContentGeneratorConfig,
_config: Config,
) {
this.client = {
apiKey,
baseURL: 'https://api.openai.com/v1',
apiKey: contentGeneratorConfig.apiKey || 'test-key',
baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1',
};
}
@@ -131,9 +135,13 @@ describe('QwenContentGenerator', () => {
};
// Create QwenContentGenerator instance
const contentGeneratorConfig = {
model: 'qwen-turbo',
authType: AuthType.QWEN_OAUTH,
};
qwenContentGenerator = new QwenContentGenerator(
mockQwenClient,
'qwen-turbo',
contentGeneratorConfig,
mockConfig,
);
});

View File

@@ -20,6 +20,7 @@ import {
EmbedContentParameters,
EmbedContentResponse,
} from '@google/genai';
import { ContentGeneratorConfig } from '../core/contentGenerator.js';
// Default fallback base URL if no endpoint is provided
const DEFAULT_QWEN_BASE_URL =
@@ -36,9 +37,13 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
private currentEndpoint: string | null = null;
private refreshPromise: Promise<string> | null = null;
constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) {
constructor(
qwenClient: IQwenOAuth2Client,
contentGeneratorConfig: ContentGeneratorConfig,
config: Config,
) {
// Initialize with empty API key, we'll override it dynamically
super('', model, config);
super(contentGeneratorConfig, config);
this.qwenClient = qwenClient;
// Set default base URL, will be updated dynamically
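With the base class taking a ContentGeneratorConfig, the Qwen subclass forwards that object instead of a placeholder API key; credentials and endpoint are still filled in dynamically from the OAuth client. A minimal construction sketch (field values illustrative):

const contentGeneratorConfig: ContentGeneratorConfig = {
  model: 'qwen-turbo',
  authType: AuthType.QWEN_OAUTH,
  // apiKey and baseUrl are overridden at request time from the OAuth token/endpoint
};
const generator = new QwenContentGenerator(qwenClient, contentGeneratorConfig, config);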

View File

@@ -205,10 +205,26 @@ describe('ClearcutLogger', () => {
'logs the current surface for as $expectedValue, preempting vscode detection',
({ env, expectedValue }) => {
const { logger } = setup({});
// Clear all environment variables that could interfere with surface detection
vi.stubEnv('SURFACE', undefined);
vi.stubEnv('GITHUB_SHA', undefined);
vi.stubEnv('CURSOR_TRACE_ID', undefined);
vi.stubEnv('__COG_BASHRC_SOURCED', undefined);
vi.stubEnv('REPLIT_USER', undefined);
vi.stubEnv('CODESPACES', undefined);
vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined);
vi.stubEnv('CLOUD_SHELL', undefined);
vi.stubEnv('TERM_PRODUCT', undefined);
vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined);
vi.stubEnv('MONOSPACE_ENV', undefined);
// Set the specific environment variables for this test case
for (const [key, value] of Object.entries(env)) {
vi.stubEnv(key, value);
}
vi.stubEnv('TERM_PROGRAM', 'vscode');
const event = logger?.createLogEvent('abc', []);
expect(event?.event_metadata[0][1]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,

View File

@@ -439,4 +439,84 @@ describe('GrepTool', () => {
expect(invocation.getDescription()).toBe("'testPattern' within ./");
});
});
describe('Result limiting', () => {
beforeEach(async () => {
// Create many test files with matches to test limiting
for (let i = 1; i <= 30; i++) {
const fileName = `test${i}.txt`;
const content = `This is test file ${i} with the pattern testword in it.`;
await fs.writeFile(path.join(tempRootDir, fileName), content);
}
});
it('should limit results to default 20 matches', async () => {
const params: GrepToolParams = { pattern: 'testword' };
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 20 matches');
expect(result.llmContent).toContain(
'showing first 20 of 30+ total matches',
);
expect(result.llmContent).toContain('WARNING: Results truncated');
expect(result.returnDisplay).toContain(
'Found 20 matches (truncated from 30+)',
);
});
it('should respect custom maxResults parameter', async () => {
const params: GrepToolParams = { pattern: 'testword', maxResults: 5 };
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 5 matches');
expect(result.llmContent).toContain(
'showing first 5 of 30+ total matches',
);
expect(result.llmContent).toContain('current: 5');
expect(result.returnDisplay).toContain(
'Found 5 matches (truncated from 30+)',
);
});
it('should not show truncation warning when all results fit', async () => {
const params: GrepToolParams = { pattern: 'testword', maxResults: 50 };
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toContain('Found 30 matches');
expect(result.llmContent).not.toContain('WARNING: Results truncated');
expect(result.llmContent).not.toContain('showing first');
expect(result.returnDisplay).toBe('Found 30 matches');
});
it('should validate maxResults parameter', () => {
const invalidParams = [
{ pattern: 'test', maxResults: 0 },
{ pattern: 'test', maxResults: 101 },
{ pattern: 'test', maxResults: -1 },
{ pattern: 'test', maxResults: 1.5 },
];
invalidParams.forEach((params) => {
const error = grepTool.validateToolParams(params as GrepToolParams);
expect(error).toBeTruthy(); // Just check that validation fails
expect(error).toMatch(/maxResults|must be/); // Check it's about maxResults validation
});
});
it('should accept valid maxResults parameter', () => {
const validParams = [
{ pattern: 'test', maxResults: 1 },
{ pattern: 'test', maxResults: 50 },
{ pattern: 'test', maxResults: 100 },
];
validParams.forEach((params) => {
const error = grepTool.validateToolParams(params);
expect(error).toBeNull();
});
});
});
});

View File

@@ -43,6 +43,11 @@ export interface GrepToolParams {
* File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")
*/
include?: string;
/**
* Maximum number of matches to return (optional, defaults to 20)
*/
maxResults?: number;
}
/**
@@ -124,6 +129,10 @@ class GrepToolInvocation extends BaseToolInvocation<
// Collect matches from all search directories
let allMatches: GrepMatch[] = [];
const maxResults = this.params.maxResults ?? 20; // Default to 20 results
let totalMatchesFound = 0;
let searchTruncated = false;
for (const searchDir of searchDirectories) {
const matches = await this.performGrepSearch({
pattern: this.params.pattern,
@@ -132,6 +141,8 @@ class GrepToolInvocation extends BaseToolInvocation<
signal,
});
totalMatchesFound += matches.length;
// Add directory prefix if searching multiple directories
if (searchDirectories.length > 1) {
const dirName = path.basename(searchDir);
@@ -140,7 +151,20 @@ class GrepToolInvocation extends BaseToolInvocation<
});
}
allMatches = allMatches.concat(matches);
// Apply result limiting
const remainingSlots = maxResults - allMatches.length;
if (remainingSlots <= 0) {
searchTruncated = true;
break;
}
if (matches.length > remainingSlots) {
allMatches = allMatches.concat(matches.slice(0, remainingSlots));
searchTruncated = true;
break;
} else {
allMatches = allMatches.concat(matches);
}
}
let searchLocationDescription: string;
@@ -176,7 +200,14 @@ class GrepToolInvocation extends BaseToolInvocation<
const matchCount = allMatches.length;
const matchTerm = matchCount === 1 ? 'match' : 'matches';
let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}:
// Build the header with truncation info if needed
let headerText = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}`;
if (searchTruncated) {
headerText += ` (showing first ${matchCount} of ${totalMatchesFound}+ total matches)`;
}
let llmContent = `${headerText}:
---
`;
@@ -189,9 +220,23 @@ class GrepToolInvocation extends BaseToolInvocation<
llmContent += '---\n';
}
// Add truncation guidance if results were limited
if (searchTruncated) {
llmContent += `\nWARNING: Results truncated to prevent context overflow. To see more results:
- Use a more specific pattern to reduce matches
- Add file filters with the 'include' parameter (e.g., "*.js", "src/**")
- Specify a narrower 'path' to search in a subdirectory
- Increase 'maxResults' parameter if you need more matches (current: ${maxResults})`;
}
let displayText = `Found ${matchCount} ${matchTerm}`;
if (searchTruncated) {
displayText += ` (truncated from ${totalMatchesFound}+)`;
}
return {
llmContent: llmContent.trim(),
returnDisplay: `Found ${matchCount} ${matchTerm}`,
returnDisplay: displayText,
};
} catch (error) {
console.error(`Error during GrepLogic execution: ${error}`);
@@ -567,6 +612,13 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
"Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).",
type: 'string',
},
maxResults: {
description:
'Optional: Maximum number of matches to return to prevent context overflow (default: 20, max: 100). Use lower values for broad searches, higher for specific searches.',
type: 'number',
minimum: 1,
maximum: 100,
},
},
required: ['pattern'],
type: 'object',
@@ -635,6 +687,17 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`;
}
// Validate maxResults if provided
if (params.maxResults !== undefined) {
if (
!Number.isInteger(params.maxResults) ||
params.maxResults < 1 ||
params.maxResults > 100
) {
return `maxResults must be an integer between 1 and 100, got: ${params.maxResults}`;
}
}
// Only validate path if one is provided
if (params.path) {
try {
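Taken together: maxResults defaults to 20, must be an integer between 1 and 100, and truncated searches annotate both llmContent (with a WARNING block and the observed total) and returnDisplay. A usage sketch (pattern and paths illustrative):

const params: GrepToolParams = {
  pattern: 'testword',
  path: 'src',       // optional subdirectory
  include: '*.ts',   // optional glob filter
  maxResults: 5,     // optional; default 20, valid range 1-100
};
const invocation = grepTool.build(params);
const result = await invocation.execute(abortSignal);
// result.returnDisplay -> e.g. "Found 5 matches (truncated from 30+)"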