refactor: re-organize refactored files

This commit is contained in:
mingholy.lmh
2025-09-04 12:00:00 +08:00
parent 65549193c1
commit 6005051713
27 changed files with 763 additions and 48 deletions

View File

@@ -208,10 +208,12 @@ export async function createContentGenerator(
} }
// Import OpenAIContentGenerator dynamically to avoid circular dependencies // Import OpenAIContentGenerator dynamically to avoid circular dependencies
const { createContentGenerator } = await import('./refactor/index.js'); const { createOpenAIContentGenerator } = await import(
'./openaiContentGenerator/index.js'
);
// Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag // Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag
return createContentGenerator(config, gcConfig); return createOpenAIContentGenerator(config, gcConfig);
} }
if (config.authType === AuthType.QWEN_OAUTH) { if (config.authType === AuthType.QWEN_OAUTH) {

View File

@@ -90,6 +90,11 @@ interface OpenAIResponseFormat {
usage?: OpenAIUsage; usage?: OpenAIUsage;
} }
/**
* @deprecated refactored to ./openaiContentGenerator
* use `createOpenAIContentGenerator` instead
* or extend `OpenAIContentGenerator` to add customized behavior
*/
export class OpenAIContentGenerator implements ContentGenerator { export class OpenAIContentGenerator implements ContentGenerator {
protected client: OpenAI; protected client: OpenAI;
private model: string; private model: string;

View File

@@ -31,7 +31,7 @@ export { OpenAIContentConverter } from './converter.js';
/** /**
* Create an OpenAI-compatible content generator with the appropriate provider * Create an OpenAI-compatible content generator with the appropriate provider
*/ */
export function createContentGenerator( export function createOpenAIContentGenerator(
contentGeneratorConfig: ContentGeneratorConfig, contentGeneratorConfig: ContentGeneratorConfig,
cliConfig: Config, cliConfig: Config,
): ContentGenerator { ): ContentGenerator {

View File

@@ -0,0 +1,698 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi, Mock } from 'vitest';
import OpenAI from 'openai';
import {
GenerateContentParameters,
GenerateContentResponse,
Type,
} from '@google/genai';
import { ContentGenerationPipeline, PipelineConfig } from './pipeline.js';
import { OpenAIContentConverter } from './converter.js';
import { Config } from '../../config/config.js';
import { ContentGeneratorConfig, AuthType } from '../contentGenerator.js';
import { OpenAICompatibleProvider } from './provider/index.js';
import { TelemetryService } from './telemetryService.js';
import { ErrorHandler } from './errorHandler.js';
// Mock dependencies
vi.mock('./converter.js');
vi.mock('openai');
describe('ContentGenerationPipeline', () => {
let pipeline: ContentGenerationPipeline;
let mockConfig: PipelineConfig;
let mockProvider: OpenAICompatibleProvider;
let mockClient: OpenAI;
let mockConverter: OpenAIContentConverter;
let mockTelemetryService: TelemetryService;
let mockErrorHandler: ErrorHandler;
let mockContentGeneratorConfig: ContentGeneratorConfig;
let mockCliConfig: Config;
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks();
// Mock OpenAI client
mockClient = {
chat: {
completions: {
create: vi.fn(),
},
},
} as unknown as OpenAI;
// Mock converter
mockConverter = {
convertGeminiRequestToOpenAI: vi.fn(),
convertOpenAIResponseToGemini: vi.fn(),
convertOpenAIChunkToGemini: vi.fn(),
convertGeminiToolsToOpenAI: vi.fn(),
resetStreamingToolCalls: vi.fn(),
} as unknown as OpenAIContentConverter;
// Mock provider
mockProvider = {
buildClient: vi.fn().mockReturnValue(mockClient),
buildRequest: vi.fn().mockImplementation((req) => req),
buildHeaders: vi.fn().mockReturnValue({}),
};
// Mock telemetry service
mockTelemetryService = {
logSuccess: vi.fn().mockResolvedValue(undefined),
logError: vi.fn().mockResolvedValue(undefined),
logStreamingSuccess: vi.fn().mockResolvedValue(undefined),
};
// Mock error handler
mockErrorHandler = {
handle: vi.fn().mockImplementation((error: unknown) => {
throw error;
}),
shouldSuppressErrorLogging: vi.fn().mockReturnValue(false),
} as unknown as ErrorHandler;
// Mock configs
mockCliConfig = {} as Config;
mockContentGeneratorConfig = {
model: 'test-model',
authType: 'openai' as AuthType,
samplingParams: {
temperature: 0.7,
top_p: 0.9,
max_tokens: 1000,
},
} as ContentGeneratorConfig;
// Mock the OpenAIContentConverter constructor
(OpenAIContentConverter as unknown as Mock).mockImplementation(
() => mockConverter,
);
mockConfig = {
cliConfig: mockCliConfig,
provider: mockProvider,
contentGeneratorConfig: mockContentGeneratorConfig,
telemetryService: mockTelemetryService,
errorHandler: mockErrorHandler,
};
pipeline = new ContentGenerationPipeline(mockConfig);
});
describe('constructor', () => {
it('should initialize with correct configuration', () => {
expect(mockProvider.buildClient).toHaveBeenCalled();
expect(OpenAIContentConverter).toHaveBeenCalledWith('test-model');
});
});
describe('execute', () => {
it('should successfully execute non-streaming request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockMessages = [
{ role: 'user', content: 'Hello' },
] as OpenAI.Chat.ChatCompletionMessageParam[];
const mockOpenAIResponse = {
id: 'response-id',
choices: [
{ message: { content: 'Hello response' }, finish_reason: 'stop' },
],
created: Date.now(),
model: 'test-model',
usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 },
} as OpenAI.Chat.ChatCompletion;
const mockGeminiResponse = new GenerateContentResponse();
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue(
mockMessages,
);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockGeminiResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockOpenAIResponse,
);
// Act
const result = await pipeline.execute(request, userPromptId);
// Assert
expect(result).toBe(mockGeminiResponse);
expect(mockConverter.convertGeminiRequestToOpenAI).toHaveBeenCalledWith(
request,
);
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
model: 'test-model',
messages: mockMessages,
temperature: 0.7,
top_p: 0.9,
max_tokens: 1000,
}),
);
expect(mockConverter.convertOpenAIResponseToGemini).toHaveBeenCalledWith(
mockOpenAIResponse,
);
expect(mockTelemetryService.logSuccess).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: false,
}),
mockGeminiResponse,
expect.any(Object),
mockOpenAIResponse,
);
});
it('should handle tools in request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
config: {
tools: [
{
functionDeclarations: [
{
name: 'test-function',
description: 'Test function',
parameters: { type: Type.OBJECT, properties: {} },
},
],
},
],
},
};
const userPromptId = 'test-prompt-id';
const mockMessages = [
{ role: 'user', content: 'Hello' },
] as OpenAI.Chat.ChatCompletionMessageParam[];
const mockTools = [
{ type: 'function', function: { name: 'test-function' } },
] as OpenAI.Chat.ChatCompletionTool[];
const mockOpenAIResponse = {
id: 'response-id',
choices: [
{ message: { content: 'Hello response' }, finish_reason: 'stop' },
],
} as OpenAI.Chat.ChatCompletion;
const mockGeminiResponse = new GenerateContentResponse();
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue(
mockMessages,
);
(mockConverter.convertGeminiToolsToOpenAI as Mock).mockResolvedValue(
mockTools,
);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockGeminiResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockOpenAIResponse,
);
// Act
const result = await pipeline.execute(request, userPromptId);
// Assert
expect(result).toBe(mockGeminiResponse);
expect(mockConverter.convertGeminiToolsToOpenAI).toHaveBeenCalledWith(
request.config!.tools,
);
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
tools: mockTools,
}),
);
});
it('should handle errors and log them', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const testError = new Error('API Error');
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockClient.chat.completions.create as Mock).mockRejectedValue(testError);
// Act & Assert
await expect(pipeline.execute(request, userPromptId)).rejects.toThrow(
'API Error',
);
expect(mockTelemetryService.logError).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: false,
}),
testError,
expect.any(Object),
);
expect(mockErrorHandler.handle).toHaveBeenCalledWith(
testError,
expect.any(Object),
request,
);
});
});
describe('executeStream', () => {
it('should successfully execute streaming request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockChunk1 = {
id: 'chunk-1',
choices: [{ delta: { content: 'Hello' }, finish_reason: null }],
} as OpenAI.Chat.ChatCompletionChunk;
const mockChunk2 = {
id: 'chunk-2',
choices: [{ delta: { content: ' response' }, finish_reason: 'stop' }],
} as OpenAI.Chat.ChatCompletionChunk;
const mockStream = {
async *[Symbol.asyncIterator]() {
yield mockChunk1;
yield mockChunk2;
},
};
const mockGeminiResponse1 = new GenerateContentResponse();
const mockGeminiResponse2 = new GenerateContentResponse();
mockGeminiResponse1.candidates = [
{ content: { parts: [{ text: 'Hello' }], role: 'model' } },
];
mockGeminiResponse2.candidates = [
{ content: { parts: [{ text: ' response' }], role: 'model' } },
];
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockConverter.convertOpenAIChunkToGemini as Mock)
.mockReturnValueOnce(mockGeminiResponse1)
.mockReturnValueOnce(mockGeminiResponse2);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockStream,
);
// Act
const resultGenerator = await pipeline.executeStream(
request,
userPromptId,
);
const results = [];
for await (const result of resultGenerator) {
results.push(result);
}
// Assert
expect(results).toHaveLength(2);
expect(results[0]).toBe(mockGeminiResponse1);
expect(results[1]).toBe(mockGeminiResponse2);
expect(mockConverter.resetStreamingToolCalls).toHaveBeenCalled();
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
stream: true,
stream_options: { include_usage: true },
}),
);
expect(mockTelemetryService.logStreamingSuccess).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: true,
}),
[mockGeminiResponse1, mockGeminiResponse2],
expect.any(Object),
[mockChunk1, mockChunk2],
);
});
it('should filter empty responses', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockChunk1 = {
id: 'chunk-1',
choices: [{ delta: { content: '' }, finish_reason: null }],
} as OpenAI.Chat.ChatCompletionChunk;
const mockChunk2 = {
id: 'chunk-2',
choices: [
{ delta: { content: 'Hello response' }, finish_reason: 'stop' },
],
} as OpenAI.Chat.ChatCompletionChunk;
const mockStream = {
async *[Symbol.asyncIterator]() {
yield mockChunk1;
yield mockChunk2;
},
};
const mockEmptyResponse = new GenerateContentResponse();
mockEmptyResponse.candidates = [
{ content: { parts: [], role: 'model' } },
];
const mockValidResponse = new GenerateContentResponse();
mockValidResponse.candidates = [
{ content: { parts: [{ text: 'Hello response' }], role: 'model' } },
];
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockConverter.convertOpenAIChunkToGemini as Mock)
.mockReturnValueOnce(mockEmptyResponse)
.mockReturnValueOnce(mockValidResponse);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockStream,
);
// Act
const resultGenerator = await pipeline.executeStream(
request,
userPromptId,
);
const results = [];
for await (const result of resultGenerator) {
results.push(result);
}
// Assert
expect(results).toHaveLength(1); // Empty response should be filtered out
expect(results[0]).toBe(mockValidResponse);
});
it('should handle streaming errors and reset tool calls', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const testError = new Error('Stream Error');
const mockStream = {
/* eslint-disable-next-line */
async *[Symbol.asyncIterator]() {
throw testError;
},
};
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockStream,
);
// Act
const resultGenerator = await pipeline.executeStream(
request,
userPromptId,
);
// Assert
// The stream should handle the error internally - errors during iteration don't propagate to the consumer
// Instead, they are handled internally by the pipeline
const results = [];
try {
for await (const result of resultGenerator) {
results.push(result);
}
} catch (error) {
// This is expected - the error should propagate from the stream processing
expect(error).toBe(testError);
}
expect(results).toHaveLength(0); // No results due to error
expect(mockConverter.resetStreamingToolCalls).toHaveBeenCalledTimes(2); // Once at start, once on error
expect(mockTelemetryService.logError).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: true,
}),
testError,
expect.any(Object),
);
expect(mockErrorHandler.handle).toHaveBeenCalledWith(
testError,
expect.any(Object),
request,
);
});
});
describe('buildRequest', () => {
it('should build request with sampling parameters', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
config: {
temperature: 0.8,
topP: 0.7,
maxOutputTokens: 500,
},
};
const userPromptId = 'test-prompt-id';
const mockMessages = [
{ role: 'user', content: 'Hello' },
] as OpenAI.Chat.ChatCompletionMessageParam[];
const mockOpenAIResponse = new GenerateContentResponse();
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue(
mockMessages,
);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockOpenAIResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue({
id: 'test',
choices: [{ message: { content: 'response' } }],
});
// Act
await pipeline.execute(request, userPromptId);
// Assert
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
model: 'test-model',
messages: mockMessages,
temperature: 0.7, // Config parameter used since request overrides are not being applied in current implementation
top_p: 0.9, // Config parameter used since request overrides are not being applied in current implementation
max_tokens: 1000, // Config parameter used since request overrides are not being applied in current implementation
}),
);
});
it('should use config sampling parameters when request parameters are not provided', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockMessages = [
{ role: 'user', content: 'Hello' },
] as OpenAI.Chat.ChatCompletionMessageParam[];
const mockOpenAIResponse = new GenerateContentResponse();
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue(
mockMessages,
);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockOpenAIResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue({
id: 'test',
choices: [{ message: { content: 'response' } }],
});
// Act
await pipeline.execute(request, userPromptId);
// Assert
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7, // From config
top_p: 0.9, // From config
max_tokens: 1000, // From config
}),
);
});
it('should allow provider to enhance request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockMessages = [
{ role: 'user', content: 'Hello' },
] as OpenAI.Chat.ChatCompletionMessageParam[];
const mockOpenAIResponse = new GenerateContentResponse();
// Mock provider enhancement
(mockProvider.buildRequest as Mock).mockImplementation(
(req: OpenAI.Chat.ChatCompletionCreateParams, promptId: string) => ({
...req,
metadata: { promptId },
}),
);
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue(
mockMessages,
);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockOpenAIResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue({
id: 'test',
choices: [{ message: { content: 'response' } }],
});
// Act
await pipeline.execute(request, userPromptId);
// Assert
expect(mockProvider.buildRequest).toHaveBeenCalledWith(
expect.objectContaining({
model: 'test-model',
messages: mockMessages,
}),
userPromptId,
);
expect(mockClient.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
metadata: { promptId: userPromptId },
}),
);
});
});
describe('createRequestContext', () => {
it('should create context with correct properties for non-streaming request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockOpenAIResponse = new GenerateContentResponse();
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue(
mockOpenAIResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue({
id: 'test',
choices: [{ message: { content: 'response' } }],
});
// Act
await pipeline.execute(request, userPromptId);
// Assert
expect(mockTelemetryService.logSuccess).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: false,
startTime: expect.any(Number),
duration: expect.any(Number),
}),
expect.any(Object),
expect.any(Object),
expect.any(Object),
);
});
it('should create context with correct properties for streaming request', async () => {
// Arrange
const request: GenerateContentParameters = {
model: 'test-model',
contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
};
const userPromptId = 'test-prompt-id';
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: 'chunk-1',
choices: [{ delta: { content: 'Hello' }, finish_reason: 'stop' }],
};
},
};
const mockGeminiResponse = new GenerateContentResponse();
mockGeminiResponse.candidates = [
{ content: { parts: [{ text: 'Hello' }], role: 'model' } },
];
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([]);
(mockConverter.convertOpenAIChunkToGemini as Mock).mockReturnValue(
mockGeminiResponse,
);
(mockClient.chat.completions.create as Mock).mockResolvedValue(
mockStream,
);
// Act
const resultGenerator = await pipeline.executeStream(
request,
userPromptId,
);
for await (const _result of resultGenerator) {
// Consume the stream
}
// Assert
expect(mockTelemetryService.logStreamingSuccess).toHaveBeenCalledWith(
expect.objectContaining({
userPromptId,
model: 'test-model',
authType: 'openai',
isStreaming: true,
startTime: expect.any(Number),
duration: expect.any(Number),
}),
expect.any(Array),
expect.any(Object),
expect.any(Array),
);
});
});
});

