mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 17:27:54 +00:00
refactor: openaiContentGenerator (#501)
* refactor: openaiContentGenerator * refactor: optimize stream handling * refactor: re-organize refactored files * fix: unit test cases * fix: try to fix parallel tool calls with irregular chunk content
This commit is contained in:
@@ -22,7 +22,96 @@ import {
|
||||
import { QwenContentGenerator } from './qwenContentGenerator.js';
|
||||
import { SharedTokenManager } from './sharedTokenManager.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
|
||||
// Mock OpenAI client to avoid real network calls
|
||||
vi.mock('openai', () => ({
|
||||
default: class MockOpenAI {
|
||||
chat = {
|
||||
completions: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
};
|
||||
embeddings = {
|
||||
create: vi.fn(),
|
||||
};
|
||||
apiKey = '';
|
||||
baseURL = '';
|
||||
constructor(config: { apiKey: string; baseURL: string }) {
|
||||
this.apiKey = config.apiKey;
|
||||
this.baseURL = config.baseURL;
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock DashScope provider
|
||||
vi.mock('../core/openaiContentGenerator/provider/dashscope.js', () => ({
|
||||
DashScopeOpenAICompatibleProvider: class {
|
||||
constructor(_config: unknown, _cliConfig: unknown) {}
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock ContentGenerationPipeline
|
||||
vi.mock('../core/openaiContentGenerator/pipeline.js', () => ({
|
||||
ContentGenerationPipeline: class {
|
||||
client: {
|
||||
apiKey: string;
|
||||
baseURL: string;
|
||||
chat: {
|
||||
completions: {
|
||||
create: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
};
|
||||
embeddings: {
|
||||
create: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
};
|
||||
|
||||
constructor(_config: unknown) {
|
||||
this.client = {
|
||||
apiKey: '',
|
||||
baseURL: '',
|
||||
chat: {
|
||||
completions: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
},
|
||||
embeddings: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async execute(
|
||||
_request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return createMockResponse('Test response');
|
||||
}
|
||||
|
||||
async executeStream(
|
||||
_request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
return (async function* () {
|
||||
yield createMockResponse('Stream chunk 1');
|
||||
yield createMockResponse('Stream chunk 2');
|
||||
})();
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
_request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return { totalTokens: 15 };
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
_request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return { embeddings: [{ values: [0.1, 0.2, 0.3] }] };
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock SharedTokenManager
|
||||
vi.mock('./sharedTokenManager.js', () => ({
|
||||
@@ -132,20 +221,21 @@ vi.mock('./sharedTokenManager.js', () => ({
|
||||
}));
|
||||
|
||||
// Mock the OpenAIContentGenerator parent class
|
||||
vi.mock('../core/openaiContentGenerator.js', () => ({
|
||||
vi.mock('../core/openaiContentGenerator/index.js', () => ({
|
||||
OpenAIContentGenerator: class {
|
||||
client: {
|
||||
apiKey: string;
|
||||
baseURL: string;
|
||||
pipeline: {
|
||||
client: {
|
||||
apiKey: string;
|
||||
baseURL: string;
|
||||
};
|
||||
};
|
||||
|
||||
constructor(
|
||||
contentGeneratorConfig: ContentGeneratorConfig,
|
||||
_config: Config,
|
||||
) {
|
||||
this.client = {
|
||||
apiKey: contentGeneratorConfig.apiKey || 'test-key',
|
||||
baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1',
|
||||
constructor(_config: Config, _provider: unknown) {
|
||||
this.pipeline = {
|
||||
client: {
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -167,7 +257,7 @@ vi.mock('../core/openaiContentGenerator.js', () => ({
|
||||
async countTokens(
|
||||
_request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return { totalTokens: 10 };
|
||||
return { totalTokens: 15 };
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
@@ -220,7 +310,10 @@ describe('QwenContentGenerator', () => {
|
||||
// Mock Config
|
||||
mockConfig = {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
model: 'qwen-turbo',
|
||||
apiKey: 'test-api-key',
|
||||
authType: 'qwen',
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
enableOpenAILogging: false,
|
||||
timeout: 120000,
|
||||
maxRetries: 3,
|
||||
@@ -230,6 +323,9 @@ describe('QwenContentGenerator', () => {
|
||||
top_p: 0.9,
|
||||
},
|
||||
}),
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Config;
|
||||
|
||||
// Mock QwenOAuth2Client
|
||||
@@ -245,7 +341,11 @@ describe('QwenContentGenerator', () => {
|
||||
// Create QwenContentGenerator instance
|
||||
const contentGeneratorConfig = {
|
||||
model: 'qwen-turbo',
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
timeout: 120000,
|
||||
maxRetries: 3,
|
||||
};
|
||||
qwenContentGenerator = new QwenContentGenerator(
|
||||
mockQwenClient,
|
||||
@@ -317,7 +417,7 @@ describe('QwenContentGenerator', () => {
|
||||
|
||||
const result = await qwenContentGenerator.countTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBe(10);
|
||||
expect(result.totalTokens).toBe(15);
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -494,8 +594,9 @@ describe('QwenContentGenerator', () => {
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
capturedBaseURL = (
|
||||
this as unknown as { pipeline: { client: { baseURL: string } } }
|
||||
).pipeline.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
@@ -534,8 +635,9 @@ describe('QwenContentGenerator', () => {
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
capturedBaseURL = (
|
||||
this as unknown as { pipeline: { client: { baseURL: string } } }
|
||||
).pipeline.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
@@ -572,8 +674,9 @@ describe('QwenContentGenerator', () => {
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
capturedBaseURL = (
|
||||
this as unknown as { pipeline: { client: { baseURL: string } } }
|
||||
).pipeline.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
@@ -610,8 +713,9 @@ describe('QwenContentGenerator', () => {
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
capturedBaseURL = (
|
||||
this as unknown as { pipeline: { client: { baseURL: string } } }
|
||||
).pipeline.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
@@ -631,20 +735,19 @@ describe('QwenContentGenerator', () => {
|
||||
});
|
||||
|
||||
describe('Client State Management', () => {
|
||||
it('should restore original client credentials after operations', async () => {
|
||||
it('should set dynamic credentials during operations', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
pipeline: { client: { apiKey: string; baseURL: string } };
|
||||
}
|
||||
).client;
|
||||
const originalApiKey = client.apiKey;
|
||||
const originalBaseURL = client.baseURL;
|
||||
).pipeline.client;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'temp-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
access_token: 'temp-token',
|
||||
resource_url: 'https://temp-endpoint.com',
|
||||
});
|
||||
|
||||
@@ -655,24 +758,25 @@ describe('QwenContentGenerator', () => {
|
||||
|
||||
await qwenContentGenerator.generateContent(request, 'test-prompt-id');
|
||||
|
||||
// Should restore original values after operation
|
||||
expect(client.apiKey).toBe(originalApiKey);
|
||||
expect(client.baseURL).toBe(originalBaseURL);
|
||||
// Should have dynamic credentials set
|
||||
expect(client.apiKey).toBe('temp-token');
|
||||
expect(client.baseURL).toBe('https://temp-endpoint.com/v1');
|
||||
});
|
||||
|
||||
it('should restore credentials even when operation throws', async () => {
|
||||
it('should set credentials even when operation throws', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
pipeline: { client: { apiKey: string; baseURL: string } };
|
||||
}
|
||||
).client;
|
||||
const originalApiKey = client.apiKey;
|
||||
const originalBaseURL = client.baseURL;
|
||||
).pipeline.client;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'temp-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
access_token: 'temp-token',
|
||||
});
|
||||
|
||||
// Mock the parent method to throw an error
|
||||
const mockError = new Error('Network error');
|
||||
@@ -693,9 +797,9 @@ describe('QwenContentGenerator', () => {
|
||||
expect(error).toBe(mockError);
|
||||
}
|
||||
|
||||
// Credentials should still be restored
|
||||
expect(client.apiKey).toBe(originalApiKey);
|
||||
expect(client.baseURL).toBe(originalBaseURL);
|
||||
// Credentials should still be set before the error occurred
|
||||
expect(client.apiKey).toBe('temp-token');
|
||||
expect(client.baseURL).toBe('https://test-endpoint.com/v1');
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
@@ -1281,20 +1385,19 @@ describe('QwenContentGenerator', () => {
|
||||
});
|
||||
|
||||
describe('Stream Error Handling', () => {
|
||||
it('should restore credentials when stream generation fails', async () => {
|
||||
it('should set credentials when stream generation fails', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
pipeline: { client: { apiKey: string; baseURL: string } };
|
||||
}
|
||||
).client;
|
||||
const originalApiKey = client.apiKey;
|
||||
const originalBaseURL = client.baseURL;
|
||||
).pipeline.client;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'stream-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
access_token: 'stream-token',
|
||||
resource_url: 'https://stream-endpoint.com',
|
||||
});
|
||||
|
||||
@@ -1322,20 +1425,20 @@ describe('QwenContentGenerator', () => {
|
||||
expect(error).toBeInstanceOf(Error);
|
||||
}
|
||||
|
||||
// Credentials should be restored even on error
|
||||
expect(client.apiKey).toBe(originalApiKey);
|
||||
expect(client.baseURL).toBe(originalBaseURL);
|
||||
// Credentials should be set before the error occurred
|
||||
expect(client.apiKey).toBe('stream-token');
|
||||
expect(client.baseURL).toBe('https://stream-endpoint.com/v1');
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContentStream = originalGenerateContentStream;
|
||||
});
|
||||
|
||||
it('should not restore credentials in finally block for successful streams', async () => {
|
||||
it('should set credentials for successful streams', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
pipeline: { client: { apiKey: string; baseURL: string } };
|
||||
}
|
||||
).client;
|
||||
).pipeline.client;
|
||||
|
||||
// Set up the mock to return stream credentials
|
||||
const streamCredentials = {
|
||||
@@ -1368,11 +1471,12 @@ describe('QwenContentGenerator', () => {
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
// After successful stream creation, credentials should still be set for the stream
|
||||
// After successful stream creation, credentials should be set for the stream
|
||||
expect(client.apiKey).toBe('stream-token');
|
||||
expect(client.baseURL).toBe('https://stream-endpoint.com/v1');
|
||||
|
||||
// Consume the stream
|
||||
// Verify stream is iterable and consume it
|
||||
expect(stream).toBeDefined();
|
||||
const chunks = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
@@ -1478,15 +1582,21 @@ describe('QwenContentGenerator', () => {
|
||||
});
|
||||
|
||||
describe('Constructor and Initialization', () => {
|
||||
it('should initialize with default base URL', () => {
|
||||
it('should initialize with configured base URL when provided', () => {
|
||||
const generator = new QwenContentGenerator(
|
||||
mockQwenClient,
|
||||
{ model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH },
|
||||
{
|
||||
model: 'qwen-turbo',
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
apiKey: 'test-key',
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
const client = (generator as unknown as { client: { baseURL: string } })
|
||||
.client;
|
||||
const client = (
|
||||
generator as unknown as { pipeline: { client: { baseURL: string } } }
|
||||
).pipeline.client;
|
||||
expect(client.baseURL).toBe(
|
||||
'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
);
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { OpenAIContentGenerator } from '../core/openaiContentGenerator.js';
|
||||
import { OpenAIContentGenerator } from '../core/openaiContentGenerator/index.js';
|
||||
import { DashScopeOpenAICompatibleProvider } from '../core/openaiContentGenerator/provider/dashscope.js';
|
||||
import { IQwenOAuth2Client } from './qwenOAuth2.js';
|
||||
import { SharedTokenManager } from './sharedTokenManager.js';
|
||||
import { Config } from '../config/config.js';
|
||||
@@ -33,15 +34,24 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
|
||||
constructor(
|
||||
qwenClient: IQwenOAuth2Client,
|
||||
contentGeneratorConfig: ContentGeneratorConfig,
|
||||
config: Config,
|
||||
cliConfig: Config,
|
||||
) {
|
||||
// Initialize with empty API key, we'll override it dynamically
|
||||
super(contentGeneratorConfig, config);
|
||||
// Create DashScope provider for Qwen
|
||||
const dashscopeProvider = new DashScopeOpenAICompatibleProvider(
|
||||
contentGeneratorConfig,
|
||||
cliConfig,
|
||||
);
|
||||
|
||||
// Initialize with DashScope provider
|
||||
super(contentGeneratorConfig, cliConfig, dashscopeProvider);
|
||||
this.qwenClient = qwenClient;
|
||||
this.sharedManager = SharedTokenManager.getInstance();
|
||||
|
||||
// Set default base URL, will be updated dynamically
|
||||
this.client.baseURL = DEFAULT_QWEN_BASE_URL;
|
||||
if (contentGeneratorConfig?.baseUrl && contentGeneratorConfig?.apiKey) {
|
||||
this.pipeline.client.baseURL = contentGeneratorConfig?.baseUrl;
|
||||
this.pipeline.client.apiKey = contentGeneratorConfig?.apiKey;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -106,46 +116,24 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
|
||||
* Execute an operation with automatic credential management and retry logic.
|
||||
* This method handles:
|
||||
* - Dynamic token and endpoint retrieval
|
||||
* - Temporary client configuration updates
|
||||
* - Automatic restoration of original configuration
|
||||
* - Client configuration updates
|
||||
* - Retry logic on authentication errors with token refresh
|
||||
*
|
||||
* @param operation - The operation to execute with updated client configuration
|
||||
* @param restoreOnCompletion - Whether to restore original config after operation completes
|
||||
* @returns The result of the operation
|
||||
*/
|
||||
private async executeWithCredentialManagement<T>(
|
||||
operation: () => Promise<T>,
|
||||
restoreOnCompletion: boolean = true,
|
||||
): Promise<T> {
|
||||
// Attempt the operation with credential management and retry logic
|
||||
const attemptOperation = async (): Promise<T> => {
|
||||
const { token, endpoint } = await this.getValidToken();
|
||||
|
||||
// Store original configuration
|
||||
const originalApiKey = this.client.apiKey;
|
||||
const originalBaseURL = this.client.baseURL;
|
||||
|
||||
// Apply dynamic configuration
|
||||
this.client.apiKey = token;
|
||||
this.client.baseURL = endpoint;
|
||||
this.pipeline.client.apiKey = token;
|
||||
this.pipeline.client.baseURL = endpoint;
|
||||
|
||||
try {
|
||||
const result = await operation();
|
||||
|
||||
// For streaming operations, we may need to keep the configuration active
|
||||
if (restoreOnCompletion) {
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
// Always restore on error
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
throw error;
|
||||
}
|
||||
return await operation();
|
||||
};
|
||||
|
||||
// Execute with retry logic for auth errors
|
||||
@@ -175,17 +163,14 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
|
||||
}
|
||||
|
||||
/**
|
||||
* Override to use dynamic token and endpoint with automatic retry.
|
||||
* Note: For streaming, the client configuration is not restored immediately
|
||||
* since the generator may continue to be used after this method returns.
|
||||
* Override to use dynamic token and endpoint with automatic retry
|
||||
*/
|
||||
override async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
return this.executeWithCredentialManagement(
|
||||
() => super.generateContentStream(request, userPromptId),
|
||||
false, // Don't restore immediately for streaming
|
||||
return this.executeWithCredentialManagement(() =>
|
||||
super.generateContentStream(request, userPromptId),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user