feat: Add deterministic cache control (#411)

* feat: add deterministic cache control
Author: tanzhenxin
Date: 2025-08-21 18:33:13 +08:00
Committed by: GitHub
Parent: cd5e592b6a
Commit: 742337c390
13 changed files with 757 additions and 527 deletions
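In short, this change gates prompt caching on DashScope-compatible providers: when the auth type is QWEN_OAUTH or the base URL is one of the dashscope compatible-mode endpoints, the OpenAI client is created with an X-DashScope-CacheControl: enable header, and the converted messages carry cache_control of type 'ephemeral' on the system message and on the last message. It also moves sampling parameters from the top-level sampling_params setting into contentGenerator.samplingParams, and changes OpenAIContentGenerator (and QwenContentGenerator) to take a ContentGeneratorConfig object instead of separate apiKey/model arguments. Illustrative only, with placeholder values taken from the new tests, a DashScope request ends up with messages shaped like this:

// Sketch of the message shape asserted by the new cache-control tests.
const messages = [
  {
    role: 'system',
    content: [
      {
        type: 'text',
        text: 'You are a helpful assistant.',
        cache_control: { type: 'ephemeral' },
      },
    ],
  },
  {
    role: 'user',
    content: [
      { type: 'text', text: 'Hello', cache_control: { type: 'ephemeral' } },
    ],
  },
];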

package-lock.json (generated)

@@ -11997,6 +11997,7 @@
"strip-ansi": "^7.1.0",
"tiktoken": "^1.0.21",
"undici": "^7.10.0",
"uuid": "^9.0.1",
"ws": "^8.18.0"
},
"devDependencies": {


@@ -517,7 +517,6 @@ export async function loadCliConfig(
(typeof argv.openaiLogging === 'undefined'
? settings.enableOpenAILogging
: argv.openaiLogging) ?? false,
sampling_params: settings.sampling_params,
systemPromptMappings: (settings.systemPromptMappings ?? [
{
baseUrls: [


@@ -503,7 +503,6 @@ export const SETTINGS_SCHEMA = {
description: 'Show line numbers in the chat.',
showInDialog: true,
},
contentGenerator: {
type: 'object',
label: 'Content Generator',
@@ -513,15 +512,6 @@ export const SETTINGS_SCHEMA = {
description: 'Content generator settings.',
showInDialog: false,
},
sampling_params: {
type: 'object',
label: 'Sampling Params',
category: 'General',
requiresRestart: false,
default: undefined as Record<string, unknown> | undefined,
description: 'Sampling parameters for the model.',
showInDialog: false,
},
enableOpenAILogging: {
type: 'boolean',
label: 'Enable OpenAI Logging',

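With sampling_params removed from the top-level schema, sampling parameters are configured under the contentGenerator block instead. A minimal sketch of the expected shape, assuming a settings object that previously used the removed top-level key; the values are placeholders mirroring the test fixtures:

const settings = {
  contentGenerator: {
    timeout: 120000, // ms, optional
    maxRetries: 3, // optional
    samplingParams: {
      temperature: 0.7,
      max_tokens: 1000,
      top_p: 0.9,
    },
  },
};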

@@ -52,6 +52,7 @@
"strip-ansi": "^7.1.0",
"tiktoken": "^1.0.21",
"undici": "^7.10.0",
"uuid": "^9.0.1",
"ws": "^8.18.0"
},
"devDependencies": {


@@ -204,7 +204,6 @@ export interface ConfigParameters {
folderTrust?: boolean;
ideMode?: boolean;
enableOpenAILogging?: boolean;
sampling_params?: Record<string, unknown>;
systemPromptMappings?: Array<{
baseUrls: string[];
modelNames: string[];
@@ -213,6 +212,9 @@ export interface ConfigParameters {
contentGenerator?: {
timeout?: number;
maxRetries?: number;
samplingParams?: {
[key: string]: unknown;
};
};
cliVersion?: string;
loadMemoryFromIncludeDirectories?: boolean;
@@ -289,10 +291,10 @@ export class Config {
| undefined;
private readonly experimentalAcp: boolean = false;
private readonly enableOpenAILogging: boolean;
private readonly sampling_params?: Record<string, unknown>;
private readonly contentGenerator?: {
timeout?: number;
maxRetries?: number;
samplingParams?: Record<string, unknown>;
};
private readonly cliVersion?: string;
private readonly loadMemoryFromIncludeDirectories: boolean = false;
@@ -367,7 +369,6 @@ export class Config {
this.ideClient = IdeClient.getInstance();
this.systemPromptMappings = params.systemPromptMappings;
this.enableOpenAILogging = params.enableOpenAILogging ?? false;
this.sampling_params = params.sampling_params;
this.contentGenerator = params.contentGenerator;
this.cliVersion = params.cliVersion;
@@ -757,10 +758,6 @@ export class Config {
return this.enableOpenAILogging;
}
getSamplingParams(): Record<string, unknown> | undefined {
return this.sampling_params;
}
getContentGeneratorTimeout(): number | undefined {
return this.contentGenerator?.timeout;
}
@@ -769,6 +766,12 @@ export class Config {
return this.contentGenerator?.maxRetries;
}
getContentGeneratorSamplingParams(): ContentGeneratorConfig['samplingParams'] {
return this.contentGenerator?.samplingParams as
| ContentGeneratorConfig['samplingParams']
| undefined;
}
getCliVersion(): string | undefined {
return this.cliVersion;
}
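Call sites that read sampling parameters migrate from the removed getSamplingParams() to the new contentGenerator-scoped getter. A hedged sketch, assuming config is a constructed Config instance:

// Old (removed): const samplingParams = config.getSamplingParams();
const samplingParams = config.getContentGeneratorSamplingParams();
const timeout = config.getContentGeneratorTimeout() ?? 120000; // generator default
const maxRetries = config.getContentGeneratorMaxRetries() ?? 3; // generator default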


@@ -7,6 +7,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -41,9 +42,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -60,7 +58,12 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create generator instance
generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
});
afterEach(() => {
@@ -237,12 +240,18 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
describe('timeout configuration', () => {
it('should use default timeout configuration', () => {
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
// Verify OpenAI client was created with timeout config
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
@@ -253,18 +262,23 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
it('should use custom timeout from config', () => {
const customConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
timeout: 300000, // 5 minutes
maxRetries: 5,
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'http://localhost:8080',
authType: AuthType.USE_OPENAI,
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
@@ -279,11 +293,17 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', noTimeoutConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: 'http://localhost:8080',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {


@@ -565,10 +565,7 @@ export class GeminiClient {
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
try {
const userMemory = this.config.getUserMemory();
const systemPromptMappings = this.config.getSystemPromptMappings();
const systemInstruction = getCoreSystemPrompt(userMemory, {
systemPromptMappings,
});
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
abortSignal,
...this.generateContentConfig,
@@ -656,10 +653,7 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemPromptMappings = this.config.getSystemPromptMappings();
const systemInstruction = getCoreSystemPrompt(userMemory, {
systemPromptMappings,
});
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
abortSignal,


@@ -84,6 +84,7 @@ describe('createContentGeneratorConfig', () => {
getSamplingParams: vi.fn().mockReturnValue(undefined),
getContentGeneratorTimeout: vi.fn().mockReturnValue(undefined),
getContentGeneratorMaxRetries: vi.fn().mockReturnValue(undefined),
getContentGeneratorSamplingParams: vi.fn().mockReturnValue(undefined),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;


@@ -53,6 +53,7 @@ export enum AuthType {
export type ContentGeneratorConfig = {
model: string;
apiKey?: string;
baseUrl?: string;
vertexai?: boolean;
authType?: AuthType | undefined;
enableOpenAILogging?: boolean;
@@ -76,11 +77,16 @@ export function createContentGeneratorConfig(
config: Config,
authType: AuthType | undefined,
): ContentGeneratorConfig {
// google auth
const geminiApiKey = process.env.GEMINI_API_KEY || undefined;
const googleApiKey = process.env.GOOGLE_API_KEY || undefined;
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined;
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined;
// openai auth
const openaiApiKey = process.env.OPENAI_API_KEY;
const openaiBaseUrl = process.env.OPENAI_BASE_URL || undefined;
const openaiModel = process.env.OPENAI_MODEL || undefined;
// Use runtime model from config if available; otherwise, fall back to parameter or default
const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL;
@@ -92,7 +98,7 @@ export function createContentGeneratorConfig(
enableOpenAILogging: config.getEnableOpenAILogging(),
timeout: config.getContentGeneratorTimeout(),
maxRetries: config.getContentGeneratorMaxRetries(),
samplingParams: config.getSamplingParams(),
samplingParams: config.getContentGeneratorSamplingParams(),
};
// If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now
@@ -127,8 +133,8 @@ export function createContentGeneratorConfig(
if (authType === AuthType.USE_OPENAI && openaiApiKey) {
contentGeneratorConfig.apiKey = openaiApiKey;
contentGeneratorConfig.model =
process.env.OPENAI_MODEL || DEFAULT_GEMINI_MODEL;
contentGeneratorConfig.baseUrl = openaiBaseUrl;
contentGeneratorConfig.model = openaiModel || DEFAULT_QWEN_MODEL;
return contentGeneratorConfig;
}
@@ -196,7 +202,7 @@ export async function createContentGenerator(
);
// Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag
return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig);
return new OpenAIContentGenerator(config, gcConfig);
}
if (config.authType === AuthType.QWEN_OAUTH) {
@@ -217,7 +223,7 @@ export async function createContentGenerator(
const qwenClient = await getQwenOauthClient(gcConfig);
// Create the content generator with dynamic token management
return new QwenContentGenerator(qwenClient, config.model, gcConfig);
return new QwenContentGenerator(qwenClient, config, gcConfig);
} catch (error) {
throw new Error(
`Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`,

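The generators are now handed the whole ContentGeneratorConfig rather than separate apiKey/model arguments. A minimal sketch of the new construction path, assuming gcConfig is an existing Config instance and the OpenAI environment variables are set:

// createContentGeneratorConfig resolves apiKey, baseUrl, model, timeout,
// maxRetries and samplingParams into one object.
const contentGeneratorConfig = createContentGeneratorConfig(gcConfig, AuthType.USE_OPENAI);
const generator = new OpenAIContentGenerator(contentGeneratorConfig, gcConfig);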

@@ -7,6 +7,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import { Config } from '../config/config.js';
import { AuthType } from './contentGenerator.js';
import OpenAI from 'openai';
import type {
GenerateContentParameters,
@@ -84,7 +85,20 @@ describe('OpenAIContentGenerator', () => {
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create generator instance
generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
samplingParams: {
temperature: 0.7,
max_tokens: 1000,
top_p: 0.9,
},
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
});
afterEach(() => {
@@ -95,7 +109,7 @@ describe('OpenAIContentGenerator', () => {
it('should initialize with basic configuration', () => {
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: undefined,
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
@@ -105,9 +119,16 @@ describe('OpenAIContentGenerator', () => {
});
it('should handle custom base URL', () => {
vi.stubEnv('OPENAI_BASE_URL', 'https://api.custom.com');
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'https://api.custom.com',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
@@ -121,9 +142,16 @@ describe('OpenAIContentGenerator', () => {
});
it('should configure OpenRouter headers when using OpenRouter', () => {
vi.stubEnv('OPENAI_BASE_URL', 'https://openrouter.ai/api/v1');
new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
baseUrl: 'https://openrouter.ai/api/v1',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
@@ -147,11 +175,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
new OpenAIContentGenerator('test-key', 'gpt-4', customConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: '',
baseURL: undefined,
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
@@ -906,9 +941,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -1029,9 +1069,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -1587,7 +1632,23 @@ describe('OpenAIContentGenerator', () => {
}
}
const testGenerator = new TestGenerator('test-key', 'gpt-4', mockConfig);
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
timeout: 120000,
maxRetries: 3,
samplingParams: {
temperature: 0.7,
max_tokens: 1000,
top_p: 0.9,
},
};
const testGenerator = new TestGenerator(
contentGeneratorConfig,
mockConfig,
);
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
@@ -1908,9 +1969,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
samplingParams: {
temperature: 0.8,
max_tokens: 500,
},
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -2093,9 +2163,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: true,
};
const loggingGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
loggingConfig,
);
@@ -2350,9 +2425,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
samplingParams: {
temperature: undefined,
max_tokens: undefined,
top_p: undefined,
},
};
const testGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
configWithUndefined,
);
@@ -2408,9 +2492,22 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
samplingParams: {
temperature: 0.8,
max_tokens: 1500,
top_p: 0.95,
top_k: 40,
repetition_penalty: 1.1,
presence_penalty: 0.5,
frequency_penalty: 0.3,
},
};
const testGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
fullSamplingConfig,
);
@@ -2489,9 +2586,14 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2528,12 +2630,6 @@ describe('OpenAIContentGenerator', () => {
});
it('should include metadata when baseURL is dashscope openai compatible mode', async () => {
// Mock environment to set dashscope base URL BEFORE creating the generator
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai', // Not QWEN_OAUTH
@@ -2543,9 +2639,15 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
dashscopeConfig,
);
@@ -2604,9 +2706,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
regularConfig,
);
@@ -2650,9 +2761,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const otherGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
otherAuthConfig,
);
@@ -2699,9 +2819,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const otherBaseUrlGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
otherBaseUrlConfig,
);
@@ -2748,9 +2877,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2804,8 +2942,6 @@ describe('OpenAIContentGenerator', () => {
sessionId: 'streaming-session-id',
promptId: 'streaming-prompt-id',
},
stream: true,
stream_options: { include_usage: true },
}),
);
@@ -2827,9 +2963,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
regularConfig,
);
@@ -2901,9 +3046,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const qwenGenerator = new OpenAIContentGenerator(
'test-key',
'qwen-turbo',
contentGeneratorConfig,
qwenConfig,
);
@@ -2955,9 +3109,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const noBaseUrlGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
noBaseUrlConfig,
);
@@ -3004,9 +3167,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const undefinedAuthGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
undefinedAuthConfig,
);
@@ -3050,9 +3222,18 @@ describe('OpenAIContentGenerator', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const undefinedConfigGenerator = new OpenAIContentGenerator(
'test-key',
'gpt-4',
contentGeneratorConfig,
undefinedConfig,
);
@@ -3089,4 +3270,232 @@ describe('OpenAIContentGenerator', () => {
);
});
});
describe('cache control for DashScope', () => {
it('should add cache control to system message for DashScope providers', async () => {
// Mock environment to set dashscope base URL
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
dashscopeConfig,
);
// Mock the client's baseURL property to return the expected value
Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
writable: true,
});
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'qwen-turbo',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
config: {
systemInstruction: 'You are a helpful assistant.',
},
model: 'qwen-turbo',
};
await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
// Should include cache control in system message
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: expect.arrayContaining([
expect.objectContaining({
type: 'text',
text: 'You are a helpful assistant.',
cache_control: { type: 'ephemeral' },
}),
]),
}),
]),
}),
);
});
it('should add cache control to last message for DashScope providers', async () => {
// Mock environment to set dashscope base URL
vi.stubEnv(
'OPENAI_BASE_URL',
'https://dashscope.aliyuncs.com/compatible-mode/v1',
);
const dashscopeConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'qwen-turbo',
apiKey: 'test-key',
authType: AuthType.QWEN_OAUTH,
enableOpenAILogging: false,
};
const dashscopeGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
dashscopeConfig,
);
// Mock the client's baseURL property to return the expected value
Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
writable: true,
});
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'qwen-turbo',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }],
model: 'qwen-turbo',
};
await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
// Should include cache control in last message
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: expect.arrayContaining([
expect.objectContaining({
type: 'text',
text: 'Hello, how are you?',
cache_control: { type: 'ephemeral' },
}),
]),
}),
]),
}),
);
});
it('should NOT add cache control for non-DashScope providers', async () => {
const regularConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getSessionId: vi.fn().mockReturnValue('regular-session-id'),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
const regularGenerator = new OpenAIContentGenerator(
contentGeneratorConfig,
regularConfig,
);
const mockResponse = {
id: 'chatcmpl-123',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Response' },
finish_reason: 'stop',
},
],
created: 1677652288,
model: 'gpt-4',
};
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
const request: GenerateContentParameters = {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
config: {
systemInstruction: 'You are a helpful assistant.',
},
model: 'gpt-4',
};
await regularGenerator.generateContent(request, 'regular-prompt-id');
// Should NOT include cache control (messages should be strings, not arrays)
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: 'You are a helpful assistant.',
}),
expect.objectContaining({
role: 'user',
content: 'Hello',
}),
]),
}),
);
});
});
});


@@ -20,7 +20,11 @@ import {
FunctionCall,
FunctionResponse,
} from '@google/genai';
import { AuthType, ContentGenerator } from './contentGenerator.js';
import {
AuthType,
ContentGenerator,
ContentGeneratorConfig,
} from './contentGenerator.js';
import OpenAI from 'openai';
import { logApiError, logApiResponse } from '../telemetry/loggers.js';
import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js';
@@ -28,6 +32,17 @@ import { Config } from '../config/config.js';
import { openaiLogger } from '../utils/openaiLogger.js';
import { safeJsonParse } from '../utils/safeJsonParse.js';
// Extended types to support cache_control
interface ChatCompletionContentPartTextWithCache
extends OpenAI.Chat.ChatCompletionContentPartText {
cache_control?: { type: 'ephemeral' };
}
type ChatCompletionContentPartWithCache =
| ChatCompletionContentPartTextWithCache
| OpenAI.Chat.ChatCompletionContentPartImage
| OpenAI.Chat.ChatCompletionContentPartRefusal;
// OpenAI API type definitions for logging
interface OpenAIToolCall {
id: string;
@@ -38,9 +53,15 @@ interface OpenAIToolCall {
};
}
interface OpenAIContentItem {
type: 'text';
text: string;
cache_control?: { type: 'ephemeral' };
}
interface OpenAIMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
content: string | null;
content: string | null | OpenAIContentItem[];
tool_calls?: OpenAIToolCall[];
tool_call_id?: string;
}
@@ -60,15 +81,6 @@ interface OpenAIChoice {
finish_reason: string;
}
interface OpenAIRequestFormat {
model: string;
messages: OpenAIMessage[];
temperature?: number;
max_tokens?: number;
top_p?: number;
tools?: unknown[];
}
interface OpenAIResponseFormat {
id: string;
object: string;
@@ -81,6 +93,7 @@ interface OpenAIResponseFormat {
export class OpenAIContentGenerator implements ContentGenerator {
protected client: OpenAI;
private model: string;
private contentGeneratorConfig: ContentGeneratorConfig;
private config: Config;
private streamingToolCalls: Map<
number,
@@ -91,50 +104,40 @@ export class OpenAIContentGenerator implements ContentGenerator {
}
> = new Map();
constructor(apiKey: string, model: string, config: Config) {
this.model = model;
this.config = config;
const baseURL = process.env.OPENAI_BASE_URL || '';
constructor(
contentGeneratorConfig: ContentGeneratorConfig,
gcConfig: Config,
) {
this.model = contentGeneratorConfig.model;
this.contentGeneratorConfig = contentGeneratorConfig;
this.config = gcConfig;
// Configure timeout settings - using progressive timeouts
const timeoutConfig = {
// Base timeout for most requests (2 minutes)
timeout: 120000,
// Maximum retries for failed requests
maxRetries: 3,
// HTTP client options
httpAgent: undefined, // Let the client use default agent
};
// Allow config to override timeout settings
const contentGeneratorConfig = this.config.getContentGeneratorConfig();
if (contentGeneratorConfig?.timeout) {
timeoutConfig.timeout = contentGeneratorConfig.timeout;
}
if (contentGeneratorConfig?.maxRetries !== undefined) {
timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries;
}
const version = config.getCliVersion() || 'unknown';
const version = gcConfig.getCliVersion() || 'unknown';
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
// Check if using OpenRouter and add required headers
const isOpenRouter = baseURL.includes('openrouter.ai');
const isOpenRouterProvider = this.isOpenRouterProvider();
const isDashScopeProvider = this.isDashScopeProvider();
const defaultHeaders = {
'User-Agent': userAgent,
...(isOpenRouter
...(isOpenRouterProvider
? {
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
'X-Title': 'Qwen Code',
}
: {}),
: isDashScopeProvider
? {
'X-DashScope-CacheControl': 'enable',
}
: {}),
};
this.client = new OpenAI({
apiKey,
baseURL,
timeout: timeoutConfig.timeout,
maxRetries: timeoutConfig.maxRetries,
apiKey: contentGeneratorConfig.apiKey,
baseURL: contentGeneratorConfig.baseUrl,
timeout: contentGeneratorConfig.timeout ?? 120000,
maxRetries: contentGeneratorConfig.maxRetries ?? 3,
defaultHeaders,
});
}
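Illustrative only: for a DashScope compatible-mode baseUrl (or QWEN_OAUTH auth) the client above is created with headers along these lines, where the version, platform and arch values are placeholders:

const defaultHeaders = {
  'User-Agent': 'QwenCode/1.0.0 (linux; x64)', // QwenCode/<cliVersion> (<platform>; <arch>)
  'X-DashScope-CacheControl': 'enable', // DashScope providers only
};
// OpenRouter base URLs get 'HTTP-Referer' and 'X-Title' instead; other providers
// receive only the User-Agent header.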
@@ -185,22 +188,25 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
private isOpenRouterProvider(): boolean {
const baseURL = this.contentGeneratorConfig.baseUrl || '';
return baseURL.includes('openrouter.ai');
}
/**
* Determine if metadata should be included in the request.
* Only include the `metadata` field if the provider is QWEN_OAUTH
* or the baseUrl is 'https://dashscope.aliyuncs.com/compatible-mode/v1'.
* This is because some models/providers do not support metadata or need extra configuration.
* Determine if this is a DashScope provider.
* DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs.
*
* @returns true if metadata should be included, false otherwise
* @returns true if this is a DashScope provider, false otherwise
*/
private shouldIncludeMetadata(): boolean {
const authType = this.config.getContentGeneratorConfig?.()?.authType;
// baseUrl may be undefined; default to empty string if so
const baseUrl = this.client?.baseURL || '';
private isDashScopeProvider(): boolean {
const authType = this.contentGeneratorConfig.authType;
const baseUrl = this.contentGeneratorConfig.baseUrl;
return (
authType === AuthType.QWEN_OAUTH ||
baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1'
baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' ||
baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
);
}
@@ -213,7 +219,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
private buildMetadata(
userPromptId: string,
): { metadata: { sessionId?: string; promptId: string } } | undefined {
if (!this.shouldIncludeMetadata()) {
if (!this.isDashScopeProvider()) {
return undefined;
}
@@ -225,35 +231,44 @@ export class OpenAIContentGenerator implements ContentGenerator {
};
}
private async buildCreateParams(
request: GenerateContentParameters,
userPromptId: string,
): Promise<Parameters<typeof this.client.chat.completions.create>[0]> {
const messages = this.convertToOpenAIFormat(request);
// Build sampling parameters with clear priority:
// 1. Request-level parameters (highest priority)
// 2. Config-level sampling parameters (medium priority)
// 3. Default values (lowest priority)
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
return createParams;
}
async generateContent(
request: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
const messages = this.convertToOpenAIFormat(request);
const createParams = await this.buildCreateParams(request, userPromptId);
try {
// Build sampling parameters with clear priority:
// 1. Request-level parameters (highest priority)
// 2. Config-level sampling parameters (medium priority)
// 3. Default values (lowest priority)
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
// console.log('createParams', createParams);
const completion = (await this.client.chat.completions.create(
createParams,
)) as OpenAI.Chat.ChatCompletion;
@@ -267,15 +282,15 @@ export class OpenAIContentGenerator implements ContentGenerator {
this.model,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
response.usageMetadata,
);
logApiResponse(this.config, responseEvent);
// Log interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
const openaiRequest = createParams;
const openaiResponse = this.convertGeminiResponseToOpenAI(response);
await openaiLogger.logInteraction(openaiRequest, openaiResponse);
}
@@ -300,7 +315,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -309,10 +324,9 @@ export class OpenAIContentGenerator implements ContentGenerator {
logApiError(this.config, errorEvent);
// Log error interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest = await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
await openaiLogger.logInteraction(
openaiRequest,
createParams,
undefined,
error as Error,
);
@@ -343,29 +357,12 @@ export class OpenAIContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
const messages = this.convertToOpenAIFormat(request);
const createParams = await this.buildCreateParams(request, userPromptId);
createParams.stream = true;
createParams.stream_options = { include_usage: true };
try {
// Build sampling parameters with clear priority
const samplingParams = this.buildSamplingParameters(request);
const createParams: Parameters<
typeof this.client.chat.completions.create
>[0] = {
model: this.model,
messages,
...samplingParams,
stream: true,
stream_options: { include_usage: true },
...(this.buildMetadata(userPromptId) || {}),
};
if (request.config?.tools) {
createParams.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
const stream = (await this.client.chat.completions.create(
createParams,
)) as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>;
@@ -397,16 +394,15 @@ export class OpenAIContentGenerator implements ContentGenerator {
this.model,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
finalUsageMetadata,
);
logApiResponse(this.config, responseEvent);
// Log interaction if enabled (same as generateContent method)
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest =
await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
const openaiRequest = createParams;
// For streaming, we combine all responses into a single response for logging
const combinedResponse =
this.combineStreamResponsesForLogging(responses);
@@ -433,7 +429,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -442,11 +438,9 @@ export class OpenAIContentGenerator implements ContentGenerator {
logApiError(this.config, errorEvent);
// Log error interaction if enabled
if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) {
const openaiRequest =
await this.convertGeminiRequestToOpenAI(request);
if (this.contentGeneratorConfig.enableOpenAILogging) {
await openaiLogger.logInteraction(
openaiRequest,
createParams,
undefined,
error as Error,
);
@@ -487,7 +481,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
errorMessage,
durationMs,
userPromptId,
this.config.getContentGeneratorConfig()?.authType,
this.contentGeneratorConfig.authType,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).type,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -944,7 +938,114 @@ export class OpenAIContentGenerator implements ContentGenerator {
// Clean up orphaned tool calls and merge consecutive assistant messages
const cleanedMessages = this.cleanOrphanedToolCalls(messages);
return this.mergeConsecutiveAssistantMessages(cleanedMessages);
const mergedMessages =
this.mergeConsecutiveAssistantMessages(cleanedMessages);
// Add cache control to system and last messages for DashScope providers
return this.addCacheControlFlag(mergedMessages, 'both');
}
/**
* Add cache control flag to specified message(s) for DashScope providers
*/
private addCacheControlFlag(
messages: OpenAI.Chat.ChatCompletionMessageParam[],
target: 'system' | 'last' | 'both' = 'both',
): OpenAI.Chat.ChatCompletionMessageParam[] {
if (!this.isDashScopeProvider() || messages.length === 0) {
return messages;
}
let updatedMessages = [...messages];
// Add cache control to system message if requested
if (target === 'system' || target === 'both') {
updatedMessages = this.addCacheControlToMessage(
updatedMessages,
'system',
);
}
// Add cache control to last message if requested
if (target === 'last' || target === 'both') {
updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last');
}
return updatedMessages;
}
/**
* Helper method to add cache control to a specific message
*/
private addCacheControlToMessage(
messages: OpenAI.Chat.ChatCompletionMessageParam[],
target: 'system' | 'last',
): OpenAI.Chat.ChatCompletionMessageParam[] {
const updatedMessages = [...messages];
let messageIndex: number;
if (target === 'system') {
// Find the first system message
messageIndex = messages.findIndex((msg) => msg.role === 'system');
if (messageIndex === -1) {
return updatedMessages;
}
} else {
// Get the last message
messageIndex = messages.length - 1;
}
const message = updatedMessages[messageIndex];
// Only process messages that have content
if ('content' in message && message.content !== null) {
if (typeof message.content === 'string') {
// Convert string content to array format with cache control
const messageWithArrayContent = {
...message,
content: [
{
type: 'text',
text: message.content,
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache,
],
};
updatedMessages[messageIndex] =
messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam;
} else if (Array.isArray(message.content)) {
// If content is already an array, add cache_control to the last item
const contentArray = [
...message.content,
] as ChatCompletionContentPartWithCache[];
if (contentArray.length > 0) {
const lastItem = contentArray[contentArray.length - 1];
if (lastItem.type === 'text') {
// Add cache_control to the last text item
contentArray[contentArray.length - 1] = {
...lastItem,
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache;
} else {
// If the last item is not text, add a new text item with cache_control
contentArray.push({
type: 'text',
text: '',
cache_control: { type: 'ephemeral' },
} as ChatCompletionContentPartTextWithCache);
}
const messageWithCache = {
...message,
content: contentArray,
};
updatedMessages[messageIndex] =
messageWithCache as OpenAI.Chat.ChatCompletionMessageParam;
}
}
}
return updatedMessages;
}
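A hedged illustration of the array-content branch above (the string-content case is what the new tests exercise; these values are placeholders):

// Only the final text part gains the cache marker; earlier parts are untouched.
const before = [
  { type: 'text', text: 'part A' },
  { type: 'text', text: 'part B' },
];
const after = [
  { type: 'text', text: 'part A' },
  { type: 'text', text: 'part B', cache_control: { type: 'ephemeral' } },
];
// If the last part is not text (for example an image part), an empty text part
// carrying cache_control is appended instead.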
/**
@@ -1368,8 +1469,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
private buildSamplingParameters(
request: GenerateContentParameters,
): Record<string, unknown> {
const configSamplingParams =
this.config.getContentGeneratorConfig()?.samplingParams;
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
const params = {
// Temperature: config > request > default
@@ -1431,313 +1531,6 @@ export class OpenAIContentGenerator implements ContentGenerator {
return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED;
}
/**
* Convert Gemini request format to OpenAI chat completion format for logging
*/
private async convertGeminiRequestToOpenAI(
request: GenerateContentParameters,
): Promise<OpenAIRequestFormat> {
const messages: OpenAIMessage[] = [];
// Handle system instruction
if (request.config?.systemInstruction) {
const systemInstruction = request.config.systemInstruction;
let systemText = '';
if (Array.isArray(systemInstruction)) {
systemText = systemInstruction
.map((content) => {
if (typeof content === 'string') return content;
if ('parts' in content) {
const contentObj = content as Content;
return (
contentObj.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || ''
);
}
return '';
})
.join('\n');
} else if (typeof systemInstruction === 'string') {
systemText = systemInstruction;
} else if (
typeof systemInstruction === 'object' &&
'parts' in systemInstruction
) {
const systemContent = systemInstruction as Content;
systemText =
systemContent.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || '';
}
if (systemText) {
messages.push({
role: 'system',
content: systemText,
});
}
}
// Handle contents
if (Array.isArray(request.contents)) {
for (const content of request.contents) {
if (typeof content === 'string') {
messages.push({ role: 'user', content });
} else if ('role' in content && 'parts' in content) {
const functionCalls: FunctionCall[] = [];
const functionResponses: FunctionResponse[] = [];
const textParts: string[] = [];
for (const part of content.parts || []) {
if (typeof part === 'string') {
textParts.push(part);
} else if ('text' in part && part.text) {
textParts.push(part.text);
} else if ('functionCall' in part && part.functionCall) {
functionCalls.push(part.functionCall);
} else if ('functionResponse' in part && part.functionResponse) {
functionResponses.push(part.functionResponse);
}
}
// Handle function responses (tool results)
if (functionResponses.length > 0) {
for (const funcResponse of functionResponses) {
messages.push({
role: 'tool',
tool_call_id: funcResponse.id || '',
content:
typeof funcResponse.response === 'string'
? funcResponse.response
: JSON.stringify(funcResponse.response),
});
}
}
// Handle model messages with function calls
else if (content.role === 'model' && functionCalls.length > 0) {
const toolCalls = functionCalls.map((fc, index) => ({
id: fc.id || `call_${index}`,
type: 'function' as const,
function: {
name: fc.name || '',
arguments: JSON.stringify(fc.args || {}),
},
}));
messages.push({
role: 'assistant',
content: textParts.join('\n') || null,
tool_calls: toolCalls,
});
}
// Handle regular text messages
else {
const role = content.role === 'model' ? 'assistant' : 'user';
const text = textParts.join('\n');
if (text) {
messages.push({ role, content: text });
}
}
}
}
} else if (request.contents) {
if (typeof request.contents === 'string') {
messages.push({ role: 'user', content: request.contents });
} else if ('role' in request.contents && 'parts' in request.contents) {
const content = request.contents;
const role = content.role === 'model' ? 'assistant' : 'user';
const text =
content.parts
?.map((p: Part) =>
typeof p === 'string' ? p : 'text' in p ? p.text : '',
)
.join('\n') || '';
messages.push({ role, content: text });
}
}
// Clean up orphaned tool calls and merge consecutive assistant messages
const cleanedMessages = this.cleanOrphanedToolCallsForLogging(messages);
const mergedMessages =
this.mergeConsecutiveAssistantMessagesForLogging(cleanedMessages);
const openaiRequest: OpenAIRequestFormat = {
model: this.model,
messages: mergedMessages,
};
// Add sampling parameters using the same logic as actual API calls
const samplingParams = this.buildSamplingParameters(request);
Object.assign(openaiRequest, samplingParams);
// Convert tools if present
if (request.config?.tools) {
openaiRequest.tools = await this.convertGeminiToolsToOpenAI(
request.config.tools,
);
}
return openaiRequest;
}
/**
* Clean up orphaned tool calls for logging purposes
*/
private cleanOrphanedToolCallsForLogging(
messages: OpenAIMessage[],
): OpenAIMessage[] {
const cleaned: OpenAIMessage[] = [];
const toolCallIds = new Set<string>();
const toolResponseIds = new Set<string>();
// First pass: collect all tool call IDs and tool response IDs
for (const message of messages) {
if (message.role === 'assistant' && message.tool_calls) {
for (const toolCall of message.tool_calls) {
if (toolCall.id) {
toolCallIds.add(toolCall.id);
}
}
} else if (message.role === 'tool' && message.tool_call_id) {
toolResponseIds.add(message.tool_call_id);
}
}
// Second pass: filter out orphaned messages
for (const message of messages) {
if (message.role === 'assistant' && message.tool_calls) {
// Filter out tool calls that don't have corresponding responses
const validToolCalls = message.tool_calls.filter(
(toolCall) => toolCall.id && toolResponseIds.has(toolCall.id),
);
if (validToolCalls.length > 0) {
// Keep the message but only with valid tool calls
const cleanedMessage = { ...message };
cleanedMessage.tool_calls = validToolCalls;
cleaned.push(cleanedMessage);
} else if (
typeof message.content === 'string' &&
message.content.trim()
) {
// Keep the message if it has text content, but remove tool calls
const cleanedMessage = { ...message };
delete cleanedMessage.tool_calls;
cleaned.push(cleanedMessage);
}
// If no valid tool calls and no content, skip the message entirely
} else if (message.role === 'tool' && message.tool_call_id) {
// Only keep tool responses that have corresponding tool calls
if (toolCallIds.has(message.tool_call_id)) {
cleaned.push(message);
}
} else {
// Keep all other messages as-is
cleaned.push(message);
}
}
// Final validation: ensure every assistant message with tool_calls has corresponding tool responses
const finalCleaned: OpenAIMessage[] = [];
const finalToolCallIds = new Set<string>();
// Collect all remaining tool call IDs
for (const message of cleaned) {
if (message.role === 'assistant' && message.tool_calls) {
for (const toolCall of message.tool_calls) {
if (toolCall.id) {
finalToolCallIds.add(toolCall.id);
}
}
}
}
// Verify all tool calls have responses
const finalToolResponseIds = new Set<string>();
for (const message of cleaned) {
if (message.role === 'tool' && message.tool_call_id) {
finalToolResponseIds.add(message.tool_call_id);
}
}
// Remove any remaining orphaned tool calls
for (const message of cleaned) {
if (message.role === 'assistant' && message.tool_calls) {
const finalValidToolCalls = message.tool_calls.filter(
(toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id),
);
if (finalValidToolCalls.length > 0) {
const cleanedMessage = { ...message };
cleanedMessage.tool_calls = finalValidToolCalls;
finalCleaned.push(cleanedMessage);
} else if (
typeof message.content === 'string' &&
message.content.trim()
) {
const cleanedMessage = { ...message };
delete cleanedMessage.tool_calls;
finalCleaned.push(cleanedMessage);
}
} else {
finalCleaned.push(message);
}
}
return finalCleaned;
}
/**
* Merge consecutive assistant messages to combine split text and tool calls for logging
*/
private mergeConsecutiveAssistantMessagesForLogging(
messages: OpenAIMessage[],
): OpenAIMessage[] {
const merged: OpenAIMessage[] = [];
for (const message of messages) {
if (message.role === 'assistant' && merged.length > 0) {
const lastMessage = merged[merged.length - 1];
// If the last message is also an assistant message, merge them
if (lastMessage.role === 'assistant') {
// Combine content
const combinedContent = [
lastMessage.content || '',
message.content || '',
]
.filter(Boolean)
.join('');
// Combine tool calls
const combinedToolCalls = [
...(lastMessage.tool_calls || []),
...(message.tool_calls || []),
];
// Update the last message with combined data
lastMessage.content = combinedContent || null;
if (combinedToolCalls.length > 0) {
lastMessage.tool_calls = combinedToolCalls;
}
continue; // Skip adding the current message since it's been merged
}
}
// Add the message as-is if no merging is needed
merged.push(message);
}
return merged;
}
/**
* Convert Gemini response format to OpenAI chat completion format for logging
*/


@@ -21,6 +21,7 @@ import {
} from '@google/genai';
import { QwenContentGenerator } from './qwenContentGenerator.js';
import { Config } from '../config/config.js';
import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js';
// Mock the OpenAIContentGenerator parent class
vi.mock('../core/openaiContentGenerator.js', () => ({
@@ -30,10 +31,13 @@ vi.mock('../core/openaiContentGenerator.js', () => ({
baseURL: string;
};
constructor(apiKey: string, _model: string, _config: Config) {
constructor(
contentGeneratorConfig: ContentGeneratorConfig,
_config: Config,
) {
this.client = {
apiKey,
baseURL: 'https://api.openai.com/v1',
apiKey: contentGeneratorConfig.apiKey || 'test-key',
baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1',
};
}
@@ -131,9 +135,13 @@ describe('QwenContentGenerator', () => {
};
// Create QwenContentGenerator instance
const contentGeneratorConfig = {
model: 'qwen-turbo',
authType: AuthType.QWEN_OAUTH,
};
qwenContentGenerator = new QwenContentGenerator(
mockQwenClient,
'qwen-turbo',
contentGeneratorConfig,
mockConfig,
);
});

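A minimal sketch of the updated QwenContentGenerator construction mirroring the test above; qwenClient is assumed to be an authenticated IQwenOAuth2Client and gcConfig an existing Config instance:

const contentGeneratorConfig = {
  model: 'qwen-turbo',
  authType: AuthType.QWEN_OAUTH,
};
// The model no longer travels as a separate constructor argument.
const qwenGenerator = new QwenContentGenerator(qwenClient, contentGeneratorConfig, gcConfig);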

@@ -20,6 +20,7 @@ import {
EmbedContentParameters,
EmbedContentResponse,
} from '@google/genai';
import { ContentGeneratorConfig } from '../core/contentGenerator.js';
// Default fallback base URL if no endpoint is provided
const DEFAULT_QWEN_BASE_URL =
@@ -36,9 +37,13 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
private currentEndpoint: string | null = null;
private refreshPromise: Promise<string> | null = null;
constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) {
constructor(
qwenClient: IQwenOAuth2Client,
contentGeneratorConfig: ContentGeneratorConfig,
config: Config,
) {
// Initialize with empty API key, we'll override it dynamically
super('', model, config);
super(contentGeneratorConfig, config);
this.qwenClient = qwenClient;
// Set default base URL, will be updated dynamically