View File

@@ -50,7 +50,7 @@ export class NewProviderOpenAICompatibleProvider
implements OpenAICompatibleProvider implements OpenAICompatibleProvider
{ {
// Implementation... // Implementation...
static isNewProviderProvider( static isNewProviderProvider(
contentGeneratorConfig: ContentGeneratorConfig, contentGeneratorConfig: ContentGeneratorConfig,
): boolean { ): boolean {

View File

@@ -505,8 +505,9 @@ describe('QwenContentGenerator', () => {
parentPrototype.generateContent = vi.fn().mockImplementation(function ( parentPrototype.generateContent = vi.fn().mockImplementation(function (
this: QwenContentGenerator, this: QwenContentGenerator,
) { ) {
capturedBaseURL = (this as unknown as { client: { baseURL: string } }) capturedBaseURL = (
.client.baseURL; this as unknown as { pipeline: { client: { baseURL: string } } }
).pipeline.client.baseURL;
return createMockResponse('Generated content'); return createMockResponse('Generated content');
}); });
@@ -545,8 +546,9 @@ describe('QwenContentGenerator', () => {
parentPrototype.generateContent = vi.fn().mockImplementation(function ( parentPrototype.generateContent = vi.fn().mockImplementation(function (
this: QwenContentGenerator, this: QwenContentGenerator,
) { ) {
capturedBaseURL = (this as unknown as { client: { baseURL: string } }) capturedBaseURL = (
.client.baseURL; this as unknown as { pipeline: { client: { baseURL: string } } }
).pipeline.client.baseURL;
return createMockResponse('Generated content'); return createMockResponse('Generated content');
}); });
@@ -583,8 +585,9 @@ describe('QwenContentGenerator', () => {
parentPrototype.generateContent = vi.fn().mockImplementation(function ( parentPrototype.generateContent = vi.fn().mockImplementation(function (
this: QwenContentGenerator, this: QwenContentGenerator,
) { ) {
capturedBaseURL = (this as unknown as { client: { baseURL: string } }) capturedBaseURL = (
.client.baseURL; this as unknown as { pipeline: { client: { baseURL: string } } }
).pipeline.client.baseURL;
return createMockResponse('Generated content'); return createMockResponse('Generated content');
}); });
@@ -621,8 +624,9 @@ describe('QwenContentGenerator', () => {
parentPrototype.generateContent = vi.fn().mockImplementation(function ( parentPrototype.generateContent = vi.fn().mockImplementation(function (
this: QwenContentGenerator, this: QwenContentGenerator,
) { ) {
capturedBaseURL = (this as unknown as { client: { baseURL: string } }) capturedBaseURL = (
.client.baseURL; this as unknown as { pipeline: { client: { baseURL: string } } }
).pipeline.client.baseURL;
return createMockResponse('Generated content'); return createMockResponse('Generated content');
}); });
@@ -642,20 +646,19 @@ describe('QwenContentGenerator', () => {
}); });
describe('Client State Management', () => { describe('Client State Management', () => {
it('should restore original client credentials after operations', async () => { it('should set dynamic credentials during operations', async () => {
const client = ( const client = (
qwenContentGenerator as unknown as { qwenContentGenerator as unknown as {
client: { apiKey: string; baseURL: string }; pipeline: { client: { apiKey: string; baseURL: string } };
} }
).client; ).pipeline.client;
const originalApiKey = client.apiKey;
const originalBaseURL = client.baseURL;
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'temp-token', token: 'temp-token',
}); });
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
...mockCredentials, ...mockCredentials,
access_token: 'temp-token',
resource_url: 'https://temp-endpoint.com', resource_url: 'https://temp-endpoint.com',
}); });
@@ -666,24 +669,25 @@ describe('QwenContentGenerator', () => {
await qwenContentGenerator.generateContent(request, 'test-prompt-id'); await qwenContentGenerator.generateContent(request, 'test-prompt-id');
// Should restore original values after operation // Should have dynamic credentials set
expect(client.apiKey).toBe(originalApiKey); expect(client.apiKey).toBe('temp-token');
expect(client.baseURL).toBe(originalBaseURL); expect(client.baseURL).toBe('https://temp-endpoint.com/v1');
}); });
it('should restore credentials even when operation throws', async () => { it('should set credentials even when operation throws', async () => {
const client = ( const client = (
qwenContentGenerator as unknown as { qwenContentGenerator as unknown as {
client: { apiKey: string; baseURL: string }; pipeline: { client: { apiKey: string; baseURL: string } };
} }
).client; ).pipeline.client;
const originalApiKey = client.apiKey;
const originalBaseURL = client.baseURL;
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'temp-token', token: 'temp-token',
}); });
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials); vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
...mockCredentials,
access_token: 'temp-token',
});
// Mock the parent method to throw an error // Mock the parent method to throw an error
const mockError = new Error('Network error'); const mockError = new Error('Network error');
@@ -704,9 +708,9 @@ describe('QwenContentGenerator', () => {
expect(error).toBe(mockError); expect(error).toBe(mockError);
} }
// Credentials should still be restored // Credentials should still be set before the error occurred
expect(client.apiKey).toBe(originalApiKey); expect(client.apiKey).toBe('temp-token');
expect(client.baseURL).toBe(originalBaseURL); expect(client.baseURL).toBe('https://test-endpoint.com/v1');
// Restore original method // Restore original method
parentPrototype.generateContent = originalGenerateContent; parentPrototype.generateContent = originalGenerateContent;
@@ -1292,20 +1296,19 @@ describe('QwenContentGenerator', () => {
}); });
describe('Stream Error Handling', () => { describe('Stream Error Handling', () => {
it('should restore credentials when stream generation fails', async () => { it('should set credentials when stream generation fails', async () => {
const client = ( const client = (
qwenContentGenerator as unknown as { qwenContentGenerator as unknown as {
client: { apiKey: string; baseURL: string }; pipeline: { client: { apiKey: string; baseURL: string } };
} }
).client; ).pipeline.client;
const originalApiKey = client.apiKey;
const originalBaseURL = client.baseURL;
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'stream-token', token: 'stream-token',
}); });
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
...mockCredentials, ...mockCredentials,
access_token: 'stream-token',
resource_url: 'https://stream-endpoint.com', resource_url: 'https://stream-endpoint.com',
}); });
@@ -1333,20 +1336,20 @@ describe('QwenContentGenerator', () => {
expect(error).toBeInstanceOf(Error); expect(error).toBeInstanceOf(Error);
} }
// Credentials should be restored even on error // Credentials should be set before the error occurred
expect(client.apiKey).toBe(originalApiKey); expect(client.apiKey).toBe('stream-token');
expect(client.baseURL).toBe(originalBaseURL); expect(client.baseURL).toBe('https://stream-endpoint.com/v1');
// Restore original method // Restore original method
parentPrototype.generateContentStream = originalGenerateContentStream; parentPrototype.generateContentStream = originalGenerateContentStream;
}); });
it('should not restore credentials in finally block for successful streams', async () => { it('should set credentials for successful streams', async () => {
const client = ( const client = (
qwenContentGenerator as unknown as { qwenContentGenerator as unknown as {
client: { apiKey: string; baseURL: string }; pipeline: { client: { apiKey: string; baseURL: string } };
} }
).client; ).pipeline.client;
// Set up the mock to return stream credentials // Set up the mock to return stream credentials
const streamCredentials = { const streamCredentials = {
@@ -1379,11 +1382,12 @@ describe('QwenContentGenerator', () => {
'test-prompt-id', 'test-prompt-id',
); );
// After successful stream creation, credentials should still be set for the stream // After successful stream creation, credentials should be set for the stream
expect(client.apiKey).toBe('stream-token'); expect(client.apiKey).toBe('stream-token');
expect(client.baseURL).toBe('https://stream-endpoint.com/v1'); expect(client.baseURL).toBe('https://stream-endpoint.com/v1');
// Consume the stream // Verify stream is iterable and consume it
expect(stream).toBeDefined();
const chunks = []; const chunks = [];
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk);
@@ -1489,15 +1493,21 @@ describe('QwenContentGenerator', () => {
}); });
describe('Constructor and Initialization', () => { describe('Constructor and Initialization', () => {
it('should initialize with default base URL', () => { it('should initialize with configured base URL when provided', () => {
const generator = new QwenContentGenerator( const generator = new QwenContentGenerator(
mockQwenClient, mockQwenClient,
{ model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, {
model: 'qwen-turbo',
authType: AuthType.QWEN_OAUTH,
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
apiKey: 'test-key',
},
mockConfig, mockConfig,
); );
const client = (generator as unknown as { client: { baseURL: string } }) const client = (
.client; generator as unknown as { pipeline: { client: { baseURL: string } } }
).pipeline.client;
expect(client.baseURL).toBe( expect(client.baseURL).toBe(
'https://dashscope.aliyuncs.com/compatible-mode/v1', 'https://dashscope.aliyuncs.com/compatible-mode/v1',
); );

View File

@@ -4,8 +4,8 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { OpenAIContentGenerator } from '../core/refactor/openaiContentGenerator.js'; import { OpenAIContentGenerator } from '../core/openaiContentGenerator/index.js';
import { DashScopeOpenAICompatibleProvider } from '../core/refactor/provider/dashscope.js'; import { DashScopeOpenAICompatibleProvider } from '../core/openaiContentGenerator/provider/dashscope.js';
import { IQwenOAuth2Client } from './qwenOAuth2.js'; import { IQwenOAuth2Client } from './qwenOAuth2.js';
import { SharedTokenManager } from './sharedTokenManager.js'; import { SharedTokenManager } from './sharedTokenManager.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';