From 761833c915c22aedd6274952d38895a4aef61101 Mon Sep 17 00:00:00 2001
From: Mingholy
Date: Thu, 18 Sep 2025 13:32:00 +0800
Subject: [PATCH] Vision model support for Qwen-OAuth (#525)

* refactor: openaiContentGenerator
* refactor: optimize stream handling
* refactor: re-organize refactored files
* fix: unit test cases
* feat: `/model` command for switching to vision model
* fix: lint error
* feat: add image tokenizer to fit vlm context window
* fix: lint and type errors
* feat: add `visionModelPreview` to control default visibility of vision models
* fix: remove deprecated files
* fix: align supported image formats with bailian doc
---
 packages/cli/src/config/settingsSchema.ts | 10 +
 .../src/services/BuiltinCommandLoader.test.ts | 10 +
 .../cli/src/services/BuiltinCommandLoader.ts | 2 +
 packages/cli/src/ui/App.tsx | 110 +
 .../cli/src/ui/commands/modelCommand.test.ts | 179 +
 packages/cli/src/ui/commands/modelCommand.ts | 88 +
 packages/cli/src/ui/commands/types.ts | 1 +
 .../components/ModelSelectionDialog.test.tsx | 246 ++
 .../ui/components/ModelSelectionDialog.tsx | 87 +
 .../ui/components/ModelSwitchDialog.test.tsx | 185 +
 .../src/ui/components/ModelSwitchDialog.tsx | 89 +
 .../ui/hooks/slashCommandProcessor.test.ts | 21 +
 .../cli/src/ui/hooks/slashCommandProcessor.ts | 5 +
 .../cli/src/ui/hooks/useGeminiStream.test.tsx | 177 +
 packages/cli/src/ui/hooks/useGeminiStream.ts | 40 +-
 .../src/ui/hooks/useVisionAutoSwitch.test.ts | 374 ++
 .../cli/src/ui/hooks/useVisionAutoSwitch.ts | 304 ++
 packages/cli/src/ui/models/availableModels.ts | 52 +
 packages/core/index.ts | 1 +
 .../__tests__/openaiTimeoutHandling.test.ts | 114 +-
 packages/core/src/core/geminiChat.ts | 2 +-
 .../src/core/openaiContentGenerator.test.ts | 3511 -----------------
 .../core/src/core/openaiContentGenerator.ts | 1711 --------
 .../openaiContentGenerator.test.ts | 39 +-
 .../openaiContentGenerator.ts | 34 +-
 .../core/openaiContentGenerator/pipeline.ts | 83 +-
 .../provider/dashscope.ts | 10 +
 packages/core/src/core/tokenLimits.ts | 3 +
 packages/core/src/core/turn.test.ts | 2 +-
 packages/core/src/core/turn.ts | 2 +-
 .../src/qwen/qwenContentGenerator.test.ts | 23 +-
 .../core/src/qwen/qwenContentGenerator.ts | 4 +-
 .../request-tokenizer/imageTokenizer.test.ts | 157 +
 .../utils/request-tokenizer/imageTokenizer.ts | 505 +++
 .../core/src/utils/request-tokenizer/index.ts | 40 +
 .../requestTokenizer.test.ts | 293 ++
 .../request-tokenizer/requestTokenizer.ts | 341 ++
 .../supportedImageFormats.ts | 56 +
 .../request-tokenizer/textTokenizer.test.ts | 347 ++
 .../utils/request-tokenizer/textTokenizer.ts | 97 +
 .../core/src/utils/request-tokenizer/types.ts | 64 +
 41 files changed, 4083 insertions(+), 5336 deletions(-)
 create mode 100644 packages/cli/src/ui/commands/modelCommand.test.ts
 create mode 100644 packages/cli/src/ui/commands/modelCommand.ts
 create mode 100644 packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
 create mode 100644 packages/cli/src/ui/components/ModelSelectionDialog.tsx
 create mode 100644 packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
 create mode 100644 packages/cli/src/ui/components/ModelSwitchDialog.tsx
 create mode 100644 packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
 create mode 100644 packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
 create mode 100644 packages/cli/src/ui/models/availableModels.ts
 delete mode 100644 packages/core/src/core/openaiContentGenerator.test.ts
 delete mode 100644 packages/core/src/core/openaiContentGenerator.ts
 create mode 100644
packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/imageTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/index.ts create mode 100644 packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/requestTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/supportedImageFormats.ts create mode 100644 packages/core/src/utils/request-tokenizer/textTokenizer.test.ts create mode 100644 packages/core/src/utils/request-tokenizer/textTokenizer.ts create mode 100644 packages/core/src/utils/request-tokenizer/types.ts diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 1cbd63e0..c7f1e94e 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = { description: 'Enable extension management features.', showInDialog: false, }, + visionModelPreview: { + type: 'boolean', + label: 'Vision Model Preview', + category: 'Experimental', + requiresRestart: false, + default: false, + description: + 'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.', + showInDialog: true, + }, }, }, diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index dcede5a3..38deb425 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({ kind: 'BUILT_IN', }, })); +vi.mock('../ui/commands/modelCommand.js', () => ({ + modelCommand: { + name: 'model', + description: 'Model command', + kind: 'BUILT_IN', + }, +})); describe('BuiltinCommandLoader', () => { let mockConfig: Config; @@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => { const mcpCmd = commands.find((c) => c.name === 'mcp'); expect(mcpCmd).toBeDefined(); + + const modelCmd = commands.find((c) => c.name === 'model'); + expect(modelCmd).toBeDefined(); }); }); diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index 74de2a3c..12c3cfc9 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js'; import { vimCommand } from '../ui/commands/vimCommand.js'; import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js'; import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js'; +import { modelCommand } from '../ui/commands/modelCommand.js'; import { agentsCommand } from '../ui/commands/agentsCommand.js'; /** @@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader { initCommand, mcpCommand, memoryCommand, + modelCommand, privacyCommand, quitCommand, quitConfirmCommand, diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 9ec15650..85691182 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js'; import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js'; import { QuitConfirmationDialog } from 
'./components/QuitConfirmationDialog.js'; import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js'; +import { ModelSelectionDialog } from './components/ModelSelectionDialog.js'; +import { + ModelSwitchDialog, + type VisionSwitchOutcome, +} from './components/ModelSwitchDialog.js'; +import { + getOpenAIAvailableModelFromEnv, + getFilteredQwenModels, + type AvailableModel, +} from './models/availableModels.js'; +import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js'; import { AgentCreationWizard, AgentsManagerDialog, @@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { onWorkspaceMigrationDialogClose, } = useWorkspaceMigration(settings); + // Model selection dialog states + const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] = + useState(false); + const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] = + useState(false); + const [visionSwitchResolver, setVisionSwitchResolver] = useState<{ + resolve: (result: { + modelOverride?: string; + persistSessionModel?: string; + showGuidance?: boolean; + }) => void; + reject: () => void; + } | null>(null); + useEffect(() => { const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState); // Set the initial value @@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { openAuthDialog(); }, [openAuthDialog, setAuthError]); + // Vision switch handler for auto-switch functionality + const handleVisionSwitchRequired = useCallback( + async (_query: unknown) => + new Promise<{ + modelOverride?: string; + persistSessionModel?: string; + showGuidance?: boolean; + }>((resolve, reject) => { + setVisionSwitchResolver({ resolve, reject }); + setIsVisionSwitchDialogOpen(true); + }), + [], + ); + + const handleVisionSwitchSelect = useCallback( + (outcome: VisionSwitchOutcome) => { + setIsVisionSwitchDialogOpen(false); + if (visionSwitchResolver) { + const result = processVisionSwitchOutcome(outcome); + visionSwitchResolver.resolve(result); + setVisionSwitchResolver(null); + } + }, + [visionSwitchResolver], + ); + + const handleModelSelectionOpen = useCallback(() => { + setIsModelSelectionDialogOpen(true); + }, []); + + const handleModelSelectionClose = useCallback(() => { + setIsModelSelectionDialogOpen(false); + }, []); + + const handleModelSelect = useCallback( + (modelId: string) => { + config.setModel(modelId); + setCurrentModel(modelId); + setIsModelSelectionDialogOpen(false); + addItem( + { + type: MessageType.INFO, + text: `Switched model to \`${modelId}\` for this session.`, + }, + Date.now(), + ); + }, + [config, setCurrentModel, addItem], + ); + + const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => { + const contentGeneratorConfig = config.getContentGeneratorConfig(); + if (!contentGeneratorConfig) return []; + + const visionModelPreviewEnabled = + settings.merged.experimental?.visionModelPreview ?? false; + + switch (contentGeneratorConfig.authType) { + case AuthType.QWEN_OAUTH: + return getFilteredQwenModels(visionModelPreviewEnabled); + case AuthType.USE_OPENAI: { + const openAIModel = getOpenAIAvailableModelFromEnv(); + return openAIModel ? 
[openAIModel] : []; + } + default: + return []; + } + }, [config, settings.merged.experimental?.visionModelPreview]); + // Core hooks and processors const { vimEnabled: vimModeEnabled, @@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { setQuittingMessages, openPrivacyNotice, openSettingsDialog, + handleModelSelectionOpen, openSubagentCreateDialog, openAgentsManagerDialog, toggleVimEnabled, @@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { setModelSwitchedFromQuotaError, refreshStatic, () => cancelHandlerRef.current(), + settings.merged.experimental?.visionModelPreview ?? false, + handleVisionSwitchRequired, ); const pendingHistoryItems = useMemo( @@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { !isAuthDialogOpen && !isThemeDialogOpen && !isEditorDialogOpen && + !isModelSelectionDialogOpen && + !isVisionSwitchDialogOpen && !isSubagentCreateDialogOpen && !showPrivacyNotice && !showWelcomeBackDialog && @@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { showWelcomeBackDialog, welcomeBackChoice, geminiClient, + isModelSelectionDialogOpen, + isVisionSwitchDialogOpen, ]); if (quittingMessages) { @@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { onExit={exitEditorDialog} /> + ) : isModelSelectionDialogOpen ? ( + + ) : isVisionSwitchDialogOpen ? ( + ) : showPrivacyNotice ? ( setShowPrivacyNotice(false)} diff --git a/packages/cli/src/ui/commands/modelCommand.test.ts b/packages/cli/src/ui/commands/modelCommand.test.ts new file mode 100644 index 00000000..f3aaad52 --- /dev/null +++ b/packages/cli/src/ui/commands/modelCommand.test.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { modelCommand } from './modelCommand.js'; +import { type CommandContext } from './types.js'; +import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import { + AuthType, + type ContentGeneratorConfig, + type Config, +} from '@qwen-code/qwen-code-core'; +import * as availableModelsModule from '../models/availableModels.js'; + +// Mock the availableModels module +vi.mock('../models/availableModels.js', () => ({ + AVAILABLE_MODELS_QWEN: [ + { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }, + { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true }, + ], + getOpenAIAvailableModelFromEnv: vi.fn(), +})); + +// Helper function to create a mock config +function createMockConfig( + contentGeneratorConfig: ContentGeneratorConfig | null, +): Partial { + return { + getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig), + }; +} + +describe('modelCommand', () => { + let mockContext: CommandContext; + const mockGetOpenAIAvailableModelFromEnv = vi.mocked( + availableModelsModule.getOpenAIAvailableModelFromEnv, + ); + + beforeEach(() => { + mockContext = createMockCommandContext(); + vi.clearAllMocks(); + }); + + it('should have the correct name and description', () => { + expect(modelCommand.name).toBe('model'); + expect(modelCommand.description).toBe('Switch the model for this session'); + }); + + it('should return error when config is not available', async () => { + mockContext.services.config = null; + + const result = await modelCommand.action!(mockContext, ''); + + 
expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Configuration not available.', + }); + }); + + it('should return error when content generator config is not available', async () => { + const mockConfig = createMockConfig(null); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Content generator configuration not available.', + }); + }); + + it('should return error when auth type is not available', async () => { + const mockConfig = createMockConfig({ + model: 'test-model', + authType: undefined, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Authentication type not available.', + }); + }); + + it('should return dialog action for QWEN_OAUTH auth type', async () => { + const mockConfig = createMockConfig({ + model: 'test-model', + authType: AuthType.QWEN_OAUTH, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'dialog', + dialog: 'model', + }); + }); + + it('should return dialog action for USE_OPENAI auth type when model is available', async () => { + mockGetOpenAIAvailableModelFromEnv.mockReturnValue({ + id: 'gpt-4', + label: 'gpt-4', + }); + + const mockConfig = createMockConfig({ + model: 'test-model', + authType: AuthType.USE_OPENAI, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'dialog', + dialog: 'model', + }); + }); + + it('should return error for USE_OPENAI auth type when no model is available', async () => { + mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null); + + const mockConfig = createMockConfig({ + model: 'test-model', + authType: AuthType.USE_OPENAI, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: + 'No models available for the current authentication type (openai).', + }); + }); + + it('should return error for unsupported auth types', async () => { + const mockConfig = createMockConfig({ + model: 'test-model', + authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: + 'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).', + }); + }); + + it('should handle undefined auth type', async () => { + const mockConfig = createMockConfig({ + model: 'test-model', + authType: undefined, + }); + mockContext.services.config = mockConfig as Config; + + const result = await modelCommand.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Authentication type not available.', + }); + }); +}); diff --git a/packages/cli/src/ui/commands/modelCommand.ts b/packages/cli/src/ui/commands/modelCommand.ts new file mode 100644 index 00000000..9e4fdcb0 --- /dev/null +++ b/packages/cli/src/ui/commands/modelCommand.ts @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: 
Apache-2.0 + */ + +import { AuthType } from '@qwen-code/qwen-code-core'; +import type { + SlashCommand, + CommandContext, + OpenDialogActionReturn, + MessageActionReturn, +} from './types.js'; +import { CommandKind } from './types.js'; +import { + AVAILABLE_MODELS_QWEN, + getOpenAIAvailableModelFromEnv, + type AvailableModel, +} from '../models/availableModels.js'; + +function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] { + switch (authType) { + case AuthType.QWEN_OAUTH: + return AVAILABLE_MODELS_QWEN; + case AuthType.USE_OPENAI: { + const openAIModel = getOpenAIAvailableModelFromEnv(); + return openAIModel ? [openAIModel] : []; + } + default: + // For other auth types, return empty array for now + // This can be expanded later according to the design doc + return []; + } +} + +export const modelCommand: SlashCommand = { + name: 'model', + description: 'Switch the model for this session', + kind: CommandKind.BUILT_IN, + action: async ( + context: CommandContext, + ): Promise => { + const { services } = context; + const { config } = services; + + if (!config) { + return { + type: 'message', + messageType: 'error', + content: 'Configuration not available.', + }; + } + + const contentGeneratorConfig = config.getContentGeneratorConfig(); + if (!contentGeneratorConfig) { + return { + type: 'message', + messageType: 'error', + content: 'Content generator configuration not available.', + }; + } + + const authType = contentGeneratorConfig.authType; + if (!authType) { + return { + type: 'message', + messageType: 'error', + content: 'Authentication type not available.', + }; + } + + const availableModels = getAvailableModelsForAuthType(authType); + + if (availableModels.length === 0) { + return { + type: 'message', + messageType: 'error', + content: `No models available for the current authentication type (${authType}).`, + }; + } + + // Trigger model selection dialog + return { + type: 'dialog', + dialog: 'model', + }; + }, +}; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 971e8a02..18484d82 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -116,6 +116,7 @@ export interface OpenDialogActionReturn { | 'editor' | 'privacy' | 'settings' + | 'model' | 'subagent_create' | 'subagent_list'; } diff --git a/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx b/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx new file mode 100644 index 00000000..4a5b6bcf --- /dev/null +++ b/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx @@ -0,0 +1,246 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render } from 'ink-testing-library'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ModelSelectionDialog } from './ModelSelectionDialog.js'; +import type { AvailableModel } from '../models/availableModels.js'; +import type { RadioSelectItem } from './shared/RadioButtonSelect.js'; + +// Mock the useKeypress hook +const mockUseKeypress = vi.hoisted(() => vi.fn()); +vi.mock('../hooks/useKeypress.js', () => ({ + useKeypress: mockUseKeypress, +})); + +// Mock the RadioButtonSelect component +const mockRadioButtonSelect = vi.hoisted(() => vi.fn()); +vi.mock('./shared/RadioButtonSelect.js', () => ({ + RadioButtonSelect: mockRadioButtonSelect, +})); + +describe('ModelSelectionDialog', () => { + const mockAvailableModels: AvailableModel[] = [ + { id: 'qwen3-coder-plus', label: 
'qwen3-coder-plus' }, + { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true }, + { id: 'gpt-4', label: 'GPT-4' }, + ]; + + const mockOnSelect = vi.fn(); + const mockOnCancel = vi.fn(); + + beforeEach(() => { + vi.clearAllMocks(); + + // Mock RadioButtonSelect to return a simple div + mockRadioButtonSelect.mockReturnValue( + React.createElement('div', { 'data-testid': 'radio-select' }), + ); + }); + + it('should setup escape key handler to call onCancel', () => { + render( + , + ); + + expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), { + isActive: true, + }); + + // Simulate escape key press + const keypressHandler = mockUseKeypress.mock.calls[0][0]; + keypressHandler({ name: 'escape' }); + + expect(mockOnCancel).toHaveBeenCalled(); + }); + + it('should not call onCancel for non-escape keys', () => { + render( + , + ); + + const keypressHandler = mockUseKeypress.mock.calls[0][0]; + keypressHandler({ name: 'enter' }); + + expect(mockOnCancel).not.toHaveBeenCalled(); + }); + + it('should set correct initial index for current model', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1 + }); + + it('should set initial index to 0 when current model is not found', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.initialIndex).toBe(0); + }); + + it('should call onSelect when a model is selected', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(typeof callArgs.onSelect).toBe('function'); + + // Simulate selection + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + onSelectCallback('qwen-vl-max-latest'); + + expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest'); + }); + + it('should handle empty models array', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.items).toEqual([]); + expect(callArgs.initialIndex).toBe(0); + }); + + it('should create correct option items with proper labels', () => { + render( + , + ); + + const expectedItems = [ + { + label: 'qwen3-coder-plus (current)', + value: 'qwen3-coder-plus', + }, + { + label: 'qwen-vl-max [Vision]', + value: 'qwen-vl-max-latest', + }, + { + label: 'GPT-4', + value: 'gpt-4', + }, + ]; + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.items).toEqual(expectedItems); + }); + + it('should show vision indicator for vision models', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + const visionModelItem = callArgs.items.find( + (item: RadioSelectItem) => item.value === 'qwen-vl-max-latest', + ); + + expect(visionModelItem?.label).toContain('[Vision]'); + }); + + it('should show current indicator for the current model', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + const currentModelItem = callArgs.items.find( + (item: RadioSelectItem) => item.value === 'qwen-vl-max-latest', + ); + + expect(currentModelItem?.label).toContain('(current)'); + }); + + it('should pass isFocused prop to RadioButtonSelect', () => { + render( + , + ); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.isFocused).toBe(true); + }); + + it('should handle multiple onSelect calls correctly', () => { + render( + , + ); + + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + + // Call 
multiple times + onSelectCallback('qwen3-coder-plus'); + onSelectCallback('qwen-vl-max-latest'); + onSelectCallback('gpt-4'); + + expect(mockOnSelect).toHaveBeenCalledTimes(3); + expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus'); + expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest'); + expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4'); + }); +}); diff --git a/packages/cli/src/ui/components/ModelSelectionDialog.tsx b/packages/cli/src/ui/components/ModelSelectionDialog.tsx new file mode 100644 index 00000000..d43e69f3 --- /dev/null +++ b/packages/cli/src/ui/components/ModelSelectionDialog.tsx @@ -0,0 +1,87 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type React from 'react'; +import { Box, Text } from 'ink'; +import { Colors } from '../colors.js'; +import { + RadioButtonSelect, + type RadioSelectItem, +} from './shared/RadioButtonSelect.js'; +import { useKeypress } from '../hooks/useKeypress.js'; +import type { AvailableModel } from '../models/availableModels.js'; + +export interface ModelSelectionDialogProps { + availableModels: AvailableModel[]; + currentModel: string; + onSelect: (modelId: string) => void; + onCancel: () => void; +} + +export const ModelSelectionDialog: React.FC = ({ + availableModels, + currentModel, + onSelect, + onCancel, +}) => { + useKeypress( + (key) => { + if (key.name === 'escape') { + onCancel(); + } + }, + { isActive: true }, + ); + + const options: Array> = availableModels.map( + (model) => { + const visionIndicator = model.isVision ? ' [Vision]' : ''; + const currentIndicator = model.id === currentModel ? ' (current)' : ''; + return { + label: `${model.label}${visionIndicator}${currentIndicator}`, + value: model.id, + }; + }, + ); + + const initialIndex = Math.max( + 0, + availableModels.findIndex((model) => model.id === currentModel), + ); + + const handleSelect = (modelId: string) => { + onSelect(modelId); + }; + + return ( + + + Select Model + Choose a model for this session: + + + + + + + + Press Enter to select, Esc to cancel + + + ); +}; diff --git a/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx b/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx new file mode 100644 index 00000000..f26dcc55 --- /dev/null +++ b/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx @@ -0,0 +1,185 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render } from 'ink-testing-library'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js'; + +// Mock the useKeypress hook +const mockUseKeypress = vi.hoisted(() => vi.fn()); +vi.mock('../hooks/useKeypress.js', () => ({ + useKeypress: mockUseKeypress, +})); + +// Mock the RadioButtonSelect component +const mockRadioButtonSelect = vi.hoisted(() => vi.fn()); +vi.mock('./shared/RadioButtonSelect.js', () => ({ + RadioButtonSelect: mockRadioButtonSelect, +})); + +describe('ModelSwitchDialog', () => { + const mockOnSelect = vi.fn(); + + beforeEach(() => { + vi.clearAllMocks(); + + // Mock RadioButtonSelect to return a simple div + mockRadioButtonSelect.mockReturnValue( + React.createElement('div', { 'data-testid': 'radio-select' }), + ); + }); + + it('should setup RadioButtonSelect with correct options', () => { + render(); + + const expectedItems = [ + { + label: 'Switch for this request only', + value: VisionSwitchOutcome.SwitchOnce, + }, + { + 
label: 'Switch session to vision model', + value: VisionSwitchOutcome.SwitchSessionToVL, + }, + { + label: 'Do not switch, show guidance', + value: VisionSwitchOutcome.DisallowWithGuidance, + }, + ]; + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.items).toEqual(expectedItems); + expect(callArgs.initialIndex).toBe(0); + expect(callArgs.isFocused).toBe(true); + }); + + it('should call onSelect when an option is selected', () => { + render(); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(typeof callArgs.onSelect).toBe('function'); + + // Simulate selection of "Switch for this request only" + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + onSelectCallback(VisionSwitchOutcome.SwitchOnce); + + expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce); + }); + + it('should call onSelect with SwitchSessionToVL when second option is selected', () => { + render(); + + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL); + + expect(mockOnSelect).toHaveBeenCalledWith( + VisionSwitchOutcome.SwitchSessionToVL, + ); + }); + + it('should call onSelect with DisallowWithGuidance when third option is selected', () => { + render(); + + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance); + + expect(mockOnSelect).toHaveBeenCalledWith( + VisionSwitchOutcome.DisallowWithGuidance, + ); + }); + + it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => { + render(); + + expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), { + isActive: true, + }); + + // Simulate escape key press + const keypressHandler = mockUseKeypress.mock.calls[0][0]; + keypressHandler({ name: 'escape' }); + + expect(mockOnSelect).toHaveBeenCalledWith( + VisionSwitchOutcome.DisallowWithGuidance, + ); + }); + + it('should not call onSelect for non-escape keys', () => { + render(); + + const keypressHandler = mockUseKeypress.mock.calls[0][0]; + keypressHandler({ name: 'enter' }); + + expect(mockOnSelect).not.toHaveBeenCalled(); + }); + + it('should set initial index to 0 (first option)', () => { + render(); + + const callArgs = mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.initialIndex).toBe(0); + }); + + describe('VisionSwitchOutcome enum', () => { + it('should have correct enum values', () => { + expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once'); + expect(VisionSwitchOutcome.SwitchSessionToVL).toBe( + 'switch_session_to_vl', + ); + expect(VisionSwitchOutcome.DisallowWithGuidance).toBe( + 'disallow_with_guidance', + ); + }); + }); + + it('should handle multiple onSelect calls correctly', () => { + render(); + + const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect; + + // Call multiple times + onSelectCallback(VisionSwitchOutcome.SwitchOnce); + onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL); + onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance); + + expect(mockOnSelect).toHaveBeenCalledTimes(3); + expect(mockOnSelect).toHaveBeenNthCalledWith( + 1, + VisionSwitchOutcome.SwitchOnce, + ); + expect(mockOnSelect).toHaveBeenNthCalledWith( + 2, + VisionSwitchOutcome.SwitchSessionToVL, + ); + expect(mockOnSelect).toHaveBeenNthCalledWith( + 3, + VisionSwitchOutcome.DisallowWithGuidance, + ); + }); + + it('should pass isFocused prop to RadioButtonSelect', () => { + render(); + + const callArgs = 
mockRadioButtonSelect.mock.calls[0][0]; + expect(callArgs.isFocused).toBe(true); + }); + + it('should handle escape key multiple times', () => { + render(); + + const keypressHandler = mockUseKeypress.mock.calls[0][0]; + + // Call escape multiple times + keypressHandler({ name: 'escape' }); + keypressHandler({ name: 'escape' }); + + expect(mockOnSelect).toHaveBeenCalledTimes(2); + expect(mockOnSelect).toHaveBeenCalledWith( + VisionSwitchOutcome.DisallowWithGuidance, + ); + }); +}); diff --git a/packages/cli/src/ui/components/ModelSwitchDialog.tsx b/packages/cli/src/ui/components/ModelSwitchDialog.tsx new file mode 100644 index 00000000..1a8c73d4 --- /dev/null +++ b/packages/cli/src/ui/components/ModelSwitchDialog.tsx @@ -0,0 +1,89 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type React from 'react'; +import { Box, Text } from 'ink'; +import { Colors } from '../colors.js'; +import { + RadioButtonSelect, + type RadioSelectItem, +} from './shared/RadioButtonSelect.js'; +import { useKeypress } from '../hooks/useKeypress.js'; + +export enum VisionSwitchOutcome { + SwitchOnce = 'switch_once', + SwitchSessionToVL = 'switch_session_to_vl', + DisallowWithGuidance = 'disallow_with_guidance', +} + +export interface ModelSwitchDialogProps { + onSelect: (outcome: VisionSwitchOutcome) => void; +} + +export const ModelSwitchDialog: React.FC = ({ + onSelect, +}) => { + useKeypress( + (key) => { + if (key.name === 'escape') { + onSelect(VisionSwitchOutcome.DisallowWithGuidance); + } + }, + { isActive: true }, + ); + + const options: Array> = [ + { + label: 'Switch for this request only', + value: VisionSwitchOutcome.SwitchOnce, + }, + { + label: 'Switch session to vision model', + value: VisionSwitchOutcome.SwitchSessionToVL, + }, + { + label: 'Do not switch, show guidance', + value: VisionSwitchOutcome.DisallowWithGuidance, + }, + ]; + + const handleSelect = (outcome: VisionSwitchOutcome) => { + onSelect(outcome); + }; + + return ( + + + Vision Model Switch Required + + Your message contains an image, but the current model doesn't + support vision. + + How would you like to proceed? 
+ + + + + + + + Press Enter to select, Esc to cancel + + + ); +}; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 5403bec8..44b99fe9 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => { const mockLoadHistory = vi.fn(); const mockOpenThemeDialog = vi.fn(); const mockOpenAuthDialog = vi.fn(); + const mockOpenModelSelectionDialog = vi.fn(); const mockSetQuittingMessages = vi.fn(); const mockConfig = makeFakeConfig({}); @@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => { mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); mockMcpLoadCommands.mockResolvedValue([]); + mockOpenModelSelectionDialog.mockClear(); }); const setupProcessorHook = ( @@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => { mockSetQuittingMessages, vi.fn(), // openPrivacyNotice vi.fn(), // openSettingsDialog + mockOpenModelSelectionDialog, vi.fn(), // openSubagentCreateDialog vi.fn(), // openAgentsManagerDialog vi.fn(), // toggleVimEnabled setIsProcessing, vi.fn(), // setGeminiMdFileCount + vi.fn(), // _showQuitConfirmation ), ); @@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => { expect(mockOpenThemeDialog).toHaveBeenCalled(); }); + it('should handle "dialog: model" action', async () => { + const command = createTestCommand({ + name: 'modelcmd', + action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }), + }); + const result = setupProcessorHook([command]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + await act(async () => { + await result.current.handleSlashCommand('/modelcmd'); + }); + + expect(mockOpenModelSelectionDialog).toHaveBeenCalled(); + }); + it('should handle "load_history" action', async () => { const command = createTestCommand({ name: 'load', @@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => { mockSetQuittingMessages, vi.fn(), // openPrivacyNotice vi.fn(), // openSettingsDialog + vi.fn(), // openModelSelectionDialog vi.fn(), // openSubagentCreateDialog vi.fn(), // openAgentsManagerDialog vi.fn(), // toggleVimEnabled vi.fn(), // setIsProcessing vi.fn(), // setGeminiMdFileCount + vi.fn(), // _showQuitConfirmation ), ); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 3e49a0eb..10c4573d 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -53,6 +53,7 @@ export const useSlashCommandProcessor = ( setQuittingMessages: (message: HistoryItem[]) => void, openPrivacyNotice: () => void, openSettingsDialog: () => void, + openModelSelectionDialog: () => void, openSubagentCreateDialog: () => void, openAgentsManagerDialog: () => void, toggleVimEnabled: () => Promise, @@ -404,6 +405,9 @@ export const useSlashCommandProcessor = ( case 'settings': openSettingsDialog(); return { type: 'handled' }; + case 'model': + openModelSelectionDialog(); + return { type: 'handled' }; case 'subagent_create': openSubagentCreateDialog(); return { type: 'handled' }; @@ -663,6 +667,7 @@ export const useSlashCommandProcessor = ( setSessionShellAllowlist, setIsProcessing, setConfirmationRequest, + openModelSelectionDialog, session.stats, ], ); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx 
b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 9eab226c..125620cf 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() => ); const mockParseAndFormatApiError = vi.hoisted(() => vi.fn()); +// Vision auto-switch mocks (hoisted) +const mockHandleVisionSwitch = vi.hoisted(() => + vi.fn().mockResolvedValue({ shouldProceed: true }), +); +const mockRestoreOriginalModel = vi.hoisted(() => vi.fn()); + vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => { const actualCoreModule = (await importOriginal()) as any; return { @@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => { }; }); +vi.mock('./useVisionAutoSwitch.js', () => ({ + useVisionAutoSwitch: vi.fn(() => ({ + handleVisionSwitch: mockHandleVisionSwitch, + restoreOriginalModel: mockRestoreOriginalModel, + })), +})); + vi.mock('./useKeypress.js', () => ({ useKeypress: vi.fn(), })); @@ -199,6 +212,7 @@ describe('useGeminiStream', () => { getContentGeneratorConfig: vi .fn() .mockReturnValue(contentGeneratorConfig), + getMaxSessionTurns: vi.fn(() => 50), } as unknown as Config; mockOnDebugMessage = vi.fn(); mockHandleSlashCommand = vi.fn().mockResolvedValue(false); @@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => { expect.any(String), // Argument 3: The prompt_id string ); }); + describe('Thought Reset', () => { it('should reset thought to null when starting a new prompt', async () => { // First, simulate a response with a thought @@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => { ); }); }); + + // --- New tests focused on recent modifications --- + describe('Vision Auto Switch Integration', () => { + it('should call handleVisionSwitch and proceed to send when allowed', async () => { + mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true }); + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { type: ServerGeminiEventType.Content, value: 'ok' }; + yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('image prompt'); + }); + + await waitFor(() => { + expect(mockHandleVisionSwitch).toHaveBeenCalled(); + expect(mockSendMessageStream).toHaveBeenCalled(); + }); + }); + + it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => { + mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false }); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('vision-gated'); + }); + + // No call to API, no restoreOriginalModel needed since no override occurred + expect(mockSendMessageStream).not.toHaveBeenCalled(); + expect(mockRestoreOriginalModel).not.toHaveBeenCalled(); + + // Next call allowed (flag reset path) + mockHandleVisionSwitch.mockResolvedValueOnce({ 
shouldProceed: true }); + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { type: ServerGeminiEventType.Content, value: 'ok' }; + yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + })(), + ); + await act(async () => { + await result.current.submitQuery('after-gate'); + }); + await waitFor(() => { + expect(mockSendMessageStream).toHaveBeenCalled(); + }); + }); + }); + + describe('Model restore on completion and errors', () => { + it('should restore model after successful stream completion', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { type: ServerGeminiEventType.Content, value: 'content' }; + yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('restore-success'); + }); + + await waitFor(() => { + expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1); + }); + }); + + it('should restore model when an error occurs during streaming', async () => { + const testError = new Error('stream failure'); + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { type: ServerGeminiEventType.Content, value: 'content' }; + throw testError; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('restore-error'); + }); + + await waitFor(() => { + expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1); + }); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 07f4d7b9..7f34eaa2 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -42,6 +42,7 @@ import type { import { StreamingState, MessageType, ToolCallStatus } from '../types.js'; import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js'; import { useShellCommandProcessor } from './shellCommandProcessor.js'; +import { useVisionAutoSwitch } from './useVisionAutoSwitch.js'; import { handleAtCommand } from './atCommandProcessor.js'; import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; @@ -88,6 +89,12 @@ export const useGeminiStream = ( setModelSwitchedFromQuotaError: React.Dispatch>, onEditorClose: () => void, onCancelSubmit: () => void, + visionModelPreviewEnabled: boolean = false, + onVisionSwitchRequired?: (query: PartListUnion) => Promise<{ + modelOverride?: string; + persistSessionModel?: string; + showGuidance?: boolean; + }>, ) => { const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); @@ -155,6 +162,13 @@ export const useGeminiStream = ( geminiClient, ); + const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch( + config, + addItem, + visionModelPreviewEnabled, + onVisionSwitchRequired, + ); + const streamingState = useMemo(() => { if 
(toolCalls.some((tc) => tc.status === 'awaiting_approval')) { return StreamingState.WaitingForConfirmation; @@ -715,6 +729,20 @@ export const useGeminiStream = ( return; } + // Handle vision switch requirement + const visionSwitchResult = await handleVisionSwitch( + queryToSend, + userMessageTimestamp, + options?.isContinuation || false, + ); + + if (!visionSwitchResult.shouldProceed) { + isSubmittingQueryRef.current = false; + return; + } + + const finalQueryToSend = queryToSend; + if (!options?.isContinuation) { startNewPrompt(); setThought(null); // Reset thought when starting a new prompt @@ -725,7 +753,7 @@ export const useGeminiStream = ( try { const stream = geminiClient.sendMessageStream( - queryToSend, + finalQueryToSend, abortSignal, prompt_id!, ); @@ -736,6 +764,8 @@ export const useGeminiStream = ( ); if (processingStatus === StreamProcessingStatus.UserCancelled) { + // Restore original model if it was temporarily overridden + restoreOriginalModel(); isSubmittingQueryRef.current = false; return; } @@ -748,7 +778,13 @@ export const useGeminiStream = ( loopDetectedRef.current = false; handleLoopDetectedEvent(); } + + // Restore original model if it was temporarily overridden + restoreOriginalModel(); } catch (error: unknown) { + // Restore original model if it was temporarily overridden + restoreOriginalModel(); + if (error instanceof UnauthorizedError) { onAuthError(); } else if (!isNodeError(error) || error.name !== 'AbortError') { @@ -786,6 +822,8 @@ export const useGeminiStream = ( startNewPrompt, getPromptCount, handleLoopDetectedEvent, + handleVisionSwitch, + restoreOriginalModel, ], ); diff --git a/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts b/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts new file mode 100644 index 00000000..dd8c6a06 --- /dev/null +++ b/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts @@ -0,0 +1,374 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { renderHook, act } from '@testing-library/react'; +import type { Part, PartListUnion } from '@google/genai'; +import { AuthType, type Config } from '@qwen-code/qwen-code-core'; +import { + shouldOfferVisionSwitch, + processVisionSwitchOutcome, + getVisionSwitchGuidanceMessage, + useVisionAutoSwitch, +} from './useVisionAutoSwitch.js'; +import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js'; +import { MessageType } from '../types.js'; +import { getDefaultVisionModel } from '../models/availableModels.js'; + +describe('useVisionAutoSwitch helpers', () => { + describe('shouldOfferVisionSwitch', () => { + it('returns false when authType is not QWEN_OAUTH', () => { + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + const result = shouldOfferVisionSwitch( + parts, + AuthType.USE_GEMINI, + 'qwen3-coder-plus', + true, + ); + expect(result).toBe(false); + }); + + it('returns false when current model is already a vision model', () => { + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' 
} }, + ]; + const result = shouldOfferVisionSwitch( + parts, + AuthType.QWEN_OAUTH, + 'qwen-vl-max-latest', + true, + ); + expect(result).toBe(false); + }); + + it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => { + const parts: PartListUnion = [ + { text: 'hello' }, + { inlineData: { mimeType: 'image/jpeg', data: '...' } }, + ]; + const result = shouldOfferVisionSwitch( + parts, + AuthType.QWEN_OAUTH, + 'qwen3-coder-plus', + true, + ); + expect(result).toBe(true); + }); + + it('detects image when provided as a single Part object (non-array)', () => { + const singleImagePart: PartListUnion = { + fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' }, + } as Part; + const result = shouldOfferVisionSwitch( + singleImagePart, + AuthType.QWEN_OAUTH, + 'qwen3-coder-plus', + true, + ); + expect(result).toBe(true); + }); + + it('returns false when parts contain no images', () => { + const parts: PartListUnion = [{ text: 'just text' }]; + const result = shouldOfferVisionSwitch( + parts, + AuthType.QWEN_OAUTH, + 'qwen3-coder-plus', + true, + ); + expect(result).toBe(false); + }); + + it('returns false when parts is a plain string', () => { + const parts: PartListUnion = 'plain text'; + const result = shouldOfferVisionSwitch( + parts, + AuthType.QWEN_OAUTH, + 'qwen3-coder-plus', + true, + ); + expect(result).toBe(false); + }); + + it('returns false when visionModelPreviewEnabled is false', () => { + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + const result = shouldOfferVisionSwitch( + parts, + AuthType.QWEN_OAUTH, + 'qwen3-coder-plus', + false, + ); + expect(result).toBe(false); + }); + }); + + describe('processVisionSwitchOutcome', () => { + it('maps SwitchOnce to a one-time model override', () => { + const vl = getDefaultVisionModel(); + const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce); + expect(result).toEqual({ modelOverride: vl }); + }); + + it('maps SwitchSessionToVL to a persistent session model', () => { + const vl = getDefaultVisionModel(); + const result = processVisionSwitchOutcome( + VisionSwitchOutcome.SwitchSessionToVL, + ); + expect(result).toEqual({ persistSessionModel: vl }); + }); + + it('maps DisallowWithGuidance to showGuidance', () => { + const result = processVisionSwitchOutcome( + VisionSwitchOutcome.DisallowWithGuidance, + ); + expect(result).toEqual({ showGuidance: true }); + }); + }); + + describe('getVisionSwitchGuidanceMessage', () => { + it('returns the expected guidance message', () => { + const vl = getDefaultVisionModel(); + const expected = + 'To use images with your query, you can:\n' + + `• Use /model set ${vl} to switch to a vision-capable model\n` + + '• Or remove the image and provide a text description instead'; + expect(getVisionSwitchGuidanceMessage()).toBe(expected); + }); + }); +}); + +describe('useVisionAutoSwitch hook', () => { + type AddItemFn = ( + item: { type: MessageType; text: string }, + ts: number, + ) => any; + + const createMockConfig = (authType: AuthType, initialModel: string) => { + let currentModel = initialModel; + const mockConfig: Partial = { + getModel: vi.fn(() => currentModel), + setModel: vi.fn((m: string) => { + currentModel = m; + }), + getContentGeneratorConfig: vi.fn(() => ({ + authType, + model: currentModel, + apiKey: 'test-key', + vertexai: false, + })), + }; + return mockConfig as Config; + }; + + let addItem: AddItemFn; + + beforeEach(() => { + vi.clearAllMocks(); + addItem = vi.fn(); + }); + + it('returns 
shouldProceed=true immediately for continuations', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, vi.fn()), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, Date.now(), true); + }); + expect(res).toEqual({ shouldProceed: true }); + expect(addItem).not.toHaveBeenCalled(); + }); + + it('does nothing when authType is not QWEN_OAUTH', async () => { + const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi.fn(); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 123, false); + }); + expect(res).toEqual({ shouldProceed: true }); + expect(onVisionSwitchRequired).not.toHaveBeenCalled(); + }); + + it('does nothing when there are no image parts', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi.fn(); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [{ text: 'no images here' }]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 456, false); + }); + expect(res).toEqual({ shouldProceed: true }); + expect(onVisionSwitchRequired).not.toHaveBeenCalled(); + }); + + it('shows guidance and blocks when dialog returns showGuidance', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi + .fn() + .mockResolvedValue({ showGuidance: true }); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + + const userTs = 1010; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, userTs, false); + }); + + expect(addItem).toHaveBeenCalledWith( + { type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() }, + userTs, + ); + expect(res).toEqual({ shouldProceed: false }); + expect(config.setModel).not.toHaveBeenCalled(); + }); + + it('applies a one-time override and returns originalModel, then restores', async () => { + const initialModel = 'qwen3-coder-plus'; + const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel); + const onVisionSwitchRequired = vi + .fn() + .mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' }); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' 
} }, + ]; + + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 2020, false); + }); + + expect(res).toEqual({ shouldProceed: true, originalModel: initialModel }); + expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest'); + + // Now restore + act(() => { + result.current.restoreOriginalModel(); + }); + expect(config.setModel).toHaveBeenLastCalledWith(initialModel); + }); + + it('persists session model when dialog requests persistence', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi + .fn() + .mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' }); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 3030, false); + }); + + expect(res).toEqual({ shouldProceed: true }); + expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest'); + + // Restore should be a no-op since no one-time override was used + act(() => { + result.current.restoreOriginalModel(); + }); + // Last call should still be the persisted model set + expect((config.setModel as any).mock.calls.pop()?.[0]).toBe( + 'qwen-vl-max-latest', + ); + }); + + it('returns shouldProceed=true when dialog returns no special flags', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi.fn().mockResolvedValue({}); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 4040, false); + }); + expect(res).toEqual({ shouldProceed: true }); + expect(config.setModel).not.toHaveBeenCalled(); + }); + + it('blocks when dialog throws or is cancelled', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x')); + const { result } = renderHook(() => + useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' } }, + ]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 5050, false); + }); + expect(res).toEqual({ shouldProceed: false }); + expect(config.setModel).not.toHaveBeenCalled(); + }); + + it('does nothing when visionModelPreviewEnabled is false', async () => { + const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus'); + const onVisionSwitchRequired = vi.fn(); + const { result } = renderHook(() => + useVisionAutoSwitch( + config, + addItem as any, + false, + onVisionSwitchRequired, + ), + ); + + const parts: PartListUnion = [ + { inlineData: { mimeType: 'image/png', data: '...' 
} }, + ]; + let res: any; + await act(async () => { + res = await result.current.handleVisionSwitch(parts, 6060, false); + }); + expect(res).toEqual({ shouldProceed: true }); + expect(onVisionSwitchRequired).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts b/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts new file mode 100644 index 00000000..d4b9629c --- /dev/null +++ b/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts @@ -0,0 +1,304 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type PartListUnion, type Part } from '@google/genai'; +import { AuthType, type Config } from '@qwen-code/qwen-code-core'; +import { useCallback, useRef } from 'react'; +import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js'; +import { + getDefaultVisionModel, + isVisionModel, +} from '../models/availableModels.js'; +import { MessageType } from '../types.js'; +import type { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { + isSupportedImageMimeType, + getUnsupportedImageFormatWarning, +} from '@qwen-code/qwen-code-core'; + +/** + * Checks if a PartListUnion contains image parts + */ +function hasImageParts(parts: PartListUnion): boolean { + if (typeof parts === 'string') { + return false; + } + + if (Array.isArray(parts)) { + return parts.some((part) => { + // Skip string parts + if (typeof part === 'string') return false; + return isImagePart(part); + }); + } + + // If it's a single Part (not a string), check if it's an image + if (typeof parts === 'object') { + return isImagePart(parts); + } + + return false; +} + +/** + * Checks if a single Part is an image part + */ +function isImagePart(part: Part): boolean { + // Check for inlineData with image mime type + if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) { + return true; + } + + // Check for fileData with image mime type + if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) { + return true; + } + + return false; +} + +/** + * Checks if image parts have supported formats and returns unsupported ones + */ +function checkImageFormatsSupport(parts: PartListUnion): { + hasImages: boolean; + hasUnsupportedFormats: boolean; + unsupportedMimeTypes: string[]; +} { + const unsupportedMimeTypes: string[] = []; + let hasImages = false; + + if (typeof parts === 'string') { + return { + hasImages: false, + hasUnsupportedFormats: false, + unsupportedMimeTypes: [], + }; + } + + const partsArray = Array.isArray(parts) ? 
parts : [parts]; + + for (const part of partsArray) { + if (typeof part === 'string') continue; + + let mimeType: string | undefined; + + // Check inlineData + if ( + 'inlineData' in part && + part.inlineData?.mimeType?.startsWith('image/') + ) { + hasImages = true; + mimeType = part.inlineData.mimeType; + } + + // Check fileData + if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) { + hasImages = true; + mimeType = part.fileData.mimeType; + } + + // Check if the mime type is supported + if (mimeType && !isSupportedImageMimeType(mimeType)) { + unsupportedMimeTypes.push(mimeType); + } + } + + return { + hasImages, + hasUnsupportedFormats: unsupportedMimeTypes.length > 0, + unsupportedMimeTypes, + }; +} + +/** + * Determines if we should offer vision switch for the given parts, auth type, and current model + */ +export function shouldOfferVisionSwitch( + parts: PartListUnion, + authType: AuthType, + currentModel: string, + visionModelPreviewEnabled: boolean = false, +): boolean { + // Only trigger for qwen-oauth + if (authType !== AuthType.QWEN_OAUTH) { + return false; + } + + // If vision model preview is disabled, never offer vision switch + if (!visionModelPreviewEnabled) { + return false; + } + + // If current model is already a vision model, no need to switch + if (isVisionModel(currentModel)) { + return false; + } + + // Check if the current message contains image parts + return hasImageParts(parts); +} + +/** + * Interface for vision switch result + */ +export interface VisionSwitchResult { + modelOverride?: string; + persistSessionModel?: string; + showGuidance?: boolean; +} + +/** + * Processes the vision switch outcome and returns the appropriate result + */ +export function processVisionSwitchOutcome( + outcome: VisionSwitchOutcome, +): VisionSwitchResult { + const vlModelId = getDefaultVisionModel(); + + switch (outcome) { + case VisionSwitchOutcome.SwitchOnce: + return { modelOverride: vlModelId }; + + case VisionSwitchOutcome.SwitchSessionToVL: + return { persistSessionModel: vlModelId }; + + case VisionSwitchOutcome.DisallowWithGuidance: + return { showGuidance: true }; + + default: + return { showGuidance: true }; + } +} + +/** + * Gets the guidance message for when vision switch is disallowed + */ +export function getVisionSwitchGuidanceMessage(): string { + const vlModelId = getDefaultVisionModel(); + return `To use images with your query, you can: +• Use /model set ${vlModelId} to switch to a vision-capable model +• Or remove the image and provide a text description instead`; +} + +/** + * Interface for vision switch handling result + */ +export interface VisionSwitchHandlingResult { + shouldProceed: boolean; + originalModel?: string; +} + +/** + * Custom hook for handling vision model auto-switching + */ +export function useVisionAutoSwitch( + config: Config, + addItem: UseHistoryManagerReturn['addItem'], + visionModelPreviewEnabled: boolean = false, + onVisionSwitchRequired?: (query: PartListUnion) => Promise<{ + modelOverride?: string; + persistSessionModel?: string; + showGuidance?: boolean; + }>, +) { + const originalModelRef = useRef<string | null>(null); + + const handleVisionSwitch = useCallback( + async ( + query: PartListUnion, + userMessageTimestamp: number, + isContinuation: boolean, + ): Promise<VisionSwitchHandlingResult> => { + // Skip vision switch handling for continuations or if no handler provided + if (isContinuation || !onVisionSwitchRequired) { + return { shouldProceed: true }; + } + + const contentGeneratorConfig = config.getContentGeneratorConfig(); + + // Only handle qwen-oauth auth type + if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) { + return { shouldProceed: true }; + } + + // Check image format support first + const formatCheck = checkImageFormatsSupport(query); + + // If there are unsupported image formats, show warning + if (formatCheck.hasUnsupportedFormats) { + addItem( + { + type: MessageType.INFO, + text: getUnsupportedImageFormatWarning(), + }, + userMessageTimestamp, + ); + // Continue processing but with warning shown + } + + // Check if vision switch is needed + if ( + !shouldOfferVisionSwitch( + query, + contentGeneratorConfig.authType, + config.getModel(), + visionModelPreviewEnabled, + ) + ) { + return { shouldProceed: true }; + } + + try { + const visionSwitchResult = await onVisionSwitchRequired(query); + + if (visionSwitchResult.showGuidance) { + // Show guidance and don't proceed with the request + addItem( + { + type: MessageType.INFO, + text: getVisionSwitchGuidanceMessage(), + }, + userMessageTimestamp, + ); + return { shouldProceed: false }; + } + + if (visionSwitchResult.modelOverride) { + // One-time model override + originalModelRef.current = config.getModel(); + config.setModel(visionSwitchResult.modelOverride); + return { + shouldProceed: true, + originalModel: originalModelRef.current, + }; + } else if (visionSwitchResult.persistSessionModel) { + // Persistent session model change + config.setModel(visionSwitchResult.persistSessionModel); + return { shouldProceed: true }; + } + + return { shouldProceed: true }; + } catch (_error) { + // If vision switch dialog was cancelled or errored, don't proceed + return { shouldProceed: false }; + } + }, + [config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired], + ); + + const restoreOriginalModel = useCallback(() => { + if (originalModelRef.current) { + config.setModel(originalModelRef.current); + originalModelRef.current = null; + } + }, [config]); + + return { + handleVisionSwitch, + restoreOriginalModel, + }; +}
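For orientation, here is a minimal sketch of how a caller might wire the hook above into a submit path. It is illustrative only and not part of the patch: `openVisionSwitchDialog` (which would render the ModelSwitchDialog and resolve to a VisionSwitchOutcome) and `sendToModel` are hypothetical stand-ins, while the other names come from the files in this change.

// Illustrative wiring only; openVisionSwitchDialog and sendToModel are hypothetical stand-ins.
function useSubmitWithVisionSwitch(
  config: Config,
  addItem: UseHistoryManagerReturn['addItem'],
  visionModelPreviewEnabled: boolean,
) {
  const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
    config,
    addItem,
    visionModelPreviewEnabled,
    // Map the user's dialog choice to { modelOverride | persistSessionModel | showGuidance }.
    async (query) => processVisionSwitchOutcome(await openVisionSwitchDialog(query)),
  );

  return useCallback(
    async (query: PartListUnion, isContinuation = false) => {
      const timestamp = Date.now();
      const { shouldProceed } = await handleVisionSwitch(query, timestamp, isContinuation);
      if (!shouldProceed) {
        return; // Guidance was shown, or the dialog was cancelled.
      }
      try {
        await sendToModel(query); // hypothetical request path
      } finally {
        // Safe to call unconditionally: it only acts if a one-time override was stored.
        restoreOriginalModel();
      }
    },
    [handleVisionSwitch, restoreOriginalModel],
  );
}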
diff --git a/packages/cli/src/ui/models/availableModels.ts b/packages/cli/src/ui/models/availableModels.ts new file mode 100644 index 00000000..7c3a1cf5 --- /dev/null +++ b/packages/cli/src/ui/models/availableModels.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +export type AvailableModel = { + id: string; + label: string; + isVision?: boolean; +}; + +export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [ + { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }, + { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true }, +]; + +/** + * Get available Qwen models filtered by the vision model preview setting + */ +export function getFilteredQwenModels( + visionModelPreviewEnabled: boolean, +): AvailableModel[] { + if (visionModelPreviewEnabled) { + return AVAILABLE_MODELS_QWEN; + } + return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision); +} + +/** + * Currently we use the single model specified by `OPENAI_MODEL` in the env. + * In the future, after settings.json is updated, we will allow users to configure this themselves. + */ +export function getOpenAIAvailableModelFromEnv(): AvailableModel | null { + const id = process.env['OPENAI_MODEL']?.trim(); + return id ? { id, label: id } : null; +} + +/** + * Hard code the default vision model as a string literal, + * until our coding model supports multimodal input.
+ */ +export function getDefaultVisionModel(): string { + return 'qwen-vl-max-latest'; +} + +export function isVisionModel(modelId: string): boolean { + return AVAILABLE_MODELS_QWEN.some( + (model) => model.id === modelId && model.isVision, + ); +} diff --git a/packages/core/index.ts b/packages/core/index.ts index 447560d4..3cc271d0 100644 --- a/packages/core/index.ts +++ b/packages/core/index.ts @@ -19,3 +19,4 @@ export { } from './src/telemetry/types.js'; export { makeFakeConfig } from './src/test-utils/config.js'; export * from './src/utils/pathReader.js'; +export * from './src/utils/request-tokenizer/supportedImageFormats.js'; diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts index f28c786c..7f4eec69 100644 --- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts +++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts @@ -5,9 +5,10 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { OpenAIContentGenerator } from '../openaiContentGenerator.js'; +import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js'; import type { Config } from '../../config/config.js'; import { AuthType } from '../contentGenerator.js'; +import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js'; import OpenAI from 'openai'; // Mock OpenAI @@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { let mockConfig: Config; // eslint-disable-next-line @typescript-eslint/no-explicit-any let mockOpenAIClient: any; + let mockProvider: OpenAICompatibleProvider; beforeEach(() => { // Reset mocks @@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { mockConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', + enableOpenAILogging: false, }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => { create: vi.fn(), }, }, + embeddings: { + create: vi.fn(), + }, }; vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); + // Create mock provider + mockProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + // Create generator instance const contentGeneratorConfig = { model: 'gpt-4', apiKey: 'test-key', authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, }; - generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); + generator = new OpenAIContentGenerator( + contentGeneratorConfig, + mockConfig, + mockProvider, + ); }); afterEach(() => { @@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => { await expect( generator.generateContentStream(request, 'test-prompt-id'), ).rejects.toThrow( - /Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./, + /Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./, ); }); @@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => { } catch (error: unknown) { const errorMessage = error instanceof Error ? 
error.message : String(error); - expect(errorMessage).toContain( - 'Streaming setup timeout troubleshooting:', - ); - expect(errorMessage).toContain( - 'Check network connectivity and firewall settings', - ); + expect(errorMessage).toContain('Streaming timeout troubleshooting:'); + expect(errorMessage).toContain('Check network connectivity'); expect(errorMessage).toContain('Consider using non-streaming mode'); } }); @@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => { authType: AuthType.USE_OPENAI, baseUrl: 'http://localhost:8080', }; - new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); + new OpenAIContentGenerator( + contentGeneratorConfig, + mockConfig, + mockProvider, + ); - // Verify OpenAI client was created with timeout config - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 120000, - maxRetries: 3, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Verify provider buildClient was called + expect(mockProvider.buildClient).toHaveBeenCalled(); }); it('should use custom timeout from config', () => { const customConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + enableOpenAILogging: false, + }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => { timeout: 300000, maxRetries: 5, }; - new OpenAIContentGenerator(contentGeneratorConfig, customConfig); - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 300000, - maxRetries: 5, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Create a custom mock provider for this test + const customMockProvider: OpenAICompatibleProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + + new OpenAIContentGenerator( + contentGeneratorConfig, + customConfig, + customMockProvider, + ); + + // Verify provider buildClient was called + expect(customMockProvider.buildClient).toHaveBeenCalled(); }); it('should handle missing timeout config gracefully', () => { const noTimeoutConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + enableOpenAILogging: false, + }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => { authType: AuthType.USE_OPENAI, baseUrl: 'http://localhost:8080', }; - new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig); - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'http://localhost:8080', - timeout: 120000, // default - maxRetries: 3, // default - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); + // Create a custom mock provider for this test + const noTimeoutMockProvider: OpenAICompatibleProvider = { + buildHeaders: vi.fn().mockReturnValue({ + 'User-Agent': 'QwenCode/1.0.0 (test; test)', + }), + buildClient: vi.fn().mockReturnValue(mockOpenAIClient), + buildRequest: vi.fn().mockImplementation((req) => req), + }; + + new OpenAIContentGenerator( + contentGeneratorConfig, + noTimeoutConfig, + noTimeoutMockProvider, + ); + + // Verify 
provider buildClient was called + expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled(); }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index ca726bd3..bf8aa804 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -500,7 +500,7 @@ export class GeminiChat { if (error instanceof Error && error.message) { if (isSchemaDepthError(error.message)) return false; if (error.message.includes('429')) return true; - if (error.message.match(/5\d{2}/)) return true; + if (error.message.match(/^5\d{2}/)) return true; } return false; }, diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts deleted file mode 100644 index d2b28842..00000000 --- a/packages/core/src/core/openaiContentGenerator.test.ts +++ /dev/null @@ -1,3511 +0,0 @@ -/** - * @license - * Copyright 2025 Qwen - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { OpenAIContentGenerator } from './openaiContentGenerator.js'; -import type { Config } from '../config/config.js'; -import { AuthType } from './contentGenerator.js'; -import OpenAI from 'openai'; -import type { - GenerateContentParameters, - CountTokensParameters, - EmbedContentParameters, - CallableTool, - Content, -} from '@google/genai'; -import { Type, FinishReason } from '@google/genai'; - -// Mock OpenAI -vi.mock('openai'); - -// Mock logger modules -vi.mock('../telemetry/loggers.js', () => ({ - logApiResponse: vi.fn(), - logApiError: vi.fn(), -})); - -vi.mock('../utils/openaiLogger.js', () => ({ - openaiLogger: { - logInteraction: vi.fn(), - }, -})); - -// Mock tiktoken -vi.mock('tiktoken', () => ({ - get_encoding: vi.fn().mockReturnValue({ - encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens - free: vi.fn(), - }), -})); - -describe('OpenAIContentGenerator', () => { - let generator: OpenAIContentGenerator; - let mockConfig: Config; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let mockOpenAIClient: any; - - beforeEach(() => { - // Reset mocks - vi.clearAllMocks(); - - // Mock environment variables - vi.stubEnv('OPENAI_BASE_URL', ''); - - // Mock config - mockConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, - samplingParams: { - temperature: 0.7, - max_tokens: 1000, - top_p: 0.9, - }, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - // Mock OpenAI client - mockOpenAIClient = { - chat: { - completions: { - create: vi.fn(), - }, - }, - embeddings: { - create: vi.fn(), - }, - }; - - vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); - - // Create generator instance - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, - samplingParams: { - temperature: 0.7, - max_tokens: 1000, - top_p: 0.9, - }, - }; - generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); - }); - - afterEach(() => { - vi.restoreAllMocks(); - }); - - describe('constructor', () => { - it('should initialize with basic configuration', () => { - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: undefined, - timeout: 120000, - maxRetries: 3, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); - }); - - 
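The one-character change to the retry predicate in geminiChat.ts above anchors the 5xx match to the start of the error message, so a 5xx-looking number embedded later in a message no longer triggers a retry. A small illustration, not part of the patch:

// Effect of anchoring the 5xx retry match.
const unanchored = /5\d{2}/;
const anchored = /^5\d{2}/;

unanchored.test('503 Service Unavailable'); // true
anchored.test('503 Service Unavailable'); // true
unanchored.test('Quota exceeded for project 12503'); // true, a false positive
anchored.test('Quota exceeded for project 12503'); // false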
it('should handle custom base URL', () => { - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - baseUrl: 'https://api.custom.com', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, - }; - new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); - - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'https://api.custom.com', - timeout: 120000, - maxRetries: 3, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); - }); - - it('should configure OpenRouter headers when using OpenRouter', () => { - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - baseUrl: 'https://openrouter.ai/api/v1', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, - }; - new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); - - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: 'https://openrouter.ai/api/v1', - timeout: 120000, - maxRetries: 3, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git', - 'X-Title': 'Qwen Code', - }, - }); - }); - - it('should override timeout settings from config', () => { - const customConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - timeout: 300000, - maxRetries: 5, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - timeout: 300000, - maxRetries: 5, - }; - new OpenAIContentGenerator(contentGeneratorConfig, customConfig); - - expect(OpenAI).toHaveBeenCalledWith({ - apiKey: 'test-key', - baseURL: undefined, - timeout: 300000, - maxRetries: 5, - defaultHeaders: { - 'User-Agent': expect.stringMatching(/^QwenCode/), - }, - }); - }); - }); - - describe('generateContent', () => { - it('should generate content successfully', async () => { - const mockResponse = { - id: 'chatcmpl-123', - object: 'chat.completion', - created: 1677652288, - model: 'gpt-4', - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: 'Hello! How can I help you?', - }, - finish_reason: 'stop', - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 15, - total_tokens: 25, - }, - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - expect(result.candidates).toHaveLength(1); - if ( - result.candidates && - result.candidates.length > 0 && - result.candidates[0] - ) { - const firstCandidate = result.candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([ - { text: 'Hello! How can I help you?' 
}, - ]); - } - } - expect(result.usageMetadata).toEqual({ - promptTokenCount: 10, - candidatesTokenCount: 15, - totalTokenCount: 25, - cachedContentTokenCount: 0, - }); - }); - - it('should handle system instructions', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - config: { - systemInstruction: 'You are a helpful assistant.', - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'system', content: 'You are a helpful assistant.' }, - { role: 'user', content: 'Hello' }, - ], - }), - ); - }); - - it('should handle function calls', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_123', - type: 'function', - function: { - name: 'get_weather', - arguments: '{"location": "New York"}', - }, - }, - ], - }, - finish_reason: 'tool_calls', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'What is the weather?' }] }], - model: 'gpt-4', - config: { - tools: [ - { - callTool: vi.fn(), - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'get_weather', - description: 'Get weather information', - parameters: { - type: Type.OBJECT, - properties: { location: { type: Type.STRING } }, - }, - }, - ], - }), - } as unknown as CallableTool, - ], - }, - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - if ( - result.candidates && - result.candidates.length > 0 && - result.candidates[0] - ) { - const firstCandidate = result.candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([ - { - functionCall: { - id: 'call_123', - name: 'get_weather', - args: { location: 'New York' }, - }, - }, - ]); - } - } - }); - - it('should apply sampling parameters from config', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.7, - max_tokens: 1000, - top_p: 0.9, - }), - ); - }); - - it('should prioritize request-level parameters over config', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: 
GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - config: { - temperature: 0.5, - maxOutputTokens: 500, - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.7, // From config sampling params (higher priority) - max_tokens: 1000, // From config sampling params (higher priority) - top_p: 0.9, - }), - ); - }); - }); - - describe('generateContentStream', () => { - it('should handle streaming responses', async () => { - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: 'Hello' }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: ' there!' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - }, - ]; - - // Mock async iterable - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const stream = await generator.generateContentStream( - request, - 'test-prompt-id', - ); - const responses = []; - for await (const response of stream) { - responses.push(response); - } - - expect(responses).toHaveLength(2); - if ( - responses[0]?.candidates && - responses[0].candidates.length > 0 && - responses[0].candidates[0] - ) { - const firstCandidate = responses[0].candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([{ text: 'Hello' }]); - } - } - if ( - responses[1]?.candidates && - responses[1].candidates.length > 0 && - responses[1].candidates[0] - ) { - const secondCandidate = responses[1].candidates[0]; - if (secondCandidate.content) { - expect(secondCandidate.content.parts).toEqual([{ text: ' there!' }]); - } - } - expect(responses[1].usageMetadata).toEqual({ - promptTokenCount: 10, - candidatesTokenCount: 5, - totalTokenCount: 15, - cachedContentTokenCount: 0, - }); - }); - - it('should handle streaming tool calls', async () => { - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { - tool_calls: [ - { - index: 0, - id: 'call_123', - function: { name: 'get_weather' }, - }, - ], - }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { - tool_calls: [ - { - index: 0, - function: { arguments: '{"location": "NYC"}' }, - }, - ], - }, - finish_reason: 'tool_calls', - }, - ], - created: 1677652288, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Weather?' 
}] }], - model: 'gpt-4', - }; - - const stream = await generator.generateContentStream( - request, - 'test-prompt-id', - ); - const responses = []; - for await (const response of stream) { - responses.push(response); - } - - // First response should contain the complete tool call (accumulated from streaming) - if ( - responses[0]?.candidates && - responses[0].candidates.length > 0 && - responses[0].candidates[0] - ) { - const firstCandidate = responses[0].candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([ - { - functionCall: { - id: 'call_123', - name: 'get_weather', - args: { location: 'NYC' }, - }, - }, - ]); - } - } - if ( - responses[1]?.candidates && - responses[1].candidates.length > 0 && - responses[1].candidates[0] - ) { - const secondCandidate = responses[1].candidates[0]; - if (secondCandidate.content) { - expect(secondCandidate.content.parts).toEqual([ - { - functionCall: { - id: 'call_123', - name: 'get_weather', - args: { location: 'NYC' }, - }, - }, - ]); - } - } - }); - }); - - describe('countTokens', () => { - it('should count tokens using tiktoken', async () => { - const request: CountTokensParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }], - model: 'gpt-4', - }; - - const result = await generator.countTokens(request); - - expect(result.totalTokens).toBe(50); // Mocked value - }); - - it('should fall back to character approximation if tiktoken fails', async () => { - // Mock tiktoken to throw error - vi.doMock('tiktoken', () => ({ - get_encoding: vi.fn().mockImplementation(() => { - throw new Error('Tiktoken failed'); - }), - })); - - const request: CountTokensParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }], - model: 'gpt-4', - }; - - const result = await generator.countTokens(request); - - // Should use character approximation (content length / 4) - expect(result.totalTokens).toBeGreaterThan(0); - }); - }); - - describe('embedContent', () => { - it('should generate embeddings for text content', async () => { - const mockEmbedding = { - data: [{ embedding: [0.1, 0.2, 0.3, 0.4] }], - model: 'text-embedding-ada-002', - usage: { prompt_tokens: 5, total_tokens: 5 }, - }; - - mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding); - - const request: EmbedContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }], - model: 'text-embedding-ada-002', - }; - - const result = await generator.embedContent(request); - - expect(result.embeddings).toHaveLength(1); - expect(result.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3, 0.4]); - expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({ - model: 'text-embedding-ada-002', - input: 'Hello world', - }); - }); - - it('should handle string content', async () => { - const mockEmbedding = { - data: [{ embedding: [0.1, 0.2] }], - }; - - mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding); - - const request: EmbedContentParameters = { - contents: 'Simple text', - model: 'text-embedding-ada-002', - }; - - await generator.embedContent(request); - - expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({ - model: 'text-embedding-ada-002', - input: 'Simple text', - }); - }); - - it('should handle embedding errors', async () => { - const error = new Error('Embedding failed'); - mockOpenAIClient.embeddings.create.mockRejectedValue(error); - - const request: EmbedContentParameters = { - contents: 'Test text', - model: 'text-embedding-ada-002', - }; - - await 
expect(generator.embedContent(request)).rejects.toThrow( - 'Embedding failed', - ); - }); - }); - - describe('error handling', () => { - it('should handle API errors with proper error message', async () => { - const apiError = new Error('Invalid API key'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - generator.generateContent(request, 'test-prompt-id'), - ).rejects.toThrow('Invalid API key'); - }); - - it('should estimate tokens on error for telemetry', async () => { - const apiError = new Error('API error'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - try { - await generator.generateContent(request, 'test-prompt-id'); - } catch (error) { - // Error should be thrown but token estimation should have been attempted - expect(error).toBeInstanceOf(Error); - } - }); - - it('should preserve error status codes like 429', async () => { - // Create an error object with status property like OpenAI SDK would - const apiError = Object.assign(new Error('Rate limit exceeded'), { - status: 429, - }); - mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - try { - await generator.generateContent(request, 'test-prompt-id'); - expect.fail('Expected error to be thrown'); - } catch (error: unknown) { - // Should throw the original error object with status preserved - expect((error as Error & { status: number }).message).toBe( - 'Rate limit exceeded', - ); - expect((error as Error & { status: number }).status).toBe(429); - } - }); - }); - - describe('message conversion', () => { - it('should convert function responses to tool messages', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { role: 'user', parts: [{ text: 'What is the weather?' }] }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_123', - name: 'get_weather', - args: { location: 'NYC' }, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - id: 'call_123', - name: 'get_weather', - response: { temperature: '72F', condition: 'sunny' }, - }, - }, - ], - }, - ], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.arrayContaining([ - { role: 'user', content: 'What is the weather?' 
}, - { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_123', - type: 'function', - function: { - name: 'get_weather', - arguments: '{"location":"NYC"}', - }, - }, - ], - }, - { - role: 'tool', - tool_call_id: 'call_123', - content: '{"temperature":"72F","condition":"sunny"}', - }, - ]), - }), - ); - }); - - it('should clean up orphaned tool calls', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_orphaned', - name: 'orphaned_function', - args: {}, - }, - }, - ], - }, - // No corresponding function response - ], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - // Should not include the orphaned tool call - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [], // Empty because orphaned tool call was cleaned up - }), - ); - }); - }); - - describe('finish reason mapping', () => { - it('should map OpenAI finish reasons to Gemini format', async () => { - const testCases = [ - { openai: 'stop', expected: FinishReason.STOP }, - { openai: 'length', expected: FinishReason.MAX_TOKENS }, - { openai: 'content_filter', expected: FinishReason.SAFETY }, - { openai: 'function_call', expected: FinishReason.STOP }, - { openai: 'tool_calls', expected: FinishReason.STOP }, - ]; - - for (const testCase of testCases) { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: testCase.openai, - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue( - mockResponse, - ); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const result = await generator.generateContent( - request, - 'test-prompt-id', - ); - if ( - result.candidates && - result.candidates.length > 0 && - result.candidates[0] - ) { - const firstCandidate = result.candidates[0]; - expect(firstCandidate.finishReason).toBe(testCase.expected); - } - } - }); - }); - - describe('logging integration', () => { - it('should log interactions when enabled', async () => { - const loggingConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - enableOpenAILogging: true, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: true, - }; - const loggingGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - loggingConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await loggingGenerator.generateContent(request, 'test-prompt-id'); - - // Verify logging was called - const { 
openaiLogger } = await import('../utils/openaiLogger.js'); - expect(openaiLogger.logInteraction).toHaveBeenCalled(); - }); - }); - - describe('timeout error detection', () => { - it('should detect various timeout error patterns', async () => { - const timeoutErrors = [ - new Error('timeout'), - new Error('Request timed out'), - new Error('Connection timeout occurred'), - new Error('ETIMEDOUT'), - new Error('ESOCKETTIMEDOUT'), - { code: 'ETIMEDOUT', message: 'Connection timed out' }, - { type: 'timeout', message: 'Request timeout' }, - new Error('deadline exceeded'), - ]; - - for (const error of timeoutErrors) { - mockOpenAIClient.chat.completions.create.mockRejectedValueOnce(error); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - try { - await generator.generateContent(request, 'test-prompt-id'); - // Should not reach here - expect(true).toBe(false); - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - expect(errorMessage).toMatch(/timeout|Troubleshooting tips/); - } - } - }); - - it('should provide timeout-specific error messages', async () => { - const timeoutError = new Error('Request timeout'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(timeoutError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - generator.generateContent(request, 'test-prompt-id'), - ).rejects.toThrow( - /Troubleshooting tips.*Reduce input length.*Increase timeout.*Check network/s, - ); - }); - }); - - describe('streaming error handling', () => { - it('should handle errors during streaming setup', async () => { - const setupError = new Error('Streaming setup failed'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(setupError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - generator.generateContent(request, 'test-prompt-id'), - ).rejects.toThrow('Streaming setup failed'); - }); - - it('should handle timeout errors during streaming setup', async () => { - const timeoutError = new Error('Streaming setup timeout'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(timeoutError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - generator.generateContentStream(request, 'test-prompt-id'), - ).rejects.toThrow( - /Streaming setup timeout troubleshooting.*Reduce input length/s, - ); - }); - - it('should handle errors during streaming with logging', async () => { - const loggingConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - enableOpenAILogging: true, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: true, - }; - const loggingGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - loggingConfig, - ); - - // Mock stream that throws an error - const mockStream = { - async *[Symbol.asyncIterator]() { - yield { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: 'Hello' }, - finish_reason: null, - }, - ], - created: 1677652288, - }; - throw new Error('Stream error'); - }, - }; - - 
mockOpenAIClient.chat.completions.create.mockResolvedValue(mockStream); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const stream = await loggingGenerator.generateContentStream( - request, - 'test-prompt-id', - ); - - // Consume the stream and expect error - await expect(async () => { - for await (const chunk of stream) { - // Stream will throw during iteration - console.log('Processing chunk:', chunk); // Use chunk to avoid warning - } - }).rejects.toThrow('Stream error'); - }); - }); - - describe('tool parameter conversion', () => { - it('should convert Gemini types to OpenAI JSON Schema types', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test' }] }], - model: 'gpt-4', - config: { - tools: [ - { - callTool: vi.fn(), - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'test_function', - description: 'Test function', - parameters: { - type: 'OBJECT', - properties: { - count: { - type: 'INTEGER', - minimum: '1', - maximum: '100', - }, - name: { - type: 'STRING', - minLength: '1', - maxLength: '50', - }, - score: { type: 'NUMBER', multipleOf: '0.1' }, - items: { - type: 'ARRAY', - minItems: '1', - maxItems: '10', - }, - }, - }, - }, - ], - }), - } as unknown as CallableTool, - ], - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - tools: [ - { - type: 'function', - function: { - name: 'test_function', - description: 'Test function', - parameters: { - type: 'object', - properties: { - count: { type: 'integer', minimum: 1, maximum: 100 }, - name: { type: 'string', minLength: 1, maxLength: 50 }, - score: { type: 'number', multipleOf: 0.1 }, - items: { type: 'array', minItems: 1, maxItems: 10 }, - }, - }, - }, - }, - ], - }), - ); - }); - - it('should handle MCP tools with parametersJsonSchema', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test' }] }], - model: 'gpt-4', - config: { - tools: [ - { - callTool: vi.fn(), - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'list-items', - description: 'Get a list of items', - parametersJsonSchema: { - type: 'object', - properties: { - page_number: { - type: 'number', - description: 'Page number', - }, - page_size: { - type: 'number', - description: 'Number of items per page', - }, - }, - additionalProperties: false, - $schema: 'http://json-schema.org/draft-07/schema#', - }, - }, - ], - }), - } as unknown as CallableTool, - ], - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - tools: [ - { - type: 'function', - function: { - name: 'list-items', - description: 'Get a list of 
items', - parameters: { - type: 'object', - properties: { - page_number: { - type: 'number', - description: 'Page number', - }, - page_size: { - type: 'number', - description: 'Number of items per page', - }, - }, - additionalProperties: false, - $schema: 'http://json-schema.org/draft-07/schema#', - }, - }, - }, - ], - }), - ); - }); - - it('should handle nested parameter objects', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test' }] }], - model: 'gpt-4', - config: { - tools: [ - { - callTool: vi.fn(), - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'nested_function', - description: 'Function with nested parameters', - parameters: { - type: 'OBJECT', - properties: { - config: { - type: 'OBJECT', - properties: { - nested_count: { type: 'INTEGER' }, - nested_array: { - type: 'ARRAY', - items: { type: 'STRING' }, - }, - }, - }, - }, - }, - }, - ], - }), - } as unknown as CallableTool, - ], - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - tools: [ - { - type: 'function', - function: { - name: 'nested_function', - description: 'Function with nested parameters', - parameters: { - type: 'object', - properties: { - config: { - type: 'object', - properties: { - nested_count: { type: 'integer' }, - nested_array: { - type: 'array', - items: { type: 'string' }, - }, - }, - }, - }, - }, - }, - }, - ], - }), - ); - }); - }); - - describe('message cleanup and conversion', () => { - it('should handle complex conversation with multiple tool calls', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { role: 'user', parts: [{ text: 'What tools are available?' }] }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_1', - name: 'list_tools', - args: { category: 'all' }, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - id: 'call_1', - name: 'list_tools', - response: { tools: ['calculator', 'weather'] }, - }, - }, - ], - }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_2', - name: 'get_weather', - args: { location: 'NYC' }, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - id: 'call_2', - name: 'get_weather', - response: { temperature: '22°C', condition: 'sunny' }, - }, - }, - ], - }, - ], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'user', content: 'What tools are available?' 
}, - { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_1', - type: 'function', - function: { - name: 'list_tools', - arguments: '{"category":"all"}', - }, - }, - ], - }, - { - role: 'tool', - tool_call_id: 'call_1', - content: '{"tools":["calculator","weather"]}', - }, - { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_2', - type: 'function', - function: { - name: 'get_weather', - arguments: '{"location":"NYC"}', - }, - }, - ], - }, - { - role: 'tool', - tool_call_id: 'call_2', - content: '{"temperature":"22°C","condition":"sunny"}', - }, - ], - }), - ); - }); - - it('should clean up orphaned tool calls without corresponding responses', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { role: 'user', parts: [{ text: 'Test' }] }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_orphaned', - name: 'orphaned_function', - args: {}, - }, - }, - ], - }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'call_valid', - name: 'valid_function', - args: {}, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - id: 'call_valid', - name: 'valid_function', - response: { result: 'success' }, - }, - }, - ], - }, - ], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'user', content: 'Test' }, - { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_valid', - type: 'function', - function: { - name: 'valid_function', - arguments: '{}', - }, - }, - ], - }, - { - role: 'tool', - tool_call_id: 'call_valid', - content: '{"result":"success"}', - }, - ], - }), - ); - }); - - it('should merge consecutive assistant messages', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { role: 'user', parts: [{ text: 'Hello' }] }, - { role: 'model', parts: [{ text: 'Part 1' }] }, - { role: 'model', parts: [{ text: 'Part 2' }] }, - { role: 'user', parts: [{ text: 'Continue' }] }, - ], - model: 'gpt-4', - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Part 1Part 2' }, - { role: 'user', content: 'Continue' }, - ], - }), - ); - }); - }); - - describe('error suppression functionality', () => { - it('should allow subclasses to suppress error logging', async () => { - class TestGenerator extends OpenAIContentGenerator { - protected override shouldSuppressErrorLogging(): boolean { - return true; // Always suppress for this test - } - } - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, - samplingParams: { - 
temperature: 0.7, - max_tokens: 1000, - top_p: 0.9, - }, - }; - const testGenerator = new TestGenerator( - contentGeneratorConfig, - mockConfig, - ); - const consoleSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - const apiError = new Error('Test error'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - testGenerator.generateContent(request, 'test-prompt-id'), - ).rejects.toThrow(); - - // Error logging should be suppressed - expect(consoleSpy).not.toHaveBeenCalledWith( - 'OpenAI API Error:', - expect.any(String), - ); - - consoleSpy.mockRestore(); - }); - - it('should log errors when not suppressed', async () => { - const consoleSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - const apiError = new Error('Test error'); - mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await expect( - generator.generateContent(request, 'test-prompt-id'), - ).rejects.toThrow(); - - // Error logging should occur by default - expect(consoleSpy).toHaveBeenCalledWith( - 'OpenAI API Error:', - 'Test error', - ); - - consoleSpy.mockRestore(); - }); - }); - - describe('edge cases and error scenarios', () => { - it('should handle malformed tool call arguments', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_123', - type: 'function', - function: { - name: 'test_function', - arguments: 'invalid json{', - }, - }, - ], - }, - finish_reason: 'tool_calls', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test' }] }], - model: 'gpt-4', - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - // Should handle malformed JSON gracefully - if ( - result.candidates && - result.candidates.length > 0 && - result.candidates[0] - ) { - const firstCandidate = result.candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([ - { - functionCall: { - id: 'call_123', - name: 'test_function', - args: {}, // Should default to empty object - }, - }, - ]); - } - } - }); - - it('should handle streaming with malformed tool call arguments', async () => { - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { - tool_calls: [ - { - index: 0, - id: 'call_123', - function: { name: 'test_function' }, - }, - ], - }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { - tool_calls: [ - { - index: 0, - function: { arguments: 'invalid json{' }, - }, - ], - }, - finish_reason: 'tool_calls', - }, - ], - created: 1677652288, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test' }] }], - model: 'gpt-4', - }; - - const stream = await generator.generateContentStream( - request, - 
'test-prompt-id', - ); - const responses = []; - for await (const response of stream) { - responses.push(response); - } - - // Should handle malformed JSON in streaming gracefully - const finalResponse = responses[responses.length - 1]; - if ( - finalResponse.candidates && - finalResponse.candidates.length > 0 && - finalResponse.candidates[0] - ) { - const firstCandidate = finalResponse.candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([ - { - functionCall: { - id: 'call_123', - name: 'test_function', - args: {}, // Should default to empty object - }, - }, - ]); - } - } - }); - - it('should handle empty or null content gracefully', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: null }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [], - model: 'gpt-4', - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - expect(result.candidates).toHaveLength(1); - if ( - result.candidates && - result.candidates.length > 0 && - result.candidates[0] - ) { - const firstCandidate = result.candidates[0]; - if (firstCandidate.content) { - expect(firstCandidate.content.parts).toEqual([]); - } - } - }); - - it('should handle usage metadata estimation when breakdown is missing', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - usage: { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 100, - }, - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - expect(result.usageMetadata).toEqual({ - promptTokenCount: 70, // 70% of 100 - candidatesTokenCount: 30, // 30% of 100 - totalTokenCount: 100, - cachedContentTokenCount: 0, - }); - }); - - it('should handle cached token metadata', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Test response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - usage: { - prompt_tokens: 50, - completion_tokens: 25, - total_tokens: 75, - prompt_tokens_details: { - cached_tokens: 10, - }, - }, - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const result = await generator.generateContent(request, 'test-prompt-id'); - - expect(result.usageMetadata).toEqual({ - promptTokenCount: 50, - candidatesTokenCount: 25, - totalTokenCount: 75, - cachedContentTokenCount: 10, - }); - }); - }); - - describe('request/response logging conversion', () => { - it('should convert complex Gemini request to OpenAI format for logging', async () => { - const loggingConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - enableOpenAILogging: true, - samplingParams: { - temperature: 0.8, - max_tokens: 500, - }, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as 
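The two usage tests encode a fallback for providers that only report total_tokens: the prompt/completion split is estimated as 70%/30%, and prompt_tokens_details.cached_tokens maps onto cachedContentTokenCount. A sketch of that normalization (the exact rounding is an assumption; the split and field mapping follow the expectations above):

interface OpenAIUsage {
  prompt_tokens: number;
  completion_tokens: number;
  total_tokens: number;
  prompt_tokens_details?: { cached_tokens?: number };
}

interface GeminiUsageMetadata {
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
  cachedContentTokenCount: number;
}

function toUsageMetadata(usage: OpenAIUsage): GeminiUsageMetadata {
  const cached = usage.prompt_tokens_details?.cached_tokens ?? 0;
  // Breakdown missing but a total exists: estimate a 70/30 prompt/completion split.
  if (usage.prompt_tokens === 0 && usage.completion_tokens === 0 && usage.total_tokens > 0) {
    return {
      promptTokenCount: Math.round(usage.total_tokens * 0.7),
      candidatesTokenCount: Math.round(usage.total_tokens * 0.3),
      totalTokenCount: usage.total_tokens,
      cachedContentTokenCount: cached,
    };
  }
  return {
    promptTokenCount: usage.prompt_tokens,
    candidatesTokenCount: usage.completion_tokens,
    totalTokenCount: usage.total_tokens,
    cachedContentTokenCount: cached,
  };
}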
unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: true, - samplingParams: { - temperature: 0.8, - max_tokens: 500, - }, - }; - const loggingGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - loggingConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'call_123', - type: 'function', - function: { - name: 'test_function', - arguments: '{"param":"value"}', - }, - }, - ], - }, - finish_reason: 'tool_calls', - }, - ], - created: 1677652288, - model: 'gpt-4', - usage: { - prompt_tokens: 100, - completion_tokens: 50, - total_tokens: 150, - }, - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [ - { role: 'user', parts: [{ text: 'Test complex request' }] }, - { - role: 'model', - parts: [ - { - functionCall: { - id: 'prev_call', - name: 'previous_function', - args: { data: 'test' }, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - id: 'prev_call', - name: 'previous_function', - response: { result: 'success' }, - }, - }, - ], - }, - ], - model: 'gpt-4', - config: { - systemInstruction: 'You are a helpful assistant', - temperature: 0.9, - tools: [ - { - callTool: vi.fn(), - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'test_function', - description: 'Test function', - parameters: { type: 'object' }, - }, - ], - }), - } as unknown as CallableTool, - ], - }, - }; - - await loggingGenerator.generateContent(request, 'test-prompt-id'); - - // Verify that logging was called with properly converted request/response - const { openaiLogger } = await import('../utils/openaiLogger.js'); - expect(openaiLogger.logInteraction).toHaveBeenCalledWith( - expect.objectContaining({ - model: 'gpt-4', - messages: [ - { - role: 'system', - content: 'You are a helpful assistant', - }, - { - role: 'user', - content: 'Test complex request', - }, - { - role: 'assistant', - content: null, - tool_calls: [ - { - id: 'prev_call', - type: 'function', - function: { - name: 'previous_function', - arguments: '{"data":"test"}', - }, - }, - ], - }, - { - role: 'tool', - tool_call_id: 'prev_call', - content: '{"result":"success"}', - }, - ], - temperature: 0.8, // Config override - max_tokens: 500, // Config override - top_p: 1, // Default value - tools: [ - { - type: 'function', - function: { - name: 'test_function', - description: 'Test function', - parameters: { - type: 'object', - }, - }, - }, - ], - }), - expect.objectContaining({ - id: 'chatcmpl-123', - object: 'chat.completion', - created: 1677652288, - model: 'gpt-4', - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: '', - tool_calls: [ - { - id: 'call_123', - type: 'function', - function: { - name: 'test_function', - arguments: '{"param":"value"}', - }, - }, - ], - }, - finish_reason: 'stop', - }, - ], - usage: { - prompt_tokens: 100, - completion_tokens: 50, - total_tokens: 150, - }, - }), - ); - }); - }); - - describe('advanced streaming scenarios', () => { - it('should combine streaming responses correctly for logging', async () => { - const loggingConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - enableOpenAILogging: true, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', 
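The logging test above asserts the shape of the converted history: a Gemini functionCall part becomes an assistant message with tool_calls, the matching functionResponse becomes a tool message keyed by tool_call_id, and plain text parts become user/assistant messages. A reduced sketch of that mapping, with the types trimmed to the fields involved:

interface GeminiPart {
  text?: string;
  functionCall?: { id?: string; name?: string; args?: Record<string, unknown> };
  functionResponse?: { id?: string; name?: string; response?: unknown };
}
interface GeminiContent { role: 'user' | 'model'; parts: GeminiPart[]; }

type OpenAIChatMessage =
  | { role: 'user' | 'assistant'; content: string | null; tool_calls?: Array<{ id: string; type: 'function'; function: { name: string; arguments: string } }> }
  | { role: 'tool'; tool_call_id: string; content: string };

function convertHistory(contents: GeminiContent[]): OpenAIChatMessage[] {
  const messages: OpenAIChatMessage[] = [];
  for (const content of contents) {
    for (const part of content.parts) {
      if (part.functionCall) {
        messages.push({
          role: 'assistant',
          content: null,
          tool_calls: [{
            id: part.functionCall.id ?? '',
            type: 'function',
            function: {
              name: part.functionCall.name ?? '',
              arguments: JSON.stringify(part.functionCall.args ?? {}),
            },
          }],
        });
      } else if (part.functionResponse) {
        messages.push({
          role: 'tool',
          tool_call_id: part.functionResponse.id ?? '',
          content: JSON.stringify(part.functionResponse.response ?? {}),
        });
      } else if (part.text) {
        messages.push({
          role: content.role === 'model' ? 'assistant' : 'user',
          content: part.text,
        });
      }
    }
  }
  return messages;
}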
- apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: true, - }; - const loggingGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - loggingConfig, - ); - - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: 'Hello' }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: ' world' }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: {}, - finish_reason: 'stop', - }, - ], - created: 1677652288, - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const stream = await loggingGenerator.generateContentStream( - request, - 'test-prompt-id', - ); - const responses = []; - for await (const response of stream) { - responses.push(response); - } - - // Verify logging was called with combined content - const { openaiLogger } = await import('../utils/openaiLogger.js'); - expect(openaiLogger.logInteraction).toHaveBeenCalledWith( - expect.any(Object), - expect.objectContaining({ - choices: [ - expect.objectContaining({ - message: expect.objectContaining({ - content: 'Hello world', // Combined text - }), - }), - ], - }), - ); - }); - - it('should handle streaming without choices', async () => { - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [], - created: 1677652288, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const stream = await generator.generateContentStream( - request, - 'test-prompt-id', - ); - const responses = []; - for await (const response of stream) { - responses.push(response); - } - - expect(responses).toHaveLength(1); - expect(responses[0].candidates).toEqual([]); - }); - }); - - describe('embed content edge cases', () => { - it('should handle mixed content types in embed request', async () => { - const mockEmbedding = { - data: [{ embedding: [0.1, 0.2, 0.3] }], - model: 'text-embedding-ada-002', - usage: { prompt_tokens: 5, total_tokens: 5 }, - }; - - mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding); - - const request: EmbedContentParameters = { - contents: 'Hello world Direct string Another part', - model: 'text-embedding-ada-002', - }; - - const result = await generator.embedContent(request); - - expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({ - model: 'text-embedding-ada-002', - input: 'Hello world Direct string Another part', - }); - - expect(result.embeddings).toHaveLength(1); - expect(result.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3]); - }); - - it('should handle empty content in embed request', async () => { - const mockEmbedding = { - data: [{ embedding: [] }], - }; - - mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding); - - const request: EmbedContentParameters = { - contents: [], - model: 'text-embedding-ada-002', - }; - - await generator.embedContent(request); - - 
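The streaming-logging test expects the 'Hello' and ' world' deltas to be logged as one assistant message with content 'Hello world', with usage taken from whichever chunk carries it. A sketch of collapsing streamed chunks into a single loggable response (field names follow the OpenAI chunk shape used in these fixtures):

interface StreamChunk {
  choices: Array<{ delta: { content?: string }; finish_reason: string | null }>;
  usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number };
}

function combineChunksForLogging(chunks: StreamChunk[]) {
  let text = '';
  let usage: StreamChunk['usage'];
  for (const chunk of chunks) {
    text += chunk.choices[0]?.delta?.content ?? '';
    if (chunk.usage) usage = chunk.usage; // usually arrives on the final chunk
  }
  return {
    choices: [{ index: 0, message: { role: 'assistant', content: text }, finish_reason: 'stop' }],
    usage,
  };
}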
expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({ - model: 'text-embedding-ada-002', - input: '', - }); - }); - }); - - describe('system instruction edge cases', () => { - it('should handle array system instructions', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - config: { - systemInstruction: 'You are helpful\nBe concise', - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'system', content: 'You are helpful\nBe concise' }, - { role: 'user', content: 'Hello' }, - ], - }), - ); - }); - - it('should handle object system instruction', async () => { - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - config: { - systemInstruction: { - parts: [{ text: 'System message' }, { text: 'Additional text' }], - } as Content, - }, - }; - - await generator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: [ - { role: 'system', content: 'System message\nAdditional text' }, - { role: 'user', content: 'Hello' }, - ], - }), - ); - }); - }); - - describe('sampling parameters edge cases', () => { - it('should handle undefined sampling parameters gracefully', async () => { - const configWithUndefined = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - samplingParams: { - temperature: undefined, - max_tokens: undefined, - top_p: undefined, - }, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - samplingParams: { - temperature: undefined, - max_tokens: undefined, - top_p: undefined, - }, - }; - const testGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - configWithUndefined, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - config: { - temperature: undefined, - maxOutputTokens: undefined, - topP: undefined, - }, - }; - - await testGenerator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.0, // Default value - top_p: 1.0, // Default value - // max_tokens should not be present when undefined - }), - ); - }); - - it('should handle all config-level sampling 
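Taken together, the sampling tests assert a resolution order when building the request: values from contentGeneratorConfig.samplingParams win over request-level config, defaults (temperature 0.0, top_p 1.0) fill anything still undefined, and max_tokens is simply omitted when nobody sets it. The sketch below follows what these tests check; the shipped pipeline may resolve the precedence differently:

interface SamplingParams {
  temperature?: number;
  top_p?: number;
  max_tokens?: number;
}

function resolveSampling(
  configParams: SamplingParams | undefined, // contentGeneratorConfig.samplingParams
  requestParams: { temperature?: number; topP?: number; maxOutputTokens?: number },
): SamplingParams {
  const resolved: SamplingParams = {
    temperature: configParams?.temperature ?? requestParams.temperature ?? 0.0,
    top_p: configParams?.top_p ?? requestParams.topP ?? 1.0,
  };
  const maxTokens = configParams?.max_tokens ?? requestParams.maxOutputTokens;
  if (maxTokens !== undefined) resolved.max_tokens = maxTokens; // left out when unset
  return resolved;
}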
parameters', async () => { - const fullSamplingConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - samplingParams: { - temperature: 0.8, - max_tokens: 1500, - top_p: 0.95, - top_k: 40, - repetition_penalty: 1.1, - presence_penalty: 0.5, - frequency_penalty: 0.3, - }, - }), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - apiKey: 'test-key', - authType: AuthType.USE_OPENAI, - samplingParams: { - temperature: 0.8, - max_tokens: 1500, - top_p: 0.95, - top_k: 40, - repetition_penalty: 1.1, - presence_penalty: 0.5, - frequency_penalty: 0.3, - }, - }; - const testGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - fullSamplingConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await testGenerator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 0.8, - max_tokens: 1500, - top_p: 0.95, - top_k: 40, - repetition_penalty: 1.1, - presence_penalty: 0.5, - frequency_penalty: 0.3, - }), - ); - }); - }); - - describe('token counting edge cases', () => { - it('should handle tiktoken import failure with console warning', async () => { - // Mock tiktoken to fail on import - vi.doMock('tiktoken', () => { - throw new Error('Failed to import tiktoken'); - }); - - const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); - - const request: CountTokensParameters = { - contents: [{ role: 'user', parts: [{ text: 'Test content' }] }], - model: 'gpt-4', - }; - - const result = await generator.countTokens(request); - - expect(consoleSpy).toHaveBeenCalledWith( - expect.stringMatching(/Failed to load tiktoken.*falling back/), - expect.any(Error), - ); - - // Should use character approximation - expect(result.totalTokens).toBeGreaterThan(0); - - consoleSpy.mockRestore(); - }); - }); - - describe('metadata control', () => { - it('should include metadata when authType is QWEN_OAUTH', async () => { - const qwenConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'qwen-oauth', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('test-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - apiKey: 'test-key', - authType: AuthType.QWEN_OAUTH, - enableOpenAILogging: false, - }; - const qwenGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - qwenConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'qwen-turbo', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'qwen-turbo', - }; - - await qwenGenerator.generateContent(request, 'test-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - 
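The token-counting test above exercises the fallback path: the request is serialized and measured with tiktoken's cl100k_base encoding, and if tiktoken cannot be loaded the code warns and falls back to a characters/4 heuristic. A standalone sketch of that flow (same approach as the deleted countTokens implementation):

async function countTokensApprox(contents: unknown): Promise<number> {
  const text = JSON.stringify(contents);
  try {
    const { get_encoding } = await import('tiktoken');
    const encoding = get_encoding('cl100k_base');
    try {
      return encoding.encode(text).length;
    } finally {
      encoding.free(); // tiktoken encodings hold WASM memory and must be freed
    }
  } catch (error) {
    console.warn('Failed to load tiktoken, falling back to character approximation:', error);
    return Math.ceil(text.length / 4); // rough estimate: ~4 characters per token
  }
}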
expect.objectContaining({ - metadata: { - sessionId: 'test-session-id', - promptId: 'test-prompt-id', - }, - }), - ); - }); - - it('should include metadata when baseURL is dashscope openai compatible mode', async () => { - const dashscopeConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', // Not QWEN_OAUTH - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - apiKey: 'test-key', - baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', - authType: AuthType.USE_OPENAI, - enableOpenAILogging: false, - }; - const dashscopeGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - dashscopeConfig, - ); - - // Debug: Check if the client was created with the correct baseURL - expect(vi.mocked(OpenAI)).toHaveBeenCalledWith( - expect.objectContaining({ - baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1', - }), - ); - - // Mock the client's baseURL property to return the expected value - Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { - value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', - writable: true, - }); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'qwen-turbo', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'qwen-turbo', - }; - - await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - metadata: { - sessionId: 'dashscope-session-id', - promptId: 'dashscope-prompt-id', - }, - }), - ); - }); - - it('should NOT include metadata for regular OpenAI providers', async () => { - const regularConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('regular-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const regularGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - regularConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await regularGenerator.generateContent(request, 'regular-prompt-id'); - - // Should NOT include metadata - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - - it('should NOT include metadata for other auth types', async () => { - const otherAuthConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'gemini-api-key', - enableOpenAILogging: false, - }), - getSessionId: 
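These metadata tests fix the gating rule: the metadata field (sessionId plus promptId) is attached only for DashScope-backed requests, meaning Qwen OAuth auth or one of the DashScope compatible-mode base URLs; plain OpenAI endpoints never receive it. A sketch of that gate, mirroring the deleted isDashScopeProvider/buildMetadata pair with simplified types:

const DASHSCOPE_BASE_URLS = [
  'https://dashscope.aliyuncs.com/compatible-mode/v1',
  'https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
];

function isDashScope(authType: string | undefined, baseUrl: string | undefined): boolean {
  return authType === 'qwen-oauth' || DASHSCOPE_BASE_URLS.includes(baseUrl ?? '');
}

function buildMetadata(
  authType: string | undefined,
  baseUrl: string | undefined,
  sessionId: string | undefined,
  promptId: string,
): { metadata: { sessionId?: string; promptId: string } } | undefined {
  if (!isDashScope(authType, baseUrl)) return undefined;
  return { metadata: { sessionId, promptId } };
}

// buildMetadata('qwen-oauth', undefined, 's-1', 'p-1') -> { metadata: { sessionId: 's-1', promptId: 'p-1' } }
// buildMetadata('openai', 'https://api.openai.com/v1', 's-1', 'p-1') -> undefined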
vi.fn().mockReturnValue('other-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const otherGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - otherAuthConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await otherGenerator.generateContent(request, 'other-prompt-id'); - - // Should NOT include metadata - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - - it('should NOT include metadata for other base URLs', async () => { - // Mock environment to set a different base URL - vi.stubEnv('OPENAI_BASE_URL', 'https://api.openai.com/v1'); - - const otherBaseUrlConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('other-base-url-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const otherBaseUrlGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - otherBaseUrlConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await otherBaseUrlGenerator.generateContent( - request, - 'other-base-url-prompt-id', - ); - - // Should NOT include metadata - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - - it('should include metadata in streaming requests when conditions are met', async () => { - const qwenConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'qwen-oauth', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('streaming-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - - apiKey: 'test-key', - - authType: AuthType.QWEN_OAUTH, - - enableOpenAILogging: false, - }; - - const qwenGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - qwenConfig, - ); - - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: 'Hello' }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: ' there!' 
}, - finish_reason: 'stop', - }, - ], - created: 1677652288, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'qwen-turbo', - }; - - const stream = await qwenGenerator.generateContentStream( - request, - 'streaming-prompt-id', - ); - - // Verify metadata was included in the streaming request - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - metadata: { - sessionId: 'streaming-session-id', - promptId: 'streaming-prompt-id', - }, - }), - ); - - // Consume the stream to complete the test - const responses = []; - for await (const response of stream) { - responses.push(response); - } - expect(responses).toHaveLength(2); - }); - - it('should NOT include metadata in streaming requests when conditions are not met', async () => { - const regularConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('regular-streaming-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const regularGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - regularConfig, - ); - - const mockStream = [ - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: 'Hello' }, - finish_reason: null, - }, - ], - created: 1677652288, - }, - { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - delta: { content: ' there!' 
}, - finish_reason: 'stop', - }, - ], - created: 1677652288, - }, - ]; - - mockOpenAIClient.chat.completions.create.mockResolvedValue({ - async *[Symbol.asyncIterator]() { - for (const chunk of mockStream) { - yield chunk; - } - }, - }); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - const stream = await regularGenerator.generateContentStream( - request, - 'regular-streaming-prompt-id', - ); - - // Verify metadata was NOT included in the streaming request - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - - // Consume the stream to complete the test - const responses = []; - for await (const response of stream) { - responses.push(response); - } - expect(responses).toHaveLength(2); - }); - - it('should handle undefined sessionId gracefully', async () => { - const qwenConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'qwen-oauth', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue(undefined), // Undefined session ID - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - - apiKey: 'test-key', - - authType: AuthType.QWEN_OAUTH, - - enableOpenAILogging: false, - }; - - const qwenGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - qwenConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'qwen-turbo', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'qwen-turbo', - }; - - await qwenGenerator.generateContent( - request, - 'undefined-session-prompt-id', - ); - - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - metadata: { - sessionId: undefined, - promptId: 'undefined-session-prompt-id', - }, - }), - ); - }); - - it('should handle undefined baseURL gracefully', async () => { - // Ensure no base URL is set - vi.stubEnv('OPENAI_BASE_URL', ''); - - const noBaseUrlConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('no-base-url-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const noBaseUrlGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - noBaseUrlConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await noBaseUrlGenerator.generateContent( - request, - 'no-base-url-prompt-id', - ); - - // Should NOT include metadata when baseURL is empty - 
expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - - it('should handle undefined authType gracefully', async () => { - const undefinedAuthConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: undefined, // Undefined auth type - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('undefined-auth-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const undefinedAuthGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - undefinedAuthConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await undefinedAuthGenerator.generateContent( - request, - 'undefined-auth-prompt-id', - ); - - // Should NOT include metadata when authType is undefined - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - - it('should handle undefined config gracefully', async () => { - const undefinedConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue(undefined), // Undefined config - getSessionId: vi.fn().mockReturnValue('undefined-config-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const undefinedConfigGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - undefinedConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - model: 'gpt-4', - }; - - await undefinedConfigGenerator.generateContent( - request, - 'undefined-config-prompt-id', - ); - - // Should NOT include metadata when config is undefined - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.not.objectContaining({ - metadata: expect.any(Object), - }), - ); - }); - }); - - describe('cache control for DashScope', () => { - it('should add cache control to system message for DashScope providers', async () => { - // Mock environment to set dashscope base URL - vi.stubEnv( - 'OPENAI_BASE_URL', - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - ); - - const dashscopeConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - - apiKey: 'test-key', - - authType: 
AuthType.QWEN_OAUTH, - - enableOpenAILogging: false, - }; - - const dashscopeGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - dashscopeConfig, - ); - - // Mock the client's baseURL property to return the expected value - Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { - value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', - writable: true, - }); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'qwen-turbo', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - config: { - systemInstruction: 'You are a helpful assistant.', - }, - model: 'qwen-turbo', - }; - - await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); - - // Should include cache control in system message - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.arrayContaining([ - expect.objectContaining({ - role: 'system', - content: expect.arrayContaining([ - expect.objectContaining({ - type: 'text', - text: 'You are a helpful assistant.', - cache_control: { type: 'ephemeral' }, - }), - ]), - }), - ]), - }), - ); - }); - - it('should add cache control to last message for DashScope providers', async () => { - // Mock environment to set dashscope base URL - vi.stubEnv( - 'OPENAI_BASE_URL', - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - ); - - const dashscopeConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'qwen-turbo', - - apiKey: 'test-key', - - authType: AuthType.QWEN_OAUTH, - - enableOpenAILogging: false, - }; - - const dashscopeGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - dashscopeConfig, - ); - - // Mock the client's baseURL property to return the expected value - Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { - value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', - writable: true, - }); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'qwen-turbo', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' 
}] }], - model: 'qwen-turbo', - }; - - await dashscopeGenerator.generateContentStream( - request, - 'dashscope-prompt-id', - ); - - // Should include cache control in last message - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.arrayContaining([ - expect.objectContaining({ - role: 'user', - content: expect.arrayContaining([ - expect.objectContaining({ - type: 'text', - text: 'Hello, how are you?', - }), - ]), - }), - ]), - }), - ); - }); - - it('should NOT add cache control for non-DashScope providers', async () => { - const regularConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - authType: 'openai', - enableOpenAILogging: false, - }), - getSessionId: vi.fn().mockReturnValue('regular-session-id'), - getCliVersion: vi.fn().mockReturnValue('1.0.0'), - } as unknown as Config; - - const contentGeneratorConfig = { - model: 'gpt-4', - - apiKey: 'test-key', - - authType: AuthType.USE_OPENAI, - - enableOpenAILogging: false, - }; - - const regularGenerator = new OpenAIContentGenerator( - contentGeneratorConfig, - regularConfig, - ); - - const mockResponse = { - id: 'chatcmpl-123', - choices: [ - { - index: 0, - message: { role: 'assistant', content: 'Response' }, - finish_reason: 'stop', - }, - ], - created: 1677652288, - model: 'gpt-4', - }; - - mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); - - const request: GenerateContentParameters = { - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], - config: { - systemInstruction: 'You are a helpful assistant.', - }, - model: 'gpt-4', - }; - - await regularGenerator.generateContent(request, 'regular-prompt-id'); - - // Should NOT include cache control (messages should be strings, not arrays) - expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( - expect.objectContaining({ - messages: expect.arrayContaining([ - expect.objectContaining({ - role: 'system', - content: 'You are a helpful assistant.', - }), - expect.objectContaining({ - role: 'user', - content: 'Hello', - }), - ]), - }), - ); - }); - }); -}); diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts deleted file mode 100644 index e3e19533..00000000 --- a/packages/core/src/core/openaiContentGenerator.ts +++ /dev/null @@ -1,1711 +0,0 @@ -/** - * @license - * Copyright 2025 Qwen - * SPDX-License-Identifier: Apache-2.0 - */ - -import type { - CountTokensResponse, - GenerateContentParameters, - CountTokensParameters, - EmbedContentResponse, - EmbedContentParameters, - Part, - Content, - Tool, - ToolListUnion, - CallableTool, - FunctionCall, - FunctionResponse, -} from '@google/genai'; -import { GenerateContentResponse, FinishReason } from '@google/genai'; -import type { - ContentGenerator, - ContentGeneratorConfig, -} from './contentGenerator.js'; -import { AuthType } from './contentGenerator.js'; -import OpenAI from 'openai'; -import { logApiError, logApiResponse } from '../telemetry/loggers.js'; -import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js'; -import type { Config } from '../config/config.js'; -import { openaiLogger } from '../utils/openaiLogger.js'; -import { safeJsonParse } from '../utils/safeJsonParse.js'; - -// Extended types to support cache_control -interface ChatCompletionContentPartTextWithCache - extends OpenAI.Chat.ChatCompletionContentPartText { - cache_control?: { type: 'ephemeral' }; -} - -type ChatCompletionContentPartWithCache = - | 
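The cache-control tests expect DashScope requests to expand the system message (and, for streaming, the last message) from a plain string into a content-part array whose text part carries cache_control: { type: 'ephemeral' }, while non-DashScope providers keep plain string content. A sketch of tagging a single message that way, with the message type reduced to what the assertion checks:

type MessageContent =
  | string
  | Array<{ type: 'text'; text: string; cache_control?: { type: 'ephemeral' } }>;

interface CacheableMessage { role: 'system' | 'user' | 'assistant'; content: MessageContent; }

function withEphemeralCache(message: CacheableMessage): CacheableMessage {
  if (typeof message.content !== 'string') return message; // already structured content
  return {
    ...message,
    content: [{ type: 'text', text: message.content, cache_control: { type: 'ephemeral' } }],
  };
}

// withEphemeralCache({ role: 'system', content: 'You are a helpful assistant.' })
//   -> content becomes [{ type: 'text', text: '...', cache_control: { type: 'ephemeral' } }]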
ChatCompletionContentPartTextWithCache - | OpenAI.Chat.ChatCompletionContentPartImage - | OpenAI.Chat.ChatCompletionContentPartRefusal; - -// OpenAI API type definitions for logging -interface OpenAIToolCall { - id: string; - type: 'function'; - function: { - name: string; - arguments: string; - }; -} - -interface OpenAIContentItem { - type: 'text'; - text: string; - cache_control?: { type: 'ephemeral' }; -} - -interface OpenAIMessage { - role: 'system' | 'user' | 'assistant' | 'tool'; - content: string | null | OpenAIContentItem[]; - tool_calls?: OpenAIToolCall[]; - tool_call_id?: string; -} - -interface OpenAIUsage { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - prompt_tokens_details?: { - cached_tokens?: number; - }; -} - -interface OpenAIChoice { - index: number; - message: OpenAIMessage; - finish_reason: string; -} - -interface OpenAIResponseFormat { - id: string; - object: string; - created: number; - model: string; - choices: OpenAIChoice[]; - usage?: OpenAIUsage; -} - -/** - * @deprecated refactored to ./openaiContentGenerator - * use `createOpenAIContentGenerator` instead - * or extend `OpenAIContentGenerator` to add customized behavior - */ -export class OpenAIContentGenerator implements ContentGenerator { - protected client: OpenAI; - private model: string; - private contentGeneratorConfig: ContentGeneratorConfig; - private config: Config; - private streamingToolCalls: Map< - number, - { - id?: string; - name?: string; - arguments: string; - } - > = new Map(); - - constructor( - contentGeneratorConfig: ContentGeneratorConfig, - gcConfig: Config, - ) { - this.model = contentGeneratorConfig.model; - this.contentGeneratorConfig = contentGeneratorConfig; - this.config = gcConfig; - - const version = gcConfig.getCliVersion() || 'unknown'; - const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`; - - // Check if using OpenRouter and add required headers - const isOpenRouterProvider = this.isOpenRouterProvider(); - const isDashScopeProvider = this.isDashScopeProvider(); - - const defaultHeaders = { - 'User-Agent': userAgent, - ...(isOpenRouterProvider - ? { - 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git', - 'X-Title': 'Qwen Code', - } - : isDashScopeProvider - ? { - 'X-DashScope-CacheControl': 'enable', - 'X-DashScope-UserAgent': userAgent, - 'X-DashScope-AuthType': contentGeneratorConfig.authType, - } - : {}), - }; - - this.client = new OpenAI({ - apiKey: contentGeneratorConfig.apiKey, - baseURL: contentGeneratorConfig.baseUrl, - timeout: contentGeneratorConfig.timeout ?? 120000, - maxRetries: contentGeneratorConfig.maxRetries ?? 3, - defaultHeaders, - }); - } - - /** - * Hook for subclasses to customize error handling behavior - * @param error The error that occurred - * @param request The original request - * @returns true if error logging should be suppressed, false otherwise - */ - protected shouldSuppressErrorLogging( - _error: unknown, - _request: GenerateContentParameters, - ): boolean { - return false; // Default behavior: never suppress error logging - } - - /** - * Check if an error is a timeout error - */ - private isTimeoutError(error: unknown): boolean { - if (!error) return false; - - const errorMessage = - error instanceof Error - ? 
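The constructor above assembles provider-specific default headers around a QwenCode User-Agent: OpenRouter endpoints get HTTP-Referer and X-Title, DashScope endpoints get the X-DashScope-* headers, and everything else just the User-Agent. A condensed sketch of that selection (parameter names here are illustrative):

function buildDefaultHeaders(opts: {
  cliVersion: string;
  baseUrl?: string;
  authType?: string;
  isDashScope: boolean;
}): Record<string, string> {
  const userAgent = `QwenCode/${opts.cliVersion} (${process.platform}; ${process.arch})`;
  if ((opts.baseUrl ?? '').includes('openrouter.ai')) {
    return {
      'User-Agent': userAgent,
      'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
      'X-Title': 'Qwen Code',
    };
  }
  if (opts.isDashScope) {
    return {
      'User-Agent': userAgent,
      'X-DashScope-CacheControl': 'enable',
      'X-DashScope-UserAgent': userAgent,
      'X-DashScope-AuthType': opts.authType ?? '',
    };
  }
  return { 'User-Agent': userAgent };
}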
error.message.toLowerCase() - : String(error).toLowerCase(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const errorCode = (error as any)?.code; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const errorType = (error as any)?.type; - - // Check for common timeout indicators - return ( - errorMessage.includes('timeout') || - errorMessage.includes('timed out') || - errorMessage.includes('connection timeout') || - errorMessage.includes('request timeout') || - errorMessage.includes('read timeout') || - errorMessage.includes('etimedout') || // Include ETIMEDOUT in message check - errorMessage.includes('esockettimedout') || // Include ESOCKETTIMEDOUT in message check - errorCode === 'ETIMEDOUT' || - errorCode === 'ESOCKETTIMEDOUT' || - errorType === 'timeout' || - // OpenAI specific timeout indicators - errorMessage.includes('request timed out') || - errorMessage.includes('deadline exceeded') - ); - } - - private isOpenRouterProvider(): boolean { - const baseURL = this.contentGeneratorConfig.baseUrl || ''; - return baseURL.includes('openrouter.ai'); - } - - /** - * Determine if this is a DashScope provider. - * DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs. - * - * @returns true if this is a DashScope provider, false otherwise - */ - private isDashScopeProvider(): boolean { - const authType = this.contentGeneratorConfig.authType; - const baseUrl = this.contentGeneratorConfig.baseUrl; - - return ( - authType === AuthType.QWEN_OAUTH || - baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' || - baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1' - ); - } - - /** - * Check if cache control should be disabled based on configuration. - * - * @returns true if cache control should be disabled, false otherwise - */ - private shouldDisableCacheControl(): boolean { - return ( - this.config.getContentGeneratorConfig()?.disableCacheControl === true - ); - } - - /** - * Build metadata object for OpenAI API requests. - * - * @param userPromptId The user prompt ID to include in metadata - * @returns metadata object if shouldIncludeMetadata() returns true, undefined otherwise - */ - private buildMetadata( - userPromptId: string, - ): { metadata: { sessionId?: string; promptId: string } } | undefined { - if (!this.isDashScopeProvider()) { - return undefined; - } - - return { - metadata: { - sessionId: this.config.getSessionId?.(), - promptId: userPromptId, - }, - }; - } - - private async buildCreateParams( - request: GenerateContentParameters, - userPromptId: string, - streaming: boolean = false, - ): Promise[0]> { - let messages = this.convertToOpenAIFormat(request); - - // Add cache control to system and last messages for DashScope providers - // Only add cache control to system message for non-streaming requests - if (this.isDashScopeProvider() && !this.shouldDisableCacheControl()) { - messages = this.addDashScopeCacheControl( - messages, - streaming ? 'both' : 'system', - ); - } - - // Build sampling parameters with clear priority: - // 1. Request-level parameters (highest priority) - // 2. Config-level sampling parameters (medium priority) - // 3. 
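isTimeoutError treats a failure as a timeout when the message, code, or type carries any of the usual timeout markers, so callers can swap in an actionable message that includes the elapsed time. A trimmed sketch of the same check:

function looksLikeTimeout(error: unknown): boolean {
  const message = (error instanceof Error ? error.message : String(error)).toLowerCase();
  const code = (error as { code?: string })?.code;
  const type = (error as { type?: string })?.type;
  return (
    message.includes('timeout') ||
    message.includes('timed out') ||
    message.includes('etimedout') ||
    message.includes('esockettimedout') ||
    message.includes('deadline exceeded') ||
    code === 'ETIMEDOUT' ||
    code === 'ESOCKETTIMEDOUT' ||
    type === 'timeout'
  );
}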
Default values (lowest priority) - const samplingParams = this.buildSamplingParameters(request); - - const createParams: Parameters< - typeof this.client.chat.completions.create - >[0] = { - model: this.model, - messages, - ...samplingParams, - ...(this.buildMetadata(userPromptId) || {}), - }; - - if (request.config?.tools) { - createParams.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - - if (streaming) { - createParams.stream = true; - createParams.stream_options = { include_usage: true }; - } - - return createParams; - } - - async generateContent( - request: GenerateContentParameters, - userPromptId: string, - ): Promise { - const startTime = Date.now(); - const createParams = await this.buildCreateParams( - request, - userPromptId, - false, - ); - - try { - const completion = (await this.client.chat.completions.create( - createParams, - )) as OpenAI.Chat.ChatCompletion; - - const response = this.convertToGeminiFormat(completion); - const durationMs = Date.now() - startTime; - - // Log API response event for UI telemetry - const responseEvent = new ApiResponseEvent( - response.responseId || 'unknown', - this.model, - durationMs, - userPromptId, - this.contentGeneratorConfig.authType, - response.usageMetadata, - ); - - logApiResponse(this.config, responseEvent); - - // Log interaction if enabled - if (this.contentGeneratorConfig.enableOpenAILogging) { - const openaiRequest = createParams; - const openaiResponse = this.convertGeminiResponseToOpenAI(response); - await openaiLogger.logInteraction(openaiRequest, openaiResponse); - } - - return response; - } catch (error) { - const durationMs = Date.now() - startTime; - - // Identify timeout errors specifically - const isTimeoutError = this.isTimeoutError(error); - const errorMessage = isTimeoutError - ? `Request timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.` - : error instanceof Error - ? 
error.message - : String(error); - - // Log API error event for UI telemetry - const errorEvent = new ApiErrorEvent( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).requestID || 'unknown', - this.model, - errorMessage, - durationMs, - userPromptId, - this.contentGeneratorConfig.authType, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).type, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code, - ); - logApiError(this.config, errorEvent); - - // Log error interaction if enabled - if (this.contentGeneratorConfig.enableOpenAILogging) { - await openaiLogger.logInteraction( - createParams, - undefined, - error as Error, - ); - } - - // Allow subclasses to suppress error logging for specific scenarios - if (!this.shouldSuppressErrorLogging(error, request)) { - console.error('OpenAI API Error:', errorMessage); - } - - // Provide helpful timeout-specific error message - if (isTimeoutError) { - throw new Error( - `${errorMessage}\n\nTroubleshooting tips:\n` + - `- Reduce input length or complexity\n` + - `- Increase timeout in config: contentGenerator.timeout\n` + - `- Check network connectivity\n` + - `- Consider using streaming mode for long responses`, - ); - } - - throw error; - } - } - - async generateContentStream( - request: GenerateContentParameters, - userPromptId: string, - ): Promise> { - const startTime = Date.now(); - const createParams = await this.buildCreateParams( - request, - userPromptId, - true, - ); - - try { - const stream = (await this.client.chat.completions.create( - createParams, - )) as AsyncIterable; - - const originalStream = this.streamGenerator(stream); - - // Collect all responses for final logging (don't log during streaming) - const responses: GenerateContentResponse[] = []; - - // Return a new generator that both yields responses and collects them - const wrappedGenerator = async function* (this: OpenAIContentGenerator) { - try { - for await (const response of originalStream) { - responses.push(response); - yield response; - } - - const durationMs = Date.now() - startTime; - - // Get final usage metadata from the last response that has it - const finalUsageMetadata = responses - .slice() - .reverse() - .find((r) => r.usageMetadata)?.usageMetadata; - - // Log API response event for UI telemetry - const responseEvent = new ApiResponseEvent( - responses[responses.length - 1]?.responseId || 'unknown', - this.model, - durationMs, - userPromptId, - this.contentGeneratorConfig.authType, - finalUsageMetadata, - ); - - logApiResponse(this.config, responseEvent); - - // Log interaction if enabled (same as generateContent method) - if (this.contentGeneratorConfig.enableOpenAILogging) { - const openaiRequest = createParams; - // For streaming, we combine all responses into a single response for logging - const combinedResponse = - this.combineStreamResponsesForLogging(responses); - const openaiResponse = - this.convertGeminiResponseToOpenAI(combinedResponse); - await openaiLogger.logInteraction(openaiRequest, openaiResponse); - } - } catch (error) { - const durationMs = Date.now() - startTime; - - // Identify timeout errors specifically for streaming - const isTimeoutError = this.isTimeoutError(error); - const errorMessage = isTimeoutError - ? `Streaming request timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.` - : error instanceof Error - ? 
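generateContentStream wraps the provider stream in a second async generator that yields each converted chunk to the caller while also collecting them, so that duration, final usage metadata, telemetry, and optional request/response logging can be emitted once after the last chunk. A reduced sketch of that tee-and-report pattern (the onDone callback stands in for the telemetry and logging calls):

async function* teeAndReport<T>(
  source: AsyncIterable<T>,
  onDone: (collected: T[], durationMs: number) => void,
): AsyncGenerator<T> {
  const startTime = Date.now();
  const collected: T[] = [];
  for await (const item of source) {
    collected.push(item); // keep a copy for post-stream reporting
    yield item;           // still stream to the caller as chunks arrive
  }
  onDone(collected, Date.now() - startTime);
}

// Usage: wrap the converted chunk stream, then emit telemetry/logging in onDone.
// const stream = teeAndReport(convertedChunks, (all, ms) => { /* report once here */ });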
error.message - : String(error); - - // Log API error event for UI telemetry - const errorEvent = new ApiErrorEvent( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).requestID || 'unknown', - this.model, - errorMessage, - durationMs, - userPromptId, - this.contentGeneratorConfig.authType, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).type, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code, - ); - logApiError(this.config, errorEvent); - - // Log error interaction if enabled - if (this.contentGeneratorConfig.enableOpenAILogging) { - await openaiLogger.logInteraction( - createParams, - undefined, - error as Error, - ); - } - - // Provide helpful timeout-specific error message for streaming - if (isTimeoutError) { - throw new Error( - `${errorMessage}\n\nStreaming timeout troubleshooting:\n` + - `- Reduce input length or complexity\n` + - `- Increase timeout in config: contentGenerator.timeout\n` + - `- Check network stability for streaming connections\n` + - `- Consider using non-streaming mode for very long inputs`, - ); - } - - throw error; - } - }.bind(this); - - return wrappedGenerator(); - } catch (error) { - const durationMs = Date.now() - startTime; - - // Identify timeout errors specifically for streaming setup - const isTimeoutError = this.isTimeoutError(error); - const errorMessage = isTimeoutError - ? `Streaming setup timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.` - : error instanceof Error - ? error.message - : String(error); - - // Log API error event for UI telemetry - const errorEvent = new ApiErrorEvent( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).requestID || 'unknown', - this.model, - errorMessage, - durationMs, - userPromptId, - this.contentGeneratorConfig.authType, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).type, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error as any).code, - ); - logApiError(this.config, errorEvent); - - // Allow subclasses to suppress error logging for specific scenarios - if (!this.shouldSuppressErrorLogging(error, request)) { - console.error('OpenAI API Streaming Error:', errorMessage); - } - - // Provide helpful timeout-specific error message for streaming setup - if (isTimeoutError) { - throw new Error( - `${errorMessage}\n\nStreaming setup timeout troubleshooting:\n` + - `- Reduce input length or complexity\n` + - `- Increase timeout in config: contentGenerator.timeout\n` + - `- Check network connectivity and firewall settings\n` + - `- Consider using non-streaming mode for very long inputs`, - ); - } - - throw error; - } - } - - private async *streamGenerator( - stream: AsyncIterable, - ): AsyncGenerator { - // Reset the accumulator for each new stream - this.streamingToolCalls.clear(); - - for await (const chunk of stream) { - const response = this.convertStreamChunkToGeminiFormat(chunk); - - // Ignore empty responses, which would cause problems with downstream code - // that expects a valid response. 
- if ( - response.candidates?.[0]?.content?.parts?.length === 0 && - !response.usageMetadata - ) { - continue; - } - - yield response; - } - } - - /** - * Combine streaming responses for logging purposes - */ - private combineStreamResponsesForLogging( - responses: GenerateContentResponse[], - ): GenerateContentResponse { - if (responses.length === 0) { - return new GenerateContentResponse(); - } - - const lastResponse = responses[responses.length - 1]; - - // Find the last response with usage metadata - const finalUsageMetadata = responses - .slice() - .reverse() - .find((r) => r.usageMetadata)?.usageMetadata; - - // Combine all text content from the stream - const combinedParts: Part[] = []; - let combinedText = ''; - const functionCalls: Part[] = []; - - for (const response of responses) { - if (response.candidates?.[0]?.content?.parts) { - for (const part of response.candidates[0].content.parts) { - if ('text' in part && part.text) { - combinedText += part.text; - } else if ('functionCall' in part && part.functionCall) { - functionCalls.push(part); - } - } - } - } - - // Add combined text if any - if (combinedText) { - combinedParts.push({ text: combinedText }); - } - - // Add function calls - combinedParts.push(...functionCalls); - - // Create combined response - const combinedResponse = new GenerateContentResponse(); - combinedResponse.candidates = [ - { - content: { - parts: combinedParts, - role: 'model' as const, - }, - finishReason: - responses[responses.length - 1]?.candidates?.[0]?.finishReason || - FinishReason.FINISH_REASON_UNSPECIFIED, - index: 0, - safetyRatings: [], - }, - ]; - combinedResponse.responseId = lastResponse?.responseId; - combinedResponse.createTime = lastResponse?.createTime; - combinedResponse.modelVersion = this.model; - combinedResponse.promptFeedback = { safetyRatings: [] }; - combinedResponse.usageMetadata = finalUsageMetadata; - - return combinedResponse; - } - - async countTokens( - request: CountTokensParameters, - ): Promise { - // Use tiktoken for accurate token counting - const content = JSON.stringify(request.contents); - let totalTokens = 0; - - try { - const { get_encoding } = await import('tiktoken'); - const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen - totalTokens = encoding.encode(content).length; - encoding.free(); - } catch (error) { - console.warn( - 'Failed to load tiktoken, falling back to character approximation:', - error, - ); - // Fallback: rough approximation using character count - totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters - } - - return { - totalTokens, - }; - } - - async embedContent( - request: EmbedContentParameters, - ): Promise { - // Extract text from contents - let text = ''; - if (Array.isArray(request.contents)) { - text = request.contents - .map((content) => { - if (typeof content === 'string') return content; - if ('parts' in content && content.parts) { - return content.parts - .map((part) => - typeof part === 'string' - ? part - : 'text' in part - ? (part as { text?: string }).text || '' - : '', - ) - .join(' '); - } - return ''; - }) - .join(' '); - } else if (request.contents) { - if (typeof request.contents === 'string') { - text = request.contents; - } else if ('parts' in request.contents && request.contents.parts) { - text = request.contents.parts - .map((part: Part) => - typeof part === 'string' ? part : 'text' in part ? 
part.text : '', - ) - .join(' '); - } - } - - try { - const embedding = await this.client.embeddings.create({ - model: 'text-embedding-ada-002', // Default embedding model - input: text, - }); - - return { - embeddings: [ - { - values: embedding.data[0].embedding, - }, - ], - }; - } catch (error) { - console.error('OpenAI API Embedding Error:', error); - throw new Error( - `OpenAI API error: ${error instanceof Error ? error.message : String(error)}`, - ); - } - } - - private convertGeminiParametersToOpenAI( - parameters: Record, - ): Record | undefined { - if (!parameters || typeof parameters !== 'object') { - return parameters; - } - - const converted = JSON.parse(JSON.stringify(parameters)); - - const convertTypes = (obj: unknown): unknown => { - if (typeof obj !== 'object' || obj === null) { - return obj; - } - - if (Array.isArray(obj)) { - return obj.map(convertTypes); - } - - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - if (key === 'type' && typeof value === 'string') { - // Convert Gemini types to OpenAI JSON Schema types - const lowerValue = value.toLowerCase(); - if (lowerValue === 'integer') { - result[key] = 'integer'; - } else if (lowerValue === 'number') { - result[key] = 'number'; - } else { - result[key] = lowerValue; - } - } else if ( - key === 'minimum' || - key === 'maximum' || - key === 'multipleOf' - ) { - // Ensure numeric constraints are actual numbers, not strings - if (typeof value === 'string' && !isNaN(Number(value))) { - result[key] = Number(value); - } else { - result[key] = value; - } - } else if ( - key === 'minLength' || - key === 'maxLength' || - key === 'minItems' || - key === 'maxItems' - ) { - // Ensure length constraints are integers, not strings - if (typeof value === 'string' && !isNaN(Number(value))) { - result[key] = parseInt(value, 10); - } else { - result[key] = value; - } - } else if (typeof value === 'object') { - result[key] = convertTypes(value); - } else { - result[key] = value; - } - } - return result; - }; - - return convertTypes(converted) as Record | undefined; - } - - /** - * Converts Gemini tools to OpenAI format for API compatibility. - * Handles both Gemini tools (using 'parameters' field) and MCP tools (using 'parametersJsonSchema' field). - * - * Gemini tools use a custom parameter format that needs conversion to OpenAI JSON Schema format. - * MCP tools already use JSON Schema format in the parametersJsonSchema field and can be used directly. 
- * - * @param geminiTools - Array of Gemini tools to convert - * @returns Promise resolving to array of OpenAI-compatible tools - */ - private async convertGeminiToolsToOpenAI( - geminiTools: ToolListUnion, - ): Promise { - const openAITools: OpenAI.Chat.ChatCompletionTool[] = []; - - for (const tool of geminiTools) { - let actualTool: Tool; - - // Handle CallableTool vs Tool - if ('tool' in tool) { - // This is a CallableTool - actualTool = await (tool as CallableTool).tool(); - } else { - // This is already a Tool - actualTool = tool as Tool; - } - - if (actualTool.functionDeclarations) { - for (const func of actualTool.functionDeclarations) { - if (func.name && func.description) { - let parameters: Record | undefined; - - // Handle both Gemini tools (parameters) and MCP tools (parametersJsonSchema) - if (func.parametersJsonSchema) { - // MCP tool format - use parametersJsonSchema directly - if (func.parametersJsonSchema) { - // Create a shallow copy to avoid mutating the original object - const paramsCopy = { - ...(func.parametersJsonSchema as Record), - }; - parameters = paramsCopy; - } - } else if (func.parameters) { - // Gemini tool format - convert parameters to OpenAI format - parameters = this.convertGeminiParametersToOpenAI( - func.parameters as Record, - ); - } - - openAITools.push({ - type: 'function', - function: { - name: func.name, - description: func.description, - parameters, - }, - }); - } - } - } - } - - // console.log( - // 'OpenAI Tools Parameters:', - // JSON.stringify(openAITools, null, 2), - // ); - return openAITools; - } - - private convertToOpenAIFormat( - request: GenerateContentParameters, - ): OpenAI.Chat.ChatCompletionMessageParam[] { - const messages: OpenAI.Chat.ChatCompletionMessageParam[] = []; - - // Handle system instruction from config - if (request.config?.systemInstruction) { - const systemInstruction = request.config.systemInstruction; - let systemText = ''; - - if (Array.isArray(systemInstruction)) { - systemText = systemInstruction - .map((content) => { - if (typeof content === 'string') return content; - if ('parts' in content) { - const contentObj = content as Content; - return ( - contentObj.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || '' - ); - } - return ''; - }) - .join('\n'); - } else if (typeof systemInstruction === 'string') { - systemText = systemInstruction; - } else if ( - typeof systemInstruction === 'object' && - 'parts' in systemInstruction - ) { - const systemContent = systemInstruction as Content; - systemText = - systemContent.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? 
p.text : '', - ) - .join('\n') || ''; - } - - if (systemText) { - messages.push({ - role: 'system' as const, - content: systemText, - }); - } - } - - // Handle contents - if (Array.isArray(request.contents)) { - for (const content of request.contents) { - if (typeof content === 'string') { - messages.push({ role: 'user' as const, content }); - } else if ('role' in content && 'parts' in content) { - // Check if this content has function calls or responses - const functionCalls: FunctionCall[] = []; - const functionResponses: FunctionResponse[] = []; - const textParts: string[] = []; - - for (const part of content.parts || []) { - if (typeof part === 'string') { - textParts.push(part); - } else if ('text' in part && part.text) { - textParts.push(part.text); - } else if ('functionCall' in part && part.functionCall) { - functionCalls.push(part.functionCall); - } else if ('functionResponse' in part && part.functionResponse) { - functionResponses.push(part.functionResponse); - } - } - - // Handle function responses (tool results) - if (functionResponses.length > 0) { - for (const funcResponse of functionResponses) { - messages.push({ - role: 'tool' as const, - tool_call_id: funcResponse.id || '', - content: - typeof funcResponse.response === 'string' - ? funcResponse.response - : JSON.stringify(funcResponse.response), - }); - } - } - // Handle model messages with function calls - else if (content.role === 'model' && functionCalls.length > 0) { - const toolCalls = functionCalls.map((fc, index) => ({ - id: fc.id || `call_${index}`, - type: 'function' as const, - function: { - name: fc.name || '', - arguments: JSON.stringify(fc.args || {}), - }, - })); - - messages.push({ - role: 'assistant' as const, - content: textParts.join('') || null, - tool_calls: toolCalls, - }); - } - // Handle regular text messages - else { - const role = - content.role === 'model' - ? ('assistant' as const) - : ('user' as const); - const text = textParts.join(''); - if (text) { - messages.push({ role, content: text }); - } - } - } - } - } else if (request.contents) { - if (typeof request.contents === 'string') { - messages.push({ role: 'user' as const, content: request.contents }); - } else if ('role' in request.contents && 'parts' in request.contents) { - const content = request.contents; - const role = - content.role === 'model' ? ('assistant' as const) : ('user' as const); - const text = - content.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? 
p.text : '', - ) - .join('\n') || ''; - messages.push({ role, content: text }); - } - } - - // Clean up orphaned tool calls and merge consecutive assistant messages - const cleanedMessages = this.cleanOrphanedToolCalls(messages); - const mergedMessages = - this.mergeConsecutiveAssistantMessages(cleanedMessages); - - return mergedMessages; - } - - /** - * Add cache control flag to specified message(s) for DashScope providers - */ - private addDashScopeCacheControl( - messages: OpenAI.Chat.ChatCompletionMessageParam[], - target: 'system' | 'last' | 'both' = 'both', - ): OpenAI.Chat.ChatCompletionMessageParam[] { - if (!this.isDashScopeProvider() || messages.length === 0) { - return messages; - } - - let updatedMessages = [...messages]; - - // Add cache control to system message if requested - if (target === 'system' || target === 'both') { - updatedMessages = this.addCacheControlToMessage( - updatedMessages, - 'system', - ); - } - - // Add cache control to last message if requested - if (target === 'last' || target === 'both') { - updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last'); - } - - return updatedMessages; - } - - /** - * Helper method to add cache control to a specific message - */ - private addCacheControlToMessage( - messages: OpenAI.Chat.ChatCompletionMessageParam[], - target: 'system' | 'last', - ): OpenAI.Chat.ChatCompletionMessageParam[] { - const updatedMessages = [...messages]; - let messageIndex: number; - - if (target === 'system') { - // Find the first system message - messageIndex = messages.findIndex((msg) => msg.role === 'system'); - if (messageIndex === -1) { - return updatedMessages; - } - } else { - // Get the last message - messageIndex = messages.length - 1; - } - - const message = updatedMessages[messageIndex]; - - // Only process messages that have content - if ('content' in message && message.content !== null) { - if (typeof message.content === 'string') { - // Convert string content to array format with cache control - const messageWithArrayContent = { - ...message, - content: [ - { - type: 'text', - text: message.content, - cache_control: { type: 'ephemeral' }, - } as ChatCompletionContentPartTextWithCache, - ], - }; - updatedMessages[messageIndex] = - messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam; - } else if (Array.isArray(message.content)) { - // If content is already an array, add cache_control to the last item - const contentArray = [ - ...message.content, - ] as ChatCompletionContentPartWithCache[]; - if (contentArray.length > 0) { - const lastItem = contentArray[contentArray.length - 1]; - if (lastItem.type === 'text') { - // Add cache_control to the last text item - contentArray[contentArray.length - 1] = { - ...lastItem, - cache_control: { type: 'ephemeral' }, - } as ChatCompletionContentPartTextWithCache; - } else { - // If the last item is not text, add a new text item with cache_control - contentArray.push({ - type: 'text', - text: '', - cache_control: { type: 'ephemeral' }, - } as ChatCompletionContentPartTextWithCache); - } - - const messageWithCache = { - ...message, - content: contentArray, - }; - updatedMessages[messageIndex] = - messageWithCache as OpenAI.Chat.ChatCompletionMessageParam; - } - } - } - - return updatedMessages; - } - - /** - * Clean up orphaned tool calls from message history to prevent OpenAI API errors - */ - private cleanOrphanedToolCalls( - messages: OpenAI.Chat.ChatCompletionMessageParam[], - ): OpenAI.Chat.ChatCompletionMessageParam[] { - const cleaned: 
OpenAI.Chat.ChatCompletionMessageParam[] = []; - const toolCallIds = new Set(); - const toolResponseIds = new Set(); - - // First pass: collect all tool call IDs and tool response IDs - for (const message of messages) { - if ( - message.role === 'assistant' && - 'tool_calls' in message && - message.tool_calls - ) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - toolCallIds.add(toolCall.id); - } - } - } else if ( - message.role === 'tool' && - 'tool_call_id' in message && - message.tool_call_id - ) { - toolResponseIds.add(message.tool_call_id); - } - } - - // Second pass: filter out orphaned messages - for (const message of messages) { - if ( - message.role === 'assistant' && - 'tool_calls' in message && - message.tool_calls - ) { - // Filter out tool calls that don't have corresponding responses - const validToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && toolResponseIds.has(toolCall.id), - ); - - if (validToolCalls.length > 0) { - // Keep the message but only with valid tool calls - const cleanedMessage = { ...message }; - ( - cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & { - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).tool_calls = validToolCalls; - cleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - // Keep the message if it has text content, but remove tool calls - const cleanedMessage = { ...message }; - delete ( - cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & { - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).tool_calls; - cleaned.push(cleanedMessage); - } - // If no valid tool calls and no content, skip the message entirely - } else if ( - message.role === 'tool' && - 'tool_call_id' in message && - message.tool_call_id - ) { - // Only keep tool responses that have corresponding tool calls - if (toolCallIds.has(message.tool_call_id)) { - cleaned.push(message); - } - } else { - // Keep all other messages as-is - cleaned.push(message); - } - } - - // Final validation: ensure every assistant message with tool_calls has corresponding tool responses - const finalCleaned: OpenAI.Chat.ChatCompletionMessageParam[] = []; - const finalToolCallIds = new Set(); - - // Collect all remaining tool call IDs - for (const message of cleaned) { - if ( - message.role === 'assistant' && - 'tool_calls' in message && - message.tool_calls - ) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - finalToolCallIds.add(toolCall.id); - } - } - } - } - - // Verify all tool calls have responses - const finalToolResponseIds = new Set(); - for (const message of cleaned) { - if ( - message.role === 'tool' && - 'tool_call_id' in message && - message.tool_call_id - ) { - finalToolResponseIds.add(message.tool_call_id); - } - } - - // Remove any remaining orphaned tool calls - for (const message of cleaned) { - if ( - message.role === 'assistant' && - 'tool_calls' in message && - message.tool_calls - ) { - const finalValidToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id), - ); - - if (finalValidToolCalls.length > 0) { - const cleanedMessage = { ...message }; - ( - cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & { - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).tool_calls = finalValidToolCalls; - finalCleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - const cleanedMessage = { 
...message }; - delete ( - cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & { - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).tool_calls; - finalCleaned.push(cleanedMessage); - } - } else { - finalCleaned.push(message); - } - } - - return finalCleaned; - } - - /** - * Merge consecutive assistant messages to combine split text and tool calls - */ - private mergeConsecutiveAssistantMessages( - messages: OpenAI.Chat.ChatCompletionMessageParam[], - ): OpenAI.Chat.ChatCompletionMessageParam[] { - const merged: OpenAI.Chat.ChatCompletionMessageParam[] = []; - - for (const message of messages) { - if (message.role === 'assistant' && merged.length > 0) { - const lastMessage = merged[merged.length - 1]; - - // If the last message is also an assistant message, merge them - if (lastMessage.role === 'assistant') { - // Combine content - const combinedContent = [ - typeof lastMessage.content === 'string' ? lastMessage.content : '', - typeof message.content === 'string' ? message.content : '', - ] - .filter(Boolean) - .join(''); - - // Combine tool calls - const lastToolCalls = - 'tool_calls' in lastMessage ? lastMessage.tool_calls || [] : []; - const currentToolCalls = - 'tool_calls' in message ? message.tool_calls || [] : []; - const combinedToolCalls = [...lastToolCalls, ...currentToolCalls]; - - // Update the last message with combined data - ( - lastMessage as OpenAI.Chat.ChatCompletionMessageParam & { - content: string | null; - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).content = combinedContent || null; - if (combinedToolCalls.length > 0) { - ( - lastMessage as OpenAI.Chat.ChatCompletionMessageParam & { - content: string | null; - tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[]; - } - ).tool_calls = combinedToolCalls; - } - - continue; // Skip adding the current message since it's been merged - } - } - - // Add the message as-is if no merging is needed - merged.push(message); - } - - return merged; - } - - private convertToGeminiFormat( - openaiResponse: OpenAI.Chat.ChatCompletion, - ): GenerateContentResponse { - const choice = openaiResponse.choices[0]; - const response = new GenerateContentResponse(); - - const parts: Part[] = []; - - // Handle text content - if (choice.message.content) { - parts.push({ text: choice.message.content }); - } - - // Handle tool calls - if (choice.message.tool_calls) { - for (const toolCall of choice.message.tool_calls) { - if (toolCall.function) { - let args: Record = {}; - if (toolCall.function.arguments) { - args = safeJsonParse(toolCall.function.arguments, {}); - } - - parts.push({ - functionCall: { - id: toolCall.id, - name: toolCall.function.name, - args, - }, - }); - } - } - } - - response.responseId = openaiResponse.id; - response.createTime = openaiResponse.created - ? 
openaiResponse.created.toString() - : new Date().getTime().toString(); - - response.candidates = [ - { - content: { - parts, - role: 'model' as const, - }, - finishReason: this.mapFinishReason(choice.finish_reason || 'stop'), - index: 0, - safetyRatings: [], - }, - ]; - - response.modelVersion = this.model; - response.promptFeedback = { safetyRatings: [] }; - - // Add usage metadata if available - if (openaiResponse.usage) { - const usage = openaiResponse.usage as OpenAIUsage; - - const promptTokens = usage.prompt_tokens || 0; - const completionTokens = usage.completion_tokens || 0; - const totalTokens = usage.total_tokens || 0; - const cachedTokens = usage.prompt_tokens_details?.cached_tokens || 0; - - // If we only have total tokens but no breakdown, estimate the split - // Typically input is ~70% and output is ~30% for most conversations - let finalPromptTokens = promptTokens; - let finalCompletionTokens = completionTokens; - - if (totalTokens > 0 && promptTokens === 0 && completionTokens === 0) { - // Estimate: assume 70% input, 30% output - finalPromptTokens = Math.round(totalTokens * 0.7); - finalCompletionTokens = Math.round(totalTokens * 0.3); - } - - response.usageMetadata = { - promptTokenCount: finalPromptTokens, - candidatesTokenCount: finalCompletionTokens, - totalTokenCount: totalTokens, - cachedContentTokenCount: cachedTokens, - }; - } - - return response; - } - - private convertStreamChunkToGeminiFormat( - chunk: OpenAI.Chat.ChatCompletionChunk, - ): GenerateContentResponse { - const choice = chunk.choices?.[0]; - const response = new GenerateContentResponse(); - - if (choice) { - const parts: Part[] = []; - - // Handle text content - if (choice.delta?.content) { - if (typeof choice.delta.content === 'string') { - parts.push({ text: choice.delta.content }); - } - } - - // Handle tool calls - only accumulate during streaming, emit when complete - if (choice.delta?.tool_calls) { - for (const toolCall of choice.delta.tool_calls) { - const index = toolCall.index ?? 
0; - - // Get or create the tool call accumulator for this index - let accumulatedCall = this.streamingToolCalls.get(index); - if (!accumulatedCall) { - accumulatedCall = { arguments: '' }; - this.streamingToolCalls.set(index, accumulatedCall); - } - - // Update accumulated data - if (toolCall.id) { - accumulatedCall.id = toolCall.id; - } - if (toolCall.function?.name) { - // If this is a new function name, reset the arguments - if (accumulatedCall.name !== toolCall.function.name) { - accumulatedCall.arguments = ''; - } - accumulatedCall.name = toolCall.function.name; - } - if (toolCall.function?.arguments) { - // Check if we already have a complete JSON object - const currentArgs = accumulatedCall.arguments; - const newArgs = toolCall.function.arguments; - - // If current arguments already form a complete JSON and new arguments start a new object, - // this indicates a new tool call with the same name - let shouldReset = false; - if (currentArgs && newArgs.trim().startsWith('{')) { - try { - JSON.parse(currentArgs); - // If we can parse current arguments as complete JSON and new args start with {, - // this is likely a new tool call - shouldReset = true; - } catch { - // Current arguments are not complete JSON, continue accumulating - } - } - - if (shouldReset) { - accumulatedCall.arguments = newArgs; - } else { - accumulatedCall.arguments += newArgs; - } - } - } - } - - // Only emit function calls when streaming is complete (finish_reason is present) - if (choice.finish_reason) { - for (const [, accumulatedCall] of this.streamingToolCalls) { - // TODO: Add back id once we have a way to generate tool_call_id from the VLLM parser. - // if (accumulatedCall.id && accumulatedCall.name) { - if (accumulatedCall.name) { - let args: Record = {}; - if (accumulatedCall.arguments) { - args = safeJsonParse(accumulatedCall.arguments, {}); - } - - parts.push({ - functionCall: { - id: - accumulatedCall.id || - `call_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`, - name: accumulatedCall.name, - args, - }, - }); - } - } - // Clear all accumulated tool calls - this.streamingToolCalls.clear(); - } - - response.candidates = [ - { - content: { - parts, - role: 'model' as const, - }, - finishReason: choice.finish_reason - ? this.mapFinishReason(choice.finish_reason) - : FinishReason.FINISH_REASON_UNSPECIFIED, - index: 0, - safetyRatings: [], - }, - ]; - } else { - response.candidates = []; - } - - response.responseId = chunk.id; - response.createTime = chunk.created - ? 
chunk.created.toString() - : new Date().getTime().toString(); - - response.modelVersion = this.model; - response.promptFeedback = { safetyRatings: [] }; - - // Add usage metadata if available in the chunk - if (chunk.usage) { - const usage = chunk.usage as OpenAIUsage; - - const promptTokens = usage.prompt_tokens || 0; - const completionTokens = usage.completion_tokens || 0; - const totalTokens = usage.total_tokens || 0; - const cachedTokens = usage.prompt_tokens_details?.cached_tokens || 0; - - // If we only have total tokens but no breakdown, estimate the split - // Typically input is ~70% and output is ~30% for most conversations - let finalPromptTokens = promptTokens; - let finalCompletionTokens = completionTokens; - - if (totalTokens > 0 && promptTokens === 0 && completionTokens === 0) { - // Estimate: assume 70% input, 30% output - finalPromptTokens = Math.round(totalTokens * 0.7); - finalCompletionTokens = Math.round(totalTokens * 0.3); - } - - response.usageMetadata = { - promptTokenCount: finalPromptTokens, - candidatesTokenCount: finalCompletionTokens, - totalTokenCount: totalTokens, - cachedContentTokenCount: cachedTokens, - }; - } - - return response; - } - - /** - * Build sampling parameters with clear priority: - * 1. Config-level sampling parameters (highest priority) - * 2. Request-level parameters (medium priority) - * 3. Default values (lowest priority) - */ - private buildSamplingParameters( - request: GenerateContentParameters, - ): Record { - const configSamplingParams = this.contentGeneratorConfig.samplingParams; - - const params = { - // Temperature: config > request > default - temperature: - configSamplingParams?.temperature !== undefined - ? configSamplingParams.temperature - : request.config?.temperature !== undefined - ? request.config.temperature - : 0.0, - - // Max tokens: config > request > undefined - ...(configSamplingParams?.max_tokens !== undefined - ? { max_tokens: configSamplingParams.max_tokens } - : request.config?.maxOutputTokens !== undefined - ? { max_tokens: request.config.maxOutputTokens } - : {}), - - // Top-p: config > request > default - top_p: - configSamplingParams?.top_p !== undefined - ? configSamplingParams.top_p - : request.config?.topP !== undefined - ? request.config.topP - : 1.0, - - // Top-k: config only (not available in request) - ...(configSamplingParams?.top_k !== undefined - ? { top_k: configSamplingParams.top_k } - : {}), - - // Repetition penalty: config only - ...(configSamplingParams?.repetition_penalty !== undefined - ? { repetition_penalty: configSamplingParams.repetition_penalty } - : {}), - - // Presence penalty: config only - ...(configSamplingParams?.presence_penalty !== undefined - ? { presence_penalty: configSamplingParams.presence_penalty } - : {}), - - // Frequency penalty: config only - ...(configSamplingParams?.frequency_penalty !== undefined - ? 
{ frequency_penalty: configSamplingParams.frequency_penalty } - : {}), - }; - - return params; - } - - private mapFinishReason(openaiReason: string | null): FinishReason { - if (!openaiReason) return FinishReason.FINISH_REASON_UNSPECIFIED; - const mapping: Record = { - stop: FinishReason.STOP, - length: FinishReason.MAX_TOKENS, - content_filter: FinishReason.SAFETY, - function_call: FinishReason.STOP, - tool_calls: FinishReason.STOP, - }; - return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED; - } - - /** - * Convert Gemini response format to OpenAI chat completion format for logging - */ - private convertGeminiResponseToOpenAI( - response: GenerateContentResponse, - ): OpenAIResponseFormat { - const candidate = response.candidates?.[0]; - const content = candidate?.content; - - let messageContent: string | null = null; - const toolCalls: OpenAIToolCall[] = []; - - if (content?.parts) { - const textParts: string[] = []; - - for (const part of content.parts) { - if ('text' in part && part.text) { - textParts.push(part.text); - } else if ('functionCall' in part && part.functionCall) { - toolCalls.push({ - id: part.functionCall.id || `call_${toolCalls.length}`, - type: 'function' as const, - function: { - name: part.functionCall.name || '', - arguments: JSON.stringify(part.functionCall.args || {}), - }, - }); - } - } - - messageContent = textParts.join('').trimEnd(); - } - - const choice: OpenAIChoice = { - index: 0, - message: { - role: 'assistant', - content: messageContent, - }, - finish_reason: this.mapGeminiFinishReasonToOpenAI( - candidate?.finishReason, - ), - }; - - if (toolCalls.length > 0) { - choice.message.tool_calls = toolCalls; - } - - const openaiResponse: OpenAIResponseFormat = { - id: response.responseId || `chatcmpl-${Date.now()}`, - object: 'chat.completion', - created: response.createTime - ? 
Number(response.createTime) - : Math.floor(Date.now() / 1000), - model: this.model, - choices: [choice], - }; - - // Add usage metadata if available - if (response.usageMetadata) { - openaiResponse.usage = { - prompt_tokens: response.usageMetadata.promptTokenCount || 0, - completion_tokens: response.usageMetadata.candidatesTokenCount || 0, - total_tokens: response.usageMetadata.totalTokenCount || 0, - }; - - if (response.usageMetadata.cachedContentTokenCount) { - openaiResponse.usage.prompt_tokens_details = { - cached_tokens: response.usageMetadata.cachedContentTokenCount, - }; - } - } - - return openaiResponse; - } - - /** - * Map Gemini finish reasons to OpenAI finish reasons - */ - private mapGeminiFinishReasonToOpenAI(geminiReason?: unknown): string { - if (!geminiReason) return 'stop'; - - switch (geminiReason) { - case 'STOP': - case 1: // FinishReason.STOP - return 'stop'; - case 'MAX_TOKENS': - case 2: // FinishReason.MAX_TOKENS - return 'length'; - case 'SAFETY': - case 3: // FinishReason.SAFETY - return 'content_filter'; - case 'RECITATION': - case 4: // FinishReason.RECITATION - return 'content_filter'; - case 'OTHER': - case 5: // FinishReason.OTHER - return 'stop'; - default: - return 'stop'; - } - } -} diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts index 60c85d23..3d1a516c 100644 --- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts +++ b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts @@ -5,6 +5,37 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; + +// Mock the request tokenizer module BEFORE importing the class that uses it +const mockTokenizer = { + calculateTokens: vi.fn().mockResolvedValue({ + totalTokens: 50, + breakdown: { + textTokens: 50, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: 1, + }), + dispose: vi.fn(), +}; + +vi.mock('../../../utils/request-tokenizer/index.js', () => ({ + getDefaultTokenizer: vi.fn(() => mockTokenizer), + DefaultRequestTokenizer: vi.fn(() => mockTokenizer), + disposeDefaultTokenizer: vi.fn(), +})); + +// Mock tiktoken as well for completeness +vi.mock('tiktoken', () => ({ + get_encoding: vi.fn(() => ({ + encode: vi.fn(() => new Array(50)), // Mock 50 tokens + free: vi.fn(), + })), +})); + +// Now import the modules that depend on the mocked modules import { OpenAIContentGenerator } from './openaiContentGenerator.js'; import type { Config } from '../../config/config.js'; import { AuthType } from '../contentGenerator.js'; @@ -15,14 +46,6 @@ import type { import type { OpenAICompatibleProvider } from './provider/index.js'; import type OpenAI from 'openai'; -// Mock tiktoken -vi.mock('tiktoken', () => ({ - get_encoding: vi.fn().mockReturnValue({ - encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens - free: vi.fn(), - }), -})); - describe('OpenAIContentGenerator (Refactored)', () => { let generator: OpenAIContentGenerator; let mockConfig: Config; diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts index fa738af3..91e69527 100644 --- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts @@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js'; import { ContentGenerationPipeline } 
from './pipeline.js'; import { DefaultTelemetryService } from './telemetryService.js'; import { EnhancedErrorHandler } from './errorHandler.js'; +import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js'; import type { ContentGeneratorConfig } from '../contentGenerator.js'; export class OpenAIContentGenerator implements ContentGenerator { @@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator { async countTokens( request: CountTokensParameters, ): Promise { - // Use tiktoken for accurate token counting - const content = JSON.stringify(request.contents); - let totalTokens = 0; - try { - const { get_encoding } = await import('tiktoken'); - const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen - totalTokens = encoding.encode(content).length; - encoding.free(); + // Use the new high-performance request tokenizer + const tokenizer = getDefaultTokenizer(); + const result = await tokenizer.calculateTokens(request, { + textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency + }); + + return { + totalTokens: result.totalTokens, + }; } catch (error) { console.warn( - 'Failed to load tiktoken, falling back to character approximation:', + 'Failed to calculate tokens with new tokenizer, falling back to simple method:', error, ); - // Fallback: rough approximation using character count - totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters - } - return { - totalTokens, - }; + // Fallback to original simple method + const content = JSON.stringify(request.contents); + const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters + + return { + totalTokens, + }; + } } async embedContent( diff --git a/packages/core/src/core/openaiContentGenerator/pipeline.ts b/packages/core/src/core/openaiContentGenerator/pipeline.ts index bf4b3892..85d279e6 100644 --- a/packages/core/src/core/openaiContentGenerator/pipeline.ts +++ b/packages/core/src/core/openaiContentGenerator/pipeline.ts @@ -10,14 +10,11 @@ import { GenerateContentResponse, } from '@google/genai'; import type { Config } from '../../config/config.js'; -import { type ContentGeneratorConfig } from '../contentGenerator.js'; -import { type OpenAICompatibleProvider } from './provider/index.js'; +import type { ContentGeneratorConfig } from '../contentGenerator.js'; +import type { OpenAICompatibleProvider } from './provider/index.js'; import { OpenAIContentConverter } from './converter.js'; -import { - type TelemetryService, - type RequestContext, -} from './telemetryService.js'; -import { type ErrorHandler } from './errorHandler.js'; +import type { TelemetryService, RequestContext } from './telemetryService.js'; +import type { ErrorHandler } from './errorHandler.js'; export interface PipelineConfig { cliConfig: Config; @@ -101,7 +98,7 @@ export class ContentGenerationPipeline { * 2. Filter empty responses * 3. Handle chunk merging for providers that send finishReason and usageMetadata separately * 4. Collect both formats for logging - * 5. Handle success/error logging with original OpenAI format + * 5. 
Handle success/error logging */ private async *processStreamWithLogging( stream: AsyncIterable, @@ -169,19 +166,11 @@ export class ContentGenerationPipeline { collectedOpenAIChunks, ); } catch (error) { - // Stage 2e: Stream failed - handle error and logging - context.duration = Date.now() - context.startTime; - // Clear streaming tool calls on error to prevent data pollution this.converter.resetStreamingToolCalls(); - await this.config.telemetryService.logError( - context, - error, - openaiRequest, - ); - - this.config.errorHandler.handle(error, context, request); + // Use shared error handling logic + await this.handleError(error, context, request); } } @@ -365,25 +354,59 @@ export class ContentGenerationPipeline { context.duration = Date.now() - context.startTime; return result; } catch (error) { - context.duration = Date.now() - context.startTime; - - // Log error - const openaiRequest = await this.buildRequest( + // Use shared error handling logic + return await this.handleError( + error, + context, request, userPromptId, isStreaming, ); - await this.config.telemetryService.logError( - context, - error, - openaiRequest, - ); - - // Handle and throw enhanced error - this.config.errorHandler.handle(error, context, request); } } + /** + * Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging + * This centralizes the common error processing steps to avoid duplication + */ + private async handleError( + error: unknown, + context: RequestContext, + request: GenerateContentParameters, + userPromptId?: string, + isStreaming?: boolean, + ): Promise { + context.duration = Date.now() - context.startTime; + + // Build request for logging (may fail, but we still want to log the error) + let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams; + try { + if (userPromptId !== undefined && isStreaming !== undefined) { + openaiRequest = await this.buildRequest( + request, + userPromptId, + isStreaming, + ); + } else { + // For processStreamWithLogging, we don't have userPromptId/isStreaming, + // so create a minimal request + openaiRequest = { + model: this.contentGeneratorConfig.model, + messages: [], + }; + } + } catch (_buildError) { + // If we can't build the request, create a minimal one for logging + openaiRequest = { + model: this.contentGeneratorConfig.model, + messages: [], + }; + } + + await this.config.telemetryService.logError(context, error, openaiRequest); + this.config.errorHandler.handle(error, context, request); + } + /** * Create request context with common properties */ diff --git a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts index 0999052f..86cb54c0 100644 --- a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts +++ b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts @@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider messages = this.addDashScopeCacheControl(messages, cacheTarget); } + if (request.model.startsWith('qwen-vl')) { + return { + ...request, + messages, + ...(this.buildMetadata(userPromptId) || {}), + /* @ts-expect-error dashscope exclusive */ + vl_high_resolution_images: true, + }; + } + return { ...request, // Preserve all original parameters including sampling params messages, diff --git a/packages/core/src/core/tokenLimits.ts b/packages/core/src/core/tokenLimits.ts index e51becab..2e502037 100644 --- a/packages/core/src/core/tokenLimits.ts +++ b/packages/core/src/core/tokenLimits.ts @@ 
-116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [ [/^qwen-flash-latest$/, LIMITS['1m']], [/^qwen-turbo.*$/, LIMITS['128k']], + // Qwen Vision Models + [/^qwen-vl-max.*$/, LIMITS['128k']], + // ------------------- // ByteDance Seed-OSS (512K) // ------------------- diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 87be432b..55bcfa0d 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -242,7 +242,7 @@ describe('Turn', () => { expect(turn.getDebugResponses().length).toBe(0); expect(reportError).toHaveBeenCalledWith( error, - 'Error when talking to Gemini API', + 'Error when talking to API', [...historyContent, reqParts], 'Turn.run-sendMessageStream', ); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 8dd8377c..ad6f8319 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -310,7 +310,7 @@ export class Turn { const contextForReport = [...this.chat.getHistory(/*curated*/ true), req]; await reportError( error, - 'Error when talking to Gemini API', + 'Error when talking to API', contextForReport, 'Turn.run-sendMessageStream', ); diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts index 8992e829..8efdd530 100644 --- a/packages/core/src/qwen/qwenContentGenerator.test.ts +++ b/packages/core/src/qwen/qwenContentGenerator.test.ts @@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => { expect(mockQwenClient.getAccessToken).toHaveBeenCalled(); }); - it('should count tokens with valid token', async () => { - vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ - token: 'valid-token', - }); - vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials); + it('should count tokens without requiring authentication', async () => { + // Clear any previous mock calls + vi.clearAllMocks(); const request: CountTokensParameters = { model: 'qwen-turbo', @@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => { const result = await qwenContentGenerator.countTokens(request); expect(result.totalTokens).toBe(15); - expect(mockQwenClient.getAccessToken).toHaveBeenCalled(); + // countTokens is a local operation and should not require OAuth credentials + expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled(); }); it('should embed content with valid token', async () => { @@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => { SharedTokenManager.getInstance = originalGetInstance; }); - it('should handle all method types with token failure', async () => { + it('should handle method types with token failure (except countTokens)', async () => { const mockTokenManager = { getValidCredentials: vi .fn() @@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => { contents: [{ parts: [{ text: 'Embed' }] }], }; - // All methods should fail with the same error + // Methods requiring authentication should fail await expect( newGenerator.generateContent(generateRequest, 'test-id'), ).rejects.toThrow('Failed to obtain valid Qwen access token'); @@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => { newGenerator.generateContentStream(generateRequest, 'test-id'), ).rejects.toThrow('Failed to obtain valid Qwen access token'); - await expect(newGenerator.countTokens(countRequest)).rejects.toThrow( - 'Failed to obtain valid Qwen access token', - ); - await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow( 'Failed to obtain valid Qwen access 
token', ); + // countTokens should succeed as it's a local operation + const countResult = await newGenerator.countTokens(countRequest); + expect(countResult.totalTokens).toBe(15); + SharedTokenManager.getInstance = originalGetInstance; }); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts index 1795903e..0e3ca12e 100644 --- a/packages/core/src/qwen/qwenContentGenerator.ts +++ b/packages/core/src/qwen/qwenContentGenerator.ts @@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator { override async countTokens( request: CountTokensParameters, ): Promise { - return this.executeWithCredentialManagement(() => - super.countTokens(request), - ); + return super.countTokens(request); } /** diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts new file mode 100644 index 00000000..cdb5f35f --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts @@ -0,0 +1,157 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { ImageTokenizer } from './imageTokenizer.js'; + +describe('ImageTokenizer', () => { + const tokenizer = new ImageTokenizer(); + + describe('token calculation', () => { + it('should calculate tokens based on image dimensions with reference logic', () => { + const metadata = { + width: 28, + height: 28, + mimeType: 'image/png', + dataSize: 1000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total + // But minimum scaling may apply for small images + expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens + }); + + it('should calculate tokens for larger images', () => { + const metadata = { + width: 512, + height: 512, + mimeType: 'image/png', + dataSize: 10000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // 512x512 with reference logic: rounded dimensions + scaling + special tokens + expect(tokens).toBeGreaterThan(300); + expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512 + }); + + it('should enforce minimum tokens per image with scaling', () => { + const metadata = { + width: 1, + height: 1, + mimeType: 'image/png', + dataSize: 100, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // Tiny images get scaled up to minimum pixels + special tokens + expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens + }); + + it('should handle very large images with scaling', () => { + const metadata = { + width: 8192, + height: 8192, + mimeType: 'image/png', + dataSize: 100000, + }; + + const tokens = tokenizer.calculateTokens(metadata); + + // Very large images should be scaled down to max limit + special tokens + expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens + expect(tokens).toBeGreaterThan(16000); // Should be close to the limit + }); + }); + + describe('PNG dimension extraction', () => { + it('should extract dimensions from valid PNG', async () => { + // 1x1 PNG image in base64 + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const metadata = await tokenizer.extractImageMetadata( + pngBase64, + 'image/png', + ); + + expect(metadata.width).toBe(1); + expect(metadata.height).toBe(1); + 
expect(metadata.mimeType).toBe('image/png'); + }); + + it('should handle invalid PNG gracefully', async () => { + const invalidBase64 = 'invalid-png-data'; + + const metadata = await tokenizer.extractImageMetadata( + invalidBase64, + 'image/png', + ); + + // Should return default dimensions + expect(metadata.width).toBe(512); + expect(metadata.height).toBe(512); + expect(metadata.mimeType).toBe('image/png'); + }); + }); + + describe('batch processing', () => { + it('should process multiple images serially', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const images = [ + { data: pngBase64, mimeType: 'image/png' }, + { data: pngBase64, mimeType: 'image/png' }, + { data: pngBase64, mimeType: 'image/png' }, + ]; + + const tokens = await tokenizer.calculateTokensBatch(images); + + expect(tokens).toHaveLength(3); + expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens + }); + + it('should handle mixed valid and invalid images', async () => { + const validPng = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + const invalidPng = 'invalid-data'; + + const images = [ + { data: validPng, mimeType: 'image/png' }, + { data: invalidPng, mimeType: 'image/png' }, + ]; + + const tokens = await tokenizer.calculateTokensBatch(images); + + expect(tokens).toHaveLength(2); + expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens + }); + }); + + describe('different image formats', () => { + it('should handle different MIME types', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif']; + + for (const mimeType of formats) { + const metadata = await tokenizer.extractImageMetadata( + pngBase64, + mimeType, + ); + expect(metadata.mimeType).toBe(mimeType); + expect(metadata.width).toBeGreaterThan(0); + expect(metadata.height).toBeGreaterThan(0); + } + }); + }); +}); diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts new file mode 100644 index 00000000..b55c6b9e --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts @@ -0,0 +1,505 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ImageMetadata } from './types.js'; +import { isSupportedImageMimeType } from './supportedImageFormats.js'; + +/** + * Image tokenizer for calculating image tokens based on dimensions + * + * Key rules: + * - 28x28 pixels = 1 token + * - Minimum: 4 tokens per image + * - Maximum: 16384 tokens per image + * - Additional: 2 special tokens (vision_bos + vision_eos) + * - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats + */ +export class ImageTokenizer { + /** 28x28 pixels = 1 token */ + private static readonly PIXELS_PER_TOKEN = 28 * 28; + + /** Minimum tokens per image */ + private static readonly MIN_TOKENS_PER_IMAGE = 4; + + /** Maximum tokens per image */ + private static readonly MAX_TOKENS_PER_IMAGE = 16384; + + /** Special tokens for vision markers */ + private static readonly VISION_SPECIAL_TOKENS = 2; + + /** + * Extract image metadata from base64 data + * + * @param base64Data Base64-encoded image data (with or without data URL prefix) + * @param mimeType MIME type of the image + * @returns Promise 
resolving to ImageMetadata with dimensions and format info
+   */
+  async extractImageMetadata(
+    base64Data: string,
+    mimeType: string,
+  ): Promise<ImageMetadata> {
+    try {
+      // Check if the MIME type is supported
+      if (!isSupportedImageMimeType(mimeType)) {
+        console.warn(`Unsupported image format: ${mimeType}`);
+        // Return default metadata for unsupported formats
+        return {
+          width: 512,
+          height: 512,
+          mimeType,
+          dataSize: Math.floor(base64Data.length * 0.75),
+        };
+      }
+
+      const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
+      const buffer = Buffer.from(cleanBase64, 'base64');
+      const dimensions = await this.extractDimensions(buffer, mimeType);
+
+      return {
+        width: dimensions.width,
+        height: dimensions.height,
+        mimeType,
+        dataSize: buffer.length,
+      };
+    } catch (error) {
+      console.warn('Failed to extract image metadata:', error);
+      // Return default metadata for fallback
+      return {
+        width: 512,
+        height: 512,
+        mimeType,
+        dataSize: Math.floor(base64Data.length * 0.75),
+      };
+    }
+  }
+
+  /**
+   * Extract image dimensions from buffer based on format
+   *
+   * @param buffer Binary image data buffer
+   * @param mimeType MIME type to determine parsing strategy
+   * @returns Promise resolving to width and height dimensions
+   */
+  private async extractDimensions(
+    buffer: Buffer,
+    mimeType: string,
+  ): Promise<{ width: number; height: number }> {
+    if (mimeType.includes('png')) {
+      return this.extractPngDimensions(buffer);
+    }
+
+    if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
+      return this.extractJpegDimensions(buffer);
+    }
+
+    if (mimeType.includes('webp')) {
+      return this.extractWebpDimensions(buffer);
+    }
+
+    if (mimeType.includes('gif')) {
+      return this.extractGifDimensions(buffer);
+    }
+
+    if (mimeType.includes('bmp')) {
+      return this.extractBmpDimensions(buffer);
+    }
+
+    if (mimeType.includes('tiff')) {
+      return this.extractTiffDimensions(buffer);
+    }
+
+    if (mimeType.includes('heic')) {
+      return this.extractHeicDimensions(buffer);
+    }
+
+    return { width: 512, height: 512 };
+  }
+
+  /**
+   * Extract PNG dimensions from IHDR chunk
+   * PNG signature: 89 50 4E 47 0D 0A 1A 0A
+   * Width/height at bytes 16-19 and 20-23 (big-endian)
+   */
+  private extractPngDimensions(buffer: Buffer): {
+    width: number;
+    height: number;
+  } {
+    if (buffer.length < 24) {
+      throw new Error('Invalid PNG: buffer too short');
+    }
+
+    // Verify PNG signature
+    const signature = buffer.subarray(0, 8);
+    const expectedSignature = Buffer.from([
+      0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
+    ]);
+    if (!signature.equals(expectedSignature)) {
+      throw new Error('Invalid PNG signature');
+    }
+
+    const width = buffer.readUInt32BE(16);
+    const height = buffer.readUInt32BE(20);
+
+    return { width, height };
+  }
+
+  /**
+   * Extract JPEG dimensions from SOF (Start of Frame) markers
+   * JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
+   * Dimensions at offset +5 (height) and +7 (width) from SOF marker
+   */
+  private extractJpegDimensions(buffer: Buffer): {
+    width: number;
+    height: number;
+  } {
+    if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
+      throw new Error('Invalid JPEG signature');
+    }
+
+    let offset = 2;
+
+    while (offset < buffer.length - 8) {
+      if (buffer[offset] !== 0xff) {
+        offset++;
+        continue;
+      }
+
+      const marker = buffer[offset + 1];
+
+      // SOF markers
+      if (
+        (marker >= 0xc0 && marker <= 0xc3) ||
+        (marker >= 0xc5 && marker <= 0xc7) ||
+        (marker >= 0xc9 && marker <= 0xcb) ||
+        (marker >= 0xcd && marker <= 0xcf)
+      ) {
+        const height =
buffer.readUInt16BE(offset + 5); + const width = buffer.readUInt16BE(offset + 7); + return { width, height }; + } + + const segmentLength = buffer.readUInt16BE(offset + 2); + offset += 2 + segmentLength; + } + + throw new Error('Could not find JPEG dimensions'); + } + + /** + * Extract WebP dimensions from RIFF container + * Supports VP8, VP8L, and VP8X formats + */ + private extractWebpDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 30) { + throw new Error('Invalid WebP: too short'); + } + + const riffSignature = buffer.subarray(0, 4).toString('ascii'); + const webpSignature = buffer.subarray(8, 12).toString('ascii'); + + if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') { + throw new Error('Invalid WebP signature'); + } + + const format = buffer.subarray(12, 16).toString('ascii'); + + if (format === 'VP8 ') { + const width = buffer.readUInt16LE(26) & 0x3fff; + const height = buffer.readUInt16LE(28) & 0x3fff; + return { width, height }; + } else if (format === 'VP8L') { + const bits = buffer.readUInt32LE(21); + const width = (bits & 0x3fff) + 1; + const height = ((bits >> 14) & 0x3fff) + 1; + return { width, height }; + } else if (format === 'VP8X') { + const width = (buffer.readUInt32LE(24) & 0xffffff) + 1; + const height = (buffer.readUInt32LE(26) & 0xffffff) + 1; + return { width, height }; + } + + throw new Error('Unsupported WebP format'); + } + + /** + * Extract GIF dimensions from header + * Supports GIF87a and GIF89a formats + */ + private extractGifDimensions(buffer: Buffer): { + width: number; + height: number; + } { + if (buffer.length < 10) { + throw new Error('Invalid GIF: too short'); + } + + const signature = buffer.subarray(0, 6).toString('ascii'); + if (signature !== 'GIF87a' && signature !== 'GIF89a') { + throw new Error('Invalid GIF signature'); + } + + const width = buffer.readUInt16LE(6); + const height = buffer.readUInt16LE(8); + + return { width, height }; + } + + /** + * Calculate tokens for an image based on its metadata + * + * @param metadata Image metadata containing width, height, and format info + * @returns Total token count including base image tokens and special tokens + */ + calculateTokens(metadata: ImageMetadata): number { + return this.calculateTokensWithScaling(metadata.width, metadata.height); + } + + /** + * Calculate tokens with scaling logic + * + * Steps: + * 1. Normalize to 28-pixel multiples + * 2. Scale large images down, small images up + * 3. 
Calculate tokens: pixels / 784 + 2 special tokens
+   *
+   * @param width Original image width in pixels
+   * @param height Original image height in pixels
+   * @returns Total token count for the image
+   */
+  private calculateTokensWithScaling(width: number, height: number): number {
+    // Normalize to 28-pixel multiples
+    let hBar = Math.round(height / 28) * 28;
+    let wBar = Math.round(width / 28) * 28;
+
+    // Define pixel boundaries
+    const minPixels =
+      ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
+    const maxPixels =
+      ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
+
+    // Apply scaling
+    if (hBar * wBar > maxPixels) {
+      // Scale down large images
+      const beta = Math.sqrt((height * width) / maxPixels);
+      hBar = Math.floor(height / beta / 28) * 28;
+      wBar = Math.floor(width / beta / 28) * 28;
+    } else if (hBar * wBar < minPixels) {
+      // Scale up small images
+      const beta = Math.sqrt(minPixels / (height * width));
+      hBar = Math.ceil((height * beta) / 28) * 28;
+      wBar = Math.ceil((width * beta) / 28) * 28;
+    }
+
+    // Calculate tokens
+    const imageTokens = Math.floor(
+      (hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
+    );
+
+    return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
+  }
+
+  /**
+   * Calculate tokens for multiple images serially
+   *
+   * @param base64DataArray Array of image data with MIME type information
+   * @returns Promise resolving to array of token counts in same order as input
+   */
+  async calculateTokensBatch(
+    base64DataArray: Array<{ data: string; mimeType: string }>,
+  ): Promise<number[]> {
+    const results: number[] = [];
+
+    for (const { data, mimeType } of base64DataArray) {
+      try {
+        const metadata = await this.extractImageMetadata(data, mimeType);
+        results.push(this.calculateTokens(metadata));
+      } catch (error) {
+        console.warn('Error calculating tokens for image:', error);
+        // Return minimum tokens as fallback
+        results.push(
+          ImageTokenizer.MIN_TOKENS_PER_IMAGE +
+            ImageTokenizer.VISION_SPECIAL_TOKENS,
+        );
+      }
+    }
+
+    return results;
+  }
+
+  /**
+   * Extract BMP dimensions from header
+   * BMP signature: 42 4D (BM)
+   * Width/height at bytes 18-21 and 22-25 (little-endian)
+   */
+  private extractBmpDimensions(buffer: Buffer): {
+    width: number;
+    height: number;
+  } {
+    if (buffer.length < 26) {
+      throw new Error('Invalid BMP: buffer too short');
+    }
+
+    // Verify BMP signature
+    if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
+      throw new Error('Invalid BMP signature');
+    }
+
+    const width = buffer.readUInt32LE(18);
+    const height = buffer.readUInt32LE(22);
+
+    return { width, height: Math.abs(height) }; // Height can be negative for top-down BMPs
+  }
+
+  /**
+   * Extract TIFF dimensions from IFD (Image File Directory)
+   * TIFF can be little-endian (II) or big-endian (MM)
+   * Width/height are stored in IFD entries with tags 0x0100 and 0x0101
+   */
+  private extractTiffDimensions(buffer: Buffer): {
+    width: number;
+    height: number;
+  } {
+    if (buffer.length < 8) {
+      throw new Error('Invalid TIFF: buffer too short');
+    }
+
+    // Check byte order
+    const byteOrder = buffer.subarray(0, 2).toString('ascii');
+    const isLittleEndian = byteOrder === 'II';
+    const isBigEndian = byteOrder === 'MM';
+
+    if (!isLittleEndian && !isBigEndian) {
+      throw new Error('Invalid TIFF byte order');
+    }
+
+    // Read magic number (should be 42)
+    const magic = isLittleEndian
+      ? buffer.readUInt16LE(2)
+      : buffer.readUInt16BE(2);
+    if (magic !== 42) {
+      throw new Error('Invalid TIFF magic number');
+    }
+
+    // Read IFD offset
+    const ifdOffset = isLittleEndian
+      ? buffer.readUInt32LE(4)
+      : buffer.readUInt32BE(4);
+
+    if (ifdOffset >= buffer.length) {
+      throw new Error('Invalid TIFF IFD offset');
+    }
+
+    // Read number of directory entries
+    const numEntries = isLittleEndian
+      ? buffer.readUInt16LE(ifdOffset)
+      : buffer.readUInt16BE(ifdOffset);
+
+    let width = 0;
+    let height = 0;
+
+    // Parse IFD entries
+    for (let i = 0; i < numEntries; i++) {
+      const entryOffset = ifdOffset + 2 + i * 12;
+
+      if (entryOffset + 12 > buffer.length) break;
+
+      const tag = isLittleEndian
+        ? buffer.readUInt16LE(entryOffset)
+        : buffer.readUInt16BE(entryOffset);
+
+      const type = isLittleEndian
+        ? buffer.readUInt16LE(entryOffset + 2)
+        : buffer.readUInt16BE(entryOffset + 2);
+
+      // SHORT (type 3) values occupy the first 2 bytes of the 4-byte value
+      // field; LONG (type 4) values use all 4 bytes.
+      const value =
+        type === 3
+          ? isLittleEndian
+            ? buffer.readUInt16LE(entryOffset + 8)
+            : buffer.readUInt16BE(entryOffset + 8)
+          : isLittleEndian
+            ? buffer.readUInt32LE(entryOffset + 8)
+            : buffer.readUInt32BE(entryOffset + 8);
+
+      if (tag === 0x0100) {
+        // ImageWidth
+        width = value;
+      } else if (tag === 0x0101) {
+        // ImageLength (height)
+        height = value;
+      }
+
+      if (width > 0 && height > 0) break;
+    }
+
+    if (width === 0 || height === 0) {
+      throw new Error('Could not find TIFF dimensions');
+    }
+
+    return { width, height };
+  }
+
+  /**
+   * Extract HEIC dimensions from meta box
+   * HEIC is based on ISO Base Media File Format
+   * This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
+   */
+  private extractHeicDimensions(buffer: Buffer): {
+    width: number;
+    height: number;
+  } {
+    if (buffer.length < 12) {
+      throw new Error('Invalid HEIC: buffer too short');
+    }
+
+    // Check for ftyp box with HEIC brand
+    const ftypBox = buffer.subarray(4, 8).toString('ascii');
+    if (ftypBox !== 'ftyp') {
+      throw new Error('Invalid HEIC: missing ftyp box');
+    }
+
+    const brand = buffer.subarray(8, 12).toString('ascii');
+    if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
+      throw new Error('Invalid HEIC brand');
+    }
+
+    // Look for meta box and then ispe box
+    let offset = 0;
+    while (offset < buffer.length - 8) {
+      const boxSize = buffer.readUInt32BE(offset);
+      const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
+
+      if (boxType === 'meta') {
+        // Look for ispe box inside meta box
+        const metaOffset = offset + 8;
+        let innerOffset = metaOffset + 4; // Skip version and flags
+
+        while (innerOffset < offset + boxSize - 8) {
+          const innerBoxSize = buffer.readUInt32BE(innerOffset);
+          const innerBoxType = buffer
+            .subarray(innerOffset + 4, innerOffset + 8)
+            .toString('ascii');
+
+          if (innerBoxType === 'ispe') {
+            // Found Image Spatial Extents box
+            if (innerOffset + 20 <= buffer.length) {
+              const width = buffer.readUInt32BE(innerOffset + 12);
+              const height = buffer.readUInt32BE(innerOffset + 16);
+              return { width, height };
+            }
+          }
+
+          if (innerBoxSize === 0) break;
+          innerOffset += innerBoxSize;
+        }
+      }
+
+      if (boxSize === 0) break;
+      offset += boxSize;
+    }
+
+    // Fallback: return default dimensions if we can't parse the structure
+    console.warn('Could not extract HEIC dimensions, using default');
+    return { width: 512, height: 512 };
+  }
+}
diff --git a/packages/core/src/utils/request-tokenizer/index.ts b/packages/core/src/utils/request-tokenizer/index.ts
new file mode 100644
index 00000000..064b93c1
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/index.ts
@@ -0,0 +1,40 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+export { DefaultRequestTokenizer } from './requestTokenizer.js';
+import { DefaultRequestTokenizer } from './requestTokenizer.js';
+export { TextTokenizer } from './textTokenizer.js';
+export { ImageTokenizer } from './imageTokenizer.js';
+
+export type {
+  RequestTokenizer,
+  TokenizerConfig,
+  TokenCalculationResult,
+  ImageMetadata,
+} from './types.js';
+
+// Singleton instance for convenient usage
+let defaultTokenizer: DefaultRequestTokenizer | null = null;
+
+/**
+ * Get the default request tokenizer instance
+ */
+export function getDefaultTokenizer(): DefaultRequestTokenizer {
+  if (!defaultTokenizer) {
+    defaultTokenizer = new DefaultRequestTokenizer();
+  }
+  return defaultTokenizer;
+}
+
+/**
+ * Dispose of the default tokenizer instance
+ */
+export async function disposeDefaultTokenizer(): Promise<void> {
+  if (defaultTokenizer) {
+    await defaultTokenizer.dispose();
+    defaultTokenizer = null;
+  }
+}
diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts
new file mode 100644
index 00000000..cb69163b
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts
@@ -0,0 +1,293 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, beforeEach, afterEach } from 'vitest';
+import { DefaultRequestTokenizer } from './requestTokenizer.js';
+import type { CountTokensParameters } from '@google/genai';
+
+describe('DefaultRequestTokenizer', () => {
+  let tokenizer: DefaultRequestTokenizer;
+
+  beforeEach(() => {
+    tokenizer = new DefaultRequestTokenizer();
+  });
+
+  afterEach(async () => {
+    await tokenizer.dispose();
+  });
+
+  describe('text token calculation', () => {
+    it('should calculate tokens for simple text content', async () => {
+      const request: CountTokensParameters = {
+        model: 'test-model',
+        contents: [
+          {
+            role: 'user',
+            parts: [{ text: 'Hello, world!'
}], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + expect(result.breakdown.imageTokens).toBe(0); + expect(result.processingTime).toBeGreaterThan(0); + }); + + it('should handle multiple text parts', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { text: 'First part' }, + { text: 'Second part' }, + { text: 'Third part' }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + }); + + it('should handle string content', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: ['Simple string content'], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + }); + }); + + describe('image token calculation', () => { + it('should calculate tokens for image content', async () => { + // Create a simple 1x1 PNG image in base64 + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4); + expect(result.breakdown.textTokens).toBe(0); + }); + + it('should handle multiple images', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8); + }); + }); + + describe('mixed content', () => { + it('should handle text and image content together', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { text: 'Here is an image:' }, + { + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }, + { text: 'What do you see?' 
}, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(4); + expect(result.breakdown.textTokens).toBeGreaterThan(0); + expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4); + }); + }); + + describe('function content', () => { + it('should handle function calls', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + functionCall: { + name: 'test_function', + args: { param1: 'value1', param2: 42 }, + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.breakdown.otherTokens).toBeGreaterThan(0); + }); + }); + + describe('empty content', () => { + it('should handle empty request', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBe(0); + expect(result.breakdown.textTokens).toBe(0); + expect(result.breakdown.imageTokens).toBe(0); + }); + + it('should handle undefined contents', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBe(0); + }); + }); + + describe('configuration', () => { + it('should use custom text encoding', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [{ text: 'Test text for encoding' }], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request, { + textEncoding: 'cl100k_base', + }); + + expect(result.totalTokens).toBeGreaterThan(0); + }); + + it('should process multiple images serially', async () => { + const pngBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg=='; + + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: Array(10).fill({ + inlineData: { + mimeType: 'image/png', + data: pngBase64, + }, + }), + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images + }); + }); + + describe('error handling', () => { + it('should handle malformed image data gracefully', async () => { + const request: CountTokensParameters = { + model: 'test-model', + contents: [ + { + role: 'user', + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: 'invalid-base64-data', + }, + }, + ], + }, + ], + }; + + const result = await tokenizer.calculateTokens(request); + + // Should still return some tokens (fallback to minimum) + expect(result.totalTokens).toBeGreaterThanOrEqual(4); + }); + }); +}); diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts new file mode 100644 index 00000000..173bb261 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts @@ -0,0 +1,341 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + CountTokensParameters, + Content, + Part, + PartUnion, +} from '@google/genai'; +import type { + RequestTokenizer, + TokenizerConfig, + TokenCalculationResult, +} from './types.js'; +import { TextTokenizer } from './textTokenizer.js'; +import { 
ImageTokenizer } from './imageTokenizer.js'; + +/** + * Simple request tokenizer that handles text and image content serially + */ +export class DefaultRequestTokenizer implements RequestTokenizer { + private textTokenizer: TextTokenizer; + private imageTokenizer: ImageTokenizer; + + constructor() { + this.textTokenizer = new TextTokenizer(); + this.imageTokenizer = new ImageTokenizer(); + } + + /** + * Calculate tokens for a request using serial processing + */ + async calculateTokens( + request: CountTokensParameters, + config: TokenizerConfig = {}, + ): Promise { + const startTime = performance.now(); + + // Apply configuration + if (config.textEncoding) { + this.textTokenizer = new TextTokenizer(config.textEncoding); + } + + try { + // Process request content and group by type + const { textContents, imageContents, audioContents, otherContents } = + this.processAndGroupContents(request); + + if ( + textContents.length === 0 && + imageContents.length === 0 && + audioContents.length === 0 && + otherContents.length === 0 + ) { + return { + totalTokens: 0, + breakdown: { + textTokens: 0, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: performance.now() - startTime, + }; + } + + // Calculate tokens for each content type serially + const textTokens = await this.calculateTextTokens(textContents); + const imageTokens = await this.calculateImageTokens(imageContents); + const audioTokens = await this.calculateAudioTokens(audioContents); + const otherTokens = await this.calculateOtherTokens(otherContents); + + const totalTokens = textTokens + imageTokens + audioTokens + otherTokens; + const processingTime = performance.now() - startTime; + + return { + totalTokens, + breakdown: { + textTokens, + imageTokens, + audioTokens, + otherTokens, + }, + processingTime, + }; + } catch (error) { + console.error('Error calculating tokens:', error); + + // Fallback calculation + const fallbackTokens = this.calculateFallbackTokens(request); + + return { + totalTokens: fallbackTokens, + breakdown: { + textTokens: fallbackTokens, + imageTokens: 0, + audioTokens: 0, + otherTokens: 0, + }, + processingTime: performance.now() - startTime, + }; + } + } + + /** + * Calculate tokens for text contents + */ + private async calculateTextTokens(textContents: string[]): Promise { + if (textContents.length === 0) return 0; + + try { + const tokenCounts = + await this.textTokenizer.calculateTokensBatch(textContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating text tokens:', error); + // Fallback: character-based estimation + const totalChars = textContents.join('').length; + return Math.ceil(totalChars / 4); + } + } + + /** + * Calculate tokens for image contents using serial processing + */ + private async calculateImageTokens( + imageContents: Array<{ data: string; mimeType: string }>, + ): Promise { + if (imageContents.length === 0) return 0; + + try { + const tokenCounts = + await this.imageTokenizer.calculateTokensBatch(imageContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating image tokens:', error); + // Fallback: minimum tokens per image + return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum + } + } + + /** + * Calculate tokens for audio contents + * TODO: Implement proper audio token calculation + */ + private async calculateAudioTokens( + audioContents: Array<{ data: string; mimeType: string }>, + ): Promise { + if 
(audioContents.length === 0) return 0; + + // Placeholder implementation - audio token calculation would depend on + // the specific model's audio processing capabilities + // For now, estimate based on data size + let totalTokens = 0; + + for (const audioContent of audioContents) { + try { + const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size + // Rough estimate: 1 token per 100 bytes of audio data + totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio + } catch (error) { + console.warn('Error calculating audio tokens:', error); + totalTokens += 10; // Fallback minimum + } + } + + return totalTokens; + } + + /** + * Calculate tokens for other content types (functions, files, etc.) + */ + private async calculateOtherTokens(otherContents: string[]): Promise { + if (otherContents.length === 0) return 0; + + try { + // Treat other content as text for token calculation + const tokenCounts = + await this.textTokenizer.calculateTokensBatch(otherContents); + return tokenCounts.reduce((sum, count) => sum + count, 0); + } catch (error) { + console.warn('Error calculating other content tokens:', error); + // Fallback: character-based estimation + const totalChars = otherContents.join('').length; + return Math.ceil(totalChars / 4); + } + } + + /** + * Fallback token calculation using simple string serialization + */ + private calculateFallbackTokens(request: CountTokensParameters): number { + try { + const content = JSON.stringify(request.contents); + return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters + } catch (error) { + console.warn('Error in fallback token calculation:', error); + return 100; // Conservative fallback + } + } + + /** + * Process request contents and group by type + */ + private processAndGroupContents(request: CountTokensParameters): { + textContents: string[]; + imageContents: Array<{ data: string; mimeType: string }>; + audioContents: Array<{ data: string; mimeType: string }>; + otherContents: string[]; + } { + const textContents: string[] = []; + const imageContents: Array<{ data: string; mimeType: string }> = []; + const audioContents: Array<{ data: string; mimeType: string }> = []; + const otherContents: string[] = []; + + if (!request.contents) { + return { textContents, imageContents, audioContents, otherContents }; + } + + const contents = Array.isArray(request.contents) + ? 
request.contents + : [request.contents]; + + for (const content of contents) { + this.processContent( + content, + textContents, + imageContents, + audioContents, + otherContents, + ); + } + + return { textContents, imageContents, audioContents, otherContents }; + } + + /** + * Process a single content item and add to appropriate arrays + */ + private processContent( + content: Content | string | PartUnion, + textContents: string[], + imageContents: Array<{ data: string; mimeType: string }>, + audioContents: Array<{ data: string; mimeType: string }>, + otherContents: string[], + ): void { + if (typeof content === 'string') { + if (content.trim()) { + textContents.push(content); + } + return; + } + + if ('parts' in content && content.parts) { + for (const part of content.parts) { + this.processPart( + part, + textContents, + imageContents, + audioContents, + otherContents, + ); + } + } + } + + /** + * Process a single part and add to appropriate arrays + */ + private processPart( + part: Part | string, + textContents: string[], + imageContents: Array<{ data: string; mimeType: string }>, + audioContents: Array<{ data: string; mimeType: string }>, + otherContents: string[], + ): void { + if (typeof part === 'string') { + if (part.trim()) { + textContents.push(part); + } + return; + } + + if ('text' in part && part.text) { + textContents.push(part.text); + return; + } + + if ('inlineData' in part && part.inlineData) { + const { data, mimeType } = part.inlineData; + if (mimeType && mimeType.startsWith('image/')) { + imageContents.push({ data: data || '', mimeType }); + return; + } + if (mimeType && mimeType.startsWith('audio/')) { + audioContents.push({ data: data || '', mimeType }); + return; + } + } + + if ('fileData' in part && part.fileData) { + otherContents.push(JSON.stringify(part.fileData)); + return; + } + + if ('functionCall' in part && part.functionCall) { + otherContents.push(JSON.stringify(part.functionCall)); + return; + } + + if ('functionResponse' in part && part.functionResponse) { + otherContents.push(JSON.stringify(part.functionResponse)); + return; + } + + // Unknown part type - try to serialize + try { + const serialized = JSON.stringify(part); + if (serialized && serialized !== '{}') { + otherContents.push(serialized); + } + } catch (error) { + console.warn('Failed to serialize unknown part type:', error); + } + } + + /** + * Dispose of resources + */ + async dispose(): Promise { + try { + // Dispose of tokenizers + this.textTokenizer.dispose(); + } catch (error) { + console.warn('Error disposing request tokenizer:', error); + } + } +} diff --git a/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts b/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts new file mode 100644 index 00000000..fce679d7 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Supported image MIME types for vision models + * These formats are supported by the vision model and can be processed by the image tokenizer + */ +export const SUPPORTED_IMAGE_MIME_TYPES = [ + 'image/bmp', + 'image/jpeg', + 'image/jpg', // Alternative MIME type for JPEG + 'image/png', + 'image/tiff', + 'image/webp', + 'image/heic', +] as const; + +/** + * Type for supported image MIME types + */ +export type SupportedImageMimeType = + (typeof SUPPORTED_IMAGE_MIME_TYPES)[number]; + +/** + * Check if a MIME type is supported for vision processing + * 
@param mimeType The MIME type to check + * @returns True if the MIME type is supported + */ +export function isSupportedImageMimeType( + mimeType: string, +): mimeType is SupportedImageMimeType { + return SUPPORTED_IMAGE_MIME_TYPES.includes( + mimeType as SupportedImageMimeType, + ); +} + +/** + * Get a human-readable list of supported image formats + * @returns Comma-separated string of supported formats + */ +export function getSupportedImageFormatsString(): string { + return SUPPORTED_IMAGE_MIME_TYPES.map((type) => + type.replace('image/', '').toUpperCase(), + ).join(', '); +} + +/** + * Get warning message for unsupported image formats + * @returns Warning message string + */ +export function getUnsupportedImageFormatWarning(): string { + return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`; +} diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts new file mode 100644 index 00000000..f29155a8 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts @@ -0,0 +1,347 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { TextTokenizer } from './textTokenizer.js'; + +// Mock tiktoken at the top level with hoisted functions +const mockEncode = vi.hoisted(() => vi.fn()); +const mockFree = vi.hoisted(() => vi.fn()); +const mockGetEncoding = vi.hoisted(() => vi.fn()); + +vi.mock('tiktoken', () => ({ + get_encoding: mockGetEncoding, +})); + +describe('TextTokenizer', () => { + let tokenizer: TextTokenizer; + let consoleWarnSpy: ReturnType; + + beforeEach(() => { + vi.resetAllMocks(); + consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + // Default mock implementation + mockGetEncoding.mockReturnValue({ + encode: mockEncode, + free: mockFree, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + tokenizer?.dispose(); + }); + + describe('constructor', () => { + it('should create tokenizer with default encoding', () => { + tokenizer = new TextTokenizer(); + expect(tokenizer).toBeInstanceOf(TextTokenizer); + }); + + it('should create tokenizer with custom encoding', () => { + tokenizer = new TextTokenizer('gpt2'); + expect(tokenizer).toBeInstanceOf(TextTokenizer); + }); + }); + + describe('calculateTokens', () => { + beforeEach(() => { + tokenizer = new TextTokenizer(); + }); + + it('should return 0 for empty text', async () => { + const result = await tokenizer.calculateTokens(''); + expect(result).toBe(0); + }); + + it('should return 0 for null/undefined text', async () => { + const result1 = await tokenizer.calculateTokens( + null as unknown as string, + ); + const result2 = await tokenizer.calculateTokens( + undefined as unknown as string, + ); + expect(result1).toBe(0); + expect(result2).toBe(0); + }); + + it('should calculate tokens using tiktoken when available', async () => { + const testText = 'Hello, world!'; + const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(testText); + + expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base'); + expect(mockEncode).toHaveBeenCalledWith(testText); + expect(result).toBe(5); + }); + + it('should use fallback calculation when tiktoken fails to load', async () => { + mockGetEncoding.mockImplementation(() => { + throw new 
Error('Failed to load tiktoken'); + }); + + const testText = 'Hello, world!'; // 13 characters + const result = await tokenizer.calculateTokens(testText); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Failed to load tiktoken with encoding cl100k_base:', + expect.any(Error), + ); + // Fallback: Math.ceil(13 / 4) = 4 + expect(result).toBe(4); + }); + + it('should use fallback calculation when encoding fails', async () => { + mockEncode.mockImplementation(() => { + throw new Error('Encoding failed'); + }); + + const testText = 'Hello, world!'; // 13 characters + const result = await tokenizer.calculateTokens(testText); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Error encoding text with tiktoken:', + expect.any(Error), + ); + // Fallback: Math.ceil(13 / 4) = 4 + expect(result).toBe(4); + }); + + it('should handle very long text', async () => { + const longText = 'a'.repeat(10000); + const mockTokens = new Array(2500); // 2500 tokens + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(longText); + + expect(result).toBe(2500); + }); + + it('should handle unicode characters', async () => { + const unicodeText = '你好世界 🌍'; + const mockTokens = [1, 2, 3, 4, 5, 6]; + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(unicodeText); + + expect(result).toBe(6); + }); + + it('should use custom encoding when specified', async () => { + tokenizer = new TextTokenizer('gpt2'); + const testText = 'Hello, world!'; + const mockTokens = [1, 2, 3]; + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(testText); + + expect(mockGetEncoding).toHaveBeenCalledWith('gpt2'); + expect(result).toBe(3); + }); + }); + + describe('calculateTokensBatch', () => { + beforeEach(() => { + tokenizer = new TextTokenizer(); + }); + + it('should process multiple texts and return token counts', async () => { + const texts = ['Hello', 'world', 'test']; + mockEncode + .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello' + .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world' + .mockReturnValueOnce([6]); // 1 token for 'test' + + const result = await tokenizer.calculateTokensBatch(texts); + + expect(result).toEqual([2, 3, 1]); + expect(mockEncode).toHaveBeenCalledTimes(3); + }); + + it('should handle empty array', async () => { + const result = await tokenizer.calculateTokensBatch([]); + expect(result).toEqual([]); + }); + + it('should handle array with empty strings', async () => { + const texts = ['', 'hello', '']; + mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello' + + const result = await tokenizer.calculateTokensBatch(texts); + + expect(result).toEqual([0, 3, 0]); + expect(mockEncode).toHaveBeenCalledTimes(1); + expect(mockEncode).toHaveBeenCalledWith('hello'); + }); + + it('should use fallback calculation when tiktoken fails to load', async () => { + mockGetEncoding.mockImplementation(() => { + throw new Error('Failed to load tiktoken'); + }); + + const texts = ['Hello', 'world']; // 5 and 5 characters + const result = await tokenizer.calculateTokensBatch(texts); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Failed to load tiktoken with encoding cl100k_base:', + expect.any(Error), + ); + // Fallback: Math.ceil(5/4) = 2 for both + expect(result).toEqual([2, 2]); + }); + + it('should use fallback calculation when encoding fails during batch processing', async () => { + mockEncode.mockImplementation(() => { + throw new Error('Encoding failed'); + }); + + const texts = ['Hello', 
'world']; // 5 and 5 characters + const result = await tokenizer.calculateTokensBatch(texts); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Error encoding texts with tiktoken:', + expect.any(Error), + ); + // Fallback: Math.ceil(5/4) = 2 for both + expect(result).toEqual([2, 2]); + }); + + it('should handle null and undefined values in batch', async () => { + const texts = [null, 'hello', undefined, 'world'] as unknown as string[]; + mockEncode + .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello' + .mockReturnValueOnce([4, 5]); // 2 tokens for 'world' + + const result = await tokenizer.calculateTokensBatch(texts); + + expect(result).toEqual([0, 3, 0, 2]); + }); + }); + + describe('dispose', () => { + beforeEach(() => { + tokenizer = new TextTokenizer(); + }); + + it('should free tiktoken encoding when disposing', async () => { + // Initialize the encoding by calling calculateTokens + await tokenizer.calculateTokens('test'); + + tokenizer.dispose(); + + expect(mockFree).toHaveBeenCalled(); + }); + + it('should handle disposal when encoding is not initialized', () => { + expect(() => tokenizer.dispose()).not.toThrow(); + expect(mockFree).not.toHaveBeenCalled(); + }); + + it('should handle disposal when encoding is null', async () => { + // Force encoding to be null by making tiktoken fail + mockGetEncoding.mockImplementation(() => { + throw new Error('Failed to load'); + }); + + await tokenizer.calculateTokens('test'); + + expect(() => tokenizer.dispose()).not.toThrow(); + expect(mockFree).not.toHaveBeenCalled(); + }); + + it('should handle errors during disposal gracefully', async () => { + await tokenizer.calculateTokens('test'); + + mockFree.mockImplementation(() => { + throw new Error('Free failed'); + }); + + tokenizer.dispose(); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Error freeing tiktoken encoding:', + expect.any(Error), + ); + }); + + it('should allow multiple calls to dispose', async () => { + await tokenizer.calculateTokens('test'); + + tokenizer.dispose(); + tokenizer.dispose(); // Second call should not throw + + expect(mockFree).toHaveBeenCalledTimes(1); + }); + }); + + describe('lazy initialization', () => { + beforeEach(() => { + tokenizer = new TextTokenizer(); + }); + + it('should not initialize tiktoken until first use', () => { + expect(mockGetEncoding).not.toHaveBeenCalled(); + }); + + it('should initialize tiktoken on first calculateTokens call', async () => { + await tokenizer.calculateTokens('test'); + expect(mockGetEncoding).toHaveBeenCalledTimes(1); + }); + + it('should not reinitialize tiktoken on subsequent calls', async () => { + await tokenizer.calculateTokens('test1'); + await tokenizer.calculateTokens('test2'); + + expect(mockGetEncoding).toHaveBeenCalledTimes(1); + }); + + it('should initialize tiktoken on first calculateTokensBatch call', async () => { + await tokenizer.calculateTokensBatch(['test']); + expect(mockGetEncoding).toHaveBeenCalledTimes(1); + }); + }); + + describe('edge cases', () => { + beforeEach(() => { + tokenizer = new TextTokenizer(); + }); + + it('should handle very short text', async () => { + const result = await tokenizer.calculateTokens('a'); + + if (mockGetEncoding.mock.calls.length > 0) { + // If tiktoken was called, use its result + expect(mockEncode).toHaveBeenCalledWith('a'); + } else { + // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1 + expect(result).toBe(1); + } + }); + + it('should handle text with only whitespace', async () => { + const whitespaceText = ' \n\t '; + const mockTokens = 
[1]; + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(whitespaceText); + + expect(result).toBe(1); + }); + + it('should handle special characters and symbols', async () => { + const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?'; + const mockTokens = new Array(10); + mockEncode.mockReturnValue(mockTokens); + + const result = await tokenizer.calculateTokens(specialText); + + expect(result).toBe(10); + }); + }); +}); diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.ts new file mode 100644 index 00000000..86c71d4c --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/textTokenizer.ts @@ -0,0 +1,97 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TiktokenEncoding, Tiktoken } from 'tiktoken'; +import { get_encoding } from 'tiktoken'; + +/** + * Text tokenizer for calculating text tokens using tiktoken + */ +export class TextTokenizer { + private encoding: Tiktoken | null = null; + private encodingName: string; + + constructor(encodingName: string = 'cl100k_base') { + this.encodingName = encodingName; + } + + /** + * Initialize the tokenizer (lazy loading) + */ + private async ensureEncoding(): Promise { + if (this.encoding) return; + + try { + // Use type assertion since we know the encoding name is valid + this.encoding = get_encoding(this.encodingName as TiktokenEncoding); + } catch (error) { + console.warn( + `Failed to load tiktoken with encoding ${this.encodingName}:`, + error, + ); + this.encoding = null; + } + } + + /** + * Calculate tokens for text content + */ + async calculateTokens(text: string): Promise { + if (!text) return 0; + + await this.ensureEncoding(); + + if (this.encoding) { + try { + return this.encoding.encode(text).length; + } catch (error) { + console.warn('Error encoding text with tiktoken:', error); + } + } + + // Fallback: rough approximation using character count + // This is a conservative estimate: 1 token ≈ 4 characters for most languages + return Math.ceil(text.length / 4); + } + + /** + * Calculate tokens for multiple text strings in parallel + */ + async calculateTokensBatch(texts: string[]): Promise { + await this.ensureEncoding(); + + if (this.encoding) { + try { + return texts.map((text) => { + if (!text) return 0; + // this.encoding may be null, add a null check to satisfy lint + return this.encoding ? 
this.encoding.encode(text).length : 0; + }); + } catch (error) { + console.warn('Error encoding texts with tiktoken:', error); + // In case of error, return fallback estimation for all texts + return texts.map((text) => Math.ceil((text || '').length / 4)); + } + } + + // Fallback for batch processing + return texts.map((text) => Math.ceil((text || '').length / 4)); + } + + /** + * Dispose of resources + */ + dispose(): void { + if (this.encoding) { + try { + this.encoding.free(); + } catch (error) { + console.warn('Error freeing tiktoken encoding:', error); + } + this.encoding = null; + } + } +} diff --git a/packages/core/src/utils/request-tokenizer/types.ts b/packages/core/src/utils/request-tokenizer/types.ts new file mode 100644 index 00000000..38c47699 --- /dev/null +++ b/packages/core/src/utils/request-tokenizer/types.ts @@ -0,0 +1,64 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CountTokensParameters } from '@google/genai'; + +/** + * Token calculation result for different content types + */ +export interface TokenCalculationResult { + /** Total tokens calculated */ + totalTokens: number; + /** Breakdown by content type */ + breakdown: { + textTokens: number; + imageTokens: number; + audioTokens: number; + otherTokens: number; + }; + /** Processing time in milliseconds */ + processingTime: number; +} + +/** + * Configuration for token calculation + */ +export interface TokenizerConfig { + /** Custom text tokenizer encoding (defaults to cl100k_base) */ + textEncoding?: string; +} + +/** + * Image metadata extracted from base64 data + */ +export interface ImageMetadata { + /** Image width in pixels */ + width: number; + /** Image height in pixels */ + height: number; + /** MIME type of the image */ + mimeType: string; + /** Size of the base64 data in bytes */ + dataSize: number; +} + +/** + * Request tokenizer interface + */ +export interface RequestTokenizer { + /** + * Calculate tokens for a request + */ + calculateTokens( + request: CountTokensParameters, + config?: TokenizerConfig, + ): Promise; + + /** + * Dispose of resources (worker threads, etc.) + */ + dispose(): Promise; +}
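Note (illustrative only, not part of the diff above): a minimal usage sketch of the new request-tokenizer module added by this patch. The exported helpers (getDefaultTokenizer, disposeDefaultTokenizer), the calculateTokens signature, and the TokenCalculationResult shape come from the files added in this patch; the import path and the sample model name are assumptions made for the example.

import type { CountTokensParameters } from '@google/genai';
import {
  getDefaultTokenizer,
  disposeDefaultTokenizer,
} from './utils/request-tokenizer/index.js'; // assumed relative path

async function estimateRequestTokens(): Promise<void> {
  const request: CountTokensParameters = {
    model: 'test-model', // placeholder, as used in the new unit tests
    contents: [
      {
        role: 'user',
        parts: [
          { text: 'Describe this image.' },
          // base64 payload elided; any supported image MIME type applies
          { inlineData: { mimeType: 'image/png', data: '<base64 data>' } },
        ],
      },
    ],
  };

  // Text parts are counted with tiktoken (cl100k_base by default); image parts
  // are sized from their headers and mapped to tokens via the 28-pixel grid.
  const result = await getDefaultTokenizer().calculateTokens(request);
  console.log(
    `total=${result.totalTokens}`,
    `text=${result.breakdown.textTokens}`,
    `image=${result.breakdown.imageTokens}`,
    `time=${result.processingTime.toFixed(1)}ms`,
  );

  // Frees the shared tiktoken encoding held by the singleton tokenizer.
  await disposeDefaultTokenizer();
}

As a worked example of the image scaling rule described in imageTokenizer.ts: a 1024x768 image normalizes to 1036x756 on the 28-pixel grid, giving floor((1036 * 756) / 784) = 999 image tokens plus 2 vision special tokens, i.e. 1001 tokens, assuming the normalized pixel count falls between the tokenizer's configured MIN and MAX bounds.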