mirror of https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00

refactor: optimize stream handling

packages/core/src/core/refactor/converter.test.ts (new file, 71 lines)
@@ -0,0 +1,71 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, beforeEach } from 'vitest';
import { OpenAIContentConverter } from './converter.js';
import { StreamingToolCallParser } from './streamingToolCallParser.js';

describe('OpenAIContentConverter', () => {
  let converter: OpenAIContentConverter;

  beforeEach(() => {
    converter = new OpenAIContentConverter('test-model');
  });

  describe('resetStreamingToolCalls', () => {
    it('should clear streaming tool calls accumulator', () => {
      // Access private field for testing
      const parser = (
        converter as unknown as {
          streamingToolCallParser: StreamingToolCallParser;
        }
      ).streamingToolCallParser;

      // Add some test data to the parser
      parser.addChunk(0, '{"arg": "value"}', 'test-id', 'test-function');
      parser.addChunk(1, '{"arg2": "value2"}', 'test-id-2', 'test-function-2');

      // Verify data is present
      expect(parser.getBuffer(0)).toBe('{"arg": "value"}');
      expect(parser.getBuffer(1)).toBe('{"arg2": "value2"}');

      // Call reset method
      converter.resetStreamingToolCalls();

      // Verify data is cleared
      expect(parser.getBuffer(0)).toBe('');
      expect(parser.getBuffer(1)).toBe('');
    });

    it('should be safe to call multiple times', () => {
      // Call reset multiple times
      converter.resetStreamingToolCalls();
      converter.resetStreamingToolCalls();
      converter.resetStreamingToolCalls();

      // Should not throw any errors
      const parser = (
        converter as unknown as {
          streamingToolCallParser: StreamingToolCallParser;
        }
      ).streamingToolCallParser;
      expect(parser.getBuffer(0)).toBe('');
    });

    it('should be safe to call on empty accumulator', () => {
      // Call reset on empty accumulator
      converter.resetStreamingToolCalls();

      // Should not throw any errors
      const parser = (
        converter as unknown as {
          streamingToolCallParser: StreamingToolCallParser;
        }
      ).streamingToolCallParser;
      expect(parser.getBuffer(0)).toBe('');
    });
  });
});
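These tests drive StreamingToolCallParser only through the converter's private field. As a standalone illustration of the parser lifecycle they rely on (the addChunk, getBuffer, and reset signatures are taken from the calls above; everything else in this snippet is a hypothetical usage sketch, not documented API):

import { StreamingToolCallParser } from './streamingToolCallParser.js';

// Hypothetical usage sketch based only on the calls made in the tests above.
const parser = new StreamingToolCallParser();

// Each streamed delta is appended to the buffer for its tool-call index.
parser.addChunk(0, '{"city": "Par', 'call_1', 'get_weather');
parser.addChunk(0, 'is"}', 'call_1', 'get_weather');

// Expected accumulated text: '{"city": "Paris"}' (assumption based on the
// converter's streaming use; the tests above only check single-chunk buffers).
console.log(parser.getBuffer(0));

// reset() clears every buffer so the next stream starts from a clean state.
parser.reset();
console.log(parser.getBuffer(0)); // ''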
@@ -15,28 +15,59 @@ import {
  CallableTool,
  FunctionCall,
  FunctionResponse,
  ContentListUnion,
  ContentUnion,
  PartUnion,
} from '@google/genai';
import OpenAI from 'openai';
import { safeJsonParse } from '../../utils/safeJsonParse.js';
import { StreamingToolCallParser } from './streamingToolCallParser.js';

/**
 * Tool call accumulator for streaming responses
 */
export interface ToolCallAccumulator {
  id?: string;
  name?: string;
  arguments: string;
}

/**
 * Parsed parts from Gemini content, categorized by type
 */
interface ParsedParts {
  textParts: string[];
  functionCalls: FunctionCall[];
  functionResponses: FunctionResponse[];
  mediaParts: Array<{
    type: 'image' | 'audio' | 'file';
    data: string;
    mimeType: string;
    fileUri?: string;
  }>;
}

/**
 * Converter class for transforming data between Gemini and OpenAI formats
 */
export class Converter {
export class OpenAIContentConverter {
  private model: string;
  private streamingToolCalls: Map<
    number,
    {
      id?: string;
      name?: string;
      arguments: string;
    }
  > = new Map();
  private streamingToolCallParser: StreamingToolCallParser =
    new StreamingToolCallParser();

  constructor(model: string) {
    this.model = model;
  }

  /**
   * Reset streaming tool calls parser for new stream processing
   * This should be called at the beginning of each stream to prevent
   * data pollution from previous incomplete streams
   */
  resetStreamingToolCalls(): void {
    this.streamingToolCallParser.reset();
  }

  /**
   * Convert Gemini tool parameters to OpenAI JSON Schema format
   */
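The hunk above renames Converter to OpenAIContentConverter and replaces the ad-hoc Map<number, ...> accumulator with a StreamingToolCallParser field, exposing resetStreamingToolCalls() so each new stream starts clean. A minimal sketch of the intended calling pattern (the helper function and stream variable are illustrative; only the converter methods come from this diff):

async function* toGeminiStream(
  converter: OpenAIContentConverter,
  openaiStream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
) {
  // Drop any parser state left over from a previous stream that ended
  // mid-tool-call before consuming the new one.
  converter.resetStreamingToolCalls();
  for await (const chunk of openaiStream) {
    // Partial tool-call arguments stay buffered inside the parser until
    // the chunk carrying finish_reason arrives.
    yield converter.convertOpenAIChunkToGemini(chunk);
  }
}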
@@ -173,132 +204,10 @@ export class Converter {
    const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [];

    // Handle system instruction from config
    if (request.config?.systemInstruction) {
      const systemInstruction = request.config.systemInstruction;
      let systemText = '';

      if (Array.isArray(systemInstruction)) {
        systemText = systemInstruction
          .map((content) => {
            if (typeof content === 'string') return content;
            if ('parts' in content) {
              const contentObj = content as Content;
              return (
                contentObj.parts
                  ?.map((p: Part) =>
                    typeof p === 'string' ? p : 'text' in p ? p.text : '',
                  )
                  .join('\n') || ''
              );
            }
            return '';
          })
          .join('\n');
      } else if (typeof systemInstruction === 'string') {
        systemText = systemInstruction;
      } else if (
        typeof systemInstruction === 'object' &&
        'parts' in systemInstruction
      ) {
        const systemContent = systemInstruction as Content;
        systemText =
          systemContent.parts
            ?.map((p: Part) =>
              typeof p === 'string' ? p : 'text' in p ? p.text : '',
            )
            .join('\n') || '';
      }

      if (systemText) {
        messages.push({
          role: 'system' as const,
          content: systemText,
        });
      }
    }
    this.addSystemInstructionMessage(request, messages);

    // Handle contents
    if (Array.isArray(request.contents)) {
      for (const content of request.contents) {
        if (typeof content === 'string') {
          messages.push({ role: 'user' as const, content });
        } else if ('role' in content && 'parts' in content) {
          // Check if this content has function calls or responses
          const functionCalls: FunctionCall[] = [];
          const functionResponses: FunctionResponse[] = [];
          const textParts: string[] = [];

          for (const part of content.parts || []) {
            if (typeof part === 'string') {
              textParts.push(part);
            } else if ('text' in part && part.text) {
              textParts.push(part.text);
            } else if ('functionCall' in part && part.functionCall) {
              functionCalls.push(part.functionCall);
            } else if ('functionResponse' in part && part.functionResponse) {
              functionResponses.push(part.functionResponse);
            }
          }

          // Handle function responses (tool results)
          if (functionResponses.length > 0) {
            for (const funcResponse of functionResponses) {
              messages.push({
                role: 'tool' as const,
                tool_call_id: funcResponse.id || '',
                content:
                  typeof funcResponse.response === 'string'
                    ? funcResponse.response
                    : JSON.stringify(funcResponse.response),
              });
            }
          }
          // Handle model messages with function calls
          else if (content.role === 'model' && functionCalls.length > 0) {
            const toolCalls = functionCalls.map((fc, index) => ({
              id: fc.id || `call_${index}`,
              type: 'function' as const,
              function: {
                name: fc.name || '',
                arguments: JSON.stringify(fc.args || {}),
              },
            }));

            messages.push({
              role: 'assistant' as const,
              content: textParts.join('') || null,
              tool_calls: toolCalls,
            });
          }
          // Handle regular text messages
          else {
            const role =
              content.role === 'model'
                ? ('assistant' as const)
                : ('user' as const);
            const text = textParts.join('');
            if (text) {
              messages.push({ role, content: text });
            }
          }
        }
      }
    } else if (request.contents) {
      if (typeof request.contents === 'string') {
        messages.push({ role: 'user' as const, content: request.contents });
      } else if ('role' in request.contents && 'parts' in request.contents) {
        const content = request.contents;
        const role =
          content.role === 'model' ? ('assistant' as const) : ('user' as const);
        const text =
          content.parts
            ?.map((p: Part) =>
              typeof p === 'string' ? p : 'text' in p ? p.text : '',
            )
            .join('\n') || '';
        messages.push({ role, content: text });
      }
    }
    this.processContents(request.contents, messages);

    // Clean up orphaned tool calls and merge consecutive assistant messages
    const cleanedMessages = this.cleanOrphanedToolCalls(messages);
@@ -308,6 +217,285 @@ export class Converter {
    return mergedMessages;
  }

  /**
   * Extract and add system instruction message from request config
   */
  private addSystemInstructionMessage(
    request: GenerateContentParameters,
    messages: OpenAI.Chat.ChatCompletionMessageParam[],
  ): void {
    if (!request.config?.systemInstruction) return;

    const systemText = this.extractTextFromContentUnion(
      request.config.systemInstruction,
    );

    if (systemText) {
      messages.push({
        role: 'system' as const,
        content: systemText,
      });
    }
  }

  /**
   * Process contents and convert to OpenAI messages
   */
  private processContents(
    contents: ContentListUnion,
    messages: OpenAI.Chat.ChatCompletionMessageParam[],
  ): void {
    if (Array.isArray(contents)) {
      for (const content of contents) {
        this.processContent(content, messages);
      }
    } else if (contents) {
      this.processContent(contents, messages);
    }
  }

  /**
   * Process a single content item and convert to OpenAI message(s)
   */
  private processContent(
    content: ContentUnion | PartUnion,
    messages: OpenAI.Chat.ChatCompletionMessageParam[],
  ): void {
    if (typeof content === 'string') {
      messages.push({ role: 'user' as const, content });
      return;
    }

    if (!this.isContentObject(content)) return;

    const parsedParts = this.parseParts(content.parts || []);

    // Handle function responses (tool results) first
    if (parsedParts.functionResponses.length > 0) {
      for (const funcResponse of parsedParts.functionResponses) {
        messages.push({
          role: 'tool' as const,
          tool_call_id: funcResponse.id || '',
          content:
            typeof funcResponse.response === 'string'
              ? funcResponse.response
              : JSON.stringify(funcResponse.response),
        });
      }
      return;
    }

    // Handle model messages with function calls
    if (content.role === 'model' && parsedParts.functionCalls.length > 0) {
      const toolCalls = parsedParts.functionCalls.map((fc, index) => ({
        id: fc.id || `call_${index}`,
        type: 'function' as const,
        function: {
          name: fc.name || '',
          arguments: JSON.stringify(fc.args || {}),
        },
      }));

      messages.push({
        role: 'assistant' as const,
        content: parsedParts.textParts.join('') || null,
        tool_calls: toolCalls,
      });
      return;
    }

    // Handle regular messages with multimodal content
    const role = content.role === 'model' ? 'assistant' : 'user';
    const openAIMessage = this.createMultimodalMessage(role, parsedParts);

    if (openAIMessage) {
      messages.push(openAIMessage);
    }
  }

  /**
   * Parse Gemini parts into categorized components
   */
  private parseParts(parts: Part[]): ParsedParts {
    const textParts: string[] = [];
    const functionCalls: FunctionCall[] = [];
    const functionResponses: FunctionResponse[] = [];
    const mediaParts: Array<{
      type: 'image' | 'audio' | 'file';
      data: string;
      mimeType: string;
      fileUri?: string;
    }> = [];

    for (const part of parts) {
      if (typeof part === 'string') {
        textParts.push(part);
      } else if ('text' in part && part.text) {
        textParts.push(part.text);
      } else if ('functionCall' in part && part.functionCall) {
        functionCalls.push(part.functionCall);
      } else if ('functionResponse' in part && part.functionResponse) {
        functionResponses.push(part.functionResponse);
      } else if ('inlineData' in part && part.inlineData) {
        const { data, mimeType } = part.inlineData;
        if (data && mimeType) {
          const mediaType = this.getMediaType(mimeType);
          mediaParts.push({ type: mediaType, data, mimeType });
        }
      } else if ('fileData' in part && part.fileData) {
        const { fileUri, mimeType } = part.fileData;
        if (fileUri && mimeType) {
          const mediaType = this.getMediaType(mimeType);
          mediaParts.push({
            type: mediaType,
            data: '',
            mimeType,
            fileUri,
          });
        }
      }
    }

    return { textParts, functionCalls, functionResponses, mediaParts };
  }

  /**
   * Determine media type from MIME type
   */
  private getMediaType(mimeType: string): 'image' | 'audio' | 'file' {
    if (mimeType.startsWith('image/')) return 'image';
    if (mimeType.startsWith('audio/')) return 'audio';
    return 'file';
  }

  /**
   * Create multimodal OpenAI message from parsed parts
   */
  private createMultimodalMessage(
    role: 'user' | 'assistant',
    parsedParts: Pick<ParsedParts, 'textParts' | 'mediaParts'>,
  ): OpenAI.Chat.ChatCompletionMessageParam | null {
    const { textParts, mediaParts } = parsedParts;
    const combinedText = textParts.join('');

    // If no media parts, return simple text message
    if (mediaParts.length === 0) {
      return combinedText ? { role, content: combinedText } : null;
    }

    // For assistant messages with media, convert to text only
    // since OpenAI assistant messages don't support media content arrays
    if (role === 'assistant') {
      return combinedText
        ? { role: 'assistant' as const, content: combinedText }
        : null;
    }

    // Create multimodal content array for user messages
    const contentArray: OpenAI.Chat.ChatCompletionContentPart[] = [];

    // Add text content
    if (combinedText) {
      contentArray.push({ type: 'text', text: combinedText });
    }

    // Add media content
    for (const mediaPart of mediaParts) {
      if (mediaPart.type === 'image') {
        if (mediaPart.fileUri) {
          // For file URIs, use the URI directly
          contentArray.push({
            type: 'image_url',
            image_url: { url: mediaPart.fileUri },
          });
        } else if (mediaPart.data) {
          // For inline data, create data URL
          const dataUrl = `data:${mediaPart.mimeType};base64,${mediaPart.data}`;
          contentArray.push({
            type: 'image_url',
            image_url: { url: dataUrl },
          });
        }
      } else if (mediaPart.type === 'audio' && mediaPart.data) {
        // Convert audio format from MIME type
        const format = this.getAudioFormat(mediaPart.mimeType);
        if (format) {
          contentArray.push({
            type: 'input_audio',
            input_audio: {
              data: mediaPart.data,
              format: format as 'wav' | 'mp3',
            },
          });
        }
      }
      // Note: File type is not directly supported in OpenAI's current API
      // Could be extended in the future or handled as text description
    }

    return contentArray.length > 0
      ? { role: 'user' as const, content: contentArray }
      : null;
  }

  /**
   * Convert MIME type to OpenAI audio format
   */
  private getAudioFormat(mimeType: string): 'wav' | 'mp3' | null {
    if (mimeType.includes('wav')) return 'wav';
    if (mimeType.includes('mp3') || mimeType.includes('mpeg')) return 'mp3';
    return null;
  }

  /**
   * Type guard to check if content is a valid Content object
   */
  private isContentObject(
    content: unknown,
  ): content is { role: string; parts: Part[] } {
    return (
      typeof content === 'object' &&
      content !== null &&
      'role' in content &&
      'parts' in content &&
      Array.isArray((content as Record<string, unknown>)['parts'])
    );
  }

  /**
   * Extract text content from various Gemini content union types
   */
  private extractTextFromContentUnion(contentUnion: unknown): string {
    if (typeof contentUnion === 'string') {
      return contentUnion;
    }

    if (Array.isArray(contentUnion)) {
      return contentUnion
        .map((item) => this.extractTextFromContentUnion(item))
        .filter(Boolean)
        .join('\n');
    }

    if (typeof contentUnion === 'object' && contentUnion !== null) {
      if ('parts' in contentUnion) {
        const content = contentUnion as Content;
        return (
          content.parts
            ?.map((part: Part) => {
              if (typeof part === 'string') return part;
              if ('text' in part) return part.text || '';
              return '';
            })
            .filter(Boolean)
            .join('\n') || ''
        );
      }
    }

    return '';
  }

  /**
   * Convert OpenAI response to Gemini format
   */
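Taken together, these helpers map one Gemini Content to at most one OpenAI message: parseParts buckets the parts, and createMultimodalMessage emits either a plain text message or a content array of text, image_url, and input_audio entries. A sketch of the expected shape for a user turn carrying inline image data (the input values are invented for illustration):

// Gemini-style input parts:
//   { text: 'What is in this picture?' }
//   { inlineData: { mimeType: 'image/png', data: '<base64>' } }
//
// Expected OpenAI-style output from createMultimodalMessage:
const expectedMessage = {
  role: 'user' as const,
  content: [
    { type: 'text', text: 'What is in this picture?' },
    {
      type: 'image_url',
      image_url: { url: 'data:image/png;base64,<base64>' },
    },
  ],
};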
@@ -416,79 +604,52 @@ export class Converter {
      }
    }

    // Handle tool calls - only accumulate during streaming, emit when complete
    // Handle tool calls using the streaming parser
    if (choice.delta?.tool_calls) {
      for (const toolCall of choice.delta.tool_calls) {
        const index = toolCall.index ?? 0;

        // Get or create the tool call accumulator for this index
        let accumulatedCall = this.streamingToolCalls.get(index);
        if (!accumulatedCall) {
          accumulatedCall = { arguments: '' };
          this.streamingToolCalls.set(index, accumulatedCall);
        }

        // Update accumulated data
        if (toolCall.id) {
          accumulatedCall.id = toolCall.id;
        }
        if (toolCall.function?.name) {
          // If this is a new function name, reset the arguments
          if (accumulatedCall.name !== toolCall.function.name) {
            accumulatedCall.arguments = '';
          }
          accumulatedCall.name = toolCall.function.name;
        }
        // Process the tool call chunk through the streaming parser
        if (toolCall.function?.arguments) {
          // Check if we already have a complete JSON object
          const currentArgs = accumulatedCall.arguments;
          const newArgs = toolCall.function.arguments;

          // If current arguments already form a complete JSON and new arguments start a new object,
          // this indicates a new tool call with the same name
          let shouldReset = false;
          if (currentArgs && newArgs.trim().startsWith('{')) {
            try {
              JSON.parse(currentArgs);
              // If we can parse current arguments as complete JSON and new args start with {,
              // this is likely a new tool call
              shouldReset = true;
            } catch {
              // Current arguments are not complete JSON, continue accumulating
            }
          }

          if (shouldReset) {
            accumulatedCall.arguments = newArgs;
          } else {
            accumulatedCall.arguments += newArgs;
          }
          this.streamingToolCallParser.addChunk(
            index,
            toolCall.function.arguments,
            toolCall.id,
            toolCall.function.name,
          );
        } else {
          // Handle metadata-only chunks (id and/or name without arguments)
          this.streamingToolCallParser.addChunk(
            index,
            '', // Empty chunk for metadata-only updates
            toolCall.id,
            toolCall.function?.name,
          );
        }
      }
    }

    // Only emit function calls when streaming is complete (finish_reason is present)
    if (choice.finish_reason) {
      for (const [, accumulatedCall] of this.streamingToolCalls) {
        if (accumulatedCall.name) {
          let args: Record<string, unknown> = {};
          if (accumulatedCall.arguments) {
            args = safeJsonParse(accumulatedCall.arguments, {});
          }
      const completedToolCalls =
        this.streamingToolCallParser.getCompletedToolCalls();

      for (const toolCall of completedToolCalls) {
        if (toolCall.name) {
          parts.push({
            functionCall: {
              id:
                accumulatedCall.id ||
                toolCall.id ||
                `call_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`,
              name: accumulatedCall.name,
              args,
              name: toolCall.name,
              args: toolCall.args,
            },
          });
        }
      }
      // Clear all accumulated tool calls
      this.streamingToolCalls.clear();

      // Clear the parser for the next stream
      this.streamingToolCallParser.reset();
    }

    response.candidates = [
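This hunk moves argument accumulation out of the converter: every tool_calls delta is forwarded to StreamingToolCallParser.addChunk keyed by its index, and only a chunk carrying finish_reason triggers getCompletedToolCalls() and the emission of functionCall parts. A worked example of the delta sequence this is meant to handle (chunk contents are illustrative, not captured traffic):

// delta 1: { index: 0, id: 'call_abc', function: { name: 'get_weather', arguments: '' } }
// delta 2: { index: 0, function: { arguments: '{"city": "Par' } }
// delta 3: { index: 0, function: { arguments: 'is"}' } }
// final chunk: finish_reason === 'tool_calls'
//
// After the final chunk, getCompletedToolCalls() is expected to return
// something like [{ id: 'call_abc', name: 'get_weather', args: { city: 'Paris' } }],
// which the converter turns into a single Gemini functionCall part before
// reset() prepares the parser for the next stream.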

packages/core/src/core/refactor/errorHandler.test.ts (new file, 393 lines)
@@ -0,0 +1,393 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { GenerateContentParameters } from '@google/genai';
import { EnhancedErrorHandler } from './errorHandler.js';
import { RequestContext } from './telemetryService.js';

describe('EnhancedErrorHandler', () => {
  let errorHandler: EnhancedErrorHandler;
  let mockConsoleError: ReturnType<typeof vi.spyOn>;
  let mockContext: RequestContext;
  let mockRequest: GenerateContentParameters;

  beforeEach(() => {
    mockConsoleError = vi.spyOn(console, 'error').mockImplementation(() => {});

    mockContext = {
      userPromptId: 'test-prompt-id',
      model: 'test-model',
      authType: 'test-auth',
      startTime: Date.now() - 5000,
      duration: 5000,
      isStreaming: false,
    };

    mockRequest = {
      model: 'test-model',
      contents: [{ parts: [{ text: 'test prompt' }] }],
    };
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('constructor', () => {
    it('should create instance with default shouldSuppressLogging function', () => {
      errorHandler = new EnhancedErrorHandler();
      expect(errorHandler).toBeInstanceOf(EnhancedErrorHandler);
    });

    it('should create instance with custom shouldSuppressLogging function', () => {
      const customSuppressLogging = vi.fn(() => true);
      errorHandler = new EnhancedErrorHandler(customSuppressLogging);
      expect(errorHandler).toBeInstanceOf(EnhancedErrorHandler);
    });
  });

  describe('handle method', () => {
    beforeEach(() => {
      errorHandler = new EnhancedErrorHandler();
    });

    it('should throw the original error for non-timeout errors', () => {
      const originalError = new Error('Test error');

      expect(() => {
        errorHandler.handle(originalError, mockContext, mockRequest);
      }).toThrow(originalError);
    });

    it('should log error message for non-timeout errors', () => {
      const originalError = new Error('Test error');

      expect(() => {
        errorHandler.handle(originalError, mockContext, mockRequest);
      }).toThrow();

      expect(mockConsoleError).toHaveBeenCalledWith(
        'OpenAI API Error:',
        'Test error',
      );
    });

    it('should log streaming error message for streaming requests', () => {
      const streamingContext = { ...mockContext, isStreaming: true };
      const originalError = new Error('Test streaming error');

      expect(() => {
        errorHandler.handle(originalError, streamingContext, mockRequest);
      }).toThrow();

      expect(mockConsoleError).toHaveBeenCalledWith(
        'OpenAI API Streaming Error:',
        'Test streaming error',
      );
    });

    it('should throw enhanced error message for timeout errors', () => {
      const timeoutError = new Error('Request timeout');

      expect(() => {
        errorHandler.handle(timeoutError, mockContext, mockRequest);
      }).toThrow(/Request timeout after 5s.*Troubleshooting tips:/s);
    });

    it('should not log error when suppression is enabled', () => {
      const suppressLogging = vi.fn(() => true);
      errorHandler = new EnhancedErrorHandler(suppressLogging);
      const originalError = new Error('Test error');

      expect(() => {
        errorHandler.handle(originalError, mockContext, mockRequest);
      }).toThrow();

      expect(mockConsoleError).not.toHaveBeenCalled();
      expect(suppressLogging).toHaveBeenCalledWith(originalError, mockRequest);
    });

    it('should handle string errors', () => {
      const stringError = 'String error message';

      expect(() => {
        errorHandler.handle(stringError, mockContext, mockRequest);
      }).toThrow(stringError);

      expect(mockConsoleError).toHaveBeenCalledWith(
        'OpenAI API Error:',
        'String error message',
      );
    });

    it('should handle null/undefined errors', () => {
      expect(() => {
        errorHandler.handle(null, mockContext, mockRequest);
      }).toThrow();

      expect(() => {
        errorHandler.handle(undefined, mockContext, mockRequest);
      }).toThrow();
    });
  });

  describe('shouldSuppressErrorLogging method', () => {
    it('should return false by default', () => {
      errorHandler = new EnhancedErrorHandler();
      const result = errorHandler.shouldSuppressErrorLogging(
        new Error('test'),
        mockRequest,
      );
      expect(result).toBe(false);
    });

    it('should use custom suppression function', () => {
      const customSuppressLogging = vi.fn(() => true);
      errorHandler = new EnhancedErrorHandler(customSuppressLogging);

      const testError = new Error('test');
      const result = errorHandler.shouldSuppressErrorLogging(
        testError,
        mockRequest,
      );

      expect(result).toBe(true);
      expect(customSuppressLogging).toHaveBeenCalledWith(
        testError,
        mockRequest,
      );
    });
  });

  describe('timeout error detection', () => {
    beforeEach(() => {
      errorHandler = new EnhancedErrorHandler();
    });

    const timeoutErrorCases = [
      { name: 'timeout in message', error: new Error('Connection timeout') },
      { name: 'timed out in message', error: new Error('Request timed out') },
      {
        name: 'connection timeout',
        error: new Error('connection timeout occurred'),
      },
      { name: 'request timeout', error: new Error('request timeout error') },
      { name: 'read timeout', error: new Error('read timeout happened') },
      { name: 'etimedout', error: new Error('ETIMEDOUT error') },
      { name: 'esockettimedout', error: new Error('ESOCKETTIMEDOUT error') },
      { name: 'deadline exceeded', error: new Error('deadline exceeded') },
      {
        name: 'ETIMEDOUT code',
        error: Object.assign(new Error('Network error'), { code: 'ETIMEDOUT' }),
      },
      {
        name: 'ESOCKETTIMEDOUT code',
        error: Object.assign(new Error('Socket error'), {
          code: 'ESOCKETTIMEDOUT',
        }),
      },
      {
        name: 'timeout type',
        error: Object.assign(new Error('Error'), { type: 'timeout' }),
      },
    ];

    timeoutErrorCases.forEach(({ name, error }) => {
      it(`should detect timeout error: ${name}`, () => {
        expect(() => {
          errorHandler.handle(error, mockContext, mockRequest);
        }).toThrow(/timeout.*Troubleshooting tips:/s);
      });
    });

    it('should not detect non-timeout errors as timeout', () => {
      const regularError = new Error('Regular API error');

      expect(() => {
        errorHandler.handle(regularError, mockContext, mockRequest);
      }).toThrow(regularError);

      expect(() => {
        errorHandler.handle(regularError, mockContext, mockRequest);
      }).not.toThrow(/Troubleshooting tips:/);
    });

    it('should handle case-insensitive timeout detection', () => {
      const uppercaseTimeoutError = new Error('REQUEST TIMEOUT');

      expect(() => {
        errorHandler.handle(uppercaseTimeoutError, mockContext, mockRequest);
      }).toThrow(/timeout.*Troubleshooting tips:/s);
    });
  });

  describe('error message building', () => {
    beforeEach(() => {
      errorHandler = new EnhancedErrorHandler();
    });

    it('should build timeout error message for non-streaming requests', () => {
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, mockContext, mockRequest);
      }).toThrow(
        /Request timeout after 5s\. Try reducing input length or increasing timeout in config\./,
      );
    });

    it('should build timeout error message for streaming requests', () => {
      const streamingContext = { ...mockContext, isStreaming: true };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, streamingContext, mockRequest);
      }).toThrow(
        /Streaming request timeout after 5s\. Try reducing input length or increasing timeout in config\./,
      );
    });

    it('should use original error message for non-timeout errors', () => {
      const originalError = new Error('Original error message');

      expect(() => {
        errorHandler.handle(originalError, mockContext, mockRequest);
      }).toThrow('Original error message');
    });

    it('should handle non-Error objects', () => {
      const objectError = { message: 'Object error', code: 500 };

      expect(() => {
        errorHandler.handle(objectError, mockContext, mockRequest);
      }).toThrow(); // Non-timeout errors are thrown as-is
    });

    it('should convert non-Error objects to strings for timeout errors', () => {
      // Create an object that will be detected as timeout error
      const objectTimeoutError = {
        toString: () => 'Connection timeout error',
        message: 'timeout occurred',
        code: 500,
      };

      expect(() => {
        errorHandler.handle(objectTimeoutError, mockContext, mockRequest);
      }).toThrow(/Request timeout after 5s.*Troubleshooting tips:/s);
    });

    it('should handle different duration values correctly', () => {
      const contextWithDifferentDuration = { ...mockContext, duration: 12345 };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(
          timeoutError,
          contextWithDifferentDuration,
          mockRequest,
        );
      }).toThrow(/Request timeout after 12s\./);
    });
  });

  describe('troubleshooting tips generation', () => {
    beforeEach(() => {
      errorHandler = new EnhancedErrorHandler();
    });

    it('should provide general troubleshooting tips for non-streaming requests', () => {
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, mockContext, mockRequest);
      }).toThrow(
        /Troubleshooting tips:\n- Reduce input length or complexity\n- Increase timeout in config: contentGenerator\.timeout\n- Check network connectivity\n- Consider using streaming mode for long responses/,
      );
    });

    it('should provide streaming-specific troubleshooting tips for streaming requests', () => {
      const streamingContext = { ...mockContext, isStreaming: true };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, streamingContext, mockRequest);
      }).toThrow(
        /Streaming timeout troubleshooting:\n- Reduce input length or complexity\n- Increase timeout in config: contentGenerator\.timeout\n- Check network connectivity\n- Check network stability for streaming connections\n- Consider using non-streaming mode for very long inputs/,
      );
    });
  });

  describe('ErrorHandler interface compliance', () => {
    it('should implement ErrorHandler interface correctly', () => {
      errorHandler = new EnhancedErrorHandler();

      // Check that the class implements the interface methods
      expect(typeof errorHandler.handle).toBe('function');
      expect(typeof errorHandler.shouldSuppressErrorLogging).toBe('function');

      // Check method signatures by calling them
      expect(() => {
        errorHandler.handle(new Error('test'), mockContext, mockRequest);
      }).toThrow();

      expect(
        errorHandler.shouldSuppressErrorLogging(new Error('test'), mockRequest),
      ).toBe(false);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      errorHandler = new EnhancedErrorHandler();
    });

    it('should handle zero duration', () => {
      const zeroContext = { ...mockContext, duration: 0 };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, zeroContext, mockRequest);
      }).toThrow(/Request timeout after 0s\./);
    });

    it('should handle negative duration', () => {
      const negativeContext = { ...mockContext, duration: -1000 };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, negativeContext, mockRequest);
      }).toThrow(/Request timeout after -1s\./);
    });

    it('should handle very large duration', () => {
      const largeContext = { ...mockContext, duration: 999999 };
      const timeoutError = new Error('timeout');

      expect(() => {
        errorHandler.handle(timeoutError, largeContext, mockRequest);
      }).toThrow(/Request timeout after 1000s\./);
    });

    it('should handle empty error message', () => {
      const emptyError = new Error('');

      expect(() => {
        errorHandler.handle(emptyError, mockContext, mockRequest);
      }).toThrow(emptyError);

      expect(mockConsoleError).toHaveBeenCalledWith('OpenAI API Error:', '');
    });

    it('should handle error with only whitespace message', () => {
      const whitespaceError = new Error(' \n\t ');

      expect(() => {
        errorHandler.handle(whitespaceError, mockContext, mockRequest);
      }).toThrow(whitespaceError);
    });
  });
});
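The timeout cases in these tests pin down the detection rules fairly precisely: keyword matches in the message (case-insensitive), the ETIMEDOUT/ESOCKETTIMEDOUT codes, and a 'timeout' type. A small predicate reconstructed from the test expectations alone (the real check inside errorHandler.ts may differ in detail):

function looksLikeTimeout(error: unknown): boolean {
  // Sketch derived from the test cases above, not the shipped implementation.
  const err = error as { message?: string; code?: string; type?: string };
  const message = String(err?.message ?? error ?? '').toLowerCase();
  return (
    message.includes('timeout') ||
    message.includes('timed out') ||
    message.includes('etimedout') ||
    message.includes('esockettimedout') ||
    message.includes('deadline exceeded') ||
    err?.code === 'ETIMEDOUT' ||
    err?.code === 'ESOCKETTIMEDOUT' ||
    err?.type === 'timeout'
  );
}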
@@ -17,22 +17,17 @@ import {
  DefaultOpenAICompatibleProvider,
} from './provider/index.js';

// Main classes
export { OpenAIContentGenerator } from './openaiContentGenerator.js';
export { ContentGenerationPipeline, type PipelineConfig } from './pipeline.js';

// Providers
export {
  type OpenAICompatibleProvider,
  DashScopeOpenAICompatibleProvider,
  OpenRouterOpenAICompatibleProvider,
} from './provider/index.js';

// Utilities
export { Converter } from './converter.js';
export { StreamingManager } from './streamingManager.js';
export { OpenAIContentConverter } from './converter.js';

// Factory utility functions
/**
 * Create an OpenAI-compatible content generator with the appropriate provider
 */
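For downstream consumers, the barrel now exports OpenAIContentConverter in place of Converter and drops the StreamingManager export. A before/after import sketch (the relative path is an assumption for a consumer elsewhere inside packages/core/src):

// Before this commit:
// import { Converter, StreamingManager } from './core/refactor/index.js';

// After this commit:
import { OpenAIContentConverter } from './core/refactor/index.js';

const converter = new OpenAIContentConverter('qwen-max');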

packages/core/src/core/refactor/openaiContentGenerator.test.ts (new file, 278 lines)
@@ -0,0 +1,278 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import type {
  GenerateContentParameters,
  CountTokensParameters,
} from '@google/genai';
import type { OpenAICompatibleProvider } from './provider/index.js';
import OpenAI from 'openai';

// Mock tiktoken
vi.mock('tiktoken', () => ({
  get_encoding: vi.fn().mockReturnValue({
    encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
    free: vi.fn(),
  }),
}));

describe('OpenAIContentGenerator (Refactored)', () => {
  let generator: OpenAIContentGenerator;
  let mockConfig: Config;

  beforeEach(() => {
    // Reset mocks
    vi.clearAllMocks();

    // Mock config
    mockConfig = {
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        authType: 'openai',
        enableOpenAILogging: false,
        timeout: 120000,
        maxRetries: 3,
        samplingParams: {
          temperature: 0.7,
          max_tokens: 1000,
          top_p: 0.9,
        },
      }),
      getCliVersion: vi.fn().mockReturnValue('1.0.0'),
    } as unknown as Config;

    // Create generator instance
    const contentGeneratorConfig = {
      model: 'gpt-4',
      apiKey: 'test-key',
      authType: AuthType.USE_OPENAI,
      enableOpenAILogging: false,
      timeout: 120000,
      maxRetries: 3,
      samplingParams: {
        temperature: 0.7,
        max_tokens: 1000,
        top_p: 0.9,
      },
    };

    // Create a minimal mock provider
    const mockProvider: OpenAICompatibleProvider = {
      buildHeaders: vi.fn().mockReturnValue({}),
      buildClient: vi.fn().mockReturnValue({
        chat: {
          completions: {
            create: vi.fn(),
          },
        },
        embeddings: {
          create: vi.fn(),
        },
      } as unknown as OpenAI),
      buildRequest: vi.fn().mockImplementation((req) => req),
    };

    generator = new OpenAIContentGenerator(
      contentGeneratorConfig,
      mockConfig,
      mockProvider,
    );
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('constructor', () => {
    it('should initialize with basic configuration', () => {
      expect(generator).toBeDefined();
    });
  });

  describe('generateContent', () => {
    it('should delegate to pipeline.execute', async () => {
      // This test verifies the method exists and can be called
      expect(typeof generator.generateContent).toBe('function');
    });
  });

  describe('generateContentStream', () => {
    it('should delegate to pipeline.executeStream', async () => {
      // This test verifies the method exists and can be called
      expect(typeof generator.generateContentStream).toBe('function');
    });
  });

  describe('countTokens', () => {
    it('should count tokens using tiktoken', async () => {
      const request: CountTokensParameters = {
        contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }],
        model: 'gpt-4',
      };

      const result = await generator.countTokens(request);

      expect(result.totalTokens).toBe(50); // Mocked value
    });

    it('should fall back to character approximation if tiktoken fails', async () => {
      // Mock tiktoken to throw error
      vi.doMock('tiktoken', () => ({
        get_encoding: vi.fn().mockImplementation(() => {
          throw new Error('Tiktoken failed');
        }),
      }));

      const request: CountTokensParameters = {
        contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }],
        model: 'gpt-4',
      };

      const result = await generator.countTokens(request);

      // Should use character approximation (content length / 4)
      expect(result.totalTokens).toBeGreaterThan(0);
    });
  });

  describe('embedContent', () => {
    it('should delegate to pipeline.client.embeddings.create', async () => {
      // This test verifies the method exists and can be called
      expect(typeof generator.embedContent).toBe('function');
    });
  });

  describe('shouldSuppressErrorLogging', () => {
    it('should return false by default', () => {
      // Create a test subclass to access the protected method
      class TestGenerator extends OpenAIContentGenerator {
        testShouldSuppressErrorLogging(
          error: unknown,
          request: GenerateContentParameters,
        ): boolean {
          return this.shouldSuppressErrorLogging(error, request);
        }
      }

      const contentGeneratorConfig = {
        model: 'gpt-4',
        apiKey: 'test-key',
        authType: AuthType.USE_OPENAI,
        enableOpenAILogging: false,
        timeout: 120000,
        maxRetries: 3,
        samplingParams: {
          temperature: 0.7,
          max_tokens: 1000,
          top_p: 0.9,
        },
      };

      // Create a minimal mock provider
      const mockProvider: OpenAICompatibleProvider = {
        buildHeaders: vi.fn().mockReturnValue({}),
        buildClient: vi.fn().mockReturnValue({
          chat: {
            completions: {
              create: vi.fn(),
            },
          },
          embeddings: {
            create: vi.fn(),
          },
        } as unknown as OpenAI),
        buildRequest: vi.fn().mockImplementation((req) => req),
      };

      const testGenerator = new TestGenerator(
        contentGeneratorConfig,
        mockConfig,
        mockProvider,
      );

      const request: GenerateContentParameters = {
        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
        model: 'gpt-4',
      };

      const result = testGenerator.testShouldSuppressErrorLogging(
        new Error('Test error'),
        request,
      );

      expect(result).toBe(false);
    });

    it('should allow subclasses to override error suppression behavior', async () => {
      class TestGenerator extends OpenAIContentGenerator {
        testShouldSuppressErrorLogging(
          error: unknown,
          request: GenerateContentParameters,
        ): boolean {
          return this.shouldSuppressErrorLogging(error, request);
        }

        protected override shouldSuppressErrorLogging(
          _error: unknown,
          _request: GenerateContentParameters,
        ): boolean {
          return true; // Always suppress for this test
        }
      }

      const contentGeneratorConfig = {
        model: 'gpt-4',
        apiKey: 'test-key',
        authType: AuthType.USE_OPENAI,
        enableOpenAILogging: false,
        timeout: 120000,
        maxRetries: 3,
        samplingParams: {
          temperature: 0.7,
          max_tokens: 1000,
          top_p: 0.9,
        },
      };

      // Create a minimal mock provider
      const mockProvider: OpenAICompatibleProvider = {
        buildHeaders: vi.fn().mockReturnValue({}),
        buildClient: vi.fn().mockReturnValue({
          chat: {
            completions: {
              create: vi.fn(),
            },
          },
          embeddings: {
            create: vi.fn(),
          },
        } as unknown as OpenAI),
        buildRequest: vi.fn().mockImplementation((req) => req),
      };

      const testGenerator = new TestGenerator(
        contentGeneratorConfig,
        mockConfig,
        mockProvider,
      );

      const request: GenerateContentParameters = {
        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
        model: 'gpt-4',
      };

      const result = testGenerator.testShouldSuppressErrorLogging(
        new Error('Test error'),
        request,
      );

      expect(result).toBe(true);
    });
  });
});
@@ -70,12 +70,82 @@ export class OpenAIContentGenerator implements ContentGenerator {
  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    return this.pipeline.countTokens(request);
    // Use tiktoken for accurate token counting
    const content = JSON.stringify(request.contents);
    let totalTokens = 0;

    try {
      const { get_encoding } = await import('tiktoken');
      const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
      totalTokens = encoding.encode(content).length;
      encoding.free();
    } catch (error) {
      console.warn(
        'Failed to load tiktoken, falling back to character approximation:',
        error,
      );
      // Fallback: rough approximation using character count
      totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
    }

    return {
      totalTokens,
    };
  }

  async embedContent(
    request: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    return this.pipeline.embedContent(request);
    // Extract text from contents
    let text = '';
    if (Array.isArray(request.contents)) {
      text = request.contents
        .map((content) => {
          if (typeof content === 'string') return content;
          if ('parts' in content && content.parts) {
            return content.parts
              .map((part) =>
                typeof part === 'string'
                  ? part
                  : 'text' in part
                    ? (part as { text?: string }).text || ''
                    : '',
              )
              .join(' ');
          }
          return '';
        })
        .join(' ');
    } else if (request.contents) {
      if (typeof request.contents === 'string') {
        text = request.contents;
      } else if ('parts' in request.contents && request.contents.parts) {
        text = request.contents.parts
          .map((part) =>
            typeof part === 'string' ? part : 'text' in part ? part.text : '',
          )
          .join(' ');
      }
    }

    try {
      const embedding = await this.pipeline.client.embeddings.create({
        model: 'text-embedding-ada-002', // Default embedding model
        input: text,
      });

      return {
        embeddings: [
          {
            values: embedding.data[0].embedding,
          },
        ],
      };
    } catch (error) {
      console.error('OpenAI API Embedding Error:', error);
      throw new Error(
        `OpenAI API error: ${error instanceof Error ? error.message : String(error)}`,
      );
    }
  }
}
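Whichever side of the pipeline boundary this logic ends up on, the counting strategy in the hunk above is simple: serialize request.contents, try tiktoken's cl100k_base encoding, and fall back to a characters/4 estimate. A condensed standalone sketch of that logic (extracted here for illustration; it is not the class method itself):

async function estimateTokens(contents: unknown): Promise<number> {
  const content = JSON.stringify(contents);
  try {
    const { get_encoding } = await import('tiktoken');
    // GPT-4 encoding; an approximation for Qwen models.
    const encoding = get_encoding('cl100k_base');
    const totalTokens = encoding.encode(content).length;
    encoding.free();
    return totalTokens;
  } catch {
    // Rough estimate: 1 token ≈ 4 characters.
    return Math.ceil(content.length / 4);
  }
}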
@@ -8,18 +8,13 @@ import OpenAI from 'openai';
import {
  GenerateContentParameters,
  GenerateContentResponse,
  CountTokensParameters,
  CountTokensResponse,
  EmbedContentParameters,
  EmbedContentResponse,
} from '@google/genai';
import { Config } from '../../config/config.js';
import { ContentGeneratorConfig } from '../contentGenerator.js';
import { type OpenAICompatibleProvider } from './provider/index.js';
import { Converter } from './converter.js';
import { OpenAIContentConverter } from './converter.js';
import { TelemetryService, RequestContext } from './telemetryService.js';
import { ErrorHandler } from './errorHandler.js';
import { StreamingManager } from './streamingManager.js';

export interface PipelineConfig {
  cliConfig: Config;
@@ -31,15 +26,15 @@ export interface PipelineConfig {

export class ContentGenerationPipeline {
  client: OpenAI;
  private converter: Converter;
  private streamingManager: StreamingManager;
  private converter: OpenAIContentConverter;
  private contentGeneratorConfig: ContentGeneratorConfig;

  constructor(private config: PipelineConfig) {
    this.contentGeneratorConfig = config.contentGeneratorConfig;
    this.client = this.config.provider.buildClient();
    this.converter = new Converter(this.contentGeneratorConfig.model);
    this.streamingManager = new StreamingManager(this.converter);
    this.converter = new OpenAIContentConverter(
      this.contentGeneratorConfig.model,
    );
  }

  async execute(
@@ -80,15 +75,14 @@ export class ContentGenerationPipeline {
      userPromptId,
      true,
      async (openaiRequest, context) => {
        // Stage 1: Create OpenAI stream
        const stream = (await this.client.chat.completions.create(
          openaiRequest,
        )) as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>;

        const originalStream = this.streamingManager.processStream(stream);

        // Create a logging stream decorator that handles collection and logging
        return this.createLoggingStream(
          originalStream,
        // Stage 2: Process stream with conversion and logging
        return this.processStreamWithLogging(
          stream,
          context,
          openaiRequest,
          request,
@@ -97,85 +91,68 @@ export class ContentGenerationPipeline {
    );
  }

  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    // Use tiktoken for accurate token counting
    const content = JSON.stringify(request.contents);
    let totalTokens = 0;
  /**
   * Stage 2: Process OpenAI stream with conversion and logging
   * This method handles the complete stream processing pipeline:
   * 1. Convert OpenAI chunks to Gemini format while preserving original chunks
   * 2. Filter empty responses
   * 3. Collect both formats for logging
   * 4. Handle success/error logging with original OpenAI format
   */
  private async *processStreamWithLogging(
    stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
    context: RequestContext,
    openaiRequest: OpenAI.Chat.ChatCompletionCreateParams,
    request: GenerateContentParameters,
  ): AsyncGenerator<GenerateContentResponse> {
    const collectedGeminiResponses: GenerateContentResponse[] = [];
    const collectedOpenAIChunks: OpenAI.Chat.ChatCompletionChunk[] = [];

    // Reset streaming tool calls to prevent data pollution from previous streams
    this.converter.resetStreamingToolCalls();

    try {
      const { get_encoding } = await import('tiktoken');
      const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
      totalTokens = encoding.encode(content).length;
      encoding.free();
    } catch (error) {
      console.warn(
        'Failed to load tiktoken, falling back to character approximation:',
        error,
      );
      // Fallback: rough approximation using character count
      totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
    }
      // Stage 2a: Convert and yield each chunk while preserving original
      for await (const chunk of stream) {
        const response = this.converter.convertOpenAIChunkToGemini(chunk);

    return {
      totalTokens,
    };
  }
        // Stage 2b: Filter empty responses to avoid downstream issues
        if (
          response.candidates?.[0]?.content?.parts?.length === 0 &&
          !response.usageMetadata
        ) {
          continue;
        }

  async embedContent(
    request: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    // Extract text from contents
    let text = '';
    if (Array.isArray(request.contents)) {
      text = request.contents
        .map((content) => {
          if (typeof content === 'string') return content;
          if ('parts' in content && content.parts) {
            return content.parts
              .map((part) =>
                typeof part === 'string'
                  ? part
                  : 'text' in part
                    ? (part as { text?: string }).text || ''
                    : '',
              )
              .join(' ');
          }
          return '';
        })
        .join(' ');
    } else if (request.contents) {
      if (typeof request.contents === 'string') {
        text = request.contents;
      } else if ('parts' in request.contents && request.contents.parts) {
        text = request.contents.parts
          .map((part) =>
            typeof part === 'string' ? part : 'text' in part ? part.text : '',
          )
          .join(' ');
        // Stage 2c: Collect both formats and yield Gemini format to consumer
        collectedGeminiResponses.push(response);
        collectedOpenAIChunks.push(chunk);
        yield response;
      }
    }

    try {
      const embedding = await this.client.embeddings.create({
        model: 'text-embedding-ada-002', // Default embedding model
        input: text,
      });
      // Stage 2d: Stream completed successfully - perform logging with original OpenAI chunks
      context.duration = Date.now() - context.startTime;

      return {
        embeddings: [
          {
            values: embedding.data[0].embedding,
          },
        ],
      };
    } catch (error) {
      console.error('OpenAI API Embedding Error:', error);
      throw new Error(
        `OpenAI API error: ${error instanceof Error ? error.message : String(error)}`,
      await this.config.telemetryService.logStreamingSuccess(
        context,
        collectedGeminiResponses,
        openaiRequest,
        collectedOpenAIChunks,
      );
    } catch (error) {
      // Stage 2e: Stream failed - handle error and logging
      context.duration = Date.now() - context.startTime;

      // Clear streaming tool calls on error to prevent data pollution
      this.converter.resetStreamingToolCalls();

      await this.config.telemetryService.logError(
        context,
        error,
        openaiRequest,
      );

      this.config.errorHandler.handle(error, context, request);
    }
  }

@@ -267,55 +244,6 @@ export class ContentGenerationPipeline {
    return params;
  }

  /**
   * Creates a stream decorator that collects responses and handles logging
   */
  private async *createLoggingStream(
    originalStream: AsyncGenerator<GenerateContentResponse>,
    context: RequestContext,
    openaiRequest: OpenAI.Chat.ChatCompletionCreateParams,
    request: GenerateContentParameters,
  ): AsyncGenerator<GenerateContentResponse> {
    const responses: GenerateContentResponse[] = [];

    try {
      // Yield all responses while collecting them
      for await (const response of originalStream) {
        responses.push(response);
        yield response;
      }

      // Stream completed successfully - perform logging
      context.duration = Date.now() - context.startTime;

      const combinedResponse =
        this.streamingManager.combineStreamResponsesForLogging(
          responses,
          this.contentGeneratorConfig.model,
        );
      const openaiResponse =
        this.converter.convertGeminiResponseToOpenAI(combinedResponse);

      await this.config.telemetryService.logStreamingSuccess(
        context,
        responses,
        openaiRequest,
        openaiResponse,
      );
    } catch (error) {
      // Stream failed - handle error and logging
      context.duration = Date.now() - context.startTime;

      await this.config.telemetryService.logError(
        context,
        error,
        openaiRequest,
      );

      this.config.errorHandler.handle(error, context, request);
    }
  }

  /**
   * Common error handling wrapper for execute methods
   */
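From the caller's side, the contract of the new stream path is: reset the tool-call parser, yield converted Gemini responses as OpenAI chunks arrive, then log once with both the collected Gemini responses and the raw OpenAI chunks, or log the error and reset again on failure. A consumer-side sketch, assuming the two-argument generateContentStream form used elsewhere in this package (variable names are illustrative):

const stream = await generator.generateContentStream(request, userPromptId);
for await (const response of stream) {
  // Each item is already a Gemini GenerateContentResponse; tool calls only
  // appear once the chunk carrying finish_reason has been processed.
  const text = response.candidates?.[0]?.content?.parts
    ?.map((part) => ('text' in part && part.text ? part.text : ''))
    .join('');
  if (text) process.stdout.write(text);
}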

packages/core/src/core/refactor/provider/dashscope.test.ts (new file, 562 lines)
@@ -0,0 +1,562 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
type MockedFunction,
|
||||
} from 'vitest';
|
||||
import OpenAI from 'openai';
|
||||
import { DashScopeOpenAICompatibleProvider } from './dashscope.js';
|
||||
import { Config } from '../../../config/config.js';
|
||||
import { AuthType, ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
|
||||
// Mock OpenAI
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation((config) => ({
|
||||
config,
|
||||
chat: {
|
||||
completions: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('DashScopeOpenAICompatibleProvider', () => {
|
||||
let provider: DashScopeOpenAICompatibleProvider;
|
||||
let mockContentGeneratorConfig: ContentGeneratorConfig;
|
||||
let mockCliConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock ContentGeneratorConfig
|
||||
mockContentGeneratorConfig = {
|
||||
apiKey: 'test-api-key',
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
timeout: 60000,
|
||||
maxRetries: 2,
|
||||
model: 'qwen-max',
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
// Mock Config
|
||||
mockCliConfig = {
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
disableCacheControl: false,
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
provider = new DashScopeOpenAICompatibleProvider(
|
||||
mockContentGeneratorConfig,
|
||||
mockCliConfig,
|
||||
);
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided configs', () => {
|
||||
expect(provider).toBeInstanceOf(DashScopeOpenAICompatibleProvider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isDashScopeProvider', () => {
|
||||
it('should return true for QWEN_OAUTH auth type', () => {
|
||||
const config = {
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
const result =
|
||||
DashScopeOpenAICompatibleProvider.isDashScopeProvider(config);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for DashScope domestic URL', () => {
|
||||
const config = {
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
const result =
|
||||
DashScopeOpenAICompatibleProvider.isDashScopeProvider(config);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for DashScope international URL', () => {
|
||||
const config = {
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
const result =
|
||||
DashScopeOpenAICompatibleProvider.isDashScopeProvider(config);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for non-DashScope configurations', () => {
|
||||
const configs = [
|
||||
{
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
},
|
||||
{
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'https://api.anthropic.com/v1',
|
||||
},
|
||||
{
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'https://openrouter.ai/api/v1',
|
||||
},
|
||||
];
|
||||
|
||||
configs.forEach((config) => {
|
||||
const result = DashScopeOpenAICompatibleProvider.isDashScopeProvider(
|
||||
config as ContentGeneratorConfig,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildHeaders', () => {
|
||||
it('should build DashScope-specific headers', () => {
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers).toEqual({
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
'X-DashScope-CacheControl': 'enable',
|
||||
'X-DashScope-UserAgent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
'X-DashScope-AuthType': AuthType.QWEN_OAUTH,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle unknown CLI version', () => {
|
||||
(
|
||||
mockCliConfig.getCliVersion as MockedFunction<
|
||||
typeof mockCliConfig.getCliVersion
|
||||
>
|
||||
).mockReturnValue(undefined);
|
||||
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers['User-Agent']).toBe(
|
||||
`QwenCode/unknown (${process.platform}; ${process.arch})`,
|
||||
);
|
||||
expect(headers['X-DashScope-UserAgent']).toBe(
|
||||
`QwenCode/unknown (${process.platform}; ${process.arch})`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildClient', () => {
|
||||
it('should create OpenAI client with DashScope configuration', () => {
|
||||
const client = provider.buildClient();
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
timeout: 60000,
|
||||
maxRetries: 2,
|
||||
defaultHeaders: {
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
'X-DashScope-CacheControl': 'enable',
|
||||
'X-DashScope-UserAgent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
'X-DashScope-AuthType': AuthType.QWEN_OAUTH,
|
||||
},
|
||||
});
|
||||
|
||||
expect(client).toBeDefined();
|
||||
});
|
||||
|
||||
it('should use default timeout and maxRetries when not provided', () => {
|
||||
mockContentGeneratorConfig.timeout = undefined;
|
||||
mockContentGeneratorConfig.maxRetries = undefined;
|
||||
|
||||
provider.buildClient();
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
timeout: DEFAULT_TIMEOUT,
|
||||
maxRetries: DEFAULT_MAX_RETRIES,
|
||||
defaultHeaders: expect.any(Object),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildMetadata', () => {
|
||||
it('should build metadata with session and prompt IDs', () => {
|
||||
const userPromptId = 'test-prompt-id';
|
||||
const metadata = provider.buildMetadata(userPromptId);
|
||||
|
||||
expect(metadata).toEqual({
|
||||
metadata: {
|
||||
sessionId: 'test-session-id',
|
||||
promptId: 'test-prompt-id',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing session ID', () => {
|
||||
// Mock the method to not exist (simulate optional chaining returning undefined)
|
||||
delete (mockCliConfig as unknown as Record<string, unknown>)[
|
||||
'getSessionId'
|
||||
];
|
||||
|
||||
const userPromptId = 'test-prompt-id';
|
||||
const metadata = provider.buildMetadata(userPromptId);
|
||||
|
||||
expect(metadata).toEqual({
|
||||
metadata: {
|
||||
sessionId: undefined,
|
||||
promptId: 'test-prompt-id',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildRequest', () => {
|
||||
const baseRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are a helpful assistant.' },
|
||||
{ role: 'user', content: 'Hello!' },
|
||||
],
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
it('should add cache control to system message only for non-streaming requests', () => {
|
||||
const request = { ...baseRequest, stream: false };
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.messages).toHaveLength(2);
|
||||
|
||||
// System message should have cache control
|
||||
const systemMessage = result.messages[0];
|
||||
expect(systemMessage.role).toBe('system');
|
||||
expect(systemMessage.content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'You are a helpful assistant.',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
|
||||
// Last message should NOT have cache control for non-streaming
|
||||
const lastMessage = result.messages[1];
|
||||
expect(lastMessage.role).toBe('user');
|
||||
expect(lastMessage.content).toBe('Hello!');
|
||||
});
|
||||
|
||||
it('should add cache control to both system and last messages for streaming requests', () => {
|
||||
const request = { ...baseRequest, stream: true };
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.messages).toHaveLength(2);
|
||||
|
||||
// System message should have cache control
|
||||
const systemMessage = result.messages[0];
|
||||
expect(systemMessage.content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'You are a helpful assistant.',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
|
||||
// Last message should also have cache control for streaming
|
||||
const lastMessage = result.messages[1];
|
||||
expect(lastMessage.content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Hello!',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should include metadata in the request', () => {
|
||||
const result = provider.buildRequest(baseRequest, 'test-prompt-id');
|
||||
|
||||
expect(result.metadata).toEqual({
|
||||
sessionId: 'test-session-id',
|
||||
promptId: 'test-prompt-id',
|
||||
});
|
||||
});
|
||||
|
||||
it('should preserve all original request parameters', () => {
|
||||
const complexRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
...baseRequest,
|
||||
temperature: 0.8,
|
||||
max_tokens: 1000,
|
||||
top_p: 0.9,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.2,
|
||||
stop: ['END'],
|
||||
user: 'test-user',
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(complexRequest, 'test-prompt-id');
|
||||
|
||||
expect(result.model).toBe('qwen-max');
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.max_tokens).toBe(1000);
|
||||
expect(result.top_p).toBe(0.9);
|
||||
expect(result.frequency_penalty).toBe(0.1);
|
||||
expect(result.presence_penalty).toBe(0.2);
|
||||
expect(result.stop).toEqual(['END']);
|
||||
expect(result.user).toBe('test-user');
|
||||
});
|
||||
|
||||
it('should skip cache control when disabled', () => {
|
||||
(
|
||||
mockCliConfig.getContentGeneratorConfig as MockedFunction<
|
||||
typeof mockCliConfig.getContentGeneratorConfig
|
||||
>
|
||||
).mockReturnValue({
|
||||
model: 'qwen-max',
|
||||
disableCacheControl: true,
|
||||
});
|
||||
|
||||
const result = provider.buildRequest(baseRequest, 'test-prompt-id');
|
||||
|
||||
// Messages should remain as strings (not converted to array format)
|
||||
expect(result.messages[0].content).toBe('You are a helpful assistant.');
|
||||
expect(result.messages[1].content).toBe('Hello!');
|
||||
});
|
||||
|
||||
it('should handle messages with array content for streaming requests', () => {
|
||||
const requestWithArrayContent: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: 'World' },
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(
|
||||
requestWithArrayContent,
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
const message = result.messages[0];
|
||||
expect(Array.isArray(message.content)).toBe(true);
|
||||
const content =
|
||||
message.content as OpenAI.Chat.ChatCompletionContentPart[];
|
||||
expect(content).toHaveLength(2);
|
||||
expect(content[1]).toEqual({
|
||||
type: 'text',
|
||||
text: 'World',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty messages array', () => {
|
||||
const emptyRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
messages: [],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(emptyRequest, 'test-prompt-id');
|
||||
|
||||
expect(result.messages).toEqual([]);
|
||||
expect(result.metadata).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle messages without content for streaming requests', () => {
|
||||
const requestWithoutContent: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{ role: 'assistant', content: null },
|
||||
{ role: 'user', content: 'Hello' },
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(
|
||||
requestWithoutContent,
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
// First message should remain unchanged
|
||||
expect(result.messages[0].content).toBeNull();
|
||||
|
||||
// Second message should have cache control (it's the last message in streaming)
|
||||
expect(result.messages[1].content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Hello',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should add cache control to last text item in mixed content for streaming requests', () => {
|
||||
const requestWithMixedContent: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this image:' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/image.jpg' },
|
||||
},
|
||||
{ type: 'text', text: 'What do you see?' },
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(
|
||||
requestWithMixedContent,
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
const content = result.messages[0]
|
||||
.content as OpenAI.Chat.ChatCompletionContentPart[];
|
||||
expect(content).toHaveLength(3);
|
||||
|
||||
// Last text item should have cache control
|
||||
expect(content[2]).toEqual({
|
||||
type: 'text',
|
||||
text: 'What do you see?',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
});
|
||||
|
||||
// Image item should remain unchanged
|
||||
expect(content[1]).toEqual({
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/image.jpg' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should add empty text item with cache control if last item is not text for streaming requests', () => {
|
||||
const requestWithNonTextLast: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this:' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/image.jpg' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(
|
||||
requestWithNonTextLast,
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
const content = result.messages[0]
|
||||
.content as OpenAI.Chat.ChatCompletionContentPart[];
|
||||
expect(content).toHaveLength(3);
|
||||
|
||||
// Should add empty text item with cache control
|
||||
expect(content[2]).toEqual({
|
||||
type: 'text',
|
||||
text: '',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache control edge cases', () => {
|
||||
it('should handle request with only system message', () => {
|
||||
const systemOnlyRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
messages: [{ role: 'system', content: 'System prompt' }],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(systemOnlyRequest, 'test-prompt-id');
|
||||
|
||||
expect(result.messages).toHaveLength(1);
|
||||
expect(result.messages[0].content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'System prompt',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle request without system message for streaming requests', () => {
|
||||
const noSystemRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{ role: 'user', content: 'First message' },
|
||||
{ role: 'assistant', content: 'Response' },
|
||||
{ role: 'user', content: 'Second message' },
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(noSystemRequest, 'test-prompt-id');
|
||||
|
||||
expect(result.messages).toHaveLength(3);
|
||||
|
||||
// Only last message should have cache control (no system message to modify)
|
||||
expect(result.messages[0].content).toBe('First message');
|
||||
expect(result.messages[1].content).toBe('Response');
|
||||
expect(result.messages[2].content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Second message',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle empty content array for streaming requests', () => {
|
||||
const emptyContentRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-max',
|
||||
stream: true, // This will trigger cache control on last message
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(
|
||||
emptyContentRequest,
|
||||
'test-prompt-id',
|
||||
);
|
||||
|
||||
const content = result.messages[0]
|
||||
.content as OpenAI.Chat.ChatCompletionContentPart[];
|
||||
expect(content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: '',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
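The detection tests above treat a configuration as DashScope when it either authenticates through Qwen OAuth or points at a DashScope-compatible endpoint (domestic or international). A hedged sketch of that rule; the literal 'qwen-oauth' string and the simplified config shape are assumptions for illustration, since the real check compares against AuthType.QWEN_OAUTH.

// Sketch: decide whether a config should use the DashScope provider.
interface ProviderConfigLike {
  authType?: string;
  baseUrl?: string;
}

function looksLikeDashScope(config: ProviderConfigLike): boolean {
  // Qwen OAuth always routes through DashScope.
  if (config.authType === 'qwen-oauth') {
    return true;
  }
  // Otherwise fall back to matching the known DashScope endpoints.
  const baseUrl = config.baseUrl ?? '';
  return (
    baseUrl.includes('dashscope.aliyuncs.com') ||
    baseUrl.includes('dashscope-intl.aliyuncs.com')
  );
}

// Both of these would be detected as DashScope configurations:
looksLikeDashScope({ authType: 'qwen-oauth' }); // true
looksLikeDashScope({
  baseUrl: 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
}); // true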
|
||||
@@ -131,72 +131,110 @@ export class DashScopeOpenAICompatibleProvider
    target: 'system' | 'last',
  ): OpenAI.Chat.ChatCompletionMessageParam[] {
    const updatedMessages = [...messages];
    let messageIndex: number;
    const messageIndex = this.findTargetMessageIndex(messages, target);

    if (target === 'system') {
      // Find the first system message
      messageIndex = messages.findIndex((msg) => msg.role === 'system');
      if (messageIndex === -1) {
        return updatedMessages;
      }
    } else {
      // Get the last message
      messageIndex = messages.length - 1;
      if (messageIndex === -1) {
        return updatedMessages;
      }

    const message = updatedMessages[messageIndex];

    // Only process messages that have content
    if ('content' in message && message.content !== null) {
      if (typeof message.content === 'string') {
        // Convert string content to array format with cache control
        const messageWithArrayContent = {
          ...message,
          content: [
            {
              type: 'text',
              text: message.content,
              cache_control: { type: 'ephemeral' },
            } as ChatCompletionContentPartTextWithCache,
          ],
        };
        updatedMessages[messageIndex] =
          messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam;
      } else if (Array.isArray(message.content)) {
        // If content is already an array, add cache_control to the last item
        const contentArray = [
          ...message.content,
        ] as ChatCompletionContentPartWithCache[];
        if (contentArray.length > 0) {
          const lastItem = contentArray[contentArray.length - 1];
          if (lastItem.type === 'text') {
            // Add cache_control to the last text item
            contentArray[contentArray.length - 1] = {
              ...lastItem,
              cache_control: { type: 'ephemeral' },
            } as ChatCompletionContentPartTextWithCache;
          } else {
            // If the last item is not text, add a new text item with cache_control
            contentArray.push({
              type: 'text',
              text: '',
              cache_control: { type: 'ephemeral' },
            } as ChatCompletionContentPartTextWithCache);
          }

          const messageWithCache = {
            ...message,
            content: contentArray,
          };
          updatedMessages[messageIndex] =
            messageWithCache as OpenAI.Chat.ChatCompletionMessageParam;
        }
      }
    if (
      'content' in message &&
      message.content !== null &&
      message.content !== undefined
    ) {
      const updatedContent = this.addCacheControlToContent(message.content);
      updatedMessages[messageIndex] = {
        ...message,
        content: updatedContent,
      } as OpenAI.Chat.ChatCompletionMessageParam;
    }

    return updatedMessages;
  }
|
||||
|
||||
/**
|
||||
* Find the index of the target message (system or last)
|
||||
*/
|
||||
private findTargetMessageIndex(
|
||||
messages: OpenAI.Chat.ChatCompletionMessageParam[],
|
||||
target: 'system' | 'last',
|
||||
): number {
|
||||
if (target === 'system') {
|
||||
return messages.findIndex((msg) => msg.role === 'system');
|
||||
} else {
|
||||
return messages.length - 1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add cache control to message content, handling both string and array formats
|
||||
*/
|
||||
private addCacheControlToContent(
|
||||
content: NonNullable<OpenAI.Chat.ChatCompletionMessageParam['content']>,
|
||||
): ChatCompletionContentPartWithCache[] {
|
||||
// Convert content to array format if it's a string
|
||||
const contentArray = this.normalizeContentToArray(content);
|
||||
|
||||
// Add cache control to the last text item or create one if needed
|
||||
return this.addCacheControlToContentArray(contentArray);
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize content to array format
|
||||
*/
|
||||
private normalizeContentToArray(
|
||||
content: NonNullable<OpenAI.Chat.ChatCompletionMessageParam['content']>,
|
||||
): ChatCompletionContentPartWithCache[] {
|
||||
if (typeof content === 'string') {
|
||||
return [
|
||||
{
|
||||
type: 'text',
|
||||
text: content,
|
||||
} as ChatCompletionContentPartTextWithCache,
|
||||
];
|
||||
}
|
||||
return [...content] as ChatCompletionContentPartWithCache[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Add cache control to the content array
|
||||
*/
|
||||
private addCacheControlToContentArray(
|
||||
contentArray: ChatCompletionContentPartWithCache[],
|
||||
): ChatCompletionContentPartWithCache[] {
|
||||
if (contentArray.length === 0) {
|
||||
return [
|
||||
{
|
||||
type: 'text',
|
||||
text: '',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
} as ChatCompletionContentPartTextWithCache,
|
||||
];
|
||||
}
|
||||
|
||||
const lastItem = contentArray[contentArray.length - 1];
|
||||
|
||||
if (lastItem.type === 'text') {
|
||||
// Add cache_control to the last text item
|
||||
contentArray[contentArray.length - 1] = {
|
||||
...lastItem,
|
||||
cache_control: { type: 'ephemeral' },
|
||||
} as ChatCompletionContentPartTextWithCache;
|
||||
} else {
|
||||
// If the last item is not text, add a new text item with cache_control
|
||||
contentArray.push({
|
||||
type: 'text',
|
||||
text: '',
|
||||
cache_control: { type: 'ephemeral' },
|
||||
} as ChatCompletionContentPartTextWithCache);
|
||||
}
|
||||
|
||||
return contentArray;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if cache control should be disabled based on configuration.
|
||||
*
|
||||
|
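The helpers above normalize message content to array form and tag the trailing text part with an ephemeral cache_control marker, appending an empty text part when the content ends with something other than text. A standalone sketch of that transformation, with simplified local types standing in for the OpenAI SDK part types:

// Sketch: mirror the normalize-then-tag flow on simplified types.
type TextPart = {
  type: 'text';
  text: string;
  cache_control?: { type: 'ephemeral' };
};
type ImagePart = { type: 'image_url'; image_url: { url: string } };
type Part = TextPart | ImagePart;

function withEphemeralCache(content: string | Part[]): Part[] {
  // String content becomes a single-element text array first.
  const parts: Part[] =
    typeof content === 'string' ? [{ type: 'text', text: content }] : [...content];

  const last = parts[parts.length - 1];
  if (last && last.type === 'text') {
    // Tag the trailing text part.
    parts[parts.length - 1] = { ...last, cache_control: { type: 'ephemeral' } };
  } else {
    // Empty content, or content ending in a non-text part: append an empty tagged text part.
    parts.push({ type: 'text', text: '', cache_control: { type: 'ephemeral' } });
  }
  return parts;
}

// withEphemeralCache('Hello!')
// -> [{ type: 'text', text: 'Hello!', cache_control: { type: 'ephemeral' } }]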
||||
229
packages/core/src/core/refactor/provider/default.test.ts
Normal file
@@ -0,0 +1,229 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
type MockedFunction,
|
||||
} from 'vitest';
|
||||
import OpenAI from 'openai';
|
||||
import { DefaultOpenAICompatibleProvider } from './default.js';
|
||||
import { Config } from '../../../config/config.js';
|
||||
import { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
|
||||
// Mock OpenAI
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation((config) => ({
|
||||
config,
|
||||
chat: {
|
||||
completions: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('DefaultOpenAICompatibleProvider', () => {
|
||||
let provider: DefaultOpenAICompatibleProvider;
|
||||
let mockContentGeneratorConfig: ContentGeneratorConfig;
|
||||
let mockCliConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock ContentGeneratorConfig
|
||||
mockContentGeneratorConfig = {
|
||||
apiKey: 'test-api-key',
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
timeout: 60000,
|
||||
maxRetries: 2,
|
||||
model: 'gpt-4',
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
// Mock Config
|
||||
mockCliConfig = {
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
|
||||
provider = new DefaultOpenAICompatibleProvider(
|
||||
mockContentGeneratorConfig,
|
||||
mockCliConfig,
|
||||
);
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provided configs', () => {
|
||||
expect(provider).toBeInstanceOf(DefaultOpenAICompatibleProvider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildHeaders', () => {
|
||||
it('should build headers with User-Agent', () => {
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers).toEqual({
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle unknown CLI version', () => {
|
||||
(
|
||||
mockCliConfig.getCliVersion as MockedFunction<
|
||||
typeof mockCliConfig.getCliVersion
|
||||
>
|
||||
).mockReturnValue(undefined);
|
||||
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers).toEqual({
|
||||
'User-Agent': `QwenCode/unknown (${process.platform}; ${process.arch})`,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildClient', () => {
|
||||
it('should create OpenAI client with correct configuration', () => {
|
||||
const client = provider.buildClient();
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
timeout: 60000,
|
||||
maxRetries: 2,
|
||||
defaultHeaders: {
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
},
|
||||
});
|
||||
|
||||
expect(client).toBeDefined();
|
||||
});
|
||||
|
||||
it('should use default timeout and maxRetries when not provided', () => {
|
||||
mockContentGeneratorConfig.timeout = undefined;
|
||||
mockContentGeneratorConfig.maxRetries = undefined;
|
||||
|
||||
provider.buildClient();
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
timeout: DEFAULT_TIMEOUT,
|
||||
maxRetries: DEFAULT_MAX_RETRIES,
|
||||
defaultHeaders: {
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should include custom headers from buildHeaders', () => {
|
||||
provider.buildClient();
|
||||
|
||||
const expectedHeaders = provider.buildHeaders();
|
||||
expect(OpenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
defaultHeaders: expectedHeaders,
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildRequest', () => {
|
||||
it('should pass through all request parameters unchanged', () => {
|
||||
const originalRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'gpt-4',
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are a helpful assistant.' },
|
||||
{ role: 'user', content: 'Hello!' },
|
||||
],
|
||||
temperature: 0.7,
|
||||
max_tokens: 1000,
|
||||
top_p: 0.9,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.2,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
const userPromptId = 'test-prompt-id';
|
||||
const result = provider.buildRequest(originalRequest, userPromptId);
|
||||
|
||||
expect(result).toEqual(originalRequest);
|
||||
expect(result).not.toBe(originalRequest); // Should be a new object
|
||||
});
|
||||
|
||||
it('should preserve all sampling parameters', () => {
|
||||
const originalRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'gpt-3.5-turbo',
|
||||
messages: [{ role: 'user', content: 'Test message' }],
|
||||
temperature: 0.5,
|
||||
max_tokens: 500,
|
||||
top_p: 0.8,
|
||||
frequency_penalty: 0.3,
|
||||
presence_penalty: 0.4,
|
||||
stop: ['END'],
|
||||
logit_bias: { '123': 10 },
|
||||
user: 'test-user',
|
||||
seed: 42,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(originalRequest, 'prompt-id');
|
||||
|
||||
expect(result).toEqual(originalRequest);
|
||||
expect(result.temperature).toBe(0.5);
|
||||
expect(result.max_tokens).toBe(500);
|
||||
expect(result.top_p).toBe(0.8);
|
||||
expect(result.frequency_penalty).toBe(0.3);
|
||||
expect(result.presence_penalty).toBe(0.4);
|
||||
expect(result.stop).toEqual(['END']);
|
||||
expect(result.logit_bias).toEqual({ '123': 10 });
|
||||
expect(result.user).toBe('test-user');
|
||||
expect(result.seed).toBe(42);
|
||||
});
|
||||
|
||||
it('should handle minimal request parameters', () => {
|
||||
const minimalRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(minimalRequest, 'prompt-id');
|
||||
|
||||
expect(result).toEqual(minimalRequest);
|
||||
});
|
||||
|
||||
it('should handle streaming requests', () => {
|
||||
const streamingRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(streamingRequest, 'prompt-id');
|
||||
|
||||
expect(result).toEqual(streamingRequest);
|
||||
expect(result.stream).toBe(true);
|
||||
});
|
||||
|
||||
it('should not modify the original request object', () => {
|
||||
const originalRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
const originalRequestCopy = { ...originalRequest };
|
||||
const result = provider.buildRequest(originalRequest, 'prompt-id');
|
||||
|
||||
// Original request should be unchanged
|
||||
expect(originalRequest).toEqual(originalRequestCopy);
|
||||
// Result should be a different object
|
||||
expect(result).not.toBe(originalRequest);
|
||||
});
|
||||
});
|
||||
});
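The client tests above pin down the construction rule: pass the configured key and base URL through, fall back to shared defaults when timeout or maxRetries are unset, and attach the provider's default headers. A rough sketch under those assumptions; the fallback constants here are illustrative placeholders, not the repository's DEFAULT_TIMEOUT and DEFAULT_MAX_RETRIES values.

import OpenAI from 'openai';

// Illustrative fallbacks; the real values live in ../constants.js.
const FALLBACK_TIMEOUT_MS = 120_000;
const FALLBACK_MAX_RETRIES = 3;

// Sketch: build an OpenAI-compatible client from a provider config.
function buildOpenAIClient(config: {
  apiKey: string;
  baseUrl: string;
  timeout?: number;
  maxRetries?: number;
  headers: Record<string, string | undefined>;
}): OpenAI {
  return new OpenAI({
    apiKey: config.apiKey,
    baseURL: config.baseUrl,
    timeout: config.timeout ?? FALLBACK_TIMEOUT_MS,
    maxRetries: config.maxRetries ?? FALLBACK_MAX_RETRIES,
    defaultHeaders: config.headers,
  });
}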
|
||||
@@ -10,8 +10,8 @@ import { OpenAICompatibleProvider } from './types.js';
|
||||
export class DefaultOpenAICompatibleProvider
|
||||
implements OpenAICompatibleProvider
|
||||
{
|
||||
private contentGeneratorConfig: ContentGeneratorConfig;
|
||||
private cliConfig: Config;
|
||||
protected contentGeneratorConfig: ContentGeneratorConfig;
|
||||
protected cliConfig: Config;
|
||||
|
||||
constructor(
|
||||
contentGeneratorConfig: ContentGeneratorConfig,
|
||||
|
||||
221
packages/core/src/core/refactor/provider/openrouter.test.ts
Normal file
@@ -0,0 +1,221 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import OpenAI from 'openai';
|
||||
import { OpenRouterOpenAICompatibleProvider } from './openrouter.js';
|
||||
import { DefaultOpenAICompatibleProvider } from './default.js';
|
||||
import { Config } from '../../../config/config.js';
|
||||
import { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
|
||||
describe('OpenRouterOpenAICompatibleProvider', () => {
|
||||
let provider: OpenRouterOpenAICompatibleProvider;
|
||||
let mockContentGeneratorConfig: ContentGeneratorConfig;
|
||||
let mockCliConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock ContentGeneratorConfig
|
||||
mockContentGeneratorConfig = {
|
||||
apiKey: 'test-api-key',
|
||||
baseUrl: 'https://openrouter.ai/api/v1',
|
||||
timeout: 60000,
|
||||
maxRetries: 2,
|
||||
model: 'openai/gpt-4',
|
||||
} as ContentGeneratorConfig;
|
||||
|
||||
// Mock Config
|
||||
mockCliConfig = {
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
|
||||
provider = new OpenRouterOpenAICompatibleProvider(
|
||||
mockContentGeneratorConfig,
|
||||
mockCliConfig,
|
||||
);
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should extend DefaultOpenAICompatibleProvider', () => {
|
||||
expect(provider).toBeInstanceOf(DefaultOpenAICompatibleProvider);
|
||||
expect(provider).toBeInstanceOf(OpenRouterOpenAICompatibleProvider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isOpenRouterProvider', () => {
|
||||
it('should return true for openrouter.ai URLs', () => {
|
||||
const configs = [
|
||||
{ baseUrl: 'https://openrouter.ai/api/v1' },
|
||||
{ baseUrl: 'https://api.openrouter.ai/v1' },
|
||||
{ baseUrl: 'https://openrouter.ai' },
|
||||
{ baseUrl: 'http://openrouter.ai/api/v1' },
|
||||
];
|
||||
|
||||
configs.forEach((config) => {
|
||||
const result = OpenRouterOpenAICompatibleProvider.isOpenRouterProvider(
|
||||
config as ContentGeneratorConfig,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should return false for non-openrouter URLs', () => {
|
||||
const configs = [
|
||||
{ baseUrl: 'https://api.openai.com/v1' },
|
||||
{ baseUrl: 'https://api.anthropic.com/v1' },
|
||||
{ baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1' },
|
||||
{ baseUrl: 'https://example.com/api/v1' }, // different domain
|
||||
{ baseUrl: '' },
|
||||
{ baseUrl: undefined },
|
||||
];
|
||||
|
||||
configs.forEach((config) => {
|
||||
const result = OpenRouterOpenAICompatibleProvider.isOpenRouterProvider(
|
||||
config as ContentGeneratorConfig,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing baseUrl gracefully', () => {
|
||||
const config = {} as ContentGeneratorConfig;
|
||||
const result =
|
||||
OpenRouterOpenAICompatibleProvider.isOpenRouterProvider(config);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildHeaders', () => {
|
||||
it('should include base headers from parent class', () => {
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
// Should include User-Agent from parent
|
||||
expect(headers['User-Agent']).toBe(
|
||||
`QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add OpenRouter-specific headers', () => {
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers).toEqual({
|
||||
'User-Agent': `QwenCode/1.0.0 (${process.platform}; ${process.arch})`,
|
||||
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
|
||||
'X-Title': 'Qwen Code',
|
||||
});
|
||||
});
|
||||
|
||||
it('should override parent headers if there are conflicts', () => {
|
||||
// Mock parent to return conflicting headers
|
||||
const parentBuildHeaders = vi.spyOn(
|
||||
DefaultOpenAICompatibleProvider.prototype,
|
||||
'buildHeaders',
|
||||
);
|
||||
parentBuildHeaders.mockReturnValue({
|
||||
'User-Agent': 'ParentAgent/1.0.0',
|
||||
'HTTP-Referer': 'https://parent.com',
|
||||
});
|
||||
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers).toEqual({
|
||||
'User-Agent': 'ParentAgent/1.0.0',
|
||||
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git', // OpenRouter-specific value should override
|
||||
'X-Title': 'Qwen Code',
|
||||
});
|
||||
|
||||
parentBuildHeaders.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle unknown CLI version from parent', () => {
|
||||
vi.mocked(mockCliConfig.getCliVersion).mockReturnValue(undefined);
|
||||
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
expect(headers['User-Agent']).toBe(
|
||||
`QwenCode/unknown (${process.platform}; ${process.arch})`,
|
||||
);
|
||||
expect(headers['HTTP-Referer']).toBe(
|
||||
'https://github.com/QwenLM/qwen-code.git',
|
||||
);
|
||||
expect(headers['X-Title']).toBe('Qwen Code');
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildClient', () => {
|
||||
it('should inherit buildClient behavior from parent', () => {
|
||||
// Mock the parent's buildClient method
|
||||
const mockClient = { test: 'client' };
|
||||
const parentBuildClient = vi.spyOn(
|
||||
DefaultOpenAICompatibleProvider.prototype,
|
||||
'buildClient',
|
||||
);
|
||||
parentBuildClient.mockReturnValue(mockClient as unknown as OpenAI);
|
||||
|
||||
const result = provider.buildClient();
|
||||
|
||||
expect(parentBuildClient).toHaveBeenCalled();
|
||||
expect(result).toBe(mockClient);
|
||||
|
||||
parentBuildClient.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildRequest', () => {
|
||||
it('should inherit buildRequest behavior from parent', () => {
|
||||
const mockRequest: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'openai/gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
};
|
||||
const mockUserPromptId = 'test-prompt-id';
|
||||
const mockResult = { ...mockRequest, modified: true };
|
||||
|
||||
// Mock the parent's buildRequest method
|
||||
const parentBuildRequest = vi.spyOn(
|
||||
DefaultOpenAICompatibleProvider.prototype,
|
||||
'buildRequest',
|
||||
);
|
||||
parentBuildRequest.mockReturnValue(mockResult);
|
||||
|
||||
const result = provider.buildRequest(mockRequest, mockUserPromptId);
|
||||
|
||||
expect(parentBuildRequest).toHaveBeenCalledWith(
|
||||
mockRequest,
|
||||
mockUserPromptId,
|
||||
);
|
||||
expect(result).toBe(mockResult);
|
||||
|
||||
parentBuildRequest.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('integration with parent class', () => {
|
||||
it('should properly call parent constructor', () => {
|
||||
const newProvider = new OpenRouterOpenAICompatibleProvider(
|
||||
mockContentGeneratorConfig,
|
||||
mockCliConfig,
|
||||
);
|
||||
|
||||
// Verify that parent properties are accessible
|
||||
expect(newProvider).toHaveProperty('buildHeaders');
|
||||
expect(newProvider).toHaveProperty('buildClient');
|
||||
expect(newProvider).toHaveProperty('buildRequest');
|
||||
});
|
||||
|
||||
it('should maintain parent functionality while adding OpenRouter specifics', () => {
|
||||
// Test that the provider can perform all parent operations
|
||||
const headers = provider.buildHeaders();
|
||||
|
||||
// Should have both parent and OpenRouter-specific headers
|
||||
expect(headers['User-Agent']).toBeDefined(); // From parent
|
||||
expect(headers['HTTP-Referer']).toBe(
|
||||
'https://github.com/QwenLM/qwen-code.git',
|
||||
); // OpenRouter-specific
|
||||
expect(headers['X-Title']).toBe('Qwen Code'); // OpenRouter-specific
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,21 +1,13 @@
|
||||
import OpenAI from 'openai';
|
||||
import { Config } from '../../../config/config.js';
|
||||
import { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
import { OpenAICompatibleProvider } from './types.js';
|
||||
|
||||
export class OpenRouterOpenAICompatibleProvider
|
||||
implements OpenAICompatibleProvider
|
||||
{
|
||||
private contentGeneratorConfig: ContentGeneratorConfig;
|
||||
private cliConfig: Config;
|
||||
import { DefaultOpenAICompatibleProvider } from './default.js';
|
||||
|
||||
export class OpenRouterOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
|
||||
constructor(
|
||||
contentGeneratorConfig: ContentGeneratorConfig,
|
||||
cliConfig: Config,
|
||||
) {
|
||||
this.cliConfig = cliConfig;
|
||||
this.contentGeneratorConfig = contentGeneratorConfig;
|
||||
super(contentGeneratorConfig, cliConfig);
|
||||
}
|
||||
|
||||
static isOpenRouterProvider(
|
||||
@@ -25,40 +17,15 @@ export class OpenRouterOpenAICompatibleProvider
|
||||
return baseURL.includes('openrouter.ai');
|
||||
}
|
||||
|
||||
buildHeaders(): Record<string, string | undefined> {
|
||||
const version = this.cliConfig.getCliVersion() || 'unknown';
|
||||
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
|
||||
override buildHeaders(): Record<string, string | undefined> {
|
||||
// Get base headers from parent class
|
||||
const baseHeaders = super.buildHeaders();
|
||||
|
||||
// Add OpenRouter-specific headers
|
||||
return {
|
||||
'User-Agent': userAgent,
|
||||
...baseHeaders,
|
||||
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
|
||||
'X-Title': 'Qwen Code',
|
||||
};
|
||||
}
|
||||
|
||||
buildClient(): OpenAI {
|
||||
const {
|
||||
apiKey,
|
||||
baseUrl,
|
||||
timeout = DEFAULT_TIMEOUT,
|
||||
maxRetries = DEFAULT_MAX_RETRIES,
|
||||
} = this.contentGeneratorConfig;
|
||||
const defaultHeaders = this.buildHeaders();
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: baseUrl,
|
||||
timeout,
|
||||
maxRetries,
|
||||
defaultHeaders,
|
||||
});
|
||||
}
|
||||
|
||||
buildRequest(
|
||||
request: OpenAI.Chat.ChatCompletionCreateParams,
|
||||
_userPromptId: string,
|
||||
): OpenAI.Chat.ChatCompletionCreateParams {
|
||||
// OpenRouter doesn't need special enhancements, just pass through all parameters
|
||||
return {
|
||||
...request, // Preserve all original parameters including sampling params
|
||||
};
|
||||
}
|
||||
}
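The refactor above replaces OpenRouter's hand-rolled client and header code with a subclass that layers its own headers on top of super.buildHeaders(). A minimal sketch of that override pattern, using a simplified stand-in base class and the header values asserted in the tests above:

// Sketch: subclass override that merges parent headers with provider-specific ones.
class BaseProvider {
  buildHeaders(): Record<string, string | undefined> {
    return { 'User-Agent': 'QwenCode/1.0.0 (linux; x64)' };
  }
}

class OpenRouterProvider extends BaseProvider {
  override buildHeaders(): Record<string, string | undefined> {
    return {
      ...super.buildHeaders(),
      // OpenRouter-specific entries win on conflict because they are spread last.
      'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
      'X-Title': 'Qwen Code',
    };
  }
}

new OpenRouterProvider().buildHeaders();
// { 'User-Agent': ..., 'HTTP-Referer': ..., 'X-Title': 'Qwen Code' }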
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import OpenAI from 'openai';
|
||||
import { GenerateContentResponse, Part, FinishReason } from '@google/genai';
|
||||
import { Converter } from './converter.js';
|
||||
|
||||
export interface ToolCallAccumulator {
|
||||
id?: string;
|
||||
name?: string;
|
||||
arguments: string;
|
||||
}
|
||||
|
||||
export class StreamingManager {
|
||||
private toolCallAccumulator = new Map<number, ToolCallAccumulator>();
|
||||
|
||||
constructor(private converter: Converter) {}
|
||||
|
||||
async *processStream(
|
||||
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
|
||||
): AsyncGenerator<GenerateContentResponse> {
|
||||
// Reset the accumulator for each new stream
|
||||
this.toolCallAccumulator.clear();
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const response = this.converter.convertOpenAIChunkToGemini(chunk);
|
||||
|
||||
// Ignore empty responses, which would cause problems with downstream code
|
||||
// that expects a valid response.
|
||||
if (
|
||||
response.candidates?.[0]?.content?.parts?.length === 0 &&
|
||||
!response.usageMetadata
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
yield response;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Combine streaming responses for logging purposes
|
||||
*/
|
||||
combineStreamResponsesForLogging(
|
||||
responses: GenerateContentResponse[],
|
||||
model: string,
|
||||
): GenerateContentResponse {
|
||||
if (responses.length === 0) {
|
||||
return new GenerateContentResponse();
|
||||
}
|
||||
|
||||
const lastResponse = responses[responses.length - 1];
|
||||
|
||||
// Find the last response with usage metadata
|
||||
const finalUsageMetadata = responses
|
||||
.slice()
|
||||
.reverse()
|
||||
.find((r) => r.usageMetadata)?.usageMetadata;
|
||||
|
||||
// Combine all text content from the stream
|
||||
const combinedParts: Part[] = [];
|
||||
let combinedText = '';
|
||||
const functionCalls: Part[] = [];
|
||||
|
||||
for (const response of responses) {
|
||||
if (response.candidates?.[0]?.content?.parts) {
|
||||
for (const part of response.candidates[0].content.parts) {
|
||||
if ('text' in part && part.text) {
|
||||
combinedText += part.text;
|
||||
} else if ('functionCall' in part && part.functionCall) {
|
||||
functionCalls.push(part);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add combined text if any
|
||||
if (combinedText) {
|
||||
combinedParts.push({ text: combinedText });
|
||||
}
|
||||
|
||||
// Add function calls
|
||||
combinedParts.push(...functionCalls);
|
||||
|
||||
// Create combined response
|
||||
const combinedResponse = new GenerateContentResponse();
|
||||
combinedResponse.candidates = [
|
||||
{
|
||||
content: {
|
||||
parts: combinedParts,
|
||||
role: 'model' as const,
|
||||
},
|
||||
finishReason:
|
||||
responses[responses.length - 1]?.candidates?.[0]?.finishReason ||
|
||||
FinishReason.FINISH_REASON_UNSPECIFIED,
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
combinedResponse.responseId = lastResponse?.responseId;
|
||||
combinedResponse.createTime = lastResponse?.createTime;
|
||||
combinedResponse.modelVersion = model;
|
||||
combinedResponse.promptFeedback = { safetyRatings: [] };
|
||||
combinedResponse.usageMetadata = finalUsageMetadata;
|
||||
|
||||
return combinedResponse;
|
||||
}
|
||||
}
|
||||
570
packages/core/src/core/refactor/streamingToolCallParser.test.ts
Normal file
@@ -0,0 +1,570 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach } from 'vitest';
|
||||
import {
|
||||
StreamingToolCallParser,
|
||||
ToolCallParseResult,
|
||||
} from './streamingToolCallParser.js';
|
||||
|
||||
describe('StreamingToolCallParser', () => {
|
||||
let parser: StreamingToolCallParser;
|
||||
|
||||
beforeEach(() => {
|
||||
parser = new StreamingToolCallParser();
|
||||
});
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should initialize with empty state', () => {
|
||||
expect(parser.getBuffer(0)).toBe('');
|
||||
expect(parser.getState(0)).toEqual({
|
||||
depth: 0,
|
||||
inString: false,
|
||||
escape: false,
|
||||
});
|
||||
expect(parser.getToolCallMeta(0)).toEqual({});
|
||||
});
|
||||
|
||||
it('should handle simple complete JSON in single chunk', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"key": "value"}',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ key: 'value' });
|
||||
expect(result.error).toBeUndefined();
|
||||
expect(result.repaired).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should accumulate chunks until complete JSON', () => {
|
||||
let result = parser.addChunk(0, '{"key":', 'call_1', 'test_function');
|
||||
expect(result.complete).toBe(false);
|
||||
|
||||
result = parser.addChunk(0, ' "val');
|
||||
expect(result.complete).toBe(false);
|
||||
|
||||
result = parser.addChunk(0, 'ue"}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ key: 'value' });
|
||||
});
|
||||
|
||||
it('should handle empty chunks gracefully', () => {
|
||||
const result = parser.addChunk(0, '', 'call_1', 'test_function');
|
||||
expect(result.complete).toBe(false);
|
||||
expect(parser.getBuffer(0)).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('JSON depth tracking', () => {
|
||||
it('should track nested objects correctly', () => {
|
||||
let result = parser.addChunk(
|
||||
0,
|
||||
'{"outer": {"inner":',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(parser.getState(0).depth).toBe(2);
|
||||
|
||||
result = parser.addChunk(0, ' "value"}}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ outer: { inner: 'value' } });
|
||||
});
|
||||
|
||||
it('should track nested arrays correctly', () => {
|
||||
let result = parser.addChunk(
|
||||
0,
|
||||
'{"arr": [1, [2,',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
// Depth: { (1) + [ (2) + [ (3) = 3
|
||||
expect(parser.getState(0).depth).toBe(3);
|
||||
|
||||
result = parser.addChunk(0, ' 3]]}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ arr: [1, [2, 3]] });
|
||||
});
|
||||
|
||||
it('should handle mixed nested structures', () => {
|
||||
let result = parser.addChunk(
|
||||
0,
|
||||
'{"obj": {"arr": [{"nested":',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
// Depth: { (1) + { (2) + [ (3) + { (4) = 4
|
||||
expect(parser.getState(0).depth).toBe(4);
|
||||
|
||||
result = parser.addChunk(0, ' true}]}}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ obj: { arr: [{ nested: true }] } });
|
||||
});
|
||||
});
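These depth-tracking tests rely on counting brace and bracket nesting while ignoring anything that appears inside string literals, including escaped quotes. A simplified sketch of that bookkeeping; it illustrates the idea rather than the parser's actual implementation.

// Sketch: per-chunk depth tracking that is aware of strings and escapes.
interface DepthState {
  depth: number;
  inString: boolean;
  escape: boolean;
}

function trackDepth(chunk: string, state: DepthState): DepthState {
  const next = { ...state };
  for (const ch of chunk) {
    if (next.escape) {
      next.escape = false; // the escaped character has no structural meaning
    } else if (ch === '\\' && next.inString) {
      next.escape = true;
    } else if (ch === '"') {
      next.inString = !next.inString;
    } else if (!next.inString) {
      if (ch === '{' || ch === '[') next.depth++;
      if (ch === '}' || ch === ']') next.depth--;
    }
  }
  return next;
}

// trackDepth('{"arr": [1, [2,', { depth: 0, inString: false, escape: false }).depth === 3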
|
||||
|
||||
describe('String handling', () => {
|
||||
it('should handle strings with special characters', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"text": "Hello, \\"World\\"!"}',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ text: 'Hello, "World"!' });
|
||||
});
|
||||
|
||||
it('should handle strings with braces and brackets', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"code": "if (x) { return [1, 2]; }"}',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ code: 'if (x) { return [1, 2]; }' });
|
||||
});
|
||||
|
||||
it('should track string boundaries correctly across chunks', () => {
|
||||
let result = parser.addChunk(
|
||||
0,
|
||||
'{"text": "Hello',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(parser.getState(0).inString).toBe(true);
|
||||
|
||||
result = parser.addChunk(0, ' World"}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ text: 'Hello World' });
|
||||
});
|
||||
|
||||
it('should handle escaped quotes in strings', () => {
|
||||
let result = parser.addChunk(
|
||||
0,
|
||||
'{"text": "Say \\"Hello',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(parser.getState(0).inString).toBe(true);
|
||||
|
||||
result = parser.addChunk(0, '\\" to me"}');
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ text: 'Say "Hello" to me' });
|
||||
});
|
||||
|
||||
it('should handle backslash escapes correctly', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"path": "C:\\\\Users\\\\test"}',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(true);
|
||||
expect(result.value).toEqual({ path: 'C:\\Users\\test' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error handling and repair', () => {
|
||||
it('should return error for malformed JSON at depth 0', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"key": invalid}',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(result.error).toBeInstanceOf(Error);
|
||||
});
|
||||
|
||||
it('should auto-repair unclosed strings', () => {
|
||||
// Test the repair functionality in getCompletedToolCalls instead
|
||||
// since that's where repair is actually used in practice
|
||||
parser.addChunk(0, '{"text": "unclosed', 'call_1', 'test_function');
|
||||
|
||||
const completed = parser.getCompletedToolCalls();
|
||||
expect(completed).toHaveLength(1);
|
||||
expect(completed[0].args).toEqual({ text: 'unclosed' });
|
||||
});
|
||||
|
||||
it('should not attempt repair when still in nested structure', () => {
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'{"obj": {"text": "unclosed',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(result.repaired).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle repair failure gracefully', () => {
|
||||
// Create a case where even repair fails - malformed JSON at depth 0
|
||||
const result = parser.addChunk(
|
||||
0,
|
||||
'invalid json',
|
||||
'call_1',
|
||||
'test_function',
|
||||
);
|
||||
expect(result.complete).toBe(false);
|
||||
expect(result.error).toBeInstanceOf(Error);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multiple tool calls', () => {
|
||||
it('should handle multiple tool calls with different indices', () => {
|
||||
const result1 = parser.addChunk(
|
||||
0,
|
||||
'{"param1": "value1"}',
|
||||
'call_1',
|
||||
'function1',
|
||||
);
|
||||
const result2 = parser.addChunk(
|
||||
1,
|
||||
'{"param2": "value2"}',
|
||||
'call_2',
|
||||
'function2',
|
||||
);
|
||||
|
||||
expect(result1.complete).toBe(true);
|
||||
expect(result1.value).toEqual({ param1: 'value1' });
|
||||
expect(result2.complete).toBe(true);
|
||||
expect(result2.value).toEqual({ param2: 'value2' });
|
||||
|
||||
expect(parser.getToolCallMeta(0)).toEqual({
|
||||
id: 'call_1',
|
||||
name: 'function1',
|
||||
      });
      expect(parser.getToolCallMeta(1)).toEqual({
        id: 'call_2',
        name: 'function2',
      });
    });

    it('should handle interleaved chunks from multiple tool calls', () => {
      let result1 = parser.addChunk(0, '{"param1":', 'call_1', 'function1');
      let result2 = parser.addChunk(1, '{"param2":', 'call_2', 'function2');

      expect(result1.complete).toBe(false);
      expect(result2.complete).toBe(false);

      result1 = parser.addChunk(0, ' "value1"}');
      result2 = parser.addChunk(1, ' "value2"}');

      expect(result1.complete).toBe(true);
      expect(result1.value).toEqual({ param1: 'value1' });
      expect(result2.complete).toBe(true);
      expect(result2.value).toEqual({ param2: 'value2' });
    });

    it('should maintain separate state for each index', () => {
      parser.addChunk(0, '{"nested": {"deep":', 'call_1', 'function1');
      parser.addChunk(1, '{"simple":', 'call_2', 'function2');

      expect(parser.getState(0).depth).toBe(2);
      expect(parser.getState(1).depth).toBe(1);

      const result1 = parser.addChunk(0, ' "value"}}');
      const result2 = parser.addChunk(1, ' "value"}');

      expect(result1.complete).toBe(true);
      expect(result2.complete).toBe(true);
    });
  });

  describe('Tool call metadata handling', () => {
    it('should store and retrieve tool call metadata', () => {
      parser.addChunk(0, '{"param": "value"}', 'call_123', 'my_function');

      const meta = parser.getToolCallMeta(0);
      expect(meta.id).toBe('call_123');
      expect(meta.name).toBe('my_function');
    });

    it('should handle metadata-only chunks', () => {
      const result = parser.addChunk(0, '', 'call_123', 'my_function');
      expect(result.complete).toBe(false);

      const meta = parser.getToolCallMeta(0);
      expect(meta.id).toBe('call_123');
      expect(meta.name).toBe('my_function');
    });

    it('should update metadata incrementally', () => {
      parser.addChunk(0, '', 'call_123');
      expect(parser.getToolCallMeta(0).id).toBe('call_123');
      expect(parser.getToolCallMeta(0).name).toBeUndefined();

      parser.addChunk(0, '{"param":', undefined, 'my_function');
      expect(parser.getToolCallMeta(0).id).toBe('call_123');
      expect(parser.getToolCallMeta(0).name).toBe('my_function');
    });

    it('should detect new tool call with same index and reset state', () => {
      // First tool call
      const result1 = parser.addChunk(
        0,
        '{"param1": "value1"}',
        'call_1',
        'function1',
      );
      expect(result1.complete).toBe(true);

      // New tool call with same index but different function name should reset
      const result2 = parser.addChunk(0, '{"param2":', 'call_2', 'function2');
      expect(result2.complete).toBe(false);
      expect(parser.getBuffer(0)).toBe('{"param2":');
      expect(parser.getToolCallMeta(0).name).toBe('function2');
      expect(parser.getToolCallMeta(0).id).toBe('call_2');
    });
  });

  describe('Completed tool calls', () => {
    it('should return completed tool calls', () => {
      parser.addChunk(0, '{"param1": "value1"}', 'call_1', 'function1');
      parser.addChunk(1, '{"param2": "value2"}', 'call_2', 'function2');

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(2);

      expect(completed[0]).toEqual({
        id: 'call_1',
        name: 'function1',
        args: { param1: 'value1' },
        index: 0,
      });

      expect(completed[1]).toEqual({
        id: 'call_2',
        name: 'function2',
        args: { param2: 'value2' },
        index: 1,
      });
    });

    it('should handle completed tool calls with repair', () => {
      parser.addChunk(0, '{"text": "unclosed', 'call_1', 'function1');

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(1);
      expect(completed[0].args).toEqual({ text: 'unclosed' });
    });

    it('should use safeJsonParse as fallback for malformed JSON', () => {
      // Simulate a case where JSON.parse fails but jsonrepair can fix it
      parser.addChunk(
        0,
        '{"valid": "data", "invalid": }',
        'call_1',
        'function1',
      );

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(1);
      // jsonrepair should fix the malformed JSON by setting invalid to null
      expect(completed[0].args).toEqual({ valid: 'data', invalid: null });
    });

    it('should not return tool calls without function name', () => {
      parser.addChunk(0, '{"param": "value"}', 'call_1'); // No function name

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(0);
    });

    it('should not return tool calls with empty buffer', () => {
      parser.addChunk(0, '', 'call_1', 'function1'); // Empty buffer

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(0);
    });
  });

  describe('Reset functionality', () => {
    it('should reset specific index', () => {
      parser.addChunk(0, '{"param":', 'call_1', 'function1');
      parser.addChunk(1, '{"other":', 'call_2', 'function2');

      parser.resetIndex(0);

      expect(parser.getBuffer(0)).toBe('');
      expect(parser.getState(0)).toEqual({
        depth: 0,
        inString: false,
        escape: false,
      });
      expect(parser.getToolCallMeta(0)).toEqual({});

      // Index 1 should remain unchanged
      expect(parser.getBuffer(1)).toBe('{"other":');
      expect(parser.getToolCallMeta(1)).toEqual({
        id: 'call_2',
        name: 'function2',
      });
    });

    it('should reset entire parser state', () => {
      parser.addChunk(0, '{"param1":', 'call_1', 'function1');
      parser.addChunk(1, '{"param2":', 'call_2', 'function2');

      parser.reset();

      expect(parser.getBuffer(0)).toBe('');
      expect(parser.getBuffer(1)).toBe('');
      expect(parser.getToolCallMeta(0)).toEqual({});
      expect(parser.getToolCallMeta(1)).toEqual({});
      expect(parser.getCompletedToolCalls()).toHaveLength(0);
    });
  });

  describe('Edge cases', () => {
    it('should handle very large JSON objects', () => {
      const largeObject = { data: 'x'.repeat(10000) };
      const jsonString = JSON.stringify(largeObject);

      const result = parser.addChunk(0, jsonString, 'call_1', 'function1');
      expect(result.complete).toBe(true);
      expect(result.value).toEqual(largeObject);
    });

    it('should handle deeply nested structures', () => {
      let nested: unknown = 'value';
      for (let i = 0; i < 100; i++) {
        nested = { level: nested };
      }

      const jsonString = JSON.stringify(nested);
      const result = parser.addChunk(0, jsonString, 'call_1', 'function1');
      expect(result.complete).toBe(true);
      expect(result.value).toEqual(nested);
    });

    it('should handle JSON with unicode characters', () => {
      const result = parser.addChunk(
        0,
        '{"emoji": "🚀", "chinese": "你好"}',
        'call_1',
        'function1',
      );
      expect(result.complete).toBe(true);
      expect(result.value).toEqual({ emoji: '🚀', chinese: '你好' });
    });

    it('should handle JSON with null and boolean values', () => {
      const result = parser.addChunk(
        0,
        '{"null": null, "bool": true, "false": false}',
        'call_1',
        'function1',
      );
      expect(result.complete).toBe(true);
      expect(result.value).toEqual({ null: null, bool: true, false: false });
    });

    it('should handle JSON with numbers', () => {
      const result = parser.addChunk(
        0,
        '{"int": 42, "float": 3.14, "negative": -1, "exp": 1e5}',
        'call_1',
        'function1',
      );
      expect(result.complete).toBe(true);
      expect(result.value).toEqual({
        int: 42,
        float: 3.14,
        negative: -1,
        exp: 1e5,
      });
    });

    it('should handle whitespace-only chunks', () => {
      let result = parser.addChunk(0, ' \n\t ', 'call_1', 'function1');
      expect(result.complete).toBe(false);

      result = parser.addChunk(0, '{"key": "value"}');
      expect(result.complete).toBe(true);
      expect(result.value).toEqual({ key: 'value' });
    });

    it('should handle chunks with only structural characters', () => {
      let result = parser.addChunk(0, '{', 'call_1', 'function1');
      expect(result.complete).toBe(false);
      expect(parser.getState(0).depth).toBe(1);

      result = parser.addChunk(0, '}');
      expect(result.complete).toBe(true);
      expect(result.value).toEqual({});
    });
  });

  describe('Real-world streaming scenarios', () => {
    it('should handle typical OpenAI streaming pattern', () => {
      // Simulate how OpenAI typically streams tool call arguments
      const chunks = [
        '{"',
        'query',
        '": "',
        'What is',
        ' the weather',
        ' in Paris',
        '?"}',
      ];

      let result: ToolCallParseResult = { complete: false };
      for (let i = 0; i < chunks.length; i++) {
        result = parser.addChunk(
          0,
          chunks[i],
          i === 0 ? 'call_1' : undefined,
          i === 0 ? 'get_weather' : undefined,
        );
        if (i < chunks.length - 1) {
          expect(result.complete).toBe(false);
        }
      }

      expect(result.complete).toBe(true);
      expect(result.value).toEqual({ query: 'What is the weather in Paris?' });
    });

    it('should handle multiple concurrent tool calls streaming', () => {
      // Simulate multiple tool calls being streamed simultaneously
      parser.addChunk(0, '{"location":', 'call_1', 'get_weather');
      parser.addChunk(1, '{"query":', 'call_2', 'search_web');
      parser.addChunk(0, ' "New York"}');

      const result1 = parser.addChunk(1, ' "OpenAI GPT"}');

      expect(result1.complete).toBe(true);
      expect(result1.value).toEqual({ query: 'OpenAI GPT' });

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(2);
      expect(completed.find((tc) => tc.name === 'get_weather')?.args).toEqual({
        location: 'New York',
      });
      expect(completed.find((tc) => tc.name === 'search_web')?.args).toEqual({
        query: 'OpenAI GPT',
      });
    });

    it('should handle malformed streaming that gets repaired', () => {
      // Simulate a stream that gets cut off mid-string
      parser.addChunk(0, '{"message": "Hello world', 'call_1', 'send_message');

      const completed = parser.getCompletedToolCalls();
      expect(completed).toHaveLength(1);
      expect(completed[0].args).toEqual({ message: 'Hello world' });
    });
  });
});
300
packages/core/src/core/refactor/streamingToolCallParser.ts
Normal file
@@ -0,0 +1,300 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { safeJsonParse } from '../../utils/safeJsonParse.js';

/**
 * Type definition for the result of parsing a JSON chunk in tool calls
 */
export interface ToolCallParseResult {
  /** Whether the JSON parsing is complete */
  complete: boolean;
  /** The parsed JSON value (only present when complete is true) */
  value?: Record<string, unknown>;
  /** Error information if parsing failed */
  error?: Error;
  /** Whether the JSON was repaired (e.g., auto-closed unclosed strings) */
  repaired?: boolean;
}
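
// Illustrative result shapes (an illustration added here, not part of the
// committed file): depending on the buffered input, addChunk() below yields
// results such as
//   { complete: false }                                    // still accumulating
//   { complete: true, value: { city: 'Paris' } }           // parsed once depth hits 0
//   { complete: true, value: { q: 'hi' }, repaired: true } // unclosed string auto-closed
//   { complete: false, error: someSyntaxError }            // depth 0 but unparseable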

/**
 * Streaming Tool Call Parser Implementation
 *
 * This class implements a sophisticated streaming parser specifically designed for
 * handling tool call arguments that arrive as partial JSON data in chunks during
 * OpenAI streaming responses. It extends the principles from the streaming JSON parser
 * to handle the specific requirements of tool call processing.
 *
 * Key Features:
 * - Real-time depth tracking for nested JSON structures in tool arguments
 * - Proper handling of string literals and escape sequences
 * - Automatic repair of common JSON formatting issues
 * - Support for multiple consecutive tool calls with same or different function names
 * - Memory-efficient processing without storing complete JSON in memory
 * - State management for individual tool call indices
 */
export class StreamingToolCallParser {
  /** Accumulated buffer containing all received chunks for each tool call index */
  private buffers: Map<number, string> = new Map();
  /** Current nesting depth in JSON structure for each tool call index */
  private depths: Map<number, number> = new Map();
  /** Whether we're currently inside a string literal for each tool call index */
  private inStrings: Map<number, boolean> = new Map();
  /** Whether the next character should be treated as escaped for each tool call index */
  private escapes: Map<number, boolean> = new Map();
  /** Metadata for each tool call index */
  private toolCallMeta: Map<number, { id?: string; name?: string }> = new Map();

  /**
   * Processes a new chunk of tool call data and attempts to parse complete JSON objects
   *
   * This method implements a state machine that tracks:
   * 1. JSON structure depth (brackets and braces) per tool call index
   * 2. String literal boundaries per tool call index
   * 3. Escape sequences within strings per tool call index
   * 4. Tool call metadata (id, function name) per tool call index
   *
   * The parser only attempts to parse when the depth returns to 0, indicating
   * a complete JSON structure has been received for that specific tool call index.
   *
   * @param index - The tool call index from OpenAI streaming response
   * @param chunk - A string chunk containing partial JSON data for arguments
   * @param id - Optional tool call ID
   * @param name - Optional function name
   * @returns ToolCallParseResult indicating whether parsing is complete and any parsed value
   */
  addChunk(
    index: number,
    chunk: string,
    id?: string,
    name?: string,
  ): ToolCallParseResult {
    // Initialize state for this index if not exists
    if (!this.buffers.has(index)) {
      this.buffers.set(index, '');
      this.depths.set(index, 0);
      this.inStrings.set(index, false);
      this.escapes.set(index, false);
      this.toolCallMeta.set(index, {});
    }

    // Update metadata
    const meta = this.toolCallMeta.get(index)!;
    if (id) meta.id = id;
    if (name) {
      // If this is a new function name and we have existing arguments,
      // it might be a new tool call with the same index - reset the buffer
      if (meta.name && meta.name !== name && this.buffers.get(index)) {
        const currentBuffer = this.buffers.get(index)!;
        // Check if current buffer contains complete JSON
        if (currentBuffer.trim()) {
          try {
            JSON.parse(currentBuffer);
            // If we can parse it, this is likely a new tool call - reset state
            this.resetIndex(index);
            // Update metadata after reset
            const resetMeta = this.toolCallMeta.get(index)!;
            if (id) resetMeta.id = id;
            resetMeta.name = name;
          } catch {
            // Current buffer is incomplete, continue accumulating
            meta.name = name;
          }
        } else {
          meta.name = name;
        }
      } else {
        meta.name = name;
      }
    }

    // Get current state for this index
    const currentBuffer = this.buffers.get(index)!;
    const currentDepth = this.depths.get(index)!;
    const currentInString = this.inStrings.get(index)!;
    const currentEscape = this.escapes.get(index)!;

    // Add chunk to buffer
    const newBuffer = currentBuffer + chunk;
    this.buffers.set(index, newBuffer);

    // Track JSON structure depth - only count brackets/braces outside of strings
    let depth = currentDepth;
    let inString = currentInString;
    let escape = currentEscape;

    for (const char of chunk) {
      if (!inString) {
        if (char === '{' || char === '[') depth++;
        else if (char === '}' || char === ']') depth--;
      }

      // Track string boundaries - toggle inString state on unescaped quotes
      if (char === '"' && !escape) {
        inString = !inString;
      }
      // Track escape sequences - backslash followed by any character is escaped
      escape = char === '\\' && !escape;
    }

    // Update state
    this.depths.set(index, depth);
    this.inStrings.set(index, inString);
    this.escapes.set(index, escape);

    // Attempt parse when we're back at root level (depth 0) and have data
    if (depth === 0 && newBuffer.trim().length > 0) {
      try {
        // Standard JSON parsing attempt
        const parsed = JSON.parse(newBuffer);
        return { complete: true, value: parsed };
      } catch (e) {
        // Intelligent repair: try auto-closing unclosed strings
        if (inString) {
          try {
            const repaired = JSON.parse(newBuffer + '"');
            return {
              complete: true,
              value: repaired,
              repaired: true,
            };
          } catch {
            // If repair fails, fall through to error case
          }
        }
        return {
          complete: false,
          error: e instanceof Error ? e : new Error(String(e)),
        };
      }
    }

    // JSON structure is incomplete, continue accumulating chunks
    return { complete: false };
  }
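
  // Worked example (illustration, not from the original commit): feeding index 0
  // the chunk '{"a": "b}"' leaves depth at 1, because the closing brace occurs
  // inside a string literal and is ignored by the depth tracker. A following '}'
  // chunk brings depth back to 0, JSON.parse succeeds on '{"a": "b}"}', and
  // addChunk returns { complete: true, value: { a: 'b}' } }.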

  /**
   * Gets the current tool call metadata for a specific index
   *
   * @param index - The tool call index
   * @returns Object containing id and name if available
   */
  getToolCallMeta(index: number): { id?: string; name?: string } {
    return this.toolCallMeta.get(index) || {};
  }

  /**
   * Gets all completed tool calls that are ready to be emitted
   * This method should be called when the streaming is complete (finish_reason is present)
   *
   * @returns Array of completed tool calls with their metadata and parsed arguments
   */
  getCompletedToolCalls(): Array<{
    id?: string;
    name?: string;
    args: Record<string, unknown>;
    index: number;
  }> {
    const completed: Array<{
      id?: string;
      name?: string;
      args: Record<string, unknown>;
      index: number;
    }> = [];

    for (const [index, buffer] of this.buffers.entries()) {
      const meta = this.toolCallMeta.get(index);
      if (meta?.name && buffer.trim()) {
        let args: Record<string, unknown> = {};

        // Try to parse the final buffer
        try {
          args = JSON.parse(buffer);
        } catch {
          // Try with repair (auto-close strings)
          const inString = this.inStrings.get(index);
          if (inString) {
            try {
              args = JSON.parse(buffer + '"');
            } catch {
              // If all parsing fails, use safeJsonParse as fallback
              args = safeJsonParse(buffer, {});
            }
          } else {
            args = safeJsonParse(buffer, {});
          }
        }

        completed.push({
          id: meta.id,
          name: meta.name,
          args,
          index,
        });
      }
    }

    return completed;
  }

  /**
   * Resets the parser state for a specific tool call index
   *
   * @param index - The tool call index to reset
   */
  resetIndex(index: number): void {
    this.buffers.set(index, '');
    this.depths.set(index, 0);
    this.inStrings.set(index, false);
    this.escapes.set(index, false);
    this.toolCallMeta.set(index, {});
  }

  /**
   * Resets the entire parser state for processing a new stream
   *
   * This method clears all internal state variables, allowing the parser
   * to be reused for multiple streams without interference.
   */
  reset(): void {
    this.buffers.clear();
    this.depths.clear();
    this.inStrings.clear();
    this.escapes.clear();
    this.toolCallMeta.clear();
  }

  /**
   * Gets the current accumulated buffer content for a specific index
   *
   * Useful for debugging or when you need to inspect the raw data
   * that has been accumulated so far.
   *
   * @param index - The tool call index
   * @returns The current buffer content for the specified index
   */
  getBuffer(index: number): string {
    return this.buffers.get(index) || '';
  }

  /**
   * Gets the current parsing state information for a specific index
   *
   * @param index - The tool call index
   * @returns Object containing current depth, string state, and escape state
   */
  getState(index: number): {
    depth: number;
    inString: boolean;
    escape: boolean;
  } {
    return {
      depth: this.depths.get(index) || 0,
      inString: this.inStrings.get(index) || false,
      escape: this.escapes.get(index) || false,
    };
  }
}
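
During streaming, the parser above is driven by the tool-call deltas of an OpenAI chat completion stream. The following is a rough usage sketch, not code from this commit: the loop, the variable names, and the collectToolCalls helper are assumptions, while addChunk, getCompletedToolCalls, and reset are the methods defined above.

import OpenAI from 'openai';
import { StreamingToolCallParser } from './streamingToolCallParser.js';

// Feeds each streamed delta into the parser and returns the finished tool calls
// once the stream reports a finish_reason.
async function collectToolCalls(
  stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
): Promise<ReturnType<StreamingToolCallParser['getCompletedToolCalls']>> {
  const parser = new StreamingToolCallParser();
  for await (const chunk of stream) {
    const choice = chunk.choices?.[0];
    for (const toolCall of choice?.delta?.tool_calls ?? []) {
      // id and function.name usually arrive only on the first delta for an index.
      parser.addChunk(
        toolCall.index,
        toolCall.function?.arguments ?? '',
        toolCall.id,
        toolCall.function?.name,
      );
    }
    if (choice?.finish_reason) {
      break;
    }
  }
  const completed = parser.getCompletedToolCalls();
  parser.reset(); // leave the parser clean for the next stream
  return completed;
}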
1306
packages/core/src/core/refactor/telemetryService.test.ts
Normal file
File diff suppressed because it is too large
@@ -38,7 +38,7 @@ export interface TelemetryService {
    context: RequestContext,
    responses: GenerateContentResponse[],
    openaiRequest?: OpenAI.Chat.ChatCompletionCreateParams,
    openaiResponse?: OpenAI.Chat.ChatCompletion,
    openaiChunks?: OpenAI.Chat.ChatCompletionChunk[],
  ): Promise<void>;
}

@@ -82,16 +82,16 @@ export class DefaultTelemetryService implements TelemetryService {
    // Log API error event for UI telemetry
    const errorEvent = new ApiErrorEvent(
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (error as any).requestID || 'unknown',
      (error as any)?.requestID || 'unknown',
      context.model,
      errorMessage,
      context.duration,
      context.userPromptId,
      context.authType,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (error as any).type,
      (error as any)?.type,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (error as any).code,
      (error as any)?.code,
    );
    logApiError(this.config, errorEvent);

@@ -109,7 +109,7 @@ export class DefaultTelemetryService implements TelemetryService {
    context: RequestContext,
    responses: GenerateContentResponse[],
    openaiRequest?: OpenAI.Chat.ChatCompletionCreateParams,
    openaiResponse?: OpenAI.Chat.ChatCompletion,
    openaiChunks?: OpenAI.Chat.ChatCompletionChunk[],
  ): Promise<void> {
    // Get final usage metadata from the last response that has it
    const finalUsageMetadata = responses

@@ -129,9 +129,127 @@ export class DefaultTelemetryService implements TelemetryService {

    logApiResponse(this.config, responseEvent);

    // Log interaction if enabled
    if (this.enableOpenAILogging && openaiRequest && openaiResponse) {
      await openaiLogger.logInteraction(openaiRequest, openaiResponse);
    // Log interaction if enabled - combine chunks only when needed
    if (
      this.enableOpenAILogging &&
      openaiRequest &&
      openaiChunks &&
      openaiChunks.length > 0
    ) {
      const combinedResponse = this.combineOpenAIChunksForLogging(openaiChunks);
      await openaiLogger.logInteraction(openaiRequest, combinedResponse);
    }
  }

  /**
   * Combine OpenAI chunks for logging purposes
   * This method consolidates all OpenAI stream chunks into a single ChatCompletion response
   * for telemetry and logging purposes, avoiding unnecessary format conversions
   */
  private combineOpenAIChunksForLogging(
    chunks: OpenAI.Chat.ChatCompletionChunk[],
  ): OpenAI.Chat.ChatCompletion {
    if (chunks.length === 0) {
      throw new Error('No chunks to combine');
    }

    const firstChunk = chunks[0];

    // Combine all content from chunks
    let combinedContent = '';
    const toolCalls: OpenAI.Chat.ChatCompletionMessageToolCall[] = [];
    let finishReason:
      | 'stop'
      | 'length'
      | 'tool_calls'
      | 'content_filter'
      | 'function_call'
      | null = null;
    let usage:
      | {
          prompt_tokens: number;
          completion_tokens: number;
          total_tokens: number;
        }
      | undefined;

    for (const chunk of chunks) {
      const choice = chunk.choices?.[0];
      if (choice) {
        // Combine text content
        if (choice.delta?.content) {
          combinedContent += choice.delta.content;
        }

        // Collect tool calls
        if (choice.delta?.tool_calls) {
          for (const toolCall of choice.delta.tool_calls) {
            if (toolCall.index !== undefined) {
              if (!toolCalls[toolCall.index]) {
                toolCalls[toolCall.index] = {
                  id: toolCall.id || '',
                  type: toolCall.type || 'function',
                  function: { name: '', arguments: '' },
                };
              }

              if (toolCall.function?.name) {
                toolCalls[toolCall.index].function.name +=
                  toolCall.function.name;
              }
              if (toolCall.function?.arguments) {
                toolCalls[toolCall.index].function.arguments +=
                  toolCall.function.arguments;
              }
            }
          }
        }

        // Get finish reason from the last chunk
        if (choice.finish_reason) {
          finishReason = choice.finish_reason;
        }
      }

      // Get usage from the last chunk that has it
      if (chunk.usage) {
        usage = chunk.usage;
      }
    }

    // Create the combined ChatCompletion response
    const message: OpenAI.Chat.ChatCompletionMessage = {
      role: 'assistant',
      content: combinedContent || null,
      refusal: null,
    };

    // Add tool calls if any
    if (toolCalls.length > 0) {
      message.tool_calls = toolCalls.filter((tc) => tc.id); // Filter out empty tool calls
    }

    const combinedResponse: OpenAI.Chat.ChatCompletion = {
      id: firstChunk.id,
      object: 'chat.completion',
      created: firstChunk.created,
      model: firstChunk.model,
      choices: [
        {
          index: 0,
          message,
          finish_reason: finishReason || 'stop',
          logprobs: null,
        },
      ],
      usage: usage || {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0,
      },
      system_fingerprint: firstChunk.system_fingerprint,
    };

    return combinedResponse;
  }
}
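
For logging, combineOpenAIChunksForLogging collapses the streamed exchange back into a single ChatCompletion. As a hedged illustration of the resulting shape (all concrete values below are invented, not taken from the commit): two chunks whose deltas carry 'Hel' and 'lo', with usage and finish_reason arriving on the final chunk, would combine into roughly the following object.

// Combined result as produced by the method above (values are placeholders).
const combined: OpenAI.Chat.ChatCompletion = {
  id: 'chatcmpl-abc123', // copied from the first chunk
  object: 'chat.completion',
  created: 1735689600,
  model: 'qwen-plus',
  choices: [
    {
      index: 0,
      message: { role: 'assistant', content: 'Hello', refusal: null },
      finish_reason: 'stop',
      logprobs: null,
    },
  ],
  usage: { prompt_tokens: 12, completion_tokens: 2, total_tokens: 14 },
  system_fingerprint: undefined, // propagated from the first chunk when present
};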