mirror of https://github.com/QwenLM/qwen-code.git (synced 2025-12-21 01:07:46 +00:00)
feat: add image tokenizer to fit vlm context window
@@ -9,7 +9,7 @@ import { Box, Text } from 'ink';
 import { Colors } from '../colors.js';
 import {
   RadioButtonSelect,
-  RadioSelectItem,
+  type RadioSelectItem,
 } from './shared/RadioButtonSelect.js';
 import { useKeypress } from '../hooks/useKeypress.js';
 import { AvailableModel } from '../models/availableModels.js';
@@ -9,7 +9,7 @@ import { Box, Text } from 'ink';
 import { Colors } from '../colors.js';
 import {
   RadioButtonSelect,
-  RadioSelectItem,
+  type RadioSelectItem,
 } from './shared/RadioButtonSelect.js';
 import { useKeypress } from '../hooks/useKeypress.js';
@@ -205,7 +205,7 @@ export const StatsDisplay: React.FC<StatsDisplayProps> = ({
         <Text>
           {tools.totalCalls} ({' '}
           <Text color={theme.status.success}>✓ {tools.totalSuccess}</Text>{' '}
-          <Text color={theme.status.error}>✖ {tools.totalFail}</Text> )
+          <Text color={theme.status.error}>x {tools.totalFail}</Text> )
         </Text>
       </StatRow>
       <StatRow title="Success Rate:">
@@ -7,7 +7,7 @@ exports[`<SessionSummaryDisplay /> > renders the summary display with a title 1`
 │ │
 │ Interaction Summary │
 │ Session ID: │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ Code Changes: +42 -15 │
 │ │
@@ -7,7 +7,7 @@ exports[`<StatsDisplay /> > Code Changes Display > displays Code Changes when li
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 1 ( ✓ 1 ✖ 0 ) │
+│ Tool Calls: 1 ( ✓ 1 x 0 ) │
 │ Success Rate: 100.0% │
 │ Code Changes: +42 -18 │
 │ │
@@ -28,7 +28,7 @@ exports[`<StatsDisplay /> > Code Changes Display > hides Code Changes when no li
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 1 ( ✓ 1 ✖ 0 ) │
+│ Tool Calls: 1 ( ✓ 1 x 0 ) │
 │ Success Rate: 100.0% │
 │ │
 │ Performance │
@@ -48,7 +48,7 @@ exports[`<StatsDisplay /> > Conditional Color Tests > renders success rate in gr
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 10 ( ✓ 10 ✖ 0 ) │
+│ Tool Calls: 10 ( ✓ 10 x 0 ) │
 │ Success Rate: 100.0% │
 │ │
 │ Performance │
@@ -68,7 +68,7 @@ exports[`<StatsDisplay /> > Conditional Color Tests > renders success rate in re
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 10 ( ✓ 5 ✖ 5 ) │
+│ Tool Calls: 10 ( ✓ 5 x 5 ) │
 │ Success Rate: 50.0% │
 │ │
 │ Performance │
@@ -88,7 +88,7 @@ exports[`<StatsDisplay /> > Conditional Color Tests > renders success rate in ye
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 10 ( ✓ 9 ✖ 1 ) │
+│ Tool Calls: 10 ( ✓ 9 x 1 ) │
 │ Success Rate: 90.0% │
 │ │
 │ Performance │
@@ -108,7 +108,7 @@ exports[`<StatsDisplay /> > Conditional Rendering Tests > hides Efficiency secti
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ │
 │ Performance │
@@ -132,7 +132,7 @@ exports[`<StatsDisplay /> > Conditional Rendering Tests > hides User Agreement w
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 2 ( ✓ 1 ✖ 1 ) │
+│ Tool Calls: 2 ( ✓ 1 x 1 ) │
 │ Success Rate: 50.0% │
 │ │
 │ Performance │
@@ -152,7 +152,7 @@ exports[`<StatsDisplay /> > Title Rendering > renders the custom title when a ti
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ │
 │ Performance │
@@ -172,7 +172,7 @@ exports[`<StatsDisplay /> > Title Rendering > renders the default title when no
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ │
 │ Performance │
@@ -192,7 +192,7 @@ exports[`<StatsDisplay /> > renders a table with two models correctly 1`] = `
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ │
 │ Performance │
@@ -221,7 +221,7 @@ exports[`<StatsDisplay /> > renders all sections when all data is present 1`] =
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 2 ( ✓ 1 ✖ 1 ) │
+│ Tool Calls: 2 ( ✓ 1 x 1 ) │
 │ Success Rate: 50.0% │
 │ User Agreement: 100.0% (1 reviewed) │
 │ │
@@ -250,7 +250,7 @@ exports[`<StatsDisplay /> > renders only the Performance section in its zero sta
 │ │
 │ Interaction Summary │
 │ Session ID: test-session-id │
-│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │
+│ Tool Calls: 0 ( ✓ 0 x 0 ) │
 │ Success Rate: 0.0% │
 │ │
 │ Performance │
@@ -5,9 +5,10 @@
  */

 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
+import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
 import { Config } from '../../config/config.js';
 import { AuthType } from '../contentGenerator.js';
+import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
 import OpenAI from 'openai';

 // Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
   let mockConfig: Config;
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   let mockOpenAIClient: any;
+  let mockProvider: OpenAICompatibleProvider;

   beforeEach(() => {
     // Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
     mockConfig = {
       getContentGeneratorConfig: vi.fn().mockReturnValue({
         authType: 'openai',
+        enableOpenAILogging: false,
       }),
       getCliVersion: vi.fn().mockReturnValue('1.0.0'),
     } as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
           create: vi.fn(),
         },
       },
       embeddings: {
         create: vi.fn(),
       },
     };

     vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);

+    // Create mock provider
+    mockProvider = {
+      buildHeaders: vi.fn().mockReturnValue({
+        'User-Agent': 'QwenCode/1.0.0 (test; test)',
+      }),
+      buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+      buildRequest: vi.fn().mockImplementation((req) => req),
+    };
+
     // Create generator instance
     const contentGeneratorConfig = {
       model: 'gpt-4',
       apiKey: 'test-key',
       authType: AuthType.USE_OPENAI,
       enableOpenAILogging: false,
     };
-    generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
+    generator = new OpenAIContentGenerator(
+      contentGeneratorConfig,
+      mockConfig,
+      mockProvider,
+    );
   });

   afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
       await expect(
         generator.generateContentStream(request, 'test-prompt-id'),
       ).rejects.toThrow(
-        /Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
+        /Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
       );
     });
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
       } catch (error: unknown) {
         const errorMessage =
           error instanceof Error ? error.message : String(error);
-        expect(errorMessage).toContain(
-          'Streaming setup timeout troubleshooting:',
-        );
-        expect(errorMessage).toContain(
-          'Check network connectivity and firewall settings',
-        );
+        expect(errorMessage).toContain('Streaming timeout troubleshooting:');
+        expect(errorMessage).toContain('Check network connectivity');
         expect(errorMessage).toContain('Consider using non-streaming mode');
       }
     });
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
         authType: AuthType.USE_OPENAI,
         baseUrl: 'http://localhost:8080',
       };
-      new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
+      new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        mockConfig,
+        mockProvider,
+      );

-      // Verify OpenAI client was created with timeout config
-      expect(OpenAI).toHaveBeenCalledWith({
-        apiKey: 'test-key',
-        baseURL: 'http://localhost:8080',
-        timeout: 120000,
-        maxRetries: 3,
-        defaultHeaders: {
-          'User-Agent': expect.stringMatching(/^QwenCode/),
-        },
-      });
+      // Verify provider buildClient was called
+      expect(mockProvider.buildClient).toHaveBeenCalled();
     });

     it('should use custom timeout from config', () => {
       const customConfig = {
-        getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+        getContentGeneratorConfig: vi.fn().mockReturnValue({
+          enableOpenAILogging: false,
+        }),
         getCliVersion: vi.fn().mockReturnValue('1.0.0'),
       } as unknown as Config;
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
         timeout: 300000,
         maxRetries: 5,
       };
-      new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
-
-      expect(OpenAI).toHaveBeenCalledWith({
-        apiKey: 'test-key',
-        baseURL: 'http://localhost:8080',
-        timeout: 300000,
-        maxRetries: 5,
-        defaultHeaders: {
-          'User-Agent': expect.stringMatching(/^QwenCode/),
-        },
-      });
+      // Create a custom mock provider for this test
+      const customMockProvider: OpenAICompatibleProvider = {
+        buildHeaders: vi.fn().mockReturnValue({
+          'User-Agent': 'QwenCode/1.0.0 (test; test)',
+        }),
+        buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+        buildRequest: vi.fn().mockImplementation((req) => req),
+      };
+
+      new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        customConfig,
+        customMockProvider,
+      );
+
+      // Verify provider buildClient was called
+      expect(customMockProvider.buildClient).toHaveBeenCalled();
     });

     it('should handle missing timeout config gracefully', () => {
       const noTimeoutConfig = {
-        getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+        getContentGeneratorConfig: vi.fn().mockReturnValue({
+          enableOpenAILogging: false,
+        }),
         getCliVersion: vi.fn().mockReturnValue('1.0.0'),
       } as unknown as Config;
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
         authType: AuthType.USE_OPENAI,
         baseUrl: 'http://localhost:8080',
       };
-      new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
-
-      expect(OpenAI).toHaveBeenCalledWith({
-        apiKey: 'test-key',
-        baseURL: 'http://localhost:8080',
-        timeout: 120000, // default
-        maxRetries: 3, // default
-        defaultHeaders: {
-          'User-Agent': expect.stringMatching(/^QwenCode/),
-        },
-      });
+      // Create a custom mock provider for this test
+      const noTimeoutMockProvider: OpenAICompatibleProvider = {
+        buildHeaders: vi.fn().mockReturnValue({
+          'User-Agent': 'QwenCode/1.0.0 (test; test)',
+        }),
+        buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+        buildRequest: vi.fn().mockImplementation((req) => req),
+      };
+
+      new OpenAIContentGenerator(
+        contentGeneratorConfig,
+        noTimeoutConfig,
+        noTimeoutMockProvider,
+      );
+
+      // Verify provider buildClient was called
+      expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
     });
   });
@@ -480,7 +480,7 @@ export class GeminiChat {
         if (error instanceof Error && error.message) {
           if (isSchemaDepthError(error.message)) return false;
           if (error.message.includes('429')) return true;
-          if (error.message.match(/5\d{2}/)) return true;
+          if (error.message.match(/^5\d{2}/)) return true;
         }
         return false;
       },
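
Why anchor the pattern: the unanchored /5\d{2}/ matches any embedded digit run starting with 5, so an error message that merely contains such a number would be retried as if it were a 5xx response. A minimal sketch of the behavior difference (illustrative messages, not actual API output):

// Unanchored: matches the "500" embedded in "1500" and would trigger a retry.
'Request timed out after 1500ms'.match(/5\d{2}/); // => ['500']

// Anchored: only messages that start with a 5xx status code match.
'Request timed out after 1500ms'.match(/^5\d{2}/); // => null
'503 Service Unavailable'.match(/^5\d{2}/); // => ['503']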
@@ -644,8 +644,8 @@ export class OpenAIContentGenerator implements ContentGenerator {
     let totalTokens = 0;

     try {
-      const { get_encoding } = await import('tiktoken');
-      const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
+      const tikToken = await import('tiktoken');
+      const encoding = tikToken.get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
       totalTokens = encoding.encode(content).length;
       encoding.free();
     } catch (error) {
@@ -5,6 +5,37 @@
  */

 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+
+// Mock the request tokenizer module BEFORE importing the class that uses it
+const mockTokenizer = {
+  calculateTokens: vi.fn().mockResolvedValue({
+    totalTokens: 50,
+    breakdown: {
+      textTokens: 50,
+      imageTokens: 0,
+      audioTokens: 0,
+      otherTokens: 0,
+    },
+    processingTime: 1,
+  }),
+  dispose: vi.fn(),
+};
+
+vi.mock('../../../utils/request-tokenizer/index.js', () => ({
+  getDefaultTokenizer: vi.fn(() => mockTokenizer),
+  DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
+  disposeDefaultTokenizer: vi.fn(),
+}));
+
+// Mock tiktoken as well for completeness
+vi.mock('tiktoken', () => ({
+  get_encoding: vi.fn(() => ({
+    encode: vi.fn(() => new Array(50)), // Mock 50 tokens
+    free: vi.fn(),
+  })),
+}));
+
+// Now import the modules that depend on the mocked modules
 import { OpenAIContentGenerator } from './openaiContentGenerator.js';
 import { Config } from '../../config/config.js';
 import { AuthType } from '../contentGenerator.js';

@@ -15,14 +46,6 @@ import type {
 import type { OpenAICompatibleProvider } from './provider/index.js';
 import OpenAI from 'openai';

-// Mock tiktoken
-vi.mock('tiktoken', () => ({
-  get_encoding: vi.fn().mockReturnValue({
-    encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
-    free: vi.fn(),
-  }),
-}));
-
 describe('OpenAIContentGenerator (Refactored)', () => {
   let generator: OpenAIContentGenerator;
   let mockConfig: Config;
@@ -13,6 +13,7 @@ import { ContentGenerationPipeline, PipelineConfig } from './pipeline.js';
 import { DefaultTelemetryService } from './telemetryService.js';
 import { EnhancedErrorHandler } from './errorHandler.js';
 import { ContentGeneratorConfig } from '../contentGenerator.js';
+import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';

 export class OpenAIContentGenerator implements ContentGenerator {
   protected pipeline: ContentGenerationPipeline;
@@ -70,27 +71,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
   async countTokens(
     request: CountTokensParameters,
   ): Promise<CountTokensResponse> {
-    // Use tiktoken for accurate token counting
-    const content = JSON.stringify(request.contents);
-    let totalTokens = 0;
-
     try {
-      const { get_encoding } = await import('tiktoken');
-      const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
-      totalTokens = encoding.encode(content).length;
-      encoding.free();
+      // Use the new high-performance request tokenizer
+      const tokenizer = getDefaultTokenizer();
+      const result = await tokenizer.calculateTokens(request, {
+        textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
+      });
+
+      return {
+        totalTokens: result.totalTokens,
+      };
     } catch (error) {
       console.warn(
-        'Failed to load tiktoken, falling back to character approximation:',
+        'Failed to calculate tokens with new tokenizer, falling back to simple method:',
         error,
       );
-      // Fallback: rough approximation using character count
-      totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
-    }
-
-    return {
-      totalTokens,
-    };
+      // Fallback to original simple method
+      const content = JSON.stringify(request.contents);
+      const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
+
+      return {
+        totalTokens,
+      };
     }
   }

   async embedContent(
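
Two things worth noting in the new countTokens: the result now comes from the shared request tokenizer (with the per-modality breakdown exercised by the mocks earlier in this diff), and the catch branch keeps the old heuristic of roughly four characters per token over the JSON-serialized contents. A quick illustration of that fallback estimate:

// Fallback heuristic: 1 token ≈ 4 characters of the serialized contents.
const serialized = JSON.stringify([{ role: 'user', parts: [{ text: 'Hello' }] }]);
const estimate = Math.ceil(serialized.length / 4); // 44 chars -> 11 tokens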
@@ -98,7 +98,7 @@ export class ContentGenerationPipeline {
    * 2. Filter empty responses
    * 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
    * 4. Collect both formats for logging
-   * 5. Handle success/error logging with original OpenAI format
+   * 5. Handle success/error logging
    */
   private async *processStreamWithLogging(
     stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -166,19 +166,11 @@ export class ContentGenerationPipeline {
         collectedOpenAIChunks,
       );
     } catch (error) {
       // Stage 2e: Stream failed - handle error and logging
-      context.duration = Date.now() - context.startTime;
-
       // Clear streaming tool calls on error to prevent data pollution
       this.converter.resetStreamingToolCalls();

-      await this.config.telemetryService.logError(
-        context,
-        error,
-        openaiRequest,
-      );
-
-      this.config.errorHandler.handle(error, context, request);
+      // Use shared error handling logic
+      await this.handleError(error, context, request);
     }
   }
@@ -362,25 +354,59 @@ export class ContentGenerationPipeline {
       context.duration = Date.now() - context.startTime;
       return result;
     } catch (error) {
-      context.duration = Date.now() - context.startTime;
-
-      // Log error
-      const openaiRequest = await this.buildRequest(
-        request,
-        userPromptId,
-        isStreaming,
-      );
-      await this.config.telemetryService.logError(
-        context,
-        error,
-        openaiRequest,
-      );
-
-      // Handle and throw enhanced error
-      this.config.errorHandler.handle(error, context, request);
+      // Use shared error handling logic
+      return await this.handleError(
+        error,
+        context,
+        request,
+        userPromptId,
+        isStreaming,
+      );
     }
   }

+  /**
+   * Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
+   * This centralizes the common error processing steps to avoid duplication
+   */
+  private async handleError(
+    error: unknown,
+    context: RequestContext,
+    request: GenerateContentParameters,
+    userPromptId?: string,
+    isStreaming?: boolean,
+  ): Promise<never> {
+    context.duration = Date.now() - context.startTime;
+
+    // Build request for logging (may fail, but we still want to log the error)
+    let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
+    try {
+      if (userPromptId !== undefined && isStreaming !== undefined) {
+        openaiRequest = await this.buildRequest(
+          request,
+          userPromptId,
+          isStreaming,
+        );
+      } else {
+        // For processStreamWithLogging, we don't have userPromptId/isStreaming,
+        // so create a minimal request
+        openaiRequest = {
+          model: this.contentGeneratorConfig.model,
+          messages: [],
+        };
+      }
+    } catch (_buildError) {
+      // If we can't build the request, create a minimal one for logging
+      openaiRequest = {
+        model: this.contentGeneratorConfig.model,
+        messages: [],
+      };
+    }
+
+    await this.config.telemetryService.logError(context, error, openaiRequest);
+    this.config.errorHandler.handle(error, context, request);
+  }
+
   /**
    * Create request context with common properties
    */
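
A note on the handleError signature: declaring Promise<never> lets callers write `return await this.handleError(...)` inside a catch block while still type-checking, because the method never resolves normally (errorHandler.handle ultimately throws, per the "Handle and throw enhanced error" comment it replaces). A standalone sketch of the pattern, independent of this class:

// Standalone sketch of the Promise<never> rethrow pattern.
async function logAndRethrow(error: unknown): Promise<never> {
  console.error('logged:', error);
  throw error; // never returns normally
}

async function run<T>(op: () => Promise<T>): Promise<T> {
  try {
    return await op();
  } catch (error) {
    return await logAndRethrow(error); // valid: never is assignable to T
  }
}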
@@ -78,6 +78,16 @@ export class DashScopeOpenAICompatibleProvider
       messages = this.addDashScopeCacheControl(messages, cacheTarget);
     }

+    if (request.model.startsWith('qwen-vl')) {
+      return {
+        ...request,
+        messages,
+        ...(this.buildMetadata(userPromptId) || {}),
+        /* @ts-expect-error dashscope exclusive */
+        vl_high_resolution_images: true,
+      };
+    }
+
     return {
       ...request, // Preserve all original parameters including sampling params
       messages,
@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
   [/^qwen-flash-latest$/, LIMITS['1m']],
   [/^qwen-turbo.*$/, LIMITS['128k']],

+  // Qwen Vision Models
+  [/^qwen-vl-max.*$/, LIMITS['128k']],
+
   // -------------------
   // ByteDance Seed-OSS (512K)
   // -------------------
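
For context, a table like PATTERNS is typically consulted first-match-wins; the sketch below shows the idea with a hypothetical lookup helper (the actual lookup function in tokenLimits.ts is not part of this hunk):

// Hypothetical helper; the real lookup in tokenLimits.ts is not shown in this diff.
function lookupLimit(
  model: string,
  patterns: Array<[RegExp, number]>,
  fallback: number,
): number {
  for (const [pattern, limit] of patterns) {
    if (pattern.test(model)) return limit; // first match wins
  }
  return fallback;
}

// With the new entry, 'qwen-vl-max-latest' resolves to the 128k limit.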
@@ -222,7 +222,7 @@ describe('Turn', () => {
       expect(turn.getDebugResponses().length).toBe(0);
       expect(reportError).toHaveBeenCalledWith(
         error,
-        'Error when talking to Gemini API',
+        'Error when talking to API',
         [...historyContent, reqParts],
         'Turn.run-sendMessageStream',
       );
@@ -273,7 +273,7 @@ export class Turn {
       const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
       await reportError(
         error,
-        'Error when talking to Gemini API',
+        'Error when talking to API',
         contextForReport,
         'Turn.run-sendMessageStream',
       );
@@ -404,11 +404,9 @@ describe('QwenContentGenerator', () => {
     expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
   });

-  it('should count tokens with valid token', async () => {
-    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
-      token: 'valid-token',
-    });
-    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
+  it('should count tokens without requiring authentication', async () => {
+    // Clear any previous mock calls
+    vi.clearAllMocks();

     const request: CountTokensParameters = {
       model: 'qwen-turbo',

@@ -418,7 +416,8 @@ describe('QwenContentGenerator', () => {
     const result = await qwenContentGenerator.countTokens(request);

     expect(result.totalTokens).toBe(15);
-    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
+    // countTokens is a local operation and should not require OAuth credentials
+    expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
   });

   it('should embed content with valid token', async () => {
@@ -1655,7 +1654,7 @@ describe('QwenContentGenerator', () => {
     SharedTokenManager.getInstance = originalGetInstance;
   });

-  it('should handle all method types with token failure', async () => {
+  it('should handle method types with token failure (except countTokens)', async () => {
     const mockTokenManager = {
       getValidCredentials: vi
         .fn()

@@ -1688,7 +1687,7 @@ describe('QwenContentGenerator', () => {
       contents: [{ parts: [{ text: 'Embed' }] }],
     };

-    // All methods should fail with the same error
+    // Methods requiring authentication should fail
     await expect(
       newGenerator.generateContent(generateRequest, 'test-id'),
     ).rejects.toThrow('Failed to obtain valid Qwen access token');

@@ -1697,14 +1696,14 @@ describe('QwenContentGenerator', () => {
       newGenerator.generateContentStream(generateRequest, 'test-id'),
     ).rejects.toThrow('Failed to obtain valid Qwen access token');

-    await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
-      'Failed to obtain valid Qwen access token',
-    );
-
     await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
       'Failed to obtain valid Qwen access token',
     );

+    // countTokens should succeed as it's a local operation
+    const countResult = await newGenerator.countTokens(countRequest);
+    expect(countResult.totalTokens).toBe(15);
+
     SharedTokenManager.getInstance = originalGetInstance;
   });
 });
@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
   override async countTokens(
     request: CountTokensParameters,
   ): Promise<CountTokensResponse> {
-    return this.executeWithCredentialManagement(() =>
-      super.countTokens(request),
-    );
+    return super.countTokens(request);
   }

   /**
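
The practical effect, as the updated tests assert: countTokens no longer goes through executeWithCredentialManagement, so it succeeds even when token refresh fails. A usage sketch mirroring the test setup (generator construction elided):

// No access token is fetched for this call (see the updated tests above);
// the count comes straight from the local request tokenizer.
const { totalTokens } = await qwenContentGenerator.countTokens({
  model: 'qwen-turbo',
  contents: [{ role: 'user', parts: [{ text: 'Count me' }] }],
});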
packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts (new file, 157 lines)
@@ -0,0 +1,157 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect } from 'vitest';
import { ImageTokenizer } from './imageTokenizer.js';

describe('ImageTokenizer', () => {
  const tokenizer = new ImageTokenizer();

  describe('token calculation', () => {
    it('should calculate tokens based on image dimensions with reference logic', () => {
      const metadata = {
        width: 28,
        height: 28,
        mimeType: 'image/png',
        dataSize: 1000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
      // But minimum scaling may apply for small images
      expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
    });

    it('should calculate tokens for larger images', () => {
      const metadata = {
        width: 512,
        height: 512,
        mimeType: 'image/png',
        dataSize: 10000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // 512x512 with reference logic: rounded dimensions + scaling + special tokens
      expect(tokens).toBeGreaterThan(300);
      expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
    });

    it('should enforce minimum tokens per image with scaling', () => {
      const metadata = {
        width: 1,
        height: 1,
        mimeType: 'image/png',
        dataSize: 100,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // Tiny images get scaled up to minimum pixels + special tokens
      expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
    });

    it('should handle very large images with scaling', () => {
      const metadata = {
        width: 8192,
        height: 8192,
        mimeType: 'image/png',
        dataSize: 100000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // Very large images should be scaled down to max limit + special tokens
      expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
      expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
    });
  });

  describe('PNG dimension extraction', () => {
    it('should extract dimensions from valid PNG', async () => {
      // 1x1 PNG image in base64
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const metadata = await tokenizer.extractImageMetadata(
        pngBase64,
        'image/png',
      );

      expect(metadata.width).toBe(1);
      expect(metadata.height).toBe(1);
      expect(metadata.mimeType).toBe('image/png');
    });

    it('should handle invalid PNG gracefully', async () => {
      const invalidBase64 = 'invalid-png-data';

      const metadata = await tokenizer.extractImageMetadata(
        invalidBase64,
        'image/png',
      );

      // Should return default dimensions
      expect(metadata.width).toBe(512);
      expect(metadata.height).toBe(512);
      expect(metadata.mimeType).toBe('image/png');
    });
  });

  describe('batch processing', () => {
    it('should process multiple images serially', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const images = [
        { data: pngBase64, mimeType: 'image/png' },
        { data: pngBase64, mimeType: 'image/png' },
        { data: pngBase64, mimeType: 'image/png' },
      ];

      const tokens = await tokenizer.calculateTokensBatch(images);

      expect(tokens).toHaveLength(3);
      expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
    });

    it('should handle mixed valid and invalid images', async () => {
      const validPng =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
      const invalidPng = 'invalid-data';

      const images = [
        { data: validPng, mimeType: 'image/png' },
        { data: invalidPng, mimeType: 'image/png' },
      ];

      const tokens = await tokenizer.calculateTokensBatch(images);

      expect(tokens).toHaveLength(2);
      expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
    });
  });

  describe('different image formats', () => {
    it('should handle different MIME types', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];

      for (const mimeType of formats) {
        const metadata = await tokenizer.extractImageMetadata(
          pngBase64,
          mimeType,
        );
        expect(metadata.mimeType).toBe(mimeType);
        expect(metadata.width).toBeGreaterThan(0);
        expect(metadata.height).toBeGreaterThan(0);
      }
    });
  });
});
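
To see where the `>= 6` assertions come from, here is the scale-up arithmetic for the 1x1 test image, following the constants and formula in imageTokenizer.ts below:

// 1x1 image: rounding to 28-pixel multiples gives 0x0, which is below
// minPixels = 4 * 784 = 3136, so the image is scaled up:
const beta = Math.sqrt(3136 / (1 * 1)); // 56
const hBar = Math.ceil((1 * beta) / 28) * 28; // 56
const wBar = Math.ceil((1 * beta) / 28) * 28; // 56
const imageTokens = Math.floor((hBar * wBar) / 784); // 3136 / 784 = 4
const total = imageTokens + 2; // 6, matching expect(tokens).toBeGreaterThanOrEqual(6)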
packages/core/src/utils/request-tokenizer/imageTokenizer.ts (new file, 309 lines)
@@ -0,0 +1,309 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { ImageMetadata } from './types.js';

/**
 * Image tokenizer for calculating image tokens based on dimensions
 *
 * Key rules:
 * - 28x28 pixels = 1 token
 * - Minimum: 4 tokens per image
 * - Maximum: 16384 tokens per image
 * - Additional: 2 special tokens (vision_bos + vision_eos)
 * - Supports: PNG, JPEG, WebP, GIF formats
 */
export class ImageTokenizer {
  /** 28x28 pixels = 1 token */
  private static readonly PIXELS_PER_TOKEN = 28 * 28;

  /** Minimum tokens per image */
  private static readonly MIN_TOKENS_PER_IMAGE = 4;

  /** Maximum tokens per image */
  private static readonly MAX_TOKENS_PER_IMAGE = 16384;

  /** Special tokens for vision markers */
  private static readonly VISION_SPECIAL_TOKENS = 2;

  /**
   * Extract image metadata from base64 data
   *
   * @param base64Data Base64-encoded image data (with or without data URL prefix)
   * @param mimeType MIME type of the image
   * @returns Promise resolving to ImageMetadata with dimensions and format info
   */
  async extractImageMetadata(
    base64Data: string,
    mimeType: string,
  ): Promise<ImageMetadata> {
    try {
      const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
      const buffer = Buffer.from(cleanBase64, 'base64');
      const dimensions = await this.extractDimensions(buffer, mimeType);

      return {
        width: dimensions.width,
        height: dimensions.height,
        mimeType,
        dataSize: buffer.length,
      };
    } catch (error) {
      console.warn('Failed to extract image metadata:', error);
      // Return default metadata for fallback
      return {
        width: 512,
        height: 512,
        mimeType,
        dataSize: Math.floor(base64Data.length * 0.75),
      };
    }
  }

  /**
   * Extract image dimensions from buffer based on format
   *
   * @param buffer Binary image data buffer
   * @param mimeType MIME type to determine parsing strategy
   * @returns Promise resolving to width and height dimensions
   */
  private async extractDimensions(
    buffer: Buffer,
    mimeType: string,
  ): Promise<{ width: number; height: number }> {
    if (mimeType.includes('png')) {
      return this.extractPngDimensions(buffer);
    }

    if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
      return this.extractJpegDimensions(buffer);
    }

    if (mimeType.includes('webp')) {
      return this.extractWebpDimensions(buffer);
    }

    if (mimeType.includes('gif')) {
      return this.extractGifDimensions(buffer);
    }

    return { width: 512, height: 512 };
  }

  /**
   * Extract PNG dimensions from IHDR chunk
   * PNG signature: 89 50 4E 47 0D 0A 1A 0A
   * Width/height at bytes 16-19 and 20-23 (big-endian)
   */
  private extractPngDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 24) {
      throw new Error('Invalid PNG: buffer too short');
    }

    // Verify PNG signature
    const signature = buffer.subarray(0, 8);
    const expectedSignature = Buffer.from([
      0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
    ]);
    if (!signature.equals(expectedSignature)) {
      throw new Error('Invalid PNG signature');
    }

    const width = buffer.readUInt32BE(16);
    const height = buffer.readUInt32BE(20);

    return { width, height };
  }

  /**
   * Extract JPEG dimensions from SOF (Start of Frame) markers
   * JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
   * Dimensions at offset +5 (height) and +7 (width) from SOF marker
   */
  private extractJpegDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
      throw new Error('Invalid JPEG signature');
    }

    let offset = 2;

    while (offset < buffer.length - 8) {
      if (buffer[offset] !== 0xff) {
        offset++;
        continue;
      }

      const marker = buffer[offset + 1];

      // SOF markers
      if (
        (marker >= 0xc0 && marker <= 0xc3) ||
        (marker >= 0xc5 && marker <= 0xc7) ||
        (marker >= 0xc9 && marker <= 0xcb) ||
        (marker >= 0xcd && marker <= 0xcf)
      ) {
        const height = buffer.readUInt16BE(offset + 5);
        const width = buffer.readUInt16BE(offset + 7);
        return { width, height };
      }

      const segmentLength = buffer.readUInt16BE(offset + 2);
      offset += 2 + segmentLength;
    }

    throw new Error('Could not find JPEG dimensions');
  }

  /**
   * Extract WebP dimensions from RIFF container
   * Supports VP8, VP8L, and VP8X formats
   */
  private extractWebpDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 30) {
      throw new Error('Invalid WebP: too short');
    }

    const riffSignature = buffer.subarray(0, 4).toString('ascii');
    const webpSignature = buffer.subarray(8, 12).toString('ascii');

    if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
      throw new Error('Invalid WebP signature');
    }

    const format = buffer.subarray(12, 16).toString('ascii');

    if (format === 'VP8 ') {
      const width = buffer.readUInt16LE(26) & 0x3fff;
      const height = buffer.readUInt16LE(28) & 0x3fff;
      return { width, height };
    } else if (format === 'VP8L') {
      const bits = buffer.readUInt32LE(21);
      const width = (bits & 0x3fff) + 1;
      const height = ((bits >> 14) & 0x3fff) + 1;
      return { width, height };
    } else if (format === 'VP8X') {
      const width = (buffer.readUInt32LE(24) & 0xffffff) + 1;
      const height = (buffer.readUInt32LE(26) & 0xffffff) + 1;
      return { width, height };
    }

    throw new Error('Unsupported WebP format');
  }

  /**
   * Extract GIF dimensions from header
   * Supports GIF87a and GIF89a formats
   */
  private extractGifDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 10) {
      throw new Error('Invalid GIF: too short');
    }

    const signature = buffer.subarray(0, 6).toString('ascii');
    if (signature !== 'GIF87a' && signature !== 'GIF89a') {
      throw new Error('Invalid GIF signature');
    }

    const width = buffer.readUInt16LE(6);
    const height = buffer.readUInt16LE(8);

    return { width, height };
  }

  /**
   * Calculate tokens for an image based on its metadata
   *
   * @param metadata Image metadata containing width, height, and format info
   * @returns Total token count including base image tokens and special tokens
   */
  calculateTokens(metadata: ImageMetadata): number {
    return this.calculateTokensWithScaling(metadata.width, metadata.height);
  }

  /**
   * Calculate tokens with scaling logic
   *
   * Steps:
   * 1. Normalize to 28-pixel multiples
   * 2. Scale large images down, small images up
   * 3. Calculate tokens: pixels / 784 + 2 special tokens
   *
   * @param width Original image width in pixels
   * @param height Original image height in pixels
   * @returns Total token count for the image
   */
  private calculateTokensWithScaling(width: number, height: number): number {
    // Normalize to 28-pixel multiples
    let hBar = Math.round(height / 28) * 28;
    let wBar = Math.round(width / 28) * 28;

    // Define pixel boundaries
    const minPixels =
      ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
    const maxPixels =
      ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;

    // Apply scaling
    if (hBar * wBar > maxPixels) {
      // Scale down large images
      const beta = Math.sqrt((height * width) / maxPixels);
      hBar = Math.floor(height / beta / 28) * 28;
      wBar = Math.floor(width / beta / 28) * 28;
    } else if (hBar * wBar < minPixels) {
      // Scale up small images
      const beta = Math.sqrt(minPixels / (height * width));
      hBar = Math.ceil((height * beta) / 28) * 28;
      wBar = Math.ceil((width * beta) / 28) * 28;
    }

    // Calculate tokens
    const imageTokens = Math.floor(
      (hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
    );

    return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
  }

  /**
   * Calculate tokens for multiple images serially
   *
   * @param base64DataArray Array of image data with MIME type information
   * @returns Promise resolving to array of token counts in same order as input
   */
  async calculateTokensBatch(
    base64DataArray: Array<{ data: string; mimeType: string }>,
  ): Promise<number[]> {
    const results: number[] = [];

    for (const { data, mimeType } of base64DataArray) {
      try {
        const metadata = await this.extractImageMetadata(data, mimeType);
        results.push(this.calculateTokens(metadata));
      } catch (error) {
        console.warn('Error calculating tokens for image:', error);
        // Return minimum tokens as fallback
        results.push(
          ImageTokenizer.MIN_TOKENS_PER_IMAGE +
            ImageTokenizer.VISION_SPECIAL_TOKENS,
        );
      }
    }

    return results;
  }
}
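
As a cross-check of calculateTokensWithScaling against the 512x512 test above:

// 512 rounds to 18 * 28 = 504 on each side; 504 * 504 = 254,016 pixels,
// which lies inside [4 * 784, 16384 * 784], so no rescaling is applied.
const imageTokens = Math.floor((504 * 504) / 784); // 324
const total = imageTokens + 2; // 326 tokens: inside the 300-400 range the test asserts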
packages/core/src/utils/request-tokenizer/index.ts (new file, 40 lines)
@@ -0,0 +1,40 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

export { DefaultRequestTokenizer } from './requestTokenizer.js';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
export { TextTokenizer } from './textTokenizer.js';
export { ImageTokenizer } from './imageTokenizer.js';

export type {
  RequestTokenizer,
  TokenizerConfig,
  TokenCalculationResult,
  ImageMetadata,
} from './types.js';

// Singleton instance for convenient usage
let defaultTokenizer: DefaultRequestTokenizer | null = null;

/**
 * Get the default request tokenizer instance
 */
export function getDefaultTokenizer(): DefaultRequestTokenizer {
  if (!defaultTokenizer) {
    defaultTokenizer = new DefaultRequestTokenizer();
  }
  return defaultTokenizer;
}

/**
 * Dispose of the default tokenizer instance
 */
export async function disposeDefaultTokenizer(): Promise<void> {
  if (defaultTokenizer) {
    await defaultTokenizer.dispose();
    defaultTokenizer = null;
  }
}
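
A short usage sketch of this module's public surface, mirroring how countTokens in openaiContentGenerator.ts consumes it (the request literal here is illustrative):

import { getDefaultTokenizer, disposeDefaultTokenizer } from './index.js';

const tokenizer = getDefaultTokenizer(); // lazily-created singleton
const result = await tokenizer.calculateTokens(
  { model: 'qwen-vl-max', contents: ['Hello, world!'] },
  { textEncoding: 'cl100k_base' },
);
console.log(result.totalTokens, result.breakdown, result.processingTime);

await disposeDefaultTokenizer(); // frees encoder resources and resets the singleton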
@@ -0,0 +1,293 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
import { CountTokensParameters } from '@google/genai';

describe('DefaultRequestTokenizer', () => {
  let tokenizer: DefaultRequestTokenizer;

  beforeEach(() => {
    tokenizer = new DefaultRequestTokenizer();
  });

  afterEach(async () => {
    await tokenizer.dispose();
  });

  describe('text token calculation', () => {
    it('should calculate tokens for simple text content', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [{ text: 'Hello, world!' }],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
      expect(result.breakdown.imageTokens).toBe(0);
      expect(result.processingTime).toBeGreaterThan(0);
    });

    it('should handle multiple text parts', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              { text: 'First part' },
              { text: 'Second part' },
              { text: 'Third part' },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
    });

    it('should handle string content', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: ['Simple string content'],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
    });
  });

  describe('image token calculation', () => {
    it('should calculate tokens for image content', async () => {
      // Create a simple 1x1 PNG image in base64
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
      expect(result.breakdown.textTokens).toBe(0);
    });

    it('should handle multiple images', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
    });
  });

  describe('mixed content', () => {
    it('should handle text and image content together', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              { text: 'Here is an image:' },
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
              { text: 'What do you see?' },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(4);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
    });
  });

  describe('function content', () => {
    it('should handle function calls', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                functionCall: {
                  name: 'test_function',
                  args: { param1: 'value1', param2: 42 },
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.otherTokens).toBeGreaterThan(0);
    });
  });

  describe('empty content', () => {
    it('should handle empty request', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBe(0);
      expect(result.breakdown.textTokens).toBe(0);
      expect(result.breakdown.imageTokens).toBe(0);
    });

    it('should handle undefined contents', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBe(0);
    });
  });

  describe('configuration', () => {
    it('should use custom text encoding', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [{ text: 'Test text for encoding' }],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request, {
        textEncoding: 'cl100k_base',
      });

      expect(result.totalTokens).toBeGreaterThan(0);
    });

    it('should process multiple images serially', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: Array(10).fill({
              inlineData: {
                mimeType: 'image/png',
                data: pngBase64,
              },
            }),
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
    });
  });

  describe('error handling', () => {
    it('should handle malformed image data gracefully', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: 'invalid-base64-data',
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      // Should still return some tokens (fallback to minimum)
      expect(result.totalTokens).toBeGreaterThanOrEqual(4);
    });
  });
});
packages/core/src/utils/request-tokenizer/requestTokenizer.ts (new file, 336 lines)
@@ -0,0 +1,336 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { CountTokensParameters, Content, Part, PartUnion } from '@google/genai';
import {
  RequestTokenizer,
  TokenizerConfig,
  TokenCalculationResult,
} from './types.js';
import { TextTokenizer } from './textTokenizer.js';
import { ImageTokenizer } from './imageTokenizer.js';

/**
 * Simple request tokenizer that handles text and image content serially
 */
export class DefaultRequestTokenizer implements RequestTokenizer {
  private textTokenizer: TextTokenizer;
  private imageTokenizer: ImageTokenizer;

  constructor() {
    this.textTokenizer = new TextTokenizer();
    this.imageTokenizer = new ImageTokenizer();
  }

  /**
   * Calculate tokens for a request using serial processing
   */
  async calculateTokens(
    request: CountTokensParameters,
    config: TokenizerConfig = {},
  ): Promise<TokenCalculationResult> {
    const startTime = performance.now();

    // Apply configuration
    if (config.textEncoding) {
      this.textTokenizer = new TextTokenizer(config.textEncoding);
    }

    try {
      // Process request content and group by type
      const { textContents, imageContents, audioContents, otherContents } =
        this.processAndGroupContents(request);

      if (
        textContents.length === 0 &&
        imageContents.length === 0 &&
        audioContents.length === 0 &&
        otherContents.length === 0
      ) {
        return {
          totalTokens: 0,
          breakdown: {
            textTokens: 0,
            imageTokens: 0,
            audioTokens: 0,
            otherTokens: 0,
          },
          processingTime: performance.now() - startTime,
        };
      }

      // Calculate tokens for each content type serially
      const textTokens = await this.calculateTextTokens(textContents);
      const imageTokens = await this.calculateImageTokens(imageContents);
      const audioTokens = await this.calculateAudioTokens(audioContents);
      const otherTokens = await this.calculateOtherTokens(otherContents);

      const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
      const processingTime = performance.now() - startTime;

      return {
        totalTokens,
        breakdown: {
          textTokens,
          imageTokens,
          audioTokens,
          otherTokens,
        },
        processingTime,
      };
    } catch (error) {
      console.error('Error calculating tokens:', error);

      // Fallback calculation
      const fallbackTokens = this.calculateFallbackTokens(request);

      return {
        totalTokens: fallbackTokens,
        breakdown: {
          textTokens: fallbackTokens,
          imageTokens: 0,
          audioTokens: 0,
          otherTokens: 0,
        },
        processingTime: performance.now() - startTime,
      };
    }
  }

  /**
   * Calculate tokens for text contents
   */
  private async calculateTextTokens(textContents: string[]): Promise<number> {
    if (textContents.length === 0) return 0;

    try {
      const tokenCounts =
        await this.textTokenizer.calculateTokensBatch(textContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating text tokens:', error);
      // Fallback: character-based estimation
      const totalChars = textContents.join('').length;
      return Math.ceil(totalChars / 4);
    }
  }

  /**
   * Calculate tokens for image contents using serial processing
   */
  private async calculateImageTokens(
    imageContents: Array<{ data: string; mimeType: string }>,
  ): Promise<number> {
    if (imageContents.length === 0) return 0;

    try {
      const tokenCounts =
        await this.imageTokenizer.calculateTokensBatch(imageContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating image tokens:', error);
      // Fallback: minimum tokens per image
      return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
    }
  }

  /**
   * Calculate tokens for audio contents
   * TODO: Implement proper audio token calculation
   */
  private async calculateAudioTokens(
    audioContents: Array<{ data: string; mimeType: string }>,
  ): Promise<number> {
    if (audioContents.length === 0) return 0;

    // Placeholder implementation - audio token calculation would depend on
    // the specific model's audio processing capabilities
    // For now, estimate based on data size
    let totalTokens = 0;

    for (const audioContent of audioContents) {
      try {
        const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
        // Rough estimate: 1 token per 100 bytes of audio data
        totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
      } catch (error) {
        console.warn('Error calculating audio tokens:', error);
        totalTokens += 10; // Fallback minimum
      }
    }

    return totalTokens;
  }

  /**
   * Calculate tokens for other content types (functions, files, etc.)
   */
  private async calculateOtherTokens(otherContents: string[]): Promise<number> {
    if (otherContents.length === 0) return 0;

    try {
      // Treat other content as text for token calculation
      const tokenCounts =
        await this.textTokenizer.calculateTokensBatch(otherContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating other content tokens:', error);
      // Fallback: character-based estimation
      const totalChars = otherContents.join('').length;
      return Math.ceil(totalChars / 4);
    }
  }

  /**
   * Fallback token calculation using simple string serialization
   */
  private calculateFallbackTokens(request: CountTokensParameters): number {
    try {
      const content = JSON.stringify(request.contents);
      return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
    } catch (error) {
      console.warn('Error in fallback token calculation:', error);
      return 100; // Conservative fallback
    }
  }

  /**
   * Process request contents and group by type
   */
  private processAndGroupContents(request: CountTokensParameters): {
    textContents: string[];
    imageContents: Array<{ data: string; mimeType: string }>;
    audioContents: Array<{ data: string; mimeType: string }>;
    otherContents: string[];
  } {
    const textContents: string[] = [];
    const imageContents: Array<{ data: string; mimeType: string }> = [];
    const audioContents: Array<{ data: string; mimeType: string }> = [];
    const otherContents: string[] = [];

    if (!request.contents) {
      return { textContents, imageContents, audioContents, otherContents };
    }

    const contents = Array.isArray(request.contents)
      ? request.contents
      : [request.contents];

    for (const content of contents) {
      this.processContent(
        content,
        textContents,
        imageContents,
        audioContents,
        otherContents,
      );
    }

    return { textContents, imageContents, audioContents, otherContents };
  }

  /**
   * Process a single content item and add to appropriate arrays
   */
  private processContent(
    content: Content | string | PartUnion,
    textContents: string[],
    imageContents: Array<{ data: string; mimeType: string }>,
    audioContents: Array<{ data: string; mimeType: string }>,
    otherContents: string[],
  ): void {
    if (typeof content === 'string') {
      if (content.trim()) {
        textContents.push(content);
      }
      return;
    }

    if ('parts' in content && content.parts) {
      for (const part of content.parts) {
        this.processPart(
          part,
          textContents,
          imageContents,
          audioContents,
          otherContents,
        );
      }
    }
  }

  /**
   * Process a single part and add to appropriate arrays
   */
  private processPart(
    part: Part | string,
    textContents: string[],
    imageContents: Array<{ data: string; mimeType: string }>,
    audioContents: Array<{ data: string; mimeType: string }>,
|
||||
otherContents: string[],
|
||||
): void {
|
||||
if (typeof part === 'string') {
|
||||
if (part.trim()) {
|
||||
textContents.push(part);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('text' in part && part.text) {
|
||||
textContents.push(part.text);
|
||||
return;
|
||||
}
|
||||
|
||||
if ('inlineData' in part && part.inlineData) {
|
||||
const { data, mimeType } = part.inlineData;
|
||||
if (mimeType && mimeType.startsWith('image/')) {
|
||||
imageContents.push({ data: data || '', mimeType });
|
||||
return;
|
||||
}
|
||||
if (mimeType && mimeType.startsWith('audio/')) {
|
||||
audioContents.push({ data: data || '', mimeType });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if ('fileData' in part && part.fileData) {
|
||||
otherContents.push(JSON.stringify(part.fileData));
|
||||
return;
|
||||
}
|
||||
|
||||
if ('functionCall' in part && part.functionCall) {
|
||||
otherContents.push(JSON.stringify(part.functionCall));
|
||||
return;
|
||||
}
|
||||
|
||||
if ('functionResponse' in part && part.functionResponse) {
|
||||
otherContents.push(JSON.stringify(part.functionResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
// Unknown part type - try to serialize
|
||||
try {
|
||||
const serialized = JSON.stringify(part);
|
||||
if (serialized && serialized !== '{}') {
|
||||
otherContents.push(serialized);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to serialize unknown part type:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose of resources
|
||||
*/
|
||||
async dispose(): Promise<void> {
|
||||
try {
|
||||
// Dispose of tokenizers
|
||||
this.textTokenizer.dispose();
|
||||
} catch (error) {
|
||||
console.warn('Error disposing request tokenizer:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
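A minimal usage sketch of the tokenizer above may help orientation. The class name (DefaultRequestTokenizer), import path, and model string are assumptions for illustration; only calculateTokens/dispose and the result shape come from this diff:

// Hypothetical call site (names assumed, not part of this commit)
import { DefaultRequestTokenizer } from './requestTokenizer.js';

const tokenizer = new DefaultRequestTokenizer();
const base64Png = '<base64-encoded image bytes>'; // placeholder
const result = await tokenizer.calculateTokens({
  model: 'qwen-vl-plus', // assumed model id
  contents: [
    { role: 'user', parts: [{ text: 'Describe this image' }] },
    {
      role: 'user',
      parts: [{ inlineData: { data: base64Png, mimeType: 'image/png' } }],
    },
  ],
});
// result.totalTokens is the sum of result.breakdown.{text,image,audio,other}Tokens;
// if anything throws internally, it degrades to the JSON.stringify-length / 4 fallback.
await tokenizer.dispose();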
packages/core/src/utils/request-tokenizer/textTokenizer.test.ts (Normal file, 347 lines)
@@ -0,0 +1,347 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';

// Mock tiktoken at the top level with hoisted functions
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());

vi.mock('tiktoken', () => ({
  get_encoding: mockGetEncoding,
}));

describe('TextTokenizer', () => {
  let tokenizer: TextTokenizer;
  let consoleWarnSpy: ReturnType<typeof vi.spyOn>;

  beforeEach(() => {
    vi.resetAllMocks();
    consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

    // Default mock implementation
    mockGetEncoding.mockReturnValue({
      encode: mockEncode,
      free: mockFree,
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    tokenizer?.dispose();
  });

  describe('constructor', () => {
    it('should create tokenizer with default encoding', () => {
      tokenizer = new TextTokenizer();
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });

    it('should create tokenizer with custom encoding', () => {
      tokenizer = new TextTokenizer('gpt2');
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });
  });

  describe('calculateTokens', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should return 0 for empty text', async () => {
      const result = await tokenizer.calculateTokens('');
      expect(result).toBe(0);
    });

    it('should return 0 for null/undefined text', async () => {
      const result1 = await tokenizer.calculateTokens(
        null as unknown as string,
      );
      const result2 = await tokenizer.calculateTokens(
        undefined as unknown as string,
      );
      expect(result1).toBe(0);
      expect(result2).toBe(0);
    });

    it('should calculate tokens using tiktoken when available', async () => {
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
      expect(mockEncode).toHaveBeenCalledWith(testText);
      expect(result).toBe(5);
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should use fallback calculation when encoding fails', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding text with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should handle very long text', async () => {
      const longText = 'a'.repeat(10000);
      const mockTokens = new Array(2500); // 2500 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(longText);

      expect(result).toBe(2500);
    });

    it('should handle unicode characters', async () => {
      const unicodeText = '你好世界 🌍';
      const mockTokens = [1, 2, 3, 4, 5, 6];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(unicodeText);

      expect(result).toBe(6);
    });

    it('should use custom encoding when specified', async () => {
      tokenizer = new TextTokenizer('gpt2');
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
      expect(result).toBe(3);
    });
  });

  describe('calculateTokensBatch', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should process multiple texts and return token counts', async () => {
      const texts = ['Hello', 'world', 'test'];
      mockEncode
        .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
        .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
        .mockReturnValueOnce([6]); // 1 token for 'test'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([2, 3, 1]);
      expect(mockEncode).toHaveBeenCalledTimes(3);
    });

    it('should handle empty array', async () => {
      const result = await tokenizer.calculateTokensBatch([]);
      expect(result).toEqual([]);
    });

    it('should handle array with empty strings', async () => {
      const texts = ['', 'hello', ''];
      mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0]);
      expect(mockEncode).toHaveBeenCalledTimes(1);
      expect(mockEncode).toHaveBeenCalledWith('hello');
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5 / 4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should use fallback calculation when encoding fails during batch processing', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding texts with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5 / 4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should handle null and undefined values in batch', async () => {
      const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
      mockEncode
        .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
        .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0, 2]);
    });
  });

  describe('dispose', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should free tiktoken encoding when disposing', async () => {
      // Initialize the encoding by calling calculateTokens
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();

      expect(mockFree).toHaveBeenCalled();
    });

    it('should handle disposal when encoding is not initialized', () => {
      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle disposal when encoding is null', async () => {
      // Force encoding to be null by making tiktoken fail
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load');
      });

      await tokenizer.calculateTokens('test');

      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle errors during disposal gracefully', async () => {
      await tokenizer.calculateTokens('test');

      mockFree.mockImplementation(() => {
        throw new Error('Free failed');
      });

      tokenizer.dispose();

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error freeing tiktoken encoding:',
        expect.any(Error),
      );
    });

    it('should allow multiple calls to dispose', async () => {
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();
      tokenizer.dispose(); // Second call should not throw

      expect(mockFree).toHaveBeenCalledTimes(1);
    });
  });

  describe('lazy initialization', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should not initialize tiktoken until first use', () => {
      expect(mockGetEncoding).not.toHaveBeenCalled();
    });

    it('should initialize tiktoken on first calculateTokens call', async () => {
      await tokenizer.calculateTokens('test');
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should not reinitialize tiktoken on subsequent calls', async () => {
      await tokenizer.calculateTokens('test1');
      await tokenizer.calculateTokens('test2');

      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should initialize tiktoken on first calculateTokensBatch call', async () => {
      await tokenizer.calculateTokensBatch(['test']);
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should handle very short text', async () => {
      const result = await tokenizer.calculateTokens('a');

      if (mockGetEncoding.mock.calls.length > 0) {
        // If tiktoken was called, use its result
        expect(mockEncode).toHaveBeenCalledWith('a');
      } else {
        // If tiktoken failed, should use fallback: Math.ceil(1 / 4) = 1
        expect(result).toBe(1);
      }
    });

    it('should handle text with only whitespace', async () => {
      const whitespaceText = ' \n\t ';
      const mockTokens = [1];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(whitespaceText);

      expect(result).toBe(1);
    });

    it('should handle special characters and symbols', async () => {
      const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
      const mockTokens = new Array(10);
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(specialText);

      expect(result).toBe(10);
    });
  });
});
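The tests above lean on a Vitest hoisting detail worth spelling out: vi.mock() factories are hoisted above the module's import statements, so any values the factory closes over must themselves be created with vi.hoisted(). A minimal sketch of the pattern in isolation (module name is illustrative):

// vi.mock() factories run before imports are evaluated, so a plain
// `const` mock would not exist yet at mock time; vi.hoisted() fixes that.
const mockFn = vi.hoisted(() => vi.fn());
vi.mock('some-module', () => ({ someExport: mockFn }));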
packages/core/src/utils/request-tokenizer/textTokenizer.ts (Normal file, 97 lines)
@@ -0,0 +1,97 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
import { get_encoding } from 'tiktoken';

/**
 * Text tokenizer for calculating text tokens using tiktoken
 */
export class TextTokenizer {
  private encoding: Tiktoken | null = null;
  private encodingName: string;

  constructor(encodingName: string = 'cl100k_base') {
    this.encodingName = encodingName;
  }

  /**
   * Initialize the tokenizer (lazy loading)
   */
  private async ensureEncoding(): Promise<void> {
    if (this.encoding) return;

    try {
      // Use type assertion since we know the encoding name is valid
      this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
    } catch (error) {
      console.warn(
        `Failed to load tiktoken with encoding ${this.encodingName}:`,
        error,
      );
      this.encoding = null;
    }
  }

  /**
   * Calculate tokens for text content
   */
  async calculateTokens(text: string): Promise<number> {
    if (!text) return 0;

    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return this.encoding.encode(text).length;
      } catch (error) {
        console.warn('Error encoding text with tiktoken:', error);
      }
    }

    // Fallback: rough approximation using character count
    // This is a conservative estimate: 1 token ≈ 4 characters for most languages
    return Math.ceil(text.length / 4);
  }

  /**
   * Calculate tokens for multiple text strings in a single pass
   */
  async calculateTokensBatch(texts: string[]): Promise<number[]> {
    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return texts.map((text) => {
          if (!text) return 0;
          // this.encoding may be null; add a null check to satisfy lint
          return this.encoding ? this.encoding.encode(text).length : 0;
        });
      } catch (error) {
        console.warn('Error encoding texts with tiktoken:', error);
        // In case of error, return fallback estimation for all texts
        return texts.map((text) => Math.ceil((text || '').length / 4));
      }
    }

    // Fallback for batch processing
    return texts.map((text) => Math.ceil((text || '').length / 4));
  }

  /**
   * Dispose of resources
   */
  dispose(): void {
    if (this.encoding) {
      try {
        this.encoding.free();
      } catch (error) {
        console.warn('Error freeing tiktoken encoding:', error);
      }
      this.encoding = null;
    }
  }
}
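A short usage sketch of TextTokenizer, shown for the character-count fallback path (with tiktoken loaded, the counts come from the encoding instead; the call site is illustrative):

const tok = new TextTokenizer(); // defaults to cl100k_base
console.log(await tok.calculateTokens('Hello, world!')); // 4 via Math.ceil(13 / 4) when tiktoken is unavailable
console.log(await tok.calculateTokensBatch(['a', 'bb', ''])); // [1, 1, 0] on the fallback path
tok.dispose(); // frees the WASM-backed encoding if one was created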
packages/core/src/utils/request-tokenizer/types.ts (Normal file, 64 lines)
@@ -0,0 +1,64 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { CountTokensParameters } from '@google/genai';

/**
 * Token calculation result for different content types
 */
export interface TokenCalculationResult {
  /** Total tokens calculated */
  totalTokens: number;
  /** Breakdown by content type */
  breakdown: {
    textTokens: number;
    imageTokens: number;
    audioTokens: number;
    otherTokens: number;
  };
  /** Processing time in milliseconds */
  processingTime: number;
}

/**
 * Configuration for token calculation
 */
export interface TokenizerConfig {
  /** Custom text tokenizer encoding (defaults to cl100k_base) */
  textEncoding?: string;
}

/**
 * Image metadata extracted from base64 data
 */
export interface ImageMetadata {
  /** Image width in pixels */
  width: number;
  /** Image height in pixels */
  height: number;
  /** MIME type of the image */
  mimeType: string;
  /** Size of the base64 data in bytes */
  dataSize: number;
}

/**
 * Request tokenizer interface
 */
export interface RequestTokenizer {
  /**
   * Calculate tokens for a request
   */
  calculateTokens(
    request: CountTokensParameters,
    config?: TokenizerConfig,
  ): Promise<TokenCalculationResult>;

  /**
   * Dispose of resources (worker threads, etc.)
   */
  dispose(): Promise<void>;
}
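To make the RequestTokenizer contract concrete, here is a trivial implementation that satisfies the interface (a hedged sketch, not part of the commit; it folds every content type into textTokens):

class CharCountTokenizer implements RequestTokenizer {
  async calculateTokens(
    request: CountTokensParameters,
  ): Promise<TokenCalculationResult> {
    const start = performance.now();
    // Everything is treated as serialized text at roughly 4 characters per token.
    const textTokens = Math.ceil(
      JSON.stringify(request.contents ?? []).length / 4,
    );
    return {
      totalTokens: textTokens,
      breakdown: { textTokens, imageTokens: 0, audioTokens: 0, otherTokens: 0 },
      processingTime: performance.now() - start,
    };
  }

  async dispose(): Promise<void> {
    // Nothing to release in this sketch.
  }
}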