From 93fbc54f88566af7e54a7d715777fe6864a05d88 Mon Sep 17 00:00:00 2001 From: "mingholy.lmh" Date: Wed, 17 Sep 2025 18:18:59 +0800 Subject: [PATCH] feat: add image tokenizer to fit vlm context window --- .../ui/components/ModelSelectionDialog.tsx | 2 +- .../src/ui/components/ModelSwitchDialog.tsx | 2 +- .../cli/src/ui/components/StatsDisplay.tsx | 2 +- .../SessionSummaryDisplay.test.tsx.snap | 2 +- .../__snapshots__/StatsDisplay.test.tsx.snap | 24 +- .../__tests__/openaiTimeoutHandling.test.ts | 114 +++--- packages/core/src/core/geminiChat.ts | 2 +- .../core/src/core/openaiContentGenerator.ts | 4 +- .../openaiContentGenerator.test.ts | 39 +- .../openaiContentGenerator.ts | 34 +- .../core/openaiContentGenerator/pipeline.ts | 72 ++-- .../provider/dashscope.ts | 10 + packages/core/src/core/tokenLimits.ts | 3 + packages/core/src/core/turn.test.ts | 2 +- packages/core/src/core/turn.ts | 2 +- .../src/qwen/qwenContentGenerator.test.ts | 23 +- .../core/src/qwen/qwenContentGenerator.ts | 4 +- .../request-tokenizer/imageTokenizer.test.ts | 157 ++++++++ .../utils/request-tokenizer/imageTokenizer.ts | 309 ++++++++++++++++ .../core/src/utils/request-tokenizer/index.ts | 40 ++ .../requestTokenizer.test.ts | 293 +++++++++++++++ .../request-tokenizer/requestTokenizer.ts | 336 +++++++++++++++++ .../request-tokenizer/textTokenizer.test.ts | 347 ++++++++++++++++++ .../utils/request-tokenizer/textTokenizer.ts | 97 +++++ .../core/src/utils/request-tokenizer/types.ts | 64 ++++ 25 files changed, 1860 insertions(+), 124 deletions(-) create mode 100644 packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/imageTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/index.ts create mode 100644 packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/requestTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/textTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/textTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/types.ts diff --git a/packages/cli/src/ui/components/ModelSelectionDialog.tsx b/packages/cli/src/ui/components/ModelSelectionDialog.tsx index 56f01e5c..7dbc570a 100644 --- a/packages/cli/src/ui/components/ModelSelectionDialog.tsx +++ b/packages/cli/src/ui/components/ModelSelectionDialog.tsx @@ -9,7 +9,7 @@ import { Box, Text } from 'ink'; import { Colors } from '../colors.js'; import { RadioButtonSelect, - RadioSelectItem, + type RadioSelectItem, } from './shared/RadioButtonSelect.js'; import { useKeypress } from '../hooks/useKeypress.js'; import { AvailableModel } from '../models/availableModels.js'; diff --git a/packages/cli/src/ui/components/ModelSwitchDialog.tsx b/packages/cli/src/ui/components/ModelSwitchDialog.tsx index 7cc37f9d..62b9bbce 100644 --- a/packages/cli/src/ui/components/ModelSwitchDialog.tsx +++ b/packages/cli/src/ui/components/ModelSwitchDialog.tsx @@ -9,7 +9,7 @@ import { Box, Text } from 'ink'; import { Colors } from '../colors.js'; import { RadioButtonSelect, - RadioSelectItem, + type RadioSelectItem, } from './shared/RadioButtonSelect.js'; import { useKeypress } from '../hooks/useKeypress.js'; diff --git a/packages/cli/src/ui/components/StatsDisplay.tsx b/packages/cli/src/ui/components/StatsDisplay.tsx index b5a14d55..f3b326a5 100644 --- a/packages/cli/src/ui/components/StatsDisplay.tsx +++ b/packages/cli/src/ui/components/StatsDisplay.tsx @@ -205,7 
+205,7 @@ export const StatsDisplay: React.FC = ({ {tools.totalCalls} ({' '} ✓ {tools.totalSuccess}{' '} - ✖ {tools.totalFail} ) + x {tools.totalFail} ) diff --git a/packages/cli/src/ui/components/__snapshots__/SessionSummaryDisplay.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/SessionSummaryDisplay.test.tsx.snap index f07cc7ec..7c925f72 100644 --- a/packages/cli/src/ui/components/__snapshots__/SessionSummaryDisplay.test.tsx.snap +++ b/packages/cli/src/ui/components/__snapshots__/SessionSummaryDisplay.test.tsx.snap @@ -7,7 +7,7 @@ exports[` > renders the summary display with a title 1` │ │ │ Interaction Summary │ │ Session ID: │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ Code Changes: +42 -15 │ │ │ diff --git a/packages/cli/src/ui/components/__snapshots__/StatsDisplay.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/StatsDisplay.test.tsx.snap index 1fbd42c8..8106d1f5 100644 --- a/packages/cli/src/ui/components/__snapshots__/StatsDisplay.test.tsx.snap +++ b/packages/cli/src/ui/components/__snapshots__/StatsDisplay.test.tsx.snap @@ -7,7 +7,7 @@ exports[` > Code Changes Display > displays Code Changes when li │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 1 ( ✓ 1 ✖ 0 ) │ +│ Tool Calls: 1 ( ✓ 1 x 0 ) │ │ Success Rate: 100.0% │ │ Code Changes: +42 -18 │ │ │ @@ -28,7 +28,7 @@ exports[` > Code Changes Display > hides Code Changes when no li │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 1 ( ✓ 1 ✖ 0 ) │ +│ Tool Calls: 1 ( ✓ 1 x 0 ) │ │ Success Rate: 100.0% │ │ │ │ Performance │ @@ -48,7 +48,7 @@ exports[` > Conditional Color Tests > renders success rate in gr │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 10 ( ✓ 10 ✖ 0 ) │ +│ Tool Calls: 10 ( ✓ 10 x 0 ) │ │ Success Rate: 100.0% │ │ │ │ Performance │ @@ -68,7 +68,7 @@ exports[` > Conditional Color Tests > renders success rate in re │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 10 ( ✓ 5 ✖ 5 ) │ +│ Tool Calls: 10 ( ✓ 5 x 5 ) │ │ Success Rate: 50.0% │ │ │ │ Performance │ @@ -88,7 +88,7 @@ exports[` > Conditional Color Tests > renders success rate in ye │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 10 ( ✓ 9 ✖ 1 ) │ +│ Tool Calls: 10 ( ✓ 9 x 1 ) │ │ Success Rate: 90.0% │ │ │ │ Performance │ @@ -108,7 +108,7 @@ exports[` > Conditional Rendering Tests > hides Efficiency secti │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ │ │ Performance │ @@ -132,7 +132,7 @@ exports[` > Conditional Rendering Tests > hides User Agreement w │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 2 ( ✓ 1 ✖ 1 ) │ +│ Tool Calls: 2 ( ✓ 1 x 1 ) │ │ Success Rate: 50.0% │ │ │ │ Performance │ @@ -152,7 +152,7 @@ exports[` > Title Rendering > renders the custom title when a ti │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ │ │ Performance │ @@ -172,7 +172,7 @@ exports[` > Title Rendering > renders the default title when no │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ │ │ Performance │ @@ -192,7 +192,7 @@ exports[` > renders a table with two models correctly 1`] = ` │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ 
Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ │ │ Performance │ @@ -221,7 +221,7 @@ exports[` > renders all sections when all data is present 1`] = │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 2 ( ✓ 1 ✖ 1 ) │ +│ Tool Calls: 2 ( ✓ 1 x 1 ) │ │ Success Rate: 50.0% │ │ User Agreement: 100.0% (1 reviewed) │ │ │ @@ -250,7 +250,7 @@ exports[` > renders only the Performance section in its zero sta │ │ │ Interaction Summary │ │ Session ID: test-session-id │ -│ Tool Calls: 0 ( ✓ 0 ✖ 0 ) │ +│ Tool Calls: 0 ( ✓ 0 x 0 ) │ │ Success Rate: 0.0% │ │ │ │ Performance │ diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts index bb46b09b..91a97874 100644 --- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts +++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts @@ -5,9 +5,10 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { OpenAIContentGenerator } from '../openaiContentGenerator.js'; +import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js'; import { Config } from '../../config/config.js'; import { AuthType } from '../contentGenerator.js'; +import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js'; import OpenAI from 'openai'; // Mock OpenAI @@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { let mockConfig: Config; // eslint-disable-next-line @typescript-eslint/no-explicit-any let mockOpenAIClient: any; + let mockProvider: OpenAICompatibleProvider; beforeEach(() => { // Reset mocks @@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { mockConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', + enableOpenAILogging: false, }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => { create: vi.fn(), }, }, + embeddings: { + create: vi.fn(), + }, }; vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); + // Create mock provider + mockProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + // Create generator instance const contentGeneratorConfig = { model: 'gpt-4', apiKey: 'test-key', authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, }; - generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); + generator = new OpenAIContentGenerator( + contentGeneratorConfig, + mockConfig, + mockProvider, + ); }); afterEach(() => { @@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { await expect( generator.generateContentStream(request, 'test-prompt-id'), ).rejects.toThrow( - /Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./, + /Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./, ); }); @@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => { } catch (error: unknown) { const errorMessage = error instanceof Error ? 
error.message : String(error); - expect(errorMessage).toContain( - 'Streaming setup timeout troubleshooting:', - ); - expect(errorMessage).toContain( - 'Check network connectivity and firewall settings', - ); + expect(errorMessage).toContain('Streaming timeout troubleshooting:'); + expect(errorMessage).toContain('Check network connectivity'); expect(errorMessage).toContain('Consider using non-streaming mode'); } }); @@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => { authType: AuthType.USE_OPENAI, baseUrl: 'http://localhost:8080', }; - new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); + new OpenAIContentGenerator( + contentGeneratorConfig, + mockConfig, + mockProvider, + ); - // Verify OpenAI client was created with timeout config - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 120000, - maxRetries: 3, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Verify provider buildClient was called + expect(mockProvider.buildClient).toHaveBeenCalled(); }); it('should use custom timeout from config', () => { const customConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + enableOpenAILogging: false, + }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => { timeout: 300000, maxRetries: 5, }; - new OpenAIContentGenerator(contentGeneratorConfig, customConfig); - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 300000, - maxRetries: 5, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Create a custom mock provider for this test + const customMockProvider: OpenAICompatibleProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + + new OpenAIContentGenerator( + contentGeneratorConfig, + customConfig, + customMockProvider, + ); + + // Verify provider buildClient was called + expect(customMockProvider.buildClient).toHaveBeenCalled(); }); it('should handle missing timeout config gracefully', () => { const noTimeoutConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + enableOpenAILogging: false, + }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => { authType: AuthType.USE_OPENAI, baseUrl: 'http://localhost:8080', }; - new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig); - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 120000, // default - maxRetries: 3, // default - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Create a custom mock provider for this test + const noTimeoutMockProvider: OpenAICompatibleProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + + new OpenAIContentGenerator( + contentGeneratorConfig, + noTimeoutConfig, + noTimeoutMockProvider, + ); + + // Verify 
provider buildClient was called + expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled(); }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 5db2a286..c2407474 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -480,7 +480,7 @@ export class GeminiChat { if (error instanceof Error && error.message) { if (isSchemaDepthError(error.message)) return false; if (error.message.includes('429')) return true; - if (error.message.match(/5\d{2}/)) return true; + if (error.message.match(/^5\d{2}/)) return true; } return false; }, diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index b0f22de8..bc8cfa77 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -644,8 +644,8 @@ export class OpenAIContentGenerator implements ContentGenerator { let totalTokens = 0; try { - const { get_encoding } = await import('tiktoken'); - const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen + const tikToken = await import('tiktoken'); + const encoding = tikToken.get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen totalTokens = encoding.encode(content).length; encoding.free(); } catch (error) { diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts index 0569c120..3ee57dd6 100644 --- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts +++ b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts @@ -5,6 +5,37 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; + +// Mock the request tokenizer module BEFORE importing the class that uses it +const mockTokenizer = { + calculateTokens: vi.fn().mockResolvedValue({ + totalTokens: 50, + breakdown: { + textTokens: 50, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: 1, + }), + dispose: vi.fn(), +}; + +vi.mock('../../../utils/request-tokenizer/index.js', () => ({ + getDefaultTokenizer: vi.fn(() => mockTokenizer), + DefaultRequestTokenizer: vi.fn(() => mockTokenizer), + disposeDefaultTokenizer: vi.fn(), +})); + +// Mock tiktoken as well for completeness +vi.mock('tiktoken', () => ({ + get_encoding: vi.fn(() => ({ + encode: vi.fn(() => new Array(50)), // Mock 50 tokens + free: vi.fn(), + })), +})); + +// Now import the modules that depend on the mocked modules import { OpenAIContentGenerator } from './openaiContentGenerator.js'; import { Config } from '../../config/config.js'; import { AuthType } from '../contentGenerator.js'; @@ -15,14 +46,6 @@ import type { import type { OpenAICompatibleProvider } from './provider/index.js'; import OpenAI from 'openai'; -// Mock tiktoken -vi.mock('tiktoken', () => ({ - get_encoding: vi.fn().mockReturnValue({ - encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens - free: vi.fn(), - }), -})); - describe('OpenAIContentGenerator (Refactored)', () => { let generator: OpenAIContentGenerator; let mockConfig: Config; diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts index c251d87d..90f784a3 100644 --- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts +++ 
b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts @@ -13,6 +13,7 @@ import { ContentGenerationPipeline, PipelineConfig } from './pipeline.js'; import { DefaultTelemetryService } from './telemetryService.js'; import { EnhancedErrorHandler } from './errorHandler.js'; import { ContentGeneratorConfig } from '../contentGenerator.js'; +import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js'; export class OpenAIContentGenerator implements ContentGenerator { protected pipeline: ContentGenerationPipeline; @@ -70,27 +71,30 @@ export class OpenAIContentGenerator implements ContentGenerator { async countTokens( request: CountTokensParameters, ): Promise<CountTokensResponse> { - // Use tiktoken for accurate token counting - const content = JSON.stringify(request.contents); - let totalTokens = 0; - try { - const { get_encoding } = await import('tiktoken'); - const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen - totalTokens = encoding.encode(content).length; - encoding.free(); + // Use the new high-performance request tokenizer + const tokenizer = getDefaultTokenizer(); + const result = await tokenizer.calculateTokens(request, { + textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency + }); + + return { + totalTokens: result.totalTokens, + }; } catch (error) { console.warn( - 'Failed to load tiktoken, falling back to character approximation:', + 'Failed to calculate tokens with new tokenizer, falling back to simple method:', error, ); - // Fallback: rough approximation using character count - totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters - } - return { - totalTokens, - }; + // Fallback to original simple method + const content = JSON.stringify(request.contents); + const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters + + return { + totalTokens, + }; + } } async embedContent( diff --git a/packages/core/src/core/openaiContentGenerator/pipeline.ts b/packages/core/src/core/openaiContentGenerator/pipeline.ts index f3697441..53cda85a 100644 --- a/packages/core/src/core/openaiContentGenerator/pipeline.ts +++ b/packages/core/src/core/openaiContentGenerator/pipeline.ts @@ -98,7 +98,7 @@ export class ContentGenerationPipeline { * 2. Filter empty responses * 3. Handle chunk merging for providers that send finishReason and usageMetadata separately * 4. Collect both formats for logging - * 5. Handle success/error logging with original OpenAI format + * 5.
Handle success/error logging */ private async *processStreamWithLogging( stream: AsyncIterable, @@ -166,19 +166,11 @@ export class ContentGenerationPipeline { collectedOpenAIChunks, ); } catch (error) { - // Stage 2e: Stream failed - handle error and logging - context.duration = Date.now() - context.startTime; - // Clear streaming tool calls on error to prevent data pollution this.converter.resetStreamingToolCalls(); - await this.config.telemetryService.logError( - context, - error, - openaiRequest, - ); - - this.config.errorHandler.handle(error, context, request); + // Use shared error handling logic + await this.handleError(error, context, request); } } @@ -362,25 +354,59 @@ export class ContentGenerationPipeline { context.duration = Date.now() - context.startTime; return result; } catch (error) { - context.duration = Date.now() - context.startTime; - - // Log error - const openaiRequest = await this.buildRequest( + // Use shared error handling logic + return await this.handleError( + error, + context, request, userPromptId, isStreaming, ); - await this.config.telemetryService.logError( - context, - error, - openaiRequest, - ); - - // Handle and throw enhanced error - this.config.errorHandler.handle(error, context, request); } } + /** + * Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging + * This centralizes the common error processing steps to avoid duplication + */ + private async handleError( + error: unknown, + context: RequestContext, + request: GenerateContentParameters, + userPromptId?: string, + isStreaming?: boolean, + ): Promise { + context.duration = Date.now() - context.startTime; + + // Build request for logging (may fail, but we still want to log the error) + let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams; + try { + if (userPromptId !== undefined && isStreaming !== undefined) { + openaiRequest = await this.buildRequest( + request, + userPromptId, + isStreaming, + ); + } else { + // For processStreamWithLogging, we don't have userPromptId/isStreaming, + // so create a minimal request + openaiRequest = { + model: this.contentGeneratorConfig.model, + messages: [], + }; + } + } catch (_buildError) { + // If we can't build the request, create a minimal one for logging + openaiRequest = { + model: this.contentGeneratorConfig.model, + messages: [], + }; + } + + await this.config.telemetryService.logError(context, error, openaiRequest); + this.config.errorHandler.handle(error, context, request); + } + /** * Create request context with common properties */ diff --git a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts index 9a82d475..1dd5c2da 100644 --- a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts +++ b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts @@ -78,6 +78,16 @@ export class DashScopeOpenAICompatibleProvider messages = this.addDashScopeCacheControl(messages, cacheTarget); } + if (request.model.startsWith('qwen-vl')) { + return { + ...request, + messages, + ...(this.buildMetadata(userPromptId) || {}), + /* @ts-expect-error dashscope exclusive */ + vl_high_resolution_images: true, + }; + } + return { ...request, // Preserve all original parameters including sampling params messages, diff --git a/packages/core/src/core/tokenLimits.ts b/packages/core/src/core/tokenLimits.ts index e51becab..2e502037 100644 --- a/packages/core/src/core/tokenLimits.ts +++ b/packages/core/src/core/tokenLimits.ts @@ 
-116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [ [/^qwen-flash-latest$/, LIMITS['1m']], [/^qwen-turbo.*$/, LIMITS['128k']], + // Qwen Vision Models + [/^qwen-vl-max.*$/, LIMITS['128k']], + // ------------------- // ByteDance Seed-OSS (512K) // ------------------- diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 7144d16b..d6cc195d 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -222,7 +222,7 @@ describe('Turn', () => { expect(turn.getDebugResponses().length).toBe(0); expect(reportError).toHaveBeenCalledWith( error, - 'Error when talking to Gemini API', + 'Error when talking to API', [...historyContent, reqParts], 'Turn.run-sendMessageStream', ); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 24f718e9..80f286c7 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -273,7 +273,7 @@ export class Turn { const contextForReport = [...this.chat.getHistory(/*curated*/ true), req]; await reportError( error, - 'Error when talking to Gemini API', + 'Error when talking to API', contextForReport, 'Turn.run-sendMessageStream', ); diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts index 90f40fd8..40a835bd 100644 --- a/packages/core/src/qwen/qwenContentGenerator.test.ts +++ b/packages/core/src/qwen/qwenContentGenerator.test.ts @@ -404,11 +404,9 @@ describe('QwenContentGenerator', () => { expect(mockQwenClient.getAccessToken).toHaveBeenCalled(); }); - it('should count tokens with valid token', async () => { - vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ - token: 'valid-token', - }); - vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials); + it('should count tokens without requiring authentication', async () => { + // Clear any previous mock calls + vi.clearAllMocks(); const request: CountTokensParameters = { model: 'qwen-turbo', @@ -418,7 +416,8 @@ describe('QwenContentGenerator', () => { const result = await qwenContentGenerator.countTokens(request); expect(result.totalTokens).toBe(15); - expect(mockQwenClient.getAccessToken).toHaveBeenCalled(); + // countTokens is a local operation and should not require OAuth credentials + expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled(); }); it('should embed content with valid token', async () => { @@ -1655,7 +1654,7 @@ describe('QwenContentGenerator', () => { SharedTokenManager.getInstance = originalGetInstance; }); - it('should handle all method types with token failure', async () => { + it('should handle method types with token failure (except countTokens)', async () => { const mockTokenManager = { getValidCredentials: vi .fn() @@ -1688,7 +1687,7 @@ describe('QwenContentGenerator', () => { contents: [{ parts: [{ text: 'Embed' }] }], }; - // All methods should fail with the same error + // Methods requiring authentication should fail await expect( newGenerator.generateContent(generateRequest, 'test-id'), ).rejects.toThrow('Failed to obtain valid Qwen access token'); @@ -1697,14 +1696,14 @@ describe('QwenContentGenerator', () => { newGenerator.generateContentStream(generateRequest, 'test-id'), ).rejects.toThrow('Failed to obtain valid Qwen access token'); - await expect(newGenerator.countTokens(countRequest)).rejects.toThrow( - 'Failed to obtain valid Qwen access token', - ); - await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow( 'Failed to obtain valid Qwen access 
token', ); + // countTokens should succeed as it's a local operation + const countResult = await newGenerator.countTokens(countRequest); + expect(countResult.totalTokens).toBe(15); + SharedTokenManager.getInstance = originalGetInstance; }); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts index b70a1bd1..5f2bdaea 100644 --- a/packages/core/src/qwen/qwenContentGenerator.ts +++ b/packages/core/src/qwen/qwenContentGenerator.ts @@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator { override async countTokens( request: CountTokensParameters, ): Promise { - return this.executeWithCredentialManagement(() => - super.countTokens(request), - ); + return super.countTokens(request); } /** diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts new file mode 100644 index 00000000..cdb5f35f --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts @@ -0,0 +1,157 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { ImageTokenizer } from './imageTokenizer.js'; + +describe('ImageTokenizer', () => { + const tokenizer = new ImageTokenizer(); + + describe('token calculation', () => { + it('should calculate tokens based on image dimensions with reference logic', () => { + const metadata = { + width: 28, + height: 28, + mimeType: 'image/png', + dataSize: 1000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total + // But minimum scaling may apply for small images + expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens + }); + + it('should calculate tokens for larger images', () => { + const metadata = { + width: 512, + height: 512, + mimeType: 'image/png', + dataSize: 10000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // 512x512 with reference logic: rounded dimensions + scaling + special tokens + expect(tokens).toBeGreaterThan(300); + expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512 + }); + + it('should enforce minimum tokens per image with scaling', () => { + const metadata = { + width: 1, + height: 1, + mimeType: 'image/png', + dataSize: 100, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // Tiny images get scaled up to minimum pixels + special tokens + expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens + }); + + it('should handle very large images with scaling', () => { + const metadata = { + width: 8192, + height: 8192, + mimeType: 'image/png', + dataSize: 100000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // Very large images should be scaled down to max limit + special tokens + expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens + expect(tokens).toBeGreaterThan(16000); // Should be close to the limit + }); + }); + + describe('PNG dimension extraction', () => { + it('should extract dimensions from valid PNG', async () => { + // 1x1 PNG image in base64 + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const metadata = await tokenizer.extractImageMetadata( + pngBase64, + 'image/png', + ); + + expect(metadata.width).toBe(1); + expect(metadata.height).toBe(1); + 
expect(metadata.mimeType).toBe('image/png'); + }); + + it('should handle invalid PNG gracefully', async () => { + const invalidBase64 = 'invalid-png-data'; + + const metadata = await tokenizer.extractImageMetadata( + invalidBase64, + 'image/png', + ); + + // Should return default dimensions + expect(metadata.width).toBe(512); + expect(metadata.height).toBe(512); + expect(metadata.mimeType).toBe('image/png'); + }); + }); + + describe('batch processing', () => { + it('should process multiple images serially', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const images = [ + { data: pngBase64, mimeType: 'image/png' }, + { data: pngBase64, mimeType: 'image/png' }, + { data: pngBase64, mimeType: 'image/png' }, + ]; + + const tokens = await tokenizer.calculateTokensBatch(images); + + expect(tokens).toHaveLength(3); + expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens + }); + + it('should handle mixed valid and invalid images', async () => { + const validPng = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + const invalidPng = 'invalid-data'; + + const images = [ + { data: validPng, mimeType: 'image/png' }, + { data: invalidPng, mimeType: 'image/png' }, + ]; + + const tokens = await tokenizer.calculateTokensBatch(images); + + expect(tokens).toHaveLength(2); + expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens + }); + }); + + describe('different image formats', () => { + it('should handle different MIME types', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif']; + + for (const mimeType of formats) { + const metadata = await tokenizer.extractImageMetadata( + pngBase64, + mimeType, + ); + expect(metadata.mimeType).toBe(mimeType); + expect(metadata.width).toBeGreaterThan(0); + expect(metadata.height).toBeGreaterThan(0); + } + }); + }); +}); diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts new file mode 100644 index 00000000..5bd4f2cf --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts @@ -0,0 +1,309 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ImageMetadata } from './types.js'; + +/** + * Image tokenizer for calculating image tokens based on dimensions + * + * Key rules: + * - 28x28 pixels = 1 token + * - Minimum: 4 tokens per image + * - Maximum: 16384 tokens per image + * - Additional: 2 special tokens (vision_bos + vision_eos) + * - Supports: PNG, JPEG, WebP, GIF formats + */ +export class ImageTokenizer { + /** 28x28 pixels = 1 token */ + private static readonly PIXELS_PER_TOKEN = 28 * 28; + + /** Minimum tokens per image */ + private static readonly MIN_TOKENS_PER_IMAGE = 4; + + /** Maximum tokens per image */ + private static readonly MAX_TOKENS_PER_IMAGE = 16384; + + /** Special tokens for vision markers */ + private static readonly VISION_SPECIAL_TOKENS = 2; + + /** + * Extract image metadata from base64 data + * + * @param base64Data Base64-encoded image data (with or without data URL prefix) + * @param mimeType MIME type of the image + * @returns Promise resolving to ImageMetadata with dimensions and format info + */ + async extractImageMetadata( + 
base64Data: string, + mimeType: string, + ): Promise<ImageMetadata> { + try { + const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, ''); + const buffer = Buffer.from(cleanBase64, 'base64'); + const dimensions = await this.extractDimensions(buffer, mimeType); + + return { + width: dimensions.width, + height: dimensions.height, + mimeType, + dataSize: buffer.length, + }; + } catch (error) { + console.warn('Failed to extract image metadata:', error); + // Return default metadata for fallback + return { + width: 512, + height: 512, + mimeType, + dataSize: Math.floor(base64Data.length * 0.75), + }; + } + } + + /** + * Extract image dimensions from buffer based on format + * + * @param buffer Binary image data buffer + * @param mimeType MIME type to determine parsing strategy + * @returns Promise resolving to width and height dimensions + */ + private async extractDimensions( + buffer: Buffer, + mimeType: string, + ): Promise<{ width: number; height: number }> { + if (mimeType.includes('png')) { + return this.extractPngDimensions(buffer); + } + + if (mimeType.includes('jpeg') || mimeType.includes('jpg')) { + return this.extractJpegDimensions(buffer); + } + + if (mimeType.includes('webp')) { + return this.extractWebpDimensions(buffer); + } + + if (mimeType.includes('gif')) { + return this.extractGifDimensions(buffer); + } + + return { width: 512, height: 512 }; + } + + /** + * Extract PNG dimensions from IHDR chunk + * PNG signature: 89 50 4E 47 0D 0A 1A 0A + * Width/height at bytes 16-19 and 20-23 (big-endian) + */ + private extractPngDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 24) { + throw new Error('Invalid PNG: buffer too short'); + } + + // Verify PNG signature + const signature = buffer.subarray(0, 8); + const expectedSignature = Buffer.from([ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + ]); + if (!signature.equals(expectedSignature)) { + throw new Error('Invalid PNG signature'); + } + + const width = buffer.readUInt32BE(16); + const height = buffer.readUInt32BE(20); + + return { width, height }; + } + + /** + * Extract JPEG dimensions from SOF (Start of Frame) markers + * JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF + * Dimensions at offset +5 (height) and +7 (width) from SOF marker + */ + private extractJpegDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) { + throw new Error('Invalid JPEG signature'); + } + + let offset = 2; + + while (offset < buffer.length - 8) { + if (buffer[offset] !== 0xff) { + offset++; + continue; + } + + const marker = buffer[offset + 1]; + + // SOF markers + if ( + (marker >= 0xc0 && marker <= 0xc3) || + (marker >= 0xc5 && marker <= 0xc7) || + (marker >= 0xc9 && marker <= 0xcb) || + (marker >= 0xcd && marker <= 0xcf) + ) { + const height = buffer.readUInt16BE(offset + 5); + const width = buffer.readUInt16BE(offset + 7); + return { width, height }; + } + + const segmentLength = buffer.readUInt16BE(offset + 2); + offset += 2 + segmentLength; + } + + throw new Error('Could not find JPEG dimensions'); + } + + /** + * Extract WebP dimensions from RIFF container + * Supports VP8, VP8L, and VP8X formats + */ + private extractWebpDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 30) { + throw new Error('Invalid WebP: too short'); + } + + const riffSignature = buffer.subarray(0, 4).toString('ascii'); + const webpSignature = buffer.subarray(8,
12).toString('ascii'); + + if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') { + throw new Error('Invalid WebP signature'); + } + + const format = buffer.subarray(12, 16).toString('ascii'); + + if (format === 'VP8 ') { + const width = buffer.readUInt16LE(26) & 0x3fff; + const height = buffer.readUInt16LE(28) & 0x3fff; + return { width, height }; + } else if (format === 'VP8L') { + const bits = buffer.readUInt32LE(21); + const width = (bits & 0x3fff) + 1; + const height = ((bits >> 14) & 0x3fff) + 1; + return { width, height }; + } else if (format === 'VP8X') { + // VP8X stores 24-bit (canvas size - 1) values: width at bytes 24-26, height at bytes 27-29 + const width = buffer.readUIntLE(24, 3) + 1; + const height = buffer.readUIntLE(27, 3) + 1; + return { width, height }; + } + + throw new Error('Unsupported WebP format'); + } + + /** + * Extract GIF dimensions from header + * Supports GIF87a and GIF89a formats + */ + private extractGifDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 10) { + throw new Error('Invalid GIF: too short'); + } + + const signature = buffer.subarray(0, 6).toString('ascii'); + if (signature !== 'GIF87a' && signature !== 'GIF89a') { + throw new Error('Invalid GIF signature'); + } + + const width = buffer.readUInt16LE(6); + const height = buffer.readUInt16LE(8); + + return { width, height }; + } + + /** + * Calculate tokens for an image based on its metadata + * + * @param metadata Image metadata containing width, height, and format info + * @returns Total token count including base image tokens and special tokens + */ + calculateTokens(metadata: ImageMetadata): number { + return this.calculateTokensWithScaling(metadata.width, metadata.height); + } + + /** + * Calculate tokens with scaling logic + * + * Steps: + * 1. Normalize to 28-pixel multiples + * 2. Scale large images down, small images up + * 3.
Calculate tokens: pixels / 784 + 2 special tokens + * + * @param width Original image width in pixels + * @param height Original image height in pixels + * @returns Total token count for the image + */ + private calculateTokensWithScaling(width: number, height: number): number { + // Normalize to 28-pixel multiples + let hBar = Math.round(height / 28) * 28; + let wBar = Math.round(width / 28) * 28; + + // Define pixel boundaries + const minPixels = + ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN; + const maxPixels = + ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN; + + // Apply scaling + if (hBar * wBar > maxPixels) { + // Scale down large images + const beta = Math.sqrt((height * width) / maxPixels); + hBar = Math.floor(height / beta / 28) * 28; + wBar = Math.floor(width / beta / 28) * 28; + } else if (hBar * wBar < minPixels) { + // Scale up small images + const beta = Math.sqrt(minPixels / (height * width)); + hBar = Math.ceil((height * beta) / 28) * 28; + wBar = Math.ceil((width * beta) / 28) * 28; + } + + // Calculate tokens + const imageTokens = Math.floor( + (hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN, + ); + + return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS; + } + + /** + * Calculate tokens for multiple images serially + * + * @param base64DataArray Array of image data with MIME type information + * @returns Promise resolving to array of token counts in same order as input + */ + async calculateTokensBatch( + base64DataArray: Array<{ data: string; mimeType: string }>, + ): Promise<number[]> { + const results: number[] = []; + + for (const { data, mimeType } of base64DataArray) { + try { + const metadata = await this.extractImageMetadata(data, mimeType); + results.push(this.calculateTokens(metadata)); + } catch (error) { + console.warn('Error calculating tokens for image:', error); + // Return minimum tokens as fallback + results.push( + ImageTokenizer.MIN_TOKENS_PER_IMAGE + + ImageTokenizer.VISION_SPECIAL_TOKENS, + ); + } + } + + return results; + } +} diff --git a/packages/core/src/utils/request-tokenizer/index.ts b/packages/core/src/utils/request-tokenizer/index.ts new file mode 100644 index 00000000..064b93c1 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/index.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DefaultRequestTokenizer } from './requestTokenizer.js'; +export { DefaultRequestTokenizer } from './requestTokenizer.js'; +export { TextTokenizer } from './textTokenizer.js'; +export { ImageTokenizer } from './imageTokenizer.js'; + +export type { + RequestTokenizer, + TokenizerConfig, + TokenCalculationResult, + ImageMetadata, +} from './types.js'; + +// Singleton instance for convenient usage +let defaultTokenizer: DefaultRequestTokenizer | null = null; + +/** + * Get the default request tokenizer instance + */ +export function getDefaultTokenizer(): DefaultRequestTokenizer { + if (!defaultTokenizer) { + defaultTokenizer = new DefaultRequestTokenizer(); + } + return defaultTokenizer; +} + +/** + * Dispose of the default tokenizer instance + */ +export async function disposeDefaultTokenizer(): Promise<void> { + if (defaultTokenizer) { + await defaultTokenizer.dispose(); + defaultTokenizer = null; + } +} diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts new file mode 100644 index 00000000..7e4493d4 --- /dev/null +++
b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts @@ -0,0 +1,293 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { DefaultRequestTokenizer } from './requestTokenizer.js'; +import { CountTokensParameters } from '@google/genai'; + +describe('DefaultRequestTokenizer', () => { + let tokenizer: DefaultRequestTokenizer; + + beforeEach(() => { + tokenizer = new DefaultRequestTokenizer(); + }); + + afterEach(async () => { + await tokenizer.dispose(); + }); + + describe('text token calculation', () => { + it('should calculate tokens for simple text content', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [{ text: 'Hello, world!' }], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + expect(result.breakdown.imageTokens).toBe(0); + expect(result.processingTime).toBeGreaterThan(0); + }); + + it('should handle multiple text parts', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { text: 'First part' }, + { text: 'Second part' }, + { text: 'Third part' }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + }); + + it('should handle string content', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: ['Simple string content'], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + }); + }); + + describe('image token calculation', () => { + it('should calculate tokens for image content', async () => { + // Create a simple 1x1 PNG image in base64 + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4); + expect(result.breakdown.textTokens).toBe(0); + }); + + it('should handle multiple images', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8); + }); + }); + + describe('mixed content', () => { + it('should handle text and image content together', async () => { + const pngBase64 = + 
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { text: 'Here is an image:' }, + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + { text: 'What do you see?' }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(4); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4); + }); + }); + + describe('function content', () => { + it('should handle function calls', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + functionCall: { + name: 'test_function', + args: { param1: 'value1', param2: 42 }, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.otherTokens).toBeGreaterThan(0); + }); + }); + + describe('empty content', () => { + it('should handle empty request', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBe(0); + expect(result.breakdown.textTokens).toBe(0); + expect(result.breakdown.imageTokens).toBe(0); + }); + + it('should handle undefined contents', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBe(0); + }); + }); + + describe('configuration', () => { + it('should use custom text encoding', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [{ text: 'Test text for encoding' }], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request, { + textEncoding: 'cl100k_base', + }); + + expect(result.totalTokens).toBeGreaterThan(0); + }); + + it('should process multiple images serially', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: Array(10).fill({ + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }), + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images + }); + }); + + describe('error handling', () => { + it('should handle malformed image data gracefully', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: 'invalid-base64-data', + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + // Should still return some tokens (fallback to minimum) + expect(result.totalTokens).toBeGreaterThanOrEqual(4); + }); + }); +}); diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts new file mode 100644 index 00000000..5f1bd16c --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts @@ -0,0 +1,336 @@ +/** + 
* @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { CountTokensParameters, Content, Part, PartUnion } from '@google/genai'; +import { + RequestTokenizer, + TokenizerConfig, + TokenCalculationResult, +} from './types.js'; +import { TextTokenizer } from './textTokenizer.js'; +import { ImageTokenizer } from './imageTokenizer.js'; + +/** + * Simple request tokenizer that handles text and image content serially + */ +export class DefaultRequestTokenizer implements RequestTokenizer { + private textTokenizer: TextTokenizer; + private imageTokenizer: ImageTokenizer; + + constructor() { + this.textTokenizer = new TextTokenizer(); + this.imageTokenizer = new ImageTokenizer(); + } + + /** + * Calculate tokens for a request using serial processing + */ + async calculateTokens( + request: CountTokensParameters, + config: TokenizerConfig = {}, + ): Promise<TokenCalculationResult> { + const startTime = performance.now(); + + // Apply configuration + if (config.textEncoding) { + // Free the previous tokenizer's tiktoken encoding before swapping to avoid leaking WASM memory + this.textTokenizer.dispose(); + this.textTokenizer = new TextTokenizer(config.textEncoding); + } + + try { + // Process request content and group by type + const { textContents, imageContents, audioContents, otherContents } = + this.processAndGroupContents(request); + + if ( + textContents.length === 0 && + imageContents.length === 0 && + audioContents.length === 0 && + otherContents.length === 0 + ) { + return { + totalTokens: 0, + breakdown: { + textTokens: 0, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: performance.now() - startTime, + }; + } + + // Calculate tokens for each content type serially + const textTokens = await this.calculateTextTokens(textContents); + const imageTokens = await this.calculateImageTokens(imageContents); + const audioTokens = await this.calculateAudioTokens(audioContents); + const otherTokens = await this.calculateOtherTokens(otherContents); + + const totalTokens = textTokens + imageTokens + audioTokens + otherTokens; + const processingTime = performance.now() - startTime; + + return { + totalTokens, + breakdown: { + textTokens, + imageTokens, + audioTokens, + otherTokens, + }, + processingTime, + }; + } catch (error) { + console.error('Error calculating tokens:', error); + + // Fallback calculation + const fallbackTokens = this.calculateFallbackTokens(request); + + return { + totalTokens: fallbackTokens, + breakdown: { + textTokens: fallbackTokens, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: performance.now() - startTime, + }; + } + } + + /** + * Calculate tokens for text contents + */ + private async calculateTextTokens(textContents: string[]): Promise<number> { + if (textContents.length === 0) return 0; + + try { + const tokenCounts = + await this.textTokenizer.calculateTokensBatch(textContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating text tokens:', error); + // Fallback: character-based estimation + const totalChars = textContents.join('').length; + return Math.ceil(totalChars / 4); + } + } + + /** + * Calculate tokens for image contents using serial processing + */ + private async calculateImageTokens( + imageContents: Array<{ data: string; mimeType: string }>, + ): Promise<number> { + if (imageContents.length === 0) return 0; + + try { + const tokenCounts = + await this.imageTokenizer.calculateTokensBatch(imageContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating image tokens:', error); + // Fallback: minimum tokens
per image + return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum + } + } + + /** + * Calculate tokens for audio contents + * TODO: Implement proper audio token calculation + */ + private async calculateAudioTokens( + audioContents: Array<{ data: string; mimeType: string }>, + ): Promise<number> { + if (audioContents.length === 0) return 0; + + // Placeholder implementation - audio token calculation would depend on + // the specific model's audio processing capabilities + // For now, estimate based on data size + let totalTokens = 0; + + for (const audioContent of audioContents) { + try { + const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size + // Rough estimate: 1 token per 100 bytes of audio data + totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio + } catch (error) { + console.warn('Error calculating audio tokens:', error); + totalTokens += 10; // Fallback minimum + } + } + + return totalTokens; + } + + /** + * Calculate tokens for other content types (functions, files, etc.) + */ + private async calculateOtherTokens(otherContents: string[]): Promise<number> { + if (otherContents.length === 0) return 0; + + try { + // Treat other content as text for token calculation + const tokenCounts = + await this.textTokenizer.calculateTokensBatch(otherContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating other content tokens:', error); + // Fallback: character-based estimation + const totalChars = otherContents.join('').length; + return Math.ceil(totalChars / 4); + } + } + + /** + * Fallback token calculation using simple string serialization + */ + private calculateFallbackTokens(request: CountTokensParameters): number { + try { + const content = JSON.stringify(request.contents); + return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters + } catch (error) { + console.warn('Error in fallback token calculation:', error); + return 100; // Conservative fallback + } + } + + /** + * Process request contents and group by type + */ + private processAndGroupContents(request: CountTokensParameters): { + textContents: string[]; + imageContents: Array<{ data: string; mimeType: string }>; + audioContents: Array<{ data: string; mimeType: string }>; + otherContents: string[]; + } { + const textContents: string[] = []; + const imageContents: Array<{ data: string; mimeType: string }> = []; + const audioContents: Array<{ data: string; mimeType: string }> = []; + const otherContents: string[] = []; + + if (!request.contents) { + return { textContents, imageContents, audioContents, otherContents }; + } + + const contents = Array.isArray(request.contents) + ?
request.contents + : [request.contents]; + + for (const content of contents) { + this.processContent( + content, + textContents, + imageContents, + audioContents, + otherContents, + ); + } + + return { textContents, imageContents, audioContents, otherContents }; + } + + /** + * Process a single content item and add to appropriate arrays + */ + private processContent( + content: Content | string | PartUnion, + textContents: string[], + imageContents: Array<{ data: string; mimeType: string }>, + audioContents: Array<{ data: string; mimeType: string }>, + otherContents: string[], + ): void { + if (typeof content === 'string') { + if (content.trim()) { + textContents.push(content); + } + return; + } + + if ('parts' in content && content.parts) { + for (const part of content.parts) { + this.processPart( + part, + textContents, + imageContents, + audioContents, + otherContents, + ); + } + } + } + + /** + * Process a single part and add to appropriate arrays + */ + private processPart( + part: Part | string, + textContents: string[], + imageContents: Array<{ data: string; mimeType: string }>, + audioContents: Array<{ data: string; mimeType: string }>, + otherContents: string[], + ): void { + if (typeof part === 'string') { + if (part.trim()) { + textContents.push(part); + } + return; + } + + if ('text' in part && part.text) { + textContents.push(part.text); + return; + } + + if ('inlineData' in part && part.inlineData) { + const { data, mimeType } = part.inlineData; + if (mimeType && mimeType.startsWith('image/')) { + imageContents.push({ data: data || '', mimeType }); + return; + } + if (mimeType && mimeType.startsWith('audio/')) { + audioContents.push({ data: data || '', mimeType }); + return; + } + } + + if ('fileData' in part && part.fileData) { + otherContents.push(JSON.stringify(part.fileData)); + return; + } + + if ('functionCall' in part && part.functionCall) { + otherContents.push(JSON.stringify(part.functionCall)); + return; + } + + if ('functionResponse' in part && part.functionResponse) { + otherContents.push(JSON.stringify(part.functionResponse)); + return; + } + + // Unknown part type - try to serialize + try { + const serialized = JSON.stringify(part); + if (serialized && serialized !== '{}') { + otherContents.push(serialized); + } + } catch (error) { + console.warn('Failed to serialize unknown part type:', error); + } + } + + /** + * Dispose of resources + */ + async dispose(): Promise { + try { + // Dispose of tokenizers + this.textTokenizer.dispose(); + } catch (error) { + console.warn('Error disposing request tokenizer:', error); + } + } +} diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts new file mode 100644 index 00000000..f29155a8 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts @@ -0,0 +1,347 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { TextTokenizer } from './textTokenizer.js'; + +// Mock tiktoken at the top level with hoisted functions +const mockEncode = vi.hoisted(() => vi.fn()); +const mockFree = vi.hoisted(() => vi.fn()); +const mockGetEncoding = vi.hoisted(() => vi.fn()); + +vi.mock('tiktoken', () => ({ + get_encoding: mockGetEncoding, +})); + +describe('TextTokenizer', () => { + let tokenizer: TextTokenizer; + let consoleWarnSpy: ReturnType; + + beforeEach(() => { + vi.resetAllMocks(); + consoleWarnSpy = 
+
+    // Default mock implementation
+    mockGetEncoding.mockReturnValue({
+      encode: mockEncode,
+      free: mockFree,
+    });
+  });
+
+  afterEach(() => {
+    vi.restoreAllMocks();
+    tokenizer?.dispose();
+  });
+
+  describe('constructor', () => {
+    it('should create tokenizer with default encoding', () => {
+      tokenizer = new TextTokenizer();
+      expect(tokenizer).toBeInstanceOf(TextTokenizer);
+    });
+
+    it('should create tokenizer with custom encoding', () => {
+      tokenizer = new TextTokenizer('gpt2');
+      expect(tokenizer).toBeInstanceOf(TextTokenizer);
+    });
+  });
+
+  describe('calculateTokens', () => {
+    beforeEach(() => {
+      tokenizer = new TextTokenizer();
+    });
+
+    it('should return 0 for empty text', async () => {
+      const result = await tokenizer.calculateTokens('');
+      expect(result).toBe(0);
+    });
+
+    it('should return 0 for null/undefined text', async () => {
+      const result1 = await tokenizer.calculateTokens(
+        null as unknown as string,
+      );
+      const result2 = await tokenizer.calculateTokens(
+        undefined as unknown as string,
+      );
+      expect(result1).toBe(0);
+      expect(result2).toBe(0);
+    });
+
+    it('should calculate tokens using tiktoken when available', async () => {
+      const testText = 'Hello, world!';
+      const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(testText);
+
+      expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
+      expect(mockEncode).toHaveBeenCalledWith(testText);
+      expect(result).toBe(5);
+    });
+
+    it('should use fallback calculation when tiktoken fails to load', async () => {
+      mockGetEncoding.mockImplementation(() => {
+        throw new Error('Failed to load tiktoken');
+      });
+
+      const testText = 'Hello, world!'; // 13 characters
+      const result = await tokenizer.calculateTokens(testText);
+
+      expect(consoleWarnSpy).toHaveBeenCalledWith(
+        'Failed to load tiktoken with encoding cl100k_base:',
+        expect.any(Error),
+      );
+      // Fallback: Math.ceil(13 / 4) = 4
+      expect(result).toBe(4);
+    });
+
+    it('should use fallback calculation when encoding fails', async () => {
+      mockEncode.mockImplementation(() => {
+        throw new Error('Encoding failed');
+      });
+
+      const testText = 'Hello, world!'; // 13 characters
+      const result = await tokenizer.calculateTokens(testText);
+
+      expect(consoleWarnSpy).toHaveBeenCalledWith(
+        'Error encoding text with tiktoken:',
+        expect.any(Error),
+      );
+      // Fallback: Math.ceil(13 / 4) = 4
+      expect(result).toBe(4);
+    });
+
+    it('should handle very long text', async () => {
+      const longText = 'a'.repeat(10000);
+      const mockTokens = new Array(2500); // 2500 tokens
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(longText);
+
+      expect(result).toBe(2500);
+    });
+
+    it('should handle unicode characters', async () => {
+      const unicodeText = '你好世界 🌍';
+      const mockTokens = [1, 2, 3, 4, 5, 6];
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(unicodeText);
+
+      expect(result).toBe(6);
+    });
+
+    it('should use custom encoding when specified', async () => {
+      tokenizer = new TextTokenizer('gpt2');
+      const testText = 'Hello, world!';
+      const mockTokens = [1, 2, 3];
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(testText);
+
+      expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
+      expect(result).toBe(3);
+    });
+  });
+
+  describe('calculateTokensBatch', () => {
+    beforeEach(() => {
+      tokenizer = new TextTokenizer();
+    });
+
+    it('should process multiple texts and return token counts', async () => {
+      const texts = ['Hello', 'world', 'test'];
+      mockEncode
+        .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
+        .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
+        .mockReturnValueOnce([6]); // 1 token for 'test'
+
+      const result = await tokenizer.calculateTokensBatch(texts);
+
+      expect(result).toEqual([2, 3, 1]);
+      expect(mockEncode).toHaveBeenCalledTimes(3);
+    });
+
+    it('should handle empty array', async () => {
+      const result = await tokenizer.calculateTokensBatch([]);
+      expect(result).toEqual([]);
+    });
+
+    it('should handle array with empty strings', async () => {
+      const texts = ['', 'hello', ''];
+      mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'
+
+      const result = await tokenizer.calculateTokensBatch(texts);
+
+      expect(result).toEqual([0, 3, 0]);
+      expect(mockEncode).toHaveBeenCalledTimes(1);
+      expect(mockEncode).toHaveBeenCalledWith('hello');
+    });
+
+    it('should use fallback calculation when tiktoken fails to load', async () => {
+      mockGetEncoding.mockImplementation(() => {
+        throw new Error('Failed to load tiktoken');
+      });
+
+      const texts = ['Hello', 'world']; // 5 and 5 characters
+      const result = await tokenizer.calculateTokensBatch(texts);
+
+      expect(consoleWarnSpy).toHaveBeenCalledWith(
+        'Failed to load tiktoken with encoding cl100k_base:',
+        expect.any(Error),
+      );
+      // Fallback: Math.ceil(5/4) = 2 for both
+      expect(result).toEqual([2, 2]);
+    });
+
+    it('should use fallback calculation when encoding fails during batch processing', async () => {
+      mockEncode.mockImplementation(() => {
+        throw new Error('Encoding failed');
+      });
+
+      const texts = ['Hello', 'world']; // 5 and 5 characters
+      const result = await tokenizer.calculateTokensBatch(texts);
+
+      expect(consoleWarnSpy).toHaveBeenCalledWith(
+        'Error encoding texts with tiktoken:',
+        expect.any(Error),
+      );
+      // Fallback: Math.ceil(5/4) = 2 for both
+      expect(result).toEqual([2, 2]);
+    });
+
+    it('should handle null and undefined values in batch', async () => {
+      const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
+      mockEncode
+        .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
+        .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'
+
+      const result = await tokenizer.calculateTokensBatch(texts);
+
+      expect(result).toEqual([0, 3, 0, 2]);
+    });
+  });
+
+  describe('dispose', () => {
+    beforeEach(() => {
+      tokenizer = new TextTokenizer();
+    });
+
+    it('should free tiktoken encoding when disposing', async () => {
+      // Initialize the encoding by calling calculateTokens
+      await tokenizer.calculateTokens('test');
+
+      tokenizer.dispose();
+
+      expect(mockFree).toHaveBeenCalled();
+    });
+
+    it('should handle disposal when encoding is not initialized', () => {
+      expect(() => tokenizer.dispose()).not.toThrow();
+      expect(mockFree).not.toHaveBeenCalled();
+    });
+
+    it('should handle disposal when encoding is null', async () => {
+      // Force encoding to be null by making tiktoken fail
+      mockGetEncoding.mockImplementation(() => {
+        throw new Error('Failed to load');
+      });
+
+      await tokenizer.calculateTokens('test');
+
+      expect(() => tokenizer.dispose()).not.toThrow();
+      expect(mockFree).not.toHaveBeenCalled();
+    });
+
+    it('should handle errors during disposal gracefully', async () => {
+      await tokenizer.calculateTokens('test');
+
+      mockFree.mockImplementation(() => {
+        throw new Error('Free failed');
+      });
+
+      tokenizer.dispose();
+
+      expect(consoleWarnSpy).toHaveBeenCalledWith(
+        'Error freeing tiktoken encoding:',
+        expect.any(Error),
+      );
+    });
+
+    it('should allow multiple calls to dispose', async () => {
+      await tokenizer.calculateTokens('test');
+
+      tokenizer.dispose();
+      tokenizer.dispose(); // Second call should not throw
+
+      expect(mockFree).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe('lazy initialization', () => {
+    beforeEach(() => {
+      tokenizer = new TextTokenizer();
+    });
+
+    it('should not initialize tiktoken until first use', () => {
+      expect(mockGetEncoding).not.toHaveBeenCalled();
+    });
+
+    it('should initialize tiktoken on first calculateTokens call', async () => {
+      await tokenizer.calculateTokens('test');
+      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+    });
+
+    it('should not reinitialize tiktoken on subsequent calls', async () => {
+      await tokenizer.calculateTokens('test1');
+      await tokenizer.calculateTokens('test2');
+
+      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+    });
+
+    it('should initialize tiktoken on first calculateTokensBatch call', async () => {
+      await tokenizer.calculateTokensBatch(['test']);
+      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe('edge cases', () => {
+    beforeEach(() => {
+      tokenizer = new TextTokenizer();
+    });
+
+    it('should handle very short text', async () => {
+      const result = await tokenizer.calculateTokens('a');
+
+      if (mockGetEncoding.mock.calls.length > 0) {
+        // If tiktoken was called, use its result
+        expect(mockEncode).toHaveBeenCalledWith('a');
+      } else {
+        // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
+        expect(result).toBe(1);
+      }
+    });
+
+    it('should handle text with only whitespace', async () => {
+      const whitespaceText = ' \n\t ';
+      const mockTokens = [1];
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(whitespaceText);
+
+      expect(result).toBe(1);
+    });
+
+    it('should handle special characters and symbols', async () => {
+      const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
+      const mockTokens = new Array(10);
+      mockEncode.mockReturnValue(mockTokens);
+
+      const result = await tokenizer.calculateTokens(specialText);
+
+      expect(result).toBe(10);
+    });
+  });
+});
diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.ts
new file mode 100644
index 00000000..86c71d4c
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/textTokenizer.ts
@@ -0,0 +1,97 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
+import { get_encoding } from 'tiktoken';
+
+/**
+ * Text tokenizer for calculating text tokens using tiktoken
+ */
+export class TextTokenizer {
+  private encoding: Tiktoken | null = null;
+  private encodingName: string;
+
+  constructor(encodingName: string = 'cl100k_base') {
+    this.encodingName = encodingName;
+  }
+
+  /**
+   * Initialize the tokenizer (lazy loading)
+   */
+  private async ensureEncoding(): Promise<void> {
+    if (this.encoding) return;
+
+    try {
+      // Use type assertion since we know the encoding name is valid
+      this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
+    } catch (error) {
+      console.warn(
+        `Failed to load tiktoken with encoding ${this.encodingName}:`,
+        error,
+      );
+      this.encoding = null;
+    }
+  }
+
+  /**
+   * Calculate tokens for text content
+   */
+  async calculateTokens(text: string): Promise<number> {
+    if (!text) return 0;
+
+    await this.ensureEncoding();
+
+    if (this.encoding) {
+      try {
+        return this.encoding.encode(text).length;
+      } catch (error) {
+        console.warn('Error encoding text with tiktoken:', error);
+      }
+    }
+
+    // Fallback: rough approximation using character count
+    // This is a conservative estimate: 1 token ≈ 4 characters for most languages
+    return Math.ceil(text.length / 4);
+  }
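+
+  // Usage sketch (illustrative only):
+  //   const tokenizer = new TextTokenizer();
+  //   const count = await tokenizer.calculateTokens('Hello, world!');
+  //   tokenizer.dispose();
+  // When tiktoken loads, `count` is the exact cl100k_base token count;
+  // otherwise it falls back to the Math.ceil(text.length / 4) estimate above.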
+
+  /**
+   * Calculate tokens for multiple text strings in parallel
+   */
+  async calculateTokensBatch(texts: string[]): Promise<number[]> {
+    await this.ensureEncoding();
+
+    if (this.encoding) {
+      try {
+        return texts.map((text) => {
+          if (!text) return 0;
+          // this.encoding may be null, add a null check to satisfy lint
+          return this.encoding ? this.encoding.encode(text).length : 0;
+        });
+      } catch (error) {
+        console.warn('Error encoding texts with tiktoken:', error);
+        // In case of error, return fallback estimation for all texts
+        return texts.map((text) => Math.ceil((text || '').length / 4));
+      }
+    }
+
+    // Fallback for batch processing
+    return texts.map((text) => Math.ceil((text || '').length / 4));
+  }
+
+  /**
+   * Dispose of resources
+   */
+  dispose(): void {
+    if (this.encoding) {
+      try {
+        this.encoding.free();
+      } catch (error) {
+        console.warn('Error freeing tiktoken encoding:', error);
+      }
+      this.encoding = null;
+    }
+  }
+}
diff --git a/packages/core/src/utils/request-tokenizer/types.ts b/packages/core/src/utils/request-tokenizer/types.ts
new file mode 100644
index 00000000..b7c09f21
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/types.ts
@@ -0,0 +1,64 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { CountTokensParameters } from '@google/genai';
+
+/**
+ * Token calculation result for different content types
+ */
+export interface TokenCalculationResult {
+  /** Total tokens calculated */
+  totalTokens: number;
+  /** Breakdown by content type */
+  breakdown: {
+    textTokens: number;
+    imageTokens: number;
+    audioTokens: number;
+    otherTokens: number;
+  };
+  /** Processing time in milliseconds */
+  processingTime: number;
+}
+
+/**
+ * Configuration for token calculation
+ */
+export interface TokenizerConfig {
+  /** Custom text tokenizer encoding (defaults to cl100k_base) */
+  textEncoding?: string;
+}
+
+/**
+ * Image metadata extracted from base64 data
+ */
+export interface ImageMetadata {
+  /** Image width in pixels */
+  width: number;
+  /** Image height in pixels */
+  height: number;
+  /** MIME type of the image */
+  mimeType: string;
+  /** Size of the base64 data in bytes */
+  dataSize: number;
+}
+
+/**
+ * Request tokenizer interface
+ */
+export interface RequestTokenizer {
+  /**
+   * Calculate tokens for a request
+   */
+  calculateTokens(
+    request: CountTokensParameters,
+    config?: TokenizerConfig,
+  ): Promise<TokenCalculationResult>;
+
+  /**
+   * Dispose of resources (worker threads, etc.)
+   */
+  dispose(): Promise<void>;
+}
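+
+// Consumption sketch (the helper name and wiring are illustrative only):
+//   async function logTokenUsage(
+//     tokenizer: RequestTokenizer,
+//     request: CountTokensParameters,
+//   ): Promise<void> {
+//     const result: TokenCalculationResult =
+//       await tokenizer.calculateTokens(request);
+//     console.log(
+//       `${result.totalTokens} tokens in ${result.processingTime}ms`,
+//       result.breakdown,
+//     );
+//     await tokenizer.dispose();
+//   }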