diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts
index 1cbd63e0..c7f1e94e 100644
--- a/packages/cli/src/config/settingsSchema.ts
+++ b/packages/cli/src/config/settingsSchema.ts
@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
description: 'Enable extension management features.',
showInDialog: false,
},
+ visionModelPreview: {
+ type: 'boolean',
+ label: 'Vision Model Preview',
+ category: 'Experimental',
+ requiresRestart: false,
+ default: false,
+ description:
+ 'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
+ showInDialog: true,
+ },
},
},
diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts
index dcede5a3..38deb425 100644
--- a/packages/cli/src/services/BuiltinCommandLoader.test.ts
+++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts
@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
kind: 'BUILT_IN',
},
}));
+vi.mock('../ui/commands/modelCommand.js', () => ({
+ modelCommand: {
+ name: 'model',
+ description: 'Model command',
+ kind: 'BUILT_IN',
+ },
+}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
const mcpCmd = commands.find((c) => c.name === 'mcp');
expect(mcpCmd).toBeDefined();
+
+ const modelCmd = commands.find((c) => c.name === 'model');
+ expect(modelCmd).toBeDefined();
});
});
diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts
index 74de2a3c..12c3cfc9 100644
--- a/packages/cli/src/services/BuiltinCommandLoader.ts
+++ b/packages/cli/src/services/BuiltinCommandLoader.ts
@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
import { vimCommand } from '../ui/commands/vimCommand.js';
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
+import { modelCommand } from '../ui/commands/modelCommand.js';
import { agentsCommand } from '../ui/commands/agentsCommand.js';
/**
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
initCommand,
mcpCommand,
memoryCommand,
+ modelCommand,
privacyCommand,
quitCommand,
quitConfirmCommand,
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index 9ec15650..85691182 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
+import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
+import {
+ ModelSwitchDialog,
+ type VisionSwitchOutcome,
+} from './components/ModelSwitchDialog.js';
+import {
+ getOpenAIAvailableModelFromEnv,
+ getFilteredQwenModels,
+ type AvailableModel,
+} from './models/availableModels.js';
+import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
import {
AgentCreationWizard,
AgentsManagerDialog,
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onWorkspaceMigrationDialogClose,
} = useWorkspaceMigration(settings);
+ // Model selection dialog states
+ const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
+ useState(false);
+ const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
+ useState(false);
+ const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
+ resolve: (result: {
+ modelOverride?: string;
+ persistSessionModel?: string;
+ showGuidance?: boolean;
+ }) => void;
+ reject: () => void;
+ } | null>(null);
+
useEffect(() => {
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
// Set the initial value
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
openAuthDialog();
}, [openAuthDialog, setAuthError]);
+ // Vision switch handler for auto-switch functionality
+ const handleVisionSwitchRequired = useCallback(
+ async (_query: unknown) =>
+ new Promise<{
+ modelOverride?: string;
+ persistSessionModel?: string;
+ showGuidance?: boolean;
+ }>((resolve, reject) => {
+ setVisionSwitchResolver({ resolve, reject });
+ setIsVisionSwitchDialogOpen(true);
+ }),
+ [],
+ );
+
+ const handleVisionSwitchSelect = useCallback(
+ (outcome: VisionSwitchOutcome) => {
+ setIsVisionSwitchDialogOpen(false);
+ if (visionSwitchResolver) {
+ const result = processVisionSwitchOutcome(outcome);
+ visionSwitchResolver.resolve(result);
+ setVisionSwitchResolver(null);
+ }
+ },
+ [visionSwitchResolver],
+ );
+
+ const handleModelSelectionOpen = useCallback(() => {
+ setIsModelSelectionDialogOpen(true);
+ }, []);
+
+ const handleModelSelectionClose = useCallback(() => {
+ setIsModelSelectionDialogOpen(false);
+ }, []);
+
+ const handleModelSelect = useCallback(
+ (modelId: string) => {
+ config.setModel(modelId);
+ setCurrentModel(modelId);
+ setIsModelSelectionDialogOpen(false);
+ addItem(
+ {
+ type: MessageType.INFO,
+ text: `Switched model to \`${modelId}\` for this session.`,
+ },
+ Date.now(),
+ );
+ },
+ [config, setCurrentModel, addItem],
+ );
+
+ const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
+ const contentGeneratorConfig = config.getContentGeneratorConfig();
+ if (!contentGeneratorConfig) return [];
+
+ const visionModelPreviewEnabled =
+ settings.merged.experimental?.visionModelPreview ?? false;
+
+ switch (contentGeneratorConfig.authType) {
+ case AuthType.QWEN_OAUTH:
+ return getFilteredQwenModels(visionModelPreviewEnabled);
+ case AuthType.USE_OPENAI: {
+ const openAIModel = getOpenAIAvailableModelFromEnv();
+ return openAIModel ? [openAIModel] : [];
+ }
+ default:
+ return [];
+ }
+ }, [config, settings.merged.experimental?.visionModelPreview]);
+
// Core hooks and processors
const {
vimEnabled: vimModeEnabled,
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setQuittingMessages,
openPrivacyNotice,
openSettingsDialog,
+ handleModelSelectionOpen,
openSubagentCreateDialog,
openAgentsManagerDialog,
toggleVimEnabled,
@@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setModelSwitchedFromQuotaError,
refreshStatic,
() => cancelHandlerRef.current(),
+ settings.merged.experimental?.visionModelPreview ?? false,
+ handleVisionSwitchRequired,
);
const pendingHistoryItems = useMemo(
@@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
!isAuthDialogOpen &&
!isThemeDialogOpen &&
!isEditorDialogOpen &&
+ !isModelSelectionDialogOpen &&
+ !isVisionSwitchDialogOpen &&
!isSubagentCreateDialogOpen &&
!showPrivacyNotice &&
!showWelcomeBackDialog &&
@@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
showWelcomeBackDialog,
welcomeBackChoice,
geminiClient,
+ isModelSelectionDialogOpen,
+ isVisionSwitchDialogOpen,
]);
if (quittingMessages) {
@@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onExit={exitEditorDialog}
/>
+ ) : isModelSelectionDialogOpen ? (
+   <ModelSelectionDialog
+     availableModels={getAvailableModelsForCurrentAuth()}
+     currentModel={currentModel}
+     onSelect={handleModelSelect}
+     onCancel={handleModelSelectionClose}
+   />
+ ) : isVisionSwitchDialogOpen ? (
+   <ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
) : showPrivacyNotice ? (
<PrivacyNotice
  onExit={() => setShowPrivacyNotice(false)}
diff --git a/packages/cli/src/ui/commands/modelCommand.test.ts b/packages/cli/src/ui/commands/modelCommand.test.ts
new file mode 100644
index 00000000..f3aaad52
--- /dev/null
+++ b/packages/cli/src/ui/commands/modelCommand.test.ts
@@ -0,0 +1,179 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { modelCommand } from './modelCommand.js';
+import { type CommandContext } from './types.js';
+import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
+import {
+ AuthType,
+ type ContentGeneratorConfig,
+ type Config,
+} from '@qwen-code/qwen-code-core';
+import * as availableModelsModule from '../models/availableModels.js';
+
+// Mock the availableModels module
+vi.mock('../models/availableModels.js', () => ({
+ AVAILABLE_MODELS_QWEN: [
+ { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
+ { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
+ ],
+ getOpenAIAvailableModelFromEnv: vi.fn(),
+}));
+
+// Helper function to create a mock config
+function createMockConfig(
+ contentGeneratorConfig: ContentGeneratorConfig | null,
+): Partial<Config> {
+ return {
+ getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
+ };
+}
+
+describe('modelCommand', () => {
+ let mockContext: CommandContext;
+ const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
+ availableModelsModule.getOpenAIAvailableModelFromEnv,
+ );
+
+ beforeEach(() => {
+ mockContext = createMockCommandContext();
+ vi.clearAllMocks();
+ });
+
+ it('should have the correct name and description', () => {
+ expect(modelCommand.name).toBe('model');
+ expect(modelCommand.description).toBe('Switch the model for this session');
+ });
+
+ it('should return error when config is not available', async () => {
+ mockContext.services.config = null;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Configuration not available.',
+ });
+ });
+
+ it('should return error when content generator config is not available', async () => {
+ const mockConfig = createMockConfig(null);
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Content generator configuration not available.',
+ });
+ });
+
+ it('should return error when auth type is not available', async () => {
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: undefined,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Authentication type not available.',
+ });
+ });
+
+ it('should return dialog action for QWEN_OAUTH auth type', async () => {
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: AuthType.QWEN_OAUTH,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'dialog',
+ dialog: 'model',
+ });
+ });
+
+ it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
+ mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
+ id: 'gpt-4',
+ label: 'gpt-4',
+ });
+
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: AuthType.USE_OPENAI,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'dialog',
+ dialog: 'model',
+ });
+ });
+
+ it('should return error for USE_OPENAI auth type when no model is available', async () => {
+ mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);
+
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: AuthType.USE_OPENAI,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content:
+ 'No models available for the current authentication type (openai).',
+ });
+ });
+
+ it('should return error for unsupported auth types', async () => {
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content:
+ 'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
+ });
+ });
+
+ it('should handle undefined auth type', async () => {
+ const mockConfig = createMockConfig({
+ model: 'test-model',
+ authType: undefined,
+ });
+ mockContext.services.config = mockConfig as Config;
+
+ const result = await modelCommand.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Authentication type not available.',
+ });
+ });
+});
diff --git a/packages/cli/src/ui/commands/modelCommand.ts b/packages/cli/src/ui/commands/modelCommand.ts
new file mode 100644
index 00000000..9e4fdcb0
--- /dev/null
+++ b/packages/cli/src/ui/commands/modelCommand.ts
@@ -0,0 +1,88 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { AuthType } from '@qwen-code/qwen-code-core';
+import type {
+ SlashCommand,
+ CommandContext,
+ OpenDialogActionReturn,
+ MessageActionReturn,
+} from './types.js';
+import { CommandKind } from './types.js';
+import {
+ AVAILABLE_MODELS_QWEN,
+ getOpenAIAvailableModelFromEnv,
+ type AvailableModel,
+} from '../models/availableModels.js';
+
+function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
+ switch (authType) {
+ case AuthType.QWEN_OAUTH:
+ return AVAILABLE_MODELS_QWEN;
+ case AuthType.USE_OPENAI: {
+ const openAIModel = getOpenAIAvailableModelFromEnv();
+ return openAIModel ? [openAIModel] : [];
+ }
+ default:
+ // For other auth types, return empty array for now
+ // This can be expanded later according to the design doc
+ return [];
+ }
+}
+
+export const modelCommand: SlashCommand = {
+ name: 'model',
+ description: 'Switch the model for this session',
+ kind: CommandKind.BUILT_IN,
+ action: async (
+ context: CommandContext,
+ ): Promise<OpenDialogActionReturn | MessageActionReturn> => {
+ const { services } = context;
+ const { config } = services;
+
+ if (!config) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Configuration not available.',
+ };
+ }
+
+ const contentGeneratorConfig = config.getContentGeneratorConfig();
+ if (!contentGeneratorConfig) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Content generator configuration not available.',
+ };
+ }
+
+ const authType = contentGeneratorConfig.authType;
+ if (!authType) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Authentication type not available.',
+ };
+ }
+
+ const availableModels = getAvailableModelsForAuthType(authType);
+
+ if (availableModels.length === 0) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `No models available for the current authentication type (${authType}).`,
+ };
+ }
+
+ // Trigger model selection dialog
+ return {
+ type: 'dialog',
+ dialog: 'model',
+ };
+ },
+};
diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts
index 971e8a02..18484d82 100644
--- a/packages/cli/src/ui/commands/types.ts
+++ b/packages/cli/src/ui/commands/types.ts
@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
| 'editor'
| 'privacy'
| 'settings'
+ | 'model'
| 'subagent_create'
| 'subagent_list';
}
diff --git a/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx b/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
new file mode 100644
index 00000000..4a5b6bcf
--- /dev/null
+++ b/packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
@@ -0,0 +1,246 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import React from 'react';
+import { render } from 'ink-testing-library';
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { ModelSelectionDialog } from './ModelSelectionDialog.js';
+import type { AvailableModel } from '../models/availableModels.js';
+import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
+
+// Mock the useKeypress hook
+const mockUseKeypress = vi.hoisted(() => vi.fn());
+vi.mock('../hooks/useKeypress.js', () => ({
+ useKeypress: mockUseKeypress,
+}));
+
+// Mock the RadioButtonSelect component
+const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
+vi.mock('./shared/RadioButtonSelect.js', () => ({
+ RadioButtonSelect: mockRadioButtonSelect,
+}));
+
+describe('ModelSelectionDialog', () => {
+ const mockAvailableModels: AvailableModel[] = [
+ { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
+ { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
+ { id: 'gpt-4', label: 'GPT-4' },
+ ];
+
+ const mockOnSelect = vi.fn();
+ const mockOnCancel = vi.fn();
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+
+ // Mock RadioButtonSelect to return a simple div
+ mockRadioButtonSelect.mockReturnValue(
+ React.createElement('div', { 'data-testid': 'radio-select' }),
+ );
+ });
+
+ it('should setup escape key handler to call onCancel', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
+ isActive: true,
+ });
+
+ // Simulate escape key press
+ const keypressHandler = mockUseKeypress.mock.calls[0][0];
+ keypressHandler({ name: 'escape' });
+
+ expect(mockOnCancel).toHaveBeenCalled();
+ });
+
+ it('should not call onCancel for non-escape keys', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const keypressHandler = mockUseKeypress.mock.calls[0][0];
+ keypressHandler({ name: 'enter' });
+
+ expect(mockOnCancel).not.toHaveBeenCalled();
+ });
+
+ it('should set correct initial index for current model', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen-vl-max-latest" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
+ });
+
+ it('should set initial index to 0 when current model is not found', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="non-existent-model" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.initialIndex).toBe(0);
+ });
+
+ it('should call onSelect when a model is selected', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(typeof callArgs.onSelect).toBe('function');
+
+ // Simulate selection
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+ onSelectCallback('qwen-vl-max-latest');
+
+ expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
+ });
+
+ it('should handle empty models array', () => {
+ render(
+   <ModelSelectionDialog availableModels={[]} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.items).toEqual([]);
+ expect(callArgs.initialIndex).toBe(0);
+ });
+
+ it('should create correct option items with proper labels', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const expectedItems = [
+ {
+ label: 'qwen3-coder-plus (current)',
+ value: 'qwen3-coder-plus',
+ },
+ {
+ label: 'qwen-vl-max [Vision]',
+ value: 'qwen-vl-max-latest',
+ },
+ {
+ label: 'GPT-4',
+ value: 'gpt-4',
+ },
+ ];
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.items).toEqual(expectedItems);
+ });
+
+ it('should show vision indicator for vision models', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ const visionModelItem = callArgs.items.find(
+ (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
+ );
+
+ expect(visionModelItem?.label).toContain('[Vision]');
+ });
+
+ it('should show current indicator for the current model', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen-vl-max-latest" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ const currentModelItem = callArgs.items.find(
+ (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
+ );
+
+ expect(currentModelItem?.label).toContain('(current)');
+ });
+
+ it('should pass isFocused prop to RadioButtonSelect', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.isFocused).toBe(true);
+ });
+
+ it('should handle multiple onSelect calls correctly', () => {
+ render(
+   <ModelSelectionDialog availableModels={mockAvailableModels} currentModel="qwen3-coder-plus" onSelect={mockOnSelect} onCancel={mockOnCancel} />,
+ );
+
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+
+ // Call multiple times
+ onSelectCallback('qwen3-coder-plus');
+ onSelectCallback('qwen-vl-max-latest');
+ onSelectCallback('gpt-4');
+
+ expect(mockOnSelect).toHaveBeenCalledTimes(3);
+ expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
+ expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
+ expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
+ });
+});
diff --git a/packages/cli/src/ui/components/ModelSelectionDialog.tsx b/packages/cli/src/ui/components/ModelSelectionDialog.tsx
new file mode 100644
index 00000000..d43e69f3
--- /dev/null
+++ b/packages/cli/src/ui/components/ModelSelectionDialog.tsx
@@ -0,0 +1,87 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type React from 'react';
+import { Box, Text } from 'ink';
+import { Colors } from '../colors.js';
+import {
+ RadioButtonSelect,
+ type RadioSelectItem,
+} from './shared/RadioButtonSelect.js';
+import { useKeypress } from '../hooks/useKeypress.js';
+import type { AvailableModel } from '../models/availableModels.js';
+
+export interface ModelSelectionDialogProps {
+ availableModels: AvailableModel[];
+ currentModel: string;
+ onSelect: (modelId: string) => void;
+ onCancel: () => void;
+}
+
+export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
+ availableModels,
+ currentModel,
+ onSelect,
+ onCancel,
+}) => {
+ useKeypress(
+ (key) => {
+ if (key.name === 'escape') {
+ onCancel();
+ }
+ },
+ { isActive: true },
+ );
+
+ const options: Array<RadioSelectItem<string>> = availableModels.map(
+ (model) => {
+ const visionIndicator = model.isVision ? ' [Vision]' : '';
+ const currentIndicator = model.id === currentModel ? ' (current)' : '';
+ return {
+ label: `${model.label}${visionIndicator}${currentIndicator}`,
+ value: model.id,
+ };
+ },
+ );
+
+ const initialIndex = Math.max(
+ 0,
+ availableModels.findIndex((model) => model.id === currentModel),
+ );
+
+ const handleSelect = (modelId: string) => {
+ onSelect(modelId);
+ };
+
+ return (
+   <Box
+     borderStyle="round"
+     borderColor={Colors.Gray}
+     flexDirection="column"
+     padding={1}
+     width="100%"
+   >
+     <Box flexDirection="column" marginBottom={1}>
+       <Text bold>Select Model</Text>
+       <Text>Choose a model for this session:</Text>
+     </Box>
+
+     <Box marginBottom={1}>
+       <RadioButtonSelect
+         items={options}
+         initialIndex={initialIndex}
+         onSelect={handleSelect}
+         isFocused
+       />
+     </Box>
+
+     <Box marginTop={1}>
+       <Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
+     </Box>
+   </Box>
+ );
+};
diff --git a/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx b/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
new file mode 100644
index 00000000..f26dcc55
--- /dev/null
+++ b/packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
@@ -0,0 +1,185 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import React from 'react';
+import { render } from 'ink-testing-library';
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
+
+// Mock the useKeypress hook
+const mockUseKeypress = vi.hoisted(() => vi.fn());
+vi.mock('../hooks/useKeypress.js', () => ({
+ useKeypress: mockUseKeypress,
+}));
+
+// Mock the RadioButtonSelect component
+const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
+vi.mock('./shared/RadioButtonSelect.js', () => ({
+ RadioButtonSelect: mockRadioButtonSelect,
+}));
+
+describe('ModelSwitchDialog', () => {
+ const mockOnSelect = vi.fn();
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+
+ // Mock RadioButtonSelect to return a simple div
+ mockRadioButtonSelect.mockReturnValue(
+ React.createElement('div', { 'data-testid': 'radio-select' }),
+ );
+ });
+
+ it('should setup RadioButtonSelect with correct options', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const expectedItems = [
+ {
+ label: 'Switch for this request only',
+ value: VisionSwitchOutcome.SwitchOnce,
+ },
+ {
+ label: 'Switch session to vision model',
+ value: VisionSwitchOutcome.SwitchSessionToVL,
+ },
+ {
+ label: 'Do not switch, show guidance',
+ value: VisionSwitchOutcome.DisallowWithGuidance,
+ },
+ ];
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.items).toEqual(expectedItems);
+ expect(callArgs.initialIndex).toBe(0);
+ expect(callArgs.isFocused).toBe(true);
+ });
+
+ it('should call onSelect when an option is selected', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(typeof callArgs.onSelect).toBe('function');
+
+ // Simulate selection of "Switch for this request only"
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+ onSelectCallback(VisionSwitchOutcome.SwitchOnce);
+
+ expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
+ });
+
+ it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+ onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
+
+ expect(mockOnSelect).toHaveBeenCalledWith(
+ VisionSwitchOutcome.SwitchSessionToVL,
+ );
+ });
+
+ it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+ onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
+
+ expect(mockOnSelect).toHaveBeenCalledWith(
+ VisionSwitchOutcome.DisallowWithGuidance,
+ );
+ });
+
+ it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
+ isActive: true,
+ });
+
+ // Simulate escape key press
+ const keypressHandler = mockUseKeypress.mock.calls[0][0];
+ keypressHandler({ name: 'escape' });
+
+ expect(mockOnSelect).toHaveBeenCalledWith(
+ VisionSwitchOutcome.DisallowWithGuidance,
+ );
+ });
+
+ it('should not call onSelect for non-escape keys', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const keypressHandler = mockUseKeypress.mock.calls[0][0];
+ keypressHandler({ name: 'enter' });
+
+ expect(mockOnSelect).not.toHaveBeenCalled();
+ });
+
+ it('should set initial index to 0 (first option)', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.initialIndex).toBe(0);
+ });
+
+ describe('VisionSwitchOutcome enum', () => {
+ it('should have correct enum values', () => {
+ expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
+ expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
+ 'switch_session_to_vl',
+ );
+ expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
+ 'disallow_with_guidance',
+ );
+ });
+ });
+
+ it('should handle multiple onSelect calls correctly', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
+
+ // Call multiple times
+ onSelectCallback(VisionSwitchOutcome.SwitchOnce);
+ onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
+ onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
+
+ expect(mockOnSelect).toHaveBeenCalledTimes(3);
+ expect(mockOnSelect).toHaveBeenNthCalledWith(
+ 1,
+ VisionSwitchOutcome.SwitchOnce,
+ );
+ expect(mockOnSelect).toHaveBeenNthCalledWith(
+ 2,
+ VisionSwitchOutcome.SwitchSessionToVL,
+ );
+ expect(mockOnSelect).toHaveBeenNthCalledWith(
+ 3,
+ VisionSwitchOutcome.DisallowWithGuidance,
+ );
+ });
+
+ it('should pass isFocused prop to RadioButtonSelect', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const callArgs = mockRadioButtonSelect.mock.calls[0][0];
+ expect(callArgs.isFocused).toBe(true);
+ });
+
+ it('should handle escape key multiple times', () => {
+ render(<ModelSwitchDialog onSelect={mockOnSelect} />);
+
+ const keypressHandler = mockUseKeypress.mock.calls[0][0];
+
+ // Call escape multiple times
+ keypressHandler({ name: 'escape' });
+ keypressHandler({ name: 'escape' });
+
+ expect(mockOnSelect).toHaveBeenCalledTimes(2);
+ expect(mockOnSelect).toHaveBeenCalledWith(
+ VisionSwitchOutcome.DisallowWithGuidance,
+ );
+ });
+});
diff --git a/packages/cli/src/ui/components/ModelSwitchDialog.tsx b/packages/cli/src/ui/components/ModelSwitchDialog.tsx
new file mode 100644
index 00000000..1a8c73d4
--- /dev/null
+++ b/packages/cli/src/ui/components/ModelSwitchDialog.tsx
@@ -0,0 +1,89 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type React from 'react';
+import { Box, Text } from 'ink';
+import { Colors } from '../colors.js';
+import {
+ RadioButtonSelect,
+ type RadioSelectItem,
+} from './shared/RadioButtonSelect.js';
+import { useKeypress } from '../hooks/useKeypress.js';
+
+export enum VisionSwitchOutcome {
+ SwitchOnce = 'switch_once',
+ SwitchSessionToVL = 'switch_session_to_vl',
+ DisallowWithGuidance = 'disallow_with_guidance',
+}
+
+export interface ModelSwitchDialogProps {
+ onSelect: (outcome: VisionSwitchOutcome) => void;
+}
+
+export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
+ onSelect,
+}) => {
+ useKeypress(
+ (key) => {
+ if (key.name === 'escape') {
+ onSelect(VisionSwitchOutcome.DisallowWithGuidance);
+ }
+ },
+ { isActive: true },
+ );
+
+ const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
+ {
+ label: 'Switch for this request only',
+ value: VisionSwitchOutcome.SwitchOnce,
+ },
+ {
+ label: 'Switch session to vision model',
+ value: VisionSwitchOutcome.SwitchSessionToVL,
+ },
+ {
+ label: 'Do not switch, show guidance',
+ value: VisionSwitchOutcome.DisallowWithGuidance,
+ },
+ ];
+
+ const handleSelect = (outcome: VisionSwitchOutcome) => {
+ onSelect(outcome);
+ };
+
+ return (
+   <Box
+     borderStyle="round"
+     borderColor={Colors.Gray}
+     flexDirection="column"
+     padding={1}
+     width="100%"
+   >
+     <Box flexDirection="column" marginBottom={1}>
+       <Text bold>Vision Model Switch Required</Text>
+       <Text>
+         Your message contains an image, but the current model doesn&apos;t
+         support vision.
+       </Text>
+       <Text>How would you like to proceed?</Text>
+     </Box>
+
+     <Box marginBottom={1}>
+       <RadioButtonSelect
+         items={options}
+         initialIndex={0}
+         onSelect={handleSelect}
+         isFocused
+       />
+     </Box>
+
+     <Box marginTop={1}>
+       <Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
+     </Box>
+   </Box>
+ );
+};
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
index 5403bec8..44b99fe9 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
const mockLoadHistory = vi.fn();
const mockOpenThemeDialog = vi.fn();
const mockOpenAuthDialog = vi.fn();
+ const mockOpenModelSelectionDialog = vi.fn();
const mockSetQuittingMessages = vi.fn();
const mockConfig = makeFakeConfig({});
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
mockMcpLoadCommands.mockResolvedValue([]);
+ mockOpenModelSelectionDialog.mockClear();
});
const setupProcessorHook = (
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
+ mockOpenModelSelectionDialog,
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
setIsProcessing,
vi.fn(), // setGeminiMdFileCount
+ vi.fn(), // _showQuitConfirmation
),
);
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
expect(mockOpenThemeDialog).toHaveBeenCalled();
});
+ it('should handle "dialog: model" action', async () => {
+ const command = createTestCommand({
+ name: 'modelcmd',
+ action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
+ });
+ const result = setupProcessorHook([command]);
+ await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
+
+ await act(async () => {
+ await result.current.handleSlashCommand('/modelcmd');
+ });
+
+ expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
+ });
+
it('should handle "load_history" action', async () => {
const command = createTestCommand({
name: 'load',
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
+ vi.fn(), // openModelSelectionDialog
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
vi.fn(), // setIsProcessing
vi.fn(), // setGeminiMdFileCount
+ vi.fn(), // _showQuitConfirmation
),
);
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
index 3e49a0eb..10c4573d 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
setQuittingMessages: (message: HistoryItem[]) => void,
openPrivacyNotice: () => void,
openSettingsDialog: () => void,
+ openModelSelectionDialog: () => void,
openSubagentCreateDialog: () => void,
openAgentsManagerDialog: () => void,
toggleVimEnabled: () => Promise<boolean>,
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
case 'settings':
openSettingsDialog();
return { type: 'handled' };
+ case 'model':
+ openModelSelectionDialog();
+ return { type: 'handled' };
case 'subagent_create':
openSubagentCreateDialog();
return { type: 'handled' };
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
setSessionShellAllowlist,
setIsProcessing,
setConfirmationRequest,
+ openModelSelectionDialog,
session.stats,
],
);
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index 9eab226c..125620cf 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
);
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
+// Vision auto-switch mocks (hoisted)
+const mockHandleVisionSwitch = vi.hoisted(() =>
+ vi.fn().mockResolvedValue({ shouldProceed: true }),
+);
+const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
+
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return {
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
};
});
+vi.mock('./useVisionAutoSwitch.js', () => ({
+ useVisionAutoSwitch: vi.fn(() => ({
+ handleVisionSwitch: mockHandleVisionSwitch,
+ restoreOriginalModel: mockRestoreOriginalModel,
+ })),
+}));
+
vi.mock('./useKeypress.js', () => ({
useKeypress: vi.fn(),
}));
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue(contentGeneratorConfig),
+ getMaxSessionTurns: vi.fn(() => 50),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
expect.any(String), // Argument 3: The prompt_id string
);
});
+
describe('Thought Reset', () => {
it('should reset thought to null when starting a new prompt', async () => {
// First, simulate a response with a thought
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
);
});
});
+
+ // --- New tests focused on recent modifications ---
+ describe('Vision Auto Switch Integration', () => {
+ it('should call handleVisionSwitch and proceed to send when allowed', async () => {
+ mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
+ mockSendMessageStream.mockReturnValue(
+ (async function* () {
+ yield { type: ServerGeminiEventType.Content, value: 'ok' };
+ yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
+ })(),
+ );
+
+ const { result } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ false,
+ () => {},
+ () => {},
+ () => {},
+ ),
+ );
+
+ await act(async () => {
+ await result.current.submitQuery('image prompt');
+ });
+
+ await waitFor(() => {
+ expect(mockHandleVisionSwitch).toHaveBeenCalled();
+ expect(mockSendMessageStream).toHaveBeenCalled();
+ });
+ });
+
+ it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
+ mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
+
+ const { result } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ false,
+ () => {},
+ () => {},
+ () => {},
+ ),
+ );
+
+ await act(async () => {
+ await result.current.submitQuery('vision-gated');
+ });
+
+ // No call to API, no restoreOriginalModel needed since no override occurred
+ expect(mockSendMessageStream).not.toHaveBeenCalled();
+ expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
+
+ // Next call allowed (flag reset path)
+ mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
+ mockSendMessageStream.mockReturnValue(
+ (async function* () {
+ yield { type: ServerGeminiEventType.Content, value: 'ok' };
+ yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
+ })(),
+ );
+ await act(async () => {
+ await result.current.submitQuery('after-gate');
+ });
+ await waitFor(() => {
+ expect(mockSendMessageStream).toHaveBeenCalled();
+ });
+ });
+ });
+
+ describe('Model restore on completion and errors', () => {
+ it('should restore model after successful stream completion', async () => {
+ mockSendMessageStream.mockReturnValue(
+ (async function* () {
+ yield { type: ServerGeminiEventType.Content, value: 'content' };
+ yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
+ })(),
+ );
+
+ const { result } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ false,
+ () => {},
+ () => {},
+ () => {},
+ ),
+ );
+
+ await act(async () => {
+ await result.current.submitQuery('restore-success');
+ });
+
+ await waitFor(() => {
+ expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ it('should restore model when an error occurs during streaming', async () => {
+ const testError = new Error('stream failure');
+ mockSendMessageStream.mockReturnValue(
+ (async function* () {
+ yield { type: ServerGeminiEventType.Content, value: 'content' };
+ throw testError;
+ })(),
+ );
+
+ const { result } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ false,
+ () => {},
+ () => {},
+ () => {},
+ ),
+ );
+
+ await act(async () => {
+ await result.current.submitQuery('restore-error');
+ });
+
+ await waitFor(() => {
+ expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
+ });
+ });
+ });
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 07f4d7b9..7f34eaa2 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -42,6 +42,7 @@ import type {
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
import { useShellCommandProcessor } from './shellCommandProcessor.js';
+import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
import { handleAtCommand } from './atCommandProcessor.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
@@ -88,6 +89,12 @@ export const useGeminiStream = (
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
onEditorClose: () => void,
onCancelSubmit: () => void,
+ visionModelPreviewEnabled: boolean = false,
+ onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
+ modelOverride?: string;
+ persistSessionModel?: string;
+ showGuidance?: boolean;
+ }>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -155,6 +162,13 @@ export const useGeminiStream = (
geminiClient,
);
+ const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
+ config,
+ addItem,
+ visionModelPreviewEnabled,
+ onVisionSwitchRequired,
+ );
+
const streamingState = useMemo(() => {
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
return StreamingState.WaitingForConfirmation;
@@ -715,6 +729,20 @@ export const useGeminiStream = (
return;
}
+ // Handle vision switch requirement
+ const visionSwitchResult = await handleVisionSwitch(
+ queryToSend,
+ userMessageTimestamp,
+ options?.isContinuation || false,
+ );
+
+ if (!visionSwitchResult.shouldProceed) {
+ isSubmittingQueryRef.current = false;
+ return;
+ }
+
+ const finalQueryToSend = queryToSend;
+
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
@@ -725,7 +753,7 @@ export const useGeminiStream = (
try {
const stream = geminiClient.sendMessageStream(
- queryToSend,
+ finalQueryToSend,
abortSignal,
prompt_id!,
);
@@ -736,6 +764,8 @@ export const useGeminiStream = (
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
+ // Restore original model if it was temporarily overridden
+ restoreOriginalModel();
isSubmittingQueryRef.current = false;
return;
}
@@ -748,7 +778,13 @@ export const useGeminiStream = (
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
+
+ // Restore original model if it was temporarily overridden
+ restoreOriginalModel();
} catch (error: unknown) {
+ // Restore original model if it was temporarily overridden
+ restoreOriginalModel();
+
if (error instanceof UnauthorizedError) {
onAuthError();
} else if (!isNodeError(error) || error.name !== 'AbortError') {
@@ -786,6 +822,8 @@ export const useGeminiStream = (
startNewPrompt,
getPromptCount,
handleLoopDetectedEvent,
+ handleVisionSwitch,
+ restoreOriginalModel,
],
);
diff --git a/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts b/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
new file mode 100644
index 00000000..dd8c6a06
--- /dev/null
+++ b/packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
@@ -0,0 +1,374 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { renderHook, act } from '@testing-library/react';
+import type { Part, PartListUnion } from '@google/genai';
+import { AuthType, type Config } from '@qwen-code/qwen-code-core';
+import {
+ shouldOfferVisionSwitch,
+ processVisionSwitchOutcome,
+ getVisionSwitchGuidanceMessage,
+ useVisionAutoSwitch,
+} from './useVisionAutoSwitch.js';
+import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
+import { MessageType } from '../types.js';
+import { getDefaultVisionModel } from '../models/availableModels.js';
+
+describe('useVisionAutoSwitch helpers', () => {
+ describe('shouldOfferVisionSwitch', () => {
+ it('returns false when authType is not QWEN_OAUTH', () => {
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.USE_GEMINI,
+ 'qwen3-coder-plus',
+ true,
+ );
+ expect(result).toBe(false);
+ });
+
+ it('returns false when current model is already a vision model', () => {
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.QWEN_OAUTH,
+ 'qwen-vl-max-latest',
+ true,
+ );
+ expect(result).toBe(false);
+ });
+
+ it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
+ const parts: PartListUnion = [
+ { text: 'hello' },
+ { inlineData: { mimeType: 'image/jpeg', data: '...' } },
+ ];
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.QWEN_OAUTH,
+ 'qwen3-coder-plus',
+ true,
+ );
+ expect(result).toBe(true);
+ });
+
+ it('detects image when provided as a single Part object (non-array)', () => {
+ const singleImagePart: PartListUnion = {
+ fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
+ } as Part;
+ const result = shouldOfferVisionSwitch(
+ singleImagePart,
+ AuthType.QWEN_OAUTH,
+ 'qwen3-coder-plus',
+ true,
+ );
+ expect(result).toBe(true);
+ });
+
+ it('returns false when parts contain no images', () => {
+ const parts: PartListUnion = [{ text: 'just text' }];
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.QWEN_OAUTH,
+ 'qwen3-coder-plus',
+ true,
+ );
+ expect(result).toBe(false);
+ });
+
+ it('returns false when parts is a plain string', () => {
+ const parts: PartListUnion = 'plain text';
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.QWEN_OAUTH,
+ 'qwen3-coder-plus',
+ true,
+ );
+ expect(result).toBe(false);
+ });
+
+ it('returns false when visionModelPreviewEnabled is false', () => {
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ const result = shouldOfferVisionSwitch(
+ parts,
+ AuthType.QWEN_OAUTH,
+ 'qwen3-coder-plus',
+ false,
+ );
+ expect(result).toBe(false);
+ });
+ });
+
+ describe('processVisionSwitchOutcome', () => {
+ it('maps SwitchOnce to a one-time model override', () => {
+ const vl = getDefaultVisionModel();
+ const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
+ expect(result).toEqual({ modelOverride: vl });
+ });
+
+ it('maps SwitchSessionToVL to a persistent session model', () => {
+ const vl = getDefaultVisionModel();
+ const result = processVisionSwitchOutcome(
+ VisionSwitchOutcome.SwitchSessionToVL,
+ );
+ expect(result).toEqual({ persistSessionModel: vl });
+ });
+
+ it('maps DisallowWithGuidance to showGuidance', () => {
+ const result = processVisionSwitchOutcome(
+ VisionSwitchOutcome.DisallowWithGuidance,
+ );
+ expect(result).toEqual({ showGuidance: true });
+ });
+ });
+
+ describe('getVisionSwitchGuidanceMessage', () => {
+ it('returns the expected guidance message', () => {
+ const vl = getDefaultVisionModel();
+ const expected =
+ 'To use images with your query, you can:\n' +
+ `• Use /model set ${vl} to switch to a vision-capable model\n` +
+ '• Or remove the image and provide a text description instead';
+ expect(getVisionSwitchGuidanceMessage()).toBe(expected);
+ });
+ });
+});
+
+describe('useVisionAutoSwitch hook', () => {
+ type AddItemFn = (
+ item: { type: MessageType; text: string },
+ ts: number,
+ ) => any;
+
+ const createMockConfig = (authType: AuthType, initialModel: string) => {
+ let currentModel = initialModel;
+ const mockConfig: Partial<Config> = {
+ getModel: vi.fn(() => currentModel),
+ setModel: vi.fn((m: string) => {
+ currentModel = m;
+ }),
+ getContentGeneratorConfig: vi.fn(() => ({
+ authType,
+ model: currentModel,
+ apiKey: 'test-key',
+ vertexai: false,
+ })),
+ };
+ return mockConfig as Config;
+ };
+
+ let addItem: AddItemFn;
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+ addItem = vi.fn();
+ });
+
+ it('returns shouldProceed=true immediately for continuations', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, Date.now(), true);
+ });
+ expect(res).toEqual({ shouldProceed: true });
+ expect(addItem).not.toHaveBeenCalled();
+ });
+
+ it('does nothing when authType is not QWEN_OAUTH', async () => {
+ const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi.fn();
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 123, false);
+ });
+ expect(res).toEqual({ shouldProceed: true });
+ expect(onVisionSwitchRequired).not.toHaveBeenCalled();
+ });
+
+ it('does nothing when there are no image parts', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi.fn();
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [{ text: 'no images here' }];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 456, false);
+ });
+ expect(res).toEqual({ shouldProceed: true });
+ expect(onVisionSwitchRequired).not.toHaveBeenCalled();
+ });
+
+ it('shows guidance and blocks when dialog returns showGuidance', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi
+ .fn()
+ .mockResolvedValue({ showGuidance: true });
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+
+ const userTs = 1010;
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, userTs, false);
+ });
+
+ expect(addItem).toHaveBeenCalledWith(
+ { type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
+ userTs,
+ );
+ expect(res).toEqual({ shouldProceed: false });
+ expect(config.setModel).not.toHaveBeenCalled();
+ });
+
+ it('applies a one-time override and returns originalModel, then restores', async () => {
+ const initialModel = 'qwen3-coder-plus';
+ const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
+ const onVisionSwitchRequired = vi
+ .fn()
+ .mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 2020, false);
+ });
+
+ expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
+ expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
+
+ // Now restore
+ act(() => {
+ result.current.restoreOriginalModel();
+ });
+ expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
+ });
+
+ it('persists session model when dialog requests persistence', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi
+ .fn()
+ .mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 3030, false);
+ });
+
+ expect(res).toEqual({ shouldProceed: true });
+ expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
+
+ // Restore should be a no-op since no one-time override was used
+ act(() => {
+ result.current.restoreOriginalModel();
+ });
+ // Last call should still be the persisted model set
+ expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
+ 'qwen-vl-max-latest',
+ );
+ });
+
+ it('returns shouldProceed=true when dialog returns no special flags', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 4040, false);
+ });
+ expect(res).toEqual({ shouldProceed: true });
+ expect(config.setModel).not.toHaveBeenCalled();
+ });
+
+ it('blocks when dialog throws or is cancelled', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 5050, false);
+ });
+ expect(res).toEqual({ shouldProceed: false });
+ expect(config.setModel).not.toHaveBeenCalled();
+ });
+
+ it('does nothing when visionModelPreviewEnabled is false', async () => {
+ const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
+ const onVisionSwitchRequired = vi.fn();
+ const { result } = renderHook(() =>
+ useVisionAutoSwitch(
+ config,
+ addItem as any,
+ false,
+ onVisionSwitchRequired,
+ ),
+ );
+
+ const parts: PartListUnion = [
+ { inlineData: { mimeType: 'image/png', data: '...' } },
+ ];
+ let res: any;
+ await act(async () => {
+ res = await result.current.handleVisionSwitch(parts, 6060, false);
+ });
+ expect(res).toEqual({ shouldProceed: true });
+ expect(onVisionSwitchRequired).not.toHaveBeenCalled();
+ });
+});
diff --git a/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts b/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
new file mode 100644
index 00000000..d4b9629c
--- /dev/null
+++ b/packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
@@ -0,0 +1,304 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { type PartListUnion, type Part } from '@google/genai';
+import { AuthType, type Config } from '@qwen-code/qwen-code-core';
+import { useCallback, useRef } from 'react';
+import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
+import {
+ getDefaultVisionModel,
+ isVisionModel,
+} from '../models/availableModels.js';
+import { MessageType } from '../types.js';
+import type { UseHistoryManagerReturn } from './useHistoryManager.js';
+import {
+ isSupportedImageMimeType,
+ getUnsupportedImageFormatWarning,
+} from '@qwen-code/qwen-code-core';
+
+/**
+ * Checks if a PartListUnion contains image parts
+ */
+function hasImageParts(parts: PartListUnion): boolean {
+ if (typeof parts === 'string') {
+ return false;
+ }
+
+ if (Array.isArray(parts)) {
+ return parts.some((part) => {
+ // Skip string parts
+ if (typeof part === 'string') return false;
+ return isImagePart(part);
+ });
+ }
+
+ // If it's a single Part (not a string), check if it's an image
+ if (typeof parts === 'object') {
+ return isImagePart(parts);
+ }
+
+ return false;
+}
+
+/**
+ * Checks if a single Part is an image part
+ */
+function isImagePart(part: Part): boolean {
+ // Check for inlineData with image mime type
+ if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
+ return true;
+ }
+
+ // Check for fileData with image mime type
+ if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
+ return true;
+ }
+
+ return false;
+}
+
+/**
+ * Checks if image parts have supported formats and returns unsupported ones
+ */
+function checkImageFormatsSupport(parts: PartListUnion): {
+ hasImages: boolean;
+ hasUnsupportedFormats: boolean;
+ unsupportedMimeTypes: string[];
+} {
+ const unsupportedMimeTypes: string[] = [];
+ let hasImages = false;
+
+ if (typeof parts === 'string') {
+ return {
+ hasImages: false,
+ hasUnsupportedFormats: false,
+ unsupportedMimeTypes: [],
+ };
+ }
+
+ const partsArray = Array.isArray(parts) ? parts : [parts];
+
+ for (const part of partsArray) {
+ if (typeof part === 'string') continue;
+
+ let mimeType: string | undefined;
+
+ // Check inlineData
+ if (
+ 'inlineData' in part &&
+ part.inlineData?.mimeType?.startsWith('image/')
+ ) {
+ hasImages = true;
+ mimeType = part.inlineData.mimeType;
+ }
+
+ // Check fileData
+ if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
+ hasImages = true;
+ mimeType = part.fileData.mimeType;
+ }
+
+ // Check if the mime type is supported
+ if (mimeType && !isSupportedImageMimeType(mimeType)) {
+ unsupportedMimeTypes.push(mimeType);
+ }
+ }
+
+ return {
+ hasImages,
+ hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
+ unsupportedMimeTypes,
+ };
+}
+
+/**
+ * Determines if we should offer vision switch for the given parts, auth type, and current model
+ */
+export function shouldOfferVisionSwitch(
+ parts: PartListUnion,
+ authType: AuthType,
+ currentModel: string,
+ visionModelPreviewEnabled: boolean = false,
+): boolean {
+ // Only trigger for qwen-oauth
+ if (authType !== AuthType.QWEN_OAUTH) {
+ return false;
+ }
+
+ // If vision model preview is disabled, never offer vision switch
+ if (!visionModelPreviewEnabled) {
+ return false;
+ }
+
+ // If current model is already a vision model, no need to switch
+ if (isVisionModel(currentModel)) {
+ return false;
+ }
+
+ // Check if the current message contains image parts
+ return hasImageParts(parts);
+}
+
+/**
+ * Interface for vision switch result
+ */
+export interface VisionSwitchResult {
+ modelOverride?: string;
+ persistSessionModel?: string;
+ showGuidance?: boolean;
+}
+
+/**
+ * Processes the vision switch outcome and returns the appropriate result
+ */
+export function processVisionSwitchOutcome(
+ outcome: VisionSwitchOutcome,
+): VisionSwitchResult {
+ const vlModelId = getDefaultVisionModel();
+
+ switch (outcome) {
+ case VisionSwitchOutcome.SwitchOnce:
+ return { modelOverride: vlModelId };
+
+ case VisionSwitchOutcome.SwitchSessionToVL:
+ return { persistSessionModel: vlModelId };
+
+ case VisionSwitchOutcome.DisallowWithGuidance:
+ return { showGuidance: true };
+
+ default:
+ return { showGuidance: true };
+ }
+}
+
+/**
+ * Gets the guidance message for when vision switch is disallowed
+ */
+export function getVisionSwitchGuidanceMessage(): string {
+ const vlModelId = getDefaultVisionModel();
+ return `To use images with your query, you can:
+• Use /model set ${vlModelId} to switch to a vision-capable model
+• Or remove the image and provide a text description instead`;
+}
+
+/**
+ * Interface for vision switch handling result
+ */
+export interface VisionSwitchHandlingResult {
+ shouldProceed: boolean;
+ originalModel?: string;
+}
+
+/**
+ * Custom hook for handling vision model auto-switching
+ */
+export function useVisionAutoSwitch(
+ config: Config,
+ addItem: UseHistoryManagerReturn['addItem'],
+ visionModelPreviewEnabled: boolean = false,
+ onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
+ modelOverride?: string;
+ persistSessionModel?: string;
+ showGuidance?: boolean;
+ }>,
+) {
+  const originalModelRef = useRef<string | null>(null);
+
+ const handleVisionSwitch = useCallback(
+ async (
+ query: PartListUnion,
+ userMessageTimestamp: number,
+ isContinuation: boolean,
+    ): Promise<VisionSwitchHandlingResult> => {
+ // Skip vision switch handling for continuations or if no handler provided
+ if (isContinuation || !onVisionSwitchRequired) {
+ return { shouldProceed: true };
+ }
+
+ const contentGeneratorConfig = config.getContentGeneratorConfig();
+
+ // Only handle qwen-oauth auth type
+ if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
+ return { shouldProceed: true };
+ }
+
+ // Check image format support first
+ const formatCheck = checkImageFormatsSupport(query);
+
+ // If there are unsupported image formats, show warning
+ if (formatCheck.hasUnsupportedFormats) {
+ addItem(
+ {
+ type: MessageType.INFO,
+ text: getUnsupportedImageFormatWarning(),
+ },
+ userMessageTimestamp,
+ );
+ // Continue processing but with warning shown
+ }
+
+ // Check if vision switch is needed
+ if (
+ !shouldOfferVisionSwitch(
+ query,
+ contentGeneratorConfig.authType,
+ config.getModel(),
+ visionModelPreviewEnabled,
+ )
+ ) {
+ return { shouldProceed: true };
+ }
+
+ try {
+ const visionSwitchResult = await onVisionSwitchRequired(query);
+
+ if (visionSwitchResult.showGuidance) {
+ // Show guidance and don't proceed with the request
+ addItem(
+ {
+ type: MessageType.INFO,
+ text: getVisionSwitchGuidanceMessage(),
+ },
+ userMessageTimestamp,
+ );
+ return { shouldProceed: false };
+ }
+
+ if (visionSwitchResult.modelOverride) {
+ // One-time model override
+ originalModelRef.current = config.getModel();
+ config.setModel(visionSwitchResult.modelOverride);
+ return {
+ shouldProceed: true,
+ originalModel: originalModelRef.current,
+ };
+ } else if (visionSwitchResult.persistSessionModel) {
+ // Persistent session model change
+ config.setModel(visionSwitchResult.persistSessionModel);
+ return { shouldProceed: true };
+ }
+
+ return { shouldProceed: true };
+ } catch (_error) {
+ // If vision switch dialog was cancelled or errored, don't proceed
+ return { shouldProceed: false };
+ }
+ },
+ [config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
+ );
+
+ const restoreOriginalModel = useCallback(() => {
+ if (originalModelRef.current) {
+ config.setModel(originalModelRef.current);
+ originalModelRef.current = null;
+ }
+ }, [config]);
+
+ return {
+ handleVisionSwitch,
+ restoreOriginalModel,
+ };
+}
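A minimal sketch (illustration only, not part of the patch) of how the pure helpers exported above behave, assuming the import paths shown in useVisionAutoSwitch.ts; the inline image data is a placeholder:

```ts
import { AuthType } from '@qwen-code/qwen-code-core';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import {
  shouldOfferVisionSwitch,
  processVisionSwitchOutcome,
} from './useVisionAutoSwitch.js';

// A prompt that mixes text with an inline PNG attachment.
const parts = [
  { text: 'What is in this image?' },
  { inlineData: { mimeType: 'image/png', data: '<base64>' } },
];

// A switch is offered only for Qwen OAuth, with the preview flag enabled,
// and when the current model is not already a vision model.
shouldOfferVisionSwitch(parts, AuthType.QWEN_OAUTH, 'qwen3-coder-plus', true); // true
shouldOfferVisionSwitch(parts, AuthType.QWEN_OAUTH, 'qwen-vl-max-latest', true); // false
shouldOfferVisionSwitch(parts, AuthType.QWEN_OAUTH, 'qwen3-coder-plus', false); // false

// Dialog outcomes map onto the three handling strategies:
processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
// -> { modelOverride: 'qwen-vl-max-latest' }
processVisionSwitchOutcome(VisionSwitchOutcome.SwitchSessionToVL);
// -> { persistSessionModel: 'qwen-vl-max-latest' }
processVisionSwitchOutcome(VisionSwitchOutcome.DisallowWithGuidance);
// -> { showGuidance: true }
```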
diff --git a/packages/cli/src/ui/models/availableModels.ts b/packages/cli/src/ui/models/availableModels.ts
new file mode 100644
index 00000000..7c3a1cf5
--- /dev/null
+++ b/packages/cli/src/ui/models/availableModels.ts
@@ -0,0 +1,52 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+export type AvailableModel = {
+ id: string;
+ label: string;
+ isVision?: boolean;
+};
+
+export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
+ { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
+ { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
+];
+
+/**
+ * Get available Qwen models filtered by vision model preview setting
+ */
+export function getFilteredQwenModels(
+ visionModelPreviewEnabled: boolean,
+): AvailableModel[] {
+ if (visionModelPreviewEnabled) {
+ return AVAILABLE_MODELS_QWEN;
+ }
+ return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
+}
+
+/**
+ * Currently we use the single model specified by the `OPENAI_MODEL` environment variable.
+ * In the future, once settings.json is updated, users will be able to configure this themselves.
+ */
+export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
+ const id = process.env['OPENAI_MODEL']?.trim();
+ return id ? { id, label: id } : null;
+}
+
+/**
+ * Hard-code the default vision model as a string literal,
+ * until our coding model supports multimodal.
+ */
+export function getDefaultVisionModel(): string {
+ return 'qwen-vl-max-latest';
+}
+
+export function isVisionModel(modelId: string): boolean {
+ return AVAILABLE_MODELS_QWEN.some(
+ (model) => model.id === modelId && model.isVision,
+ );
+}
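A quick sketch (illustration only, not part of the patch) of the expected behaviour of these helpers; the `OPENAI_MODEL` value below is hypothetical:

```ts
import {
  getFilteredQwenModels,
  getDefaultVisionModel,
  isVisionModel,
  getOpenAIAvailableModelFromEnv,
} from './availableModels.js';

// With the preview flag off, the vision entry is filtered out.
getFilteredQwenModels(false); // [{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }]
getFilteredQwenModels(true).length; // 2

isVisionModel(getDefaultVisionModel()); // true  ('qwen-vl-max-latest' is flagged isVision)
isVisionModel('qwen3-coder-plus');      // false

// Reads OPENAI_MODEL from the environment; returns null when it is unset or blank.
process.env['OPENAI_MODEL'] = 'my-openai-model'; // hypothetical value
getOpenAIAvailableModelFromEnv(); // { id: 'my-openai-model', label: 'my-openai-model' }
```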
diff --git a/packages/core/index.ts b/packages/core/index.ts
index 447560d4..3cc271d0 100644
--- a/packages/core/index.ts
+++ b/packages/core/index.ts
@@ -19,3 +19,4 @@ export {
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js';
+export * from './src/utils/request-tokenizer/supportedImageFormats.js';
diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
index f28c786c..7f4eec69 100644
--- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
+++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts
@@ -5,9 +5,10 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
+import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
+import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
let mockConfig: Config;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockOpenAIClient: any;
+ let mockProvider: OpenAICompatibleProvider;
beforeEach(() => {
// Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
+ enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
create: vi.fn(),
},
},
+ embeddings: {
+ create: vi.fn(),
+ },
};
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
+ // Create mock provider
+ mockProvider = {
+ buildHeaders: vi.fn().mockReturnValue({
+ 'User-Agent': 'QwenCode/1.0.0 (test; test)',
+ }),
+ buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+ buildRequest: vi.fn().mockImplementation((req) => req),
+ };
+
// Create generator instance
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
+ enableOpenAILogging: false,
};
- generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
+ generator = new OpenAIContentGenerator(
+ contentGeneratorConfig,
+ mockConfig,
+ mockProvider,
+ );
});
afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
await expect(
generator.generateContentStream(request, 'test-prompt-id'),
).rejects.toThrow(
- /Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
+ /Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
);
});
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
} catch (error: unknown) {
const errorMessage =
error instanceof Error ? error.message : String(error);
- expect(errorMessage).toContain(
- 'Streaming setup timeout troubleshooting:',
- );
- expect(errorMessage).toContain(
- 'Check network connectivity and firewall settings',
- );
+ expect(errorMessage).toContain('Streaming timeout troubleshooting:');
+ expect(errorMessage).toContain('Check network connectivity');
expect(errorMessage).toContain('Consider using non-streaming mode');
}
});
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
- new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
+ new OpenAIContentGenerator(
+ contentGeneratorConfig,
+ mockConfig,
+ mockProvider,
+ );
- // Verify OpenAI client was created with timeout config
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: 'http://localhost:8080',
- timeout: 120000,
- maxRetries: 3,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
+ // Verify provider buildClient was called
+ expect(mockProvider.buildClient).toHaveBeenCalled();
});
it('should use custom timeout from config', () => {
const customConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+ getContentGeneratorConfig: vi.fn().mockReturnValue({
+ enableOpenAILogging: false,
+ }),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
timeout: 300000,
maxRetries: 5,
};
- new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: 'http://localhost:8080',
- timeout: 300000,
- maxRetries: 5,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
+ // Create a custom mock provider for this test
+ const customMockProvider: OpenAICompatibleProvider = {
+ buildHeaders: vi.fn().mockReturnValue({
+ 'User-Agent': 'QwenCode/1.0.0 (test; test)',
+ }),
+ buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+ buildRequest: vi.fn().mockImplementation((req) => req),
+ };
+
+ new OpenAIContentGenerator(
+ contentGeneratorConfig,
+ customConfig,
+ customMockProvider,
+ );
+
+ // Verify provider buildClient was called
+ expect(customMockProvider.buildClient).toHaveBeenCalled();
});
it('should handle missing timeout config gracefully', () => {
const noTimeoutConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+ getContentGeneratorConfig: vi.fn().mockReturnValue({
+ enableOpenAILogging: false,
+ }),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
- new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: 'http://localhost:8080',
- timeout: 120000, // default
- maxRetries: 3, // default
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
+ // Create a custom mock provider for this test
+ const noTimeoutMockProvider: OpenAICompatibleProvider = {
+ buildHeaders: vi.fn().mockReturnValue({
+ 'User-Agent': 'QwenCode/1.0.0 (test; test)',
+ }),
+ buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
+ buildRequest: vi.fn().mockImplementation((req) => req),
+ };
+
+ new OpenAIContentGenerator(
+ contentGeneratorConfig,
+ noTimeoutConfig,
+ noTimeoutMockProvider,
+ );
+
+ // Verify provider buildClient was called
+ expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
});
});
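A possible tidy-up, sketched here only and not part of the patch: the three mock providers built in this test file share the same shape, so a small factory could reduce the repetition. This assumes `OpenAICompatibleProvider` only requires the three members mocked above.

```ts
import { vi } from 'vitest';
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';

// Builds a provider whose buildClient returns the supplied (mock) OpenAI client.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function makeMockProvider(client: any): OpenAICompatibleProvider {
  return {
    buildHeaders: vi.fn().mockReturnValue({
      'User-Agent': 'QwenCode/1.0.0 (test; test)',
    }),
    buildClient: vi.fn().mockReturnValue(client),
    buildRequest: vi.fn().mockImplementation((req) => req),
  };
}

// Usage in the tests above:
//   mockProvider = makeMockProvider(mockOpenAIClient);
//   generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig, mockProvider);
```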
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index ca726bd3..bf8aa804 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -500,7 +500,7 @@ export class GeminiChat {
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
- if (error.message.match(/5\d{2}/)) return true;
+ if (error.message.match(/^5\d{2}/)) return true;
}
return false;
},
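For context on the one-character regex change above: anchoring the pattern means a retry is triggered only when the error message starts with a 5xx code, rather than whenever any three-digit substring beginning with 5 appears. Actual error message formats vary; the sketch below only illustrates the regex semantics.

```ts
const unanchored = /5\d{2}/;
const anchored = /^5\d{2}/;

const incidental = 'Request failed after 1532ms'; // contains "532" but is not an HTTP 5xx
unanchored.test(incidental); // true  -> previously treated as retryable
anchored.test(incidental);   // false -> no longer retried

anchored.test('503 Service Unavailable'); // true -> still retried
```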
diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts
deleted file mode 100644
index d2b28842..00000000
--- a/packages/core/src/core/openaiContentGenerator.test.ts
+++ /dev/null
@@ -1,3511 +0,0 @@
-/**
- * @license
- * Copyright 2025 Qwen
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import { OpenAIContentGenerator } from './openaiContentGenerator.js';
-import type { Config } from '../config/config.js';
-import { AuthType } from './contentGenerator.js';
-import OpenAI from 'openai';
-import type {
- GenerateContentParameters,
- CountTokensParameters,
- EmbedContentParameters,
- CallableTool,
- Content,
-} from '@google/genai';
-import { Type, FinishReason } from '@google/genai';
-
-// Mock OpenAI
-vi.mock('openai');
-
-// Mock logger modules
-vi.mock('../telemetry/loggers.js', () => ({
- logApiResponse: vi.fn(),
- logApiError: vi.fn(),
-}));
-
-vi.mock('../utils/openaiLogger.js', () => ({
- openaiLogger: {
- logInteraction: vi.fn(),
- },
-}));
-
-// Mock tiktoken
-vi.mock('tiktoken', () => ({
- get_encoding: vi.fn().mockReturnValue({
- encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
- free: vi.fn(),
- }),
-}));
-
-describe('OpenAIContentGenerator', () => {
- let generator: OpenAIContentGenerator;
- let mockConfig: Config;
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- let mockOpenAIClient: any;
-
- beforeEach(() => {
- // Reset mocks
- vi.clearAllMocks();
-
- // Mock environment variables
- vi.stubEnv('OPENAI_BASE_URL', '');
-
- // Mock config
- mockConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- timeout: 120000,
- maxRetries: 3,
- samplingParams: {
- temperature: 0.7,
- max_tokens: 1000,
- top_p: 0.9,
- },
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- // Mock OpenAI client
- mockOpenAIClient = {
- chat: {
- completions: {
- create: vi.fn(),
- },
- },
- embeddings: {
- create: vi.fn(),
- },
- };
-
- vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
-
- // Create generator instance
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: false,
- timeout: 120000,
- maxRetries: 3,
- samplingParams: {
- temperature: 0.7,
- max_tokens: 1000,
- top_p: 0.9,
- },
- };
- generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
- });
-
- afterEach(() => {
- vi.restoreAllMocks();
- });
-
- describe('constructor', () => {
- it('should initialize with basic configuration', () => {
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: undefined,
- timeout: 120000,
- maxRetries: 3,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
- });
-
- it('should handle custom base URL', () => {
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- baseUrl: 'https://api.custom.com',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: false,
- timeout: 120000,
- maxRetries: 3,
- };
- new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
-
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: 'https://api.custom.com',
- timeout: 120000,
- maxRetries: 3,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
- });
-
- it('should configure OpenRouter headers when using OpenRouter', () => {
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- baseUrl: 'https://openrouter.ai/api/v1',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: false,
- timeout: 120000,
- maxRetries: 3,
- };
- new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
-
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: 'https://openrouter.ai/api/v1',
- timeout: 120000,
- maxRetries: 3,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
- 'X-Title': 'Qwen Code',
- },
- });
- });
-
- it('should override timeout settings from config', () => {
- const customConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- timeout: 300000,
- maxRetries: 5,
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- timeout: 300000,
- maxRetries: 5,
- };
- new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
-
- expect(OpenAI).toHaveBeenCalledWith({
- apiKey: 'test-key',
- baseURL: undefined,
- timeout: 300000,
- maxRetries: 5,
- defaultHeaders: {
- 'User-Agent': expect.stringMatching(/^QwenCode/),
- },
- });
- });
- });
-
- describe('generateContent', () => {
- it('should generate content successfully', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- object: 'chat.completion',
- created: 1677652288,
- model: 'gpt-4',
- choices: [
- {
- index: 0,
- message: {
- role: 'assistant',
- content: 'Hello! How can I help you?',
- },
- finish_reason: 'stop',
- },
- ],
- usage: {
- prompt_tokens: 10,
- completion_tokens: 15,
- total_tokens: 25,
- },
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- expect(result.candidates).toHaveLength(1);
- if (
- result.candidates &&
- result.candidates.length > 0 &&
- result.candidates[0]
- ) {
- const firstCandidate = result.candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([
- { text: 'Hello! How can I help you?' },
- ]);
- }
- }
- expect(result.usageMetadata).toEqual({
- promptTokenCount: 10,
- candidatesTokenCount: 15,
- totalTokenCount: 25,
- cachedContentTokenCount: 0,
- });
- });
-
- it('should handle system instructions', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- config: {
- systemInstruction: 'You are a helpful assistant.',
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'system', content: 'You are a helpful assistant.' },
- { role: 'user', content: 'Hello' },
- ],
- }),
- );
- });
-
- it('should handle function calls', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_123',
- type: 'function',
- function: {
- name: 'get_weather',
- arguments: '{"location": "New York"}',
- },
- },
- ],
- },
- finish_reason: 'tool_calls',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'What is the weather?' }] }],
- model: 'gpt-4',
- config: {
- tools: [
- {
- callTool: vi.fn(),
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'get_weather',
- description: 'Get weather information',
- parameters: {
- type: Type.OBJECT,
- properties: { location: { type: Type.STRING } },
- },
- },
- ],
- }),
- } as unknown as CallableTool,
- ],
- },
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- if (
- result.candidates &&
- result.candidates.length > 0 &&
- result.candidates[0]
- ) {
- const firstCandidate = result.candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([
- {
- functionCall: {
- id: 'call_123',
- name: 'get_weather',
- args: { location: 'New York' },
- },
- },
- ]);
- }
- }
- });
-
- it('should apply sampling parameters from config', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- temperature: 0.7,
- max_tokens: 1000,
- top_p: 0.9,
- }),
- );
- });
-
- it('should prioritize request-level parameters over config', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- config: {
- temperature: 0.5,
- maxOutputTokens: 500,
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- temperature: 0.7, // From config sampling params (higher priority)
- max_tokens: 1000, // From config sampling params (higher priority)
- top_p: 0.9,
- }),
- );
- });
- });
-
- describe('generateContentStream', () => {
- it('should handle streaming responses', async () => {
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: 'Hello' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: ' there!' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- usage: {
- prompt_tokens: 10,
- completion_tokens: 5,
- total_tokens: 15,
- },
- },
- ];
-
- // Mock async iterable
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const stream = await generator.generateContentStream(
- request,
- 'test-prompt-id',
- );
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
-
- expect(responses).toHaveLength(2);
- if (
- responses[0]?.candidates &&
- responses[0].candidates.length > 0 &&
- responses[0].candidates[0]
- ) {
- const firstCandidate = responses[0].candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([{ text: 'Hello' }]);
- }
- }
- if (
- responses[1]?.candidates &&
- responses[1].candidates.length > 0 &&
- responses[1].candidates[0]
- ) {
- const secondCandidate = responses[1].candidates[0];
- if (secondCandidate.content) {
- expect(secondCandidate.content.parts).toEqual([{ text: ' there!' }]);
- }
- }
- expect(responses[1].usageMetadata).toEqual({
- promptTokenCount: 10,
- candidatesTokenCount: 5,
- totalTokenCount: 15,
- cachedContentTokenCount: 0,
- });
- });
-
- it('should handle streaming tool calls', async () => {
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: {
- tool_calls: [
- {
- index: 0,
- id: 'call_123',
- function: { name: 'get_weather' },
- },
- ],
- },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: {
- tool_calls: [
- {
- index: 0,
- function: { arguments: '{"location": "NYC"}' },
- },
- ],
- },
- finish_reason: 'tool_calls',
- },
- ],
- created: 1677652288,
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Weather?' }] }],
- model: 'gpt-4',
- };
-
- const stream = await generator.generateContentStream(
- request,
- 'test-prompt-id',
- );
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
-
- // First response should contain the complete tool call (accumulated from streaming)
- if (
- responses[0]?.candidates &&
- responses[0].candidates.length > 0 &&
- responses[0].candidates[0]
- ) {
- const firstCandidate = responses[0].candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([
- {
- functionCall: {
- id: 'call_123',
- name: 'get_weather',
- args: { location: 'NYC' },
- },
- },
- ]);
- }
- }
- if (
- responses[1]?.candidates &&
- responses[1].candidates.length > 0 &&
- responses[1].candidates[0]
- ) {
- const secondCandidate = responses[1].candidates[0];
- if (secondCandidate.content) {
- expect(secondCandidate.content.parts).toEqual([
- {
- functionCall: {
- id: 'call_123',
- name: 'get_weather',
- args: { location: 'NYC' },
- },
- },
- ]);
- }
- }
- });
- });
-
- describe('countTokens', () => {
- it('should count tokens using tiktoken', async () => {
- const request: CountTokensParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.countTokens(request);
-
- expect(result.totalTokens).toBe(50); // Mocked value
- });
-
- it('should fall back to character approximation if tiktoken fails', async () => {
- // Mock tiktoken to throw error
- vi.doMock('tiktoken', () => ({
- get_encoding: vi.fn().mockImplementation(() => {
- throw new Error('Tiktoken failed');
- }),
- }));
-
- const request: CountTokensParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.countTokens(request);
-
- // Should use character approximation (content length / 4)
- expect(result.totalTokens).toBeGreaterThan(0);
- });
- });
-
- describe('embedContent', () => {
- it('should generate embeddings for text content', async () => {
- const mockEmbedding = {
- data: [{ embedding: [0.1, 0.2, 0.3, 0.4] }],
- model: 'text-embedding-ada-002',
- usage: { prompt_tokens: 5, total_tokens: 5 },
- };
-
- mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding);
-
- const request: EmbedContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello world' }] }],
- model: 'text-embedding-ada-002',
- };
-
- const result = await generator.embedContent(request);
-
- expect(result.embeddings).toHaveLength(1);
- expect(result.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3, 0.4]);
- expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({
- model: 'text-embedding-ada-002',
- input: 'Hello world',
- });
- });
-
- it('should handle string content', async () => {
- const mockEmbedding = {
- data: [{ embedding: [0.1, 0.2] }],
- };
-
- mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding);
-
- const request: EmbedContentParameters = {
- contents: 'Simple text',
- model: 'text-embedding-ada-002',
- };
-
- await generator.embedContent(request);
-
- expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({
- model: 'text-embedding-ada-002',
- input: 'Simple text',
- });
- });
-
- it('should handle embedding errors', async () => {
- const error = new Error('Embedding failed');
- mockOpenAIClient.embeddings.create.mockRejectedValue(error);
-
- const request: EmbedContentParameters = {
- contents: 'Test text',
- model: 'text-embedding-ada-002',
- };
-
- await expect(generator.embedContent(request)).rejects.toThrow(
- 'Embedding failed',
- );
- });
- });
-
- describe('error handling', () => {
- it('should handle API errors with proper error message', async () => {
- const apiError = new Error('Invalid API key');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- generator.generateContent(request, 'test-prompt-id'),
- ).rejects.toThrow('Invalid API key');
- });
-
- it('should estimate tokens on error for telemetry', async () => {
- const apiError = new Error('API error');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- try {
- await generator.generateContent(request, 'test-prompt-id');
- } catch (error) {
- // Error should be thrown but token estimation should have been attempted
- expect(error).toBeInstanceOf(Error);
- }
- });
-
- it('should preserve error status codes like 429', async () => {
- // Create an error object with status property like OpenAI SDK would
- const apiError = Object.assign(new Error('Rate limit exceeded'), {
- status: 429,
- });
- mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- try {
- await generator.generateContent(request, 'test-prompt-id');
- expect.fail('Expected error to be thrown');
- } catch (error: unknown) {
- // Should throw the original error object with status preserved
- expect((error as Error & { status: number }).message).toBe(
- 'Rate limit exceeded',
- );
- expect((error as Error & { status: number }).status).toBe(429);
- }
- });
- });
-
- describe('message conversion', () => {
- it('should convert function responses to tool messages', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- { role: 'user', parts: [{ text: 'What is the weather?' }] },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_123',
- name: 'get_weather',
- args: { location: 'NYC' },
- },
- },
- ],
- },
- {
- role: 'user',
- parts: [
- {
- functionResponse: {
- id: 'call_123',
- name: 'get_weather',
- response: { temperature: '72F', condition: 'sunny' },
- },
- },
- ],
- },
- ],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: expect.arrayContaining([
- { role: 'user', content: 'What is the weather?' },
- {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_123',
- type: 'function',
- function: {
- name: 'get_weather',
- arguments: '{"location":"NYC"}',
- },
- },
- ],
- },
- {
- role: 'tool',
- tool_call_id: 'call_123',
- content: '{"temperature":"72F","condition":"sunny"}',
- },
- ]),
- }),
- );
- });
-
- it('should clean up orphaned tool calls', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_orphaned',
- name: 'orphaned_function',
- args: {},
- },
- },
- ],
- },
- // No corresponding function response
- ],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- // Should not include the orphaned tool call
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [], // Empty because orphaned tool call was cleaned up
- }),
- );
- });
- });
-
- describe('finish reason mapping', () => {
- it('should map OpenAI finish reasons to Gemini format', async () => {
- const testCases = [
- { openai: 'stop', expected: FinishReason.STOP },
- { openai: 'length', expected: FinishReason.MAX_TOKENS },
- { openai: 'content_filter', expected: FinishReason.SAFETY },
- { openai: 'function_call', expected: FinishReason.STOP },
- { openai: 'tool_calls', expected: FinishReason.STOP },
- ];
-
- for (const testCase of testCases) {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: testCase.openai,
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(
- mockResponse,
- );
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(
- request,
- 'test-prompt-id',
- );
- if (
- result.candidates &&
- result.candidates.length > 0 &&
- result.candidates[0]
- ) {
- const firstCandidate = result.candidates[0];
- expect(firstCandidate.finishReason).toBe(testCase.expected);
- }
- }
- });
- });
-
- describe('logging integration', () => {
- it('should log interactions when enabled', async () => {
- const loggingConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- enableOpenAILogging: true,
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: true,
- };
- const loggingGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- loggingConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await loggingGenerator.generateContent(request, 'test-prompt-id');
-
- // Verify logging was called
- const { openaiLogger } = await import('../utils/openaiLogger.js');
- expect(openaiLogger.logInteraction).toHaveBeenCalled();
- });
- });
-
- describe('timeout error detection', () => {
- it('should detect various timeout error patterns', async () => {
- const timeoutErrors = [
- new Error('timeout'),
- new Error('Request timed out'),
- new Error('Connection timeout occurred'),
- new Error('ETIMEDOUT'),
- new Error('ESOCKETTIMEDOUT'),
- { code: 'ETIMEDOUT', message: 'Connection timed out' },
- { type: 'timeout', message: 'Request timeout' },
- new Error('deadline exceeded'),
- ];
-
- for (const error of timeoutErrors) {
- mockOpenAIClient.chat.completions.create.mockRejectedValueOnce(error);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- try {
- await generator.generateContent(request, 'test-prompt-id');
- // Should not reach here
- expect(true).toBe(false);
- } catch (error) {
- const errorMessage =
- error instanceof Error ? error.message : String(error);
- expect(errorMessage).toMatch(/timeout|Troubleshooting tips/);
- }
- }
- });
-
- it('should provide timeout-specific error messages', async () => {
- const timeoutError = new Error('Request timeout');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(timeoutError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- generator.generateContent(request, 'test-prompt-id'),
- ).rejects.toThrow(
- /Troubleshooting tips.*Reduce input length.*Increase timeout.*Check network/s,
- );
- });
- });
-
- describe('streaming error handling', () => {
- it('should handle errors during streaming setup', async () => {
- const setupError = new Error('Streaming setup failed');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(setupError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- generator.generateContent(request, 'test-prompt-id'),
- ).rejects.toThrow('Streaming setup failed');
- });
-
- it('should handle timeout errors during streaming setup', async () => {
- const timeoutError = new Error('Streaming setup timeout');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(timeoutError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- generator.generateContentStream(request, 'test-prompt-id'),
- ).rejects.toThrow(
- /Streaming setup timeout troubleshooting.*Reduce input length/s,
- );
- });
-
- it('should handle errors during streaming with logging', async () => {
- const loggingConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- enableOpenAILogging: true,
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: true,
- };
- const loggingGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- loggingConfig,
- );
-
- // Mock stream that throws an error
- const mockStream = {
- async *[Symbol.asyncIterator]() {
- yield {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: 'Hello' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- };
- throw new Error('Stream error');
- },
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockStream);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const stream = await loggingGenerator.generateContentStream(
- request,
- 'test-prompt-id',
- );
-
- // Consume the stream and expect error
- await expect(async () => {
- for await (const chunk of stream) {
- // Stream will throw during iteration
- console.log('Processing chunk:', chunk); // Use chunk to avoid warning
- }
- }).rejects.toThrow('Stream error');
- });
- });
-
- describe('tool parameter conversion', () => {
- it('should convert Gemini types to OpenAI JSON Schema types', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test' }] }],
- model: 'gpt-4',
- config: {
- tools: [
- {
- callTool: vi.fn(),
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'test_function',
- description: 'Test function',
- parameters: {
- type: 'OBJECT',
- properties: {
- count: {
- type: 'INTEGER',
- minimum: '1',
- maximum: '100',
- },
- name: {
- type: 'STRING',
- minLength: '1',
- maxLength: '50',
- },
- score: { type: 'NUMBER', multipleOf: '0.1' },
- items: {
- type: 'ARRAY',
- minItems: '1',
- maxItems: '10',
- },
- },
- },
- },
- ],
- }),
- } as unknown as CallableTool,
- ],
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- tools: [
- {
- type: 'function',
- function: {
- name: 'test_function',
- description: 'Test function',
- parameters: {
- type: 'object',
- properties: {
- count: { type: 'integer', minimum: 1, maximum: 100 },
- name: { type: 'string', minLength: 1, maxLength: 50 },
- score: { type: 'number', multipleOf: 0.1 },
- items: { type: 'array', minItems: 1, maxItems: 10 },
- },
- },
- },
- },
- ],
- }),
- );
- });
-
- it('should handle MCP tools with parametersJsonSchema', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test' }] }],
- model: 'gpt-4',
- config: {
- tools: [
- {
- callTool: vi.fn(),
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'list-items',
- description: 'Get a list of items',
- parametersJsonSchema: {
- type: 'object',
- properties: {
- page_number: {
- type: 'number',
- description: 'Page number',
- },
- page_size: {
- type: 'number',
- description: 'Number of items per page',
- },
- },
- additionalProperties: false,
- $schema: 'http://json-schema.org/draft-07/schema#',
- },
- },
- ],
- }),
- } as unknown as CallableTool,
- ],
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- tools: [
- {
- type: 'function',
- function: {
- name: 'list-items',
- description: 'Get a list of items',
- parameters: {
- type: 'object',
- properties: {
- page_number: {
- type: 'number',
- description: 'Page number',
- },
- page_size: {
- type: 'number',
- description: 'Number of items per page',
- },
- },
- additionalProperties: false,
- $schema: 'http://json-schema.org/draft-07/schema#',
- },
- },
- },
- ],
- }),
- );
- });
-
- it('should handle nested parameter objects', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test' }] }],
- model: 'gpt-4',
- config: {
- tools: [
- {
- callTool: vi.fn(),
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'nested_function',
- description: 'Function with nested parameters',
- parameters: {
- type: 'OBJECT',
- properties: {
- config: {
- type: 'OBJECT',
- properties: {
- nested_count: { type: 'INTEGER' },
- nested_array: {
- type: 'ARRAY',
- items: { type: 'STRING' },
- },
- },
- },
- },
- },
- },
- ],
- }),
- } as unknown as CallableTool,
- ],
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- tools: [
- {
- type: 'function',
- function: {
- name: 'nested_function',
- description: 'Function with nested parameters',
- parameters: {
- type: 'object',
- properties: {
- config: {
- type: 'object',
- properties: {
- nested_count: { type: 'integer' },
- nested_array: {
- type: 'array',
- items: { type: 'string' },
- },
- },
- },
- },
- },
- },
- },
- ],
- }),
- );
- });
- });
-
- describe('message cleanup and conversion', () => {
- it('should handle complex conversation with multiple tool calls', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- { role: 'user', parts: [{ text: 'What tools are available?' }] },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_1',
- name: 'list_tools',
- args: { category: 'all' },
- },
- },
- ],
- },
- {
- role: 'user',
- parts: [
- {
- functionResponse: {
- id: 'call_1',
- name: 'list_tools',
- response: { tools: ['calculator', 'weather'] },
- },
- },
- ],
- },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_2',
- name: 'get_weather',
- args: { location: 'NYC' },
- },
- },
- ],
- },
- {
- role: 'user',
- parts: [
- {
- functionResponse: {
- id: 'call_2',
- name: 'get_weather',
- response: { temperature: '22°C', condition: 'sunny' },
- },
- },
- ],
- },
- ],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'user', content: 'What tools are available?' },
- {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_1',
- type: 'function',
- function: {
- name: 'list_tools',
- arguments: '{"category":"all"}',
- },
- },
- ],
- },
- {
- role: 'tool',
- tool_call_id: 'call_1',
- content: '{"tools":["calculator","weather"]}',
- },
- {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_2',
- type: 'function',
- function: {
- name: 'get_weather',
- arguments: '{"location":"NYC"}',
- },
- },
- ],
- },
- {
- role: 'tool',
- tool_call_id: 'call_2',
- content: '{"temperature":"22°C","condition":"sunny"}',
- },
- ],
- }),
- );
- });
-
- it('should clean up orphaned tool calls without corresponding responses', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- { role: 'user', parts: [{ text: 'Test' }] },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_orphaned',
- name: 'orphaned_function',
- args: {},
- },
- },
- ],
- },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'call_valid',
- name: 'valid_function',
- args: {},
- },
- },
- ],
- },
- {
- role: 'user',
- parts: [
- {
- functionResponse: {
- id: 'call_valid',
- name: 'valid_function',
- response: { result: 'success' },
- },
- },
- ],
- },
- ],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'user', content: 'Test' },
- {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_valid',
- type: 'function',
- function: {
- name: 'valid_function',
- arguments: '{}',
- },
- },
- ],
- },
- {
- role: 'tool',
- tool_call_id: 'call_valid',
- content: '{"result":"success"}',
- },
- ],
- }),
- );
- });
-
- it('should merge consecutive assistant messages', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- { role: 'user', parts: [{ text: 'Hello' }] },
- { role: 'model', parts: [{ text: 'Part 1' }] },
- { role: 'model', parts: [{ text: 'Part 2' }] },
- { role: 'user', parts: [{ text: 'Continue' }] },
- ],
- model: 'gpt-4',
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'user', content: 'Hello' },
- { role: 'assistant', content: 'Part 1Part 2' },
- { role: 'user', content: 'Continue' },
- ],
- }),
- );
- });
- });
-
- describe('error suppression functionality', () => {
- it('should allow subclasses to suppress error logging', async () => {
- class TestGenerator extends OpenAIContentGenerator {
- protected override shouldSuppressErrorLogging(): boolean {
- return true; // Always suppress for this test
- }
- }
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: false,
- timeout: 120000,
- maxRetries: 3,
- samplingParams: {
- temperature: 0.7,
- max_tokens: 1000,
- top_p: 0.9,
- },
- };
- const testGenerator = new TestGenerator(
- contentGeneratorConfig,
- mockConfig,
- );
- const consoleSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- const apiError = new Error('Test error');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- testGenerator.generateContent(request, 'test-prompt-id'),
- ).rejects.toThrow();
-
- // Error logging should be suppressed
- expect(consoleSpy).not.toHaveBeenCalledWith(
- 'OpenAI API Error:',
- expect.any(String),
- );
-
- consoleSpy.mockRestore();
- });
-
- it('should log errors when not suppressed', async () => {
- const consoleSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- const apiError = new Error('Test error');
- mockOpenAIClient.chat.completions.create.mockRejectedValue(apiError);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await expect(
- generator.generateContent(request, 'test-prompt-id'),
- ).rejects.toThrow();
-
- // Error logging should occur by default
- expect(consoleSpy).toHaveBeenCalledWith(
- 'OpenAI API Error:',
- 'Test error',
- );
-
- consoleSpy.mockRestore();
- });
- });
-
- describe('edge cases and error scenarios', () => {
- it('should handle malformed tool call arguments', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_123',
- type: 'function',
- function: {
- name: 'test_function',
- arguments: 'invalid json{',
- },
- },
- ],
- },
- finish_reason: 'tool_calls',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- // Should handle malformed JSON gracefully
- if (
- result.candidates &&
- result.candidates.length > 0 &&
- result.candidates[0]
- ) {
- const firstCandidate = result.candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([
- {
- functionCall: {
- id: 'call_123',
- name: 'test_function',
- args: {}, // Should default to empty object
- },
- },
- ]);
- }
- }
- });
-
- it('should handle streaming with malformed tool call arguments', async () => {
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: {
- tool_calls: [
- {
- index: 0,
- id: 'call_123',
- function: { name: 'test_function' },
- },
- ],
- },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: {
- tool_calls: [
- {
- index: 0,
- function: { arguments: 'invalid json{' },
- },
- ],
- },
- finish_reason: 'tool_calls',
- },
- ],
- created: 1677652288,
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test' }] }],
- model: 'gpt-4',
- };
-
- const stream = await generator.generateContentStream(
- request,
- 'test-prompt-id',
- );
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
-
- // Should handle malformed JSON in streaming gracefully
- const finalResponse = responses[responses.length - 1];
- if (
- finalResponse.candidates &&
- finalResponse.candidates.length > 0 &&
- finalResponse.candidates[0]
- ) {
- const firstCandidate = finalResponse.candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([
- {
- functionCall: {
- id: 'call_123',
- name: 'test_function',
- args: {}, // Should default to empty object
- },
- },
- ]);
- }
- }
- });
-
- it('should handle empty or null content gracefully', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: null },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- expect(result.candidates).toHaveLength(1);
- if (
- result.candidates &&
- result.candidates.length > 0 &&
- result.candidates[0]
- ) {
- const firstCandidate = result.candidates[0];
- if (firstCandidate.content) {
- expect(firstCandidate.content.parts).toEqual([]);
- }
- }
- });
-
- it('should handle usage metadata estimation when breakdown is missing', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Test response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- usage: {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 100,
- },
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- expect(result.usageMetadata).toEqual({
- promptTokenCount: 70, // 70% of 100
- candidatesTokenCount: 30, // 30% of 100
- totalTokenCount: 100,
- cachedContentTokenCount: 0,
- });
- });
-
- it('should handle cached token metadata', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Test response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- usage: {
- prompt_tokens: 50,
- completion_tokens: 25,
- total_tokens: 75,
- prompt_tokens_details: {
- cached_tokens: 10,
- },
- },
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.generateContent(request, 'test-prompt-id');
-
- expect(result.usageMetadata).toEqual({
- promptTokenCount: 50,
- candidatesTokenCount: 25,
- totalTokenCount: 75,
- cachedContentTokenCount: 10,
- });
- });
- });
-
- describe('request/response logging conversion', () => {
- it('should convert complex Gemini request to OpenAI format for logging', async () => {
- const loggingConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- enableOpenAILogging: true,
- samplingParams: {
- temperature: 0.8,
- max_tokens: 500,
- },
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: true,
- samplingParams: {
- temperature: 0.8,
- max_tokens: 500,
- },
- };
- const loggingGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- loggingConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'call_123',
- type: 'function',
- function: {
- name: 'test_function',
- arguments: '{"param":"value"}',
- },
- },
- ],
- },
- finish_reason: 'tool_calls',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- usage: {
- prompt_tokens: 100,
- completion_tokens: 50,
- total_tokens: 150,
- },
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [
- { role: 'user', parts: [{ text: 'Test complex request' }] },
- {
- role: 'model',
- parts: [
- {
- functionCall: {
- id: 'prev_call',
- name: 'previous_function',
- args: { data: 'test' },
- },
- },
- ],
- },
- {
- role: 'user',
- parts: [
- {
- functionResponse: {
- id: 'prev_call',
- name: 'previous_function',
- response: { result: 'success' },
- },
- },
- ],
- },
- ],
- model: 'gpt-4',
- config: {
- systemInstruction: 'You are a helpful assistant',
- temperature: 0.9,
- tools: [
- {
- callTool: vi.fn(),
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'test_function',
- description: 'Test function',
- parameters: { type: 'object' },
- },
- ],
- }),
- } as unknown as CallableTool,
- ],
- },
- };
-
- await loggingGenerator.generateContent(request, 'test-prompt-id');
-
- // Verify that logging was called with properly converted request/response
- const { openaiLogger } = await import('../utils/openaiLogger.js');
- expect(openaiLogger.logInteraction).toHaveBeenCalledWith(
- expect.objectContaining({
- model: 'gpt-4',
- messages: [
- {
- role: 'system',
- content: 'You are a helpful assistant',
- },
- {
- role: 'user',
- content: 'Test complex request',
- },
- {
- role: 'assistant',
- content: null,
- tool_calls: [
- {
- id: 'prev_call',
- type: 'function',
- function: {
- name: 'previous_function',
- arguments: '{"data":"test"}',
- },
- },
- ],
- },
- {
- role: 'tool',
- tool_call_id: 'prev_call',
- content: '{"result":"success"}',
- },
- ],
- temperature: 0.8, // Config override
- max_tokens: 500, // Config override
- top_p: 1, // Default value
- tools: [
- {
- type: 'function',
- function: {
- name: 'test_function',
- description: 'Test function',
- parameters: {
- type: 'object',
- },
- },
- },
- ],
- }),
- expect.objectContaining({
- id: 'chatcmpl-123',
- object: 'chat.completion',
- created: 1677652288,
- model: 'gpt-4',
- choices: [
- {
- index: 0,
- message: {
- role: 'assistant',
- content: '',
- tool_calls: [
- {
- id: 'call_123',
- type: 'function',
- function: {
- name: 'test_function',
- arguments: '{"param":"value"}',
- },
- },
- ],
- },
- finish_reason: 'stop',
- },
- ],
- usage: {
- prompt_tokens: 100,
- completion_tokens: 50,
- total_tokens: 150,
- },
- }),
- );
- });
- });
-
- describe('advanced streaming scenarios', () => {
- it('should combine streaming responses correctly for logging', async () => {
- const loggingConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- enableOpenAILogging: true,
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: true,
- };
- const loggingGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- loggingConfig,
- );
-
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: 'Hello' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: ' world' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: {},
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- usage: {
- prompt_tokens: 10,
- completion_tokens: 5,
- total_tokens: 15,
- },
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const stream = await loggingGenerator.generateContentStream(
- request,
- 'test-prompt-id',
- );
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
-
- // Verify logging was called with combined content
- const { openaiLogger } = await import('../utils/openaiLogger.js');
- expect(openaiLogger.logInteraction).toHaveBeenCalledWith(
- expect.any(Object),
- expect.objectContaining({
- choices: [
- expect.objectContaining({
- message: expect.objectContaining({
- content: 'Hello world', // Combined text
- }),
- }),
- ],
- }),
- );
- });
-
- it('should handle streaming without choices', async () => {
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [],
- created: 1677652288,
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const stream = await generator.generateContentStream(
- request,
- 'test-prompt-id',
- );
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
-
- expect(responses).toHaveLength(1);
- expect(responses[0].candidates).toEqual([]);
- });
- });
-
- describe('embed content edge cases', () => {
- it('should handle mixed content types in embed request', async () => {
- const mockEmbedding = {
- data: [{ embedding: [0.1, 0.2, 0.3] }],
- model: 'text-embedding-ada-002',
- usage: { prompt_tokens: 5, total_tokens: 5 },
- };
-
- mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding);
-
- const request: EmbedContentParameters = {
- contents: 'Hello world Direct string Another part',
- model: 'text-embedding-ada-002',
- };
-
- const result = await generator.embedContent(request);
-
- expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({
- model: 'text-embedding-ada-002',
- input: 'Hello world Direct string Another part',
- });
-
- expect(result.embeddings).toHaveLength(1);
- expect(result.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3]);
- });
-
- it('should handle empty content in embed request', async () => {
- const mockEmbedding = {
- data: [{ embedding: [] }],
- };
-
- mockOpenAIClient.embeddings.create.mockResolvedValue(mockEmbedding);
-
- const request: EmbedContentParameters = {
- contents: [],
- model: 'text-embedding-ada-002',
- };
-
- await generator.embedContent(request);
-
- expect(mockOpenAIClient.embeddings.create).toHaveBeenCalledWith({
- model: 'text-embedding-ada-002',
- input: '',
- });
- });
- });
-
- describe('system instruction edge cases', () => {
- it('should handle array system instructions', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- config: {
- systemInstruction: 'You are helpful\nBe concise',
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'system', content: 'You are helpful\nBe concise' },
- { role: 'user', content: 'Hello' },
- ],
- }),
- );
- });
-
- it('should handle object system instruction', async () => {
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- config: {
- systemInstruction: {
- parts: [{ text: 'System message' }, { text: 'Additional text' }],
- } as Content,
- },
- };
-
- await generator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: [
- { role: 'system', content: 'System message\nAdditional text' },
- { role: 'user', content: 'Hello' },
- ],
- }),
- );
- });
- });
-
- describe('sampling parameters edge cases', () => {
- it('should handle undefined sampling parameters gracefully', async () => {
- const configWithUndefined = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- samplingParams: {
- temperature: undefined,
- max_tokens: undefined,
- top_p: undefined,
- },
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- samplingParams: {
- temperature: undefined,
- max_tokens: undefined,
- top_p: undefined,
- },
- };
- const testGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- configWithUndefined,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- config: {
- temperature: undefined,
- maxOutputTokens: undefined,
- topP: undefined,
- },
- };
-
- await testGenerator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- temperature: 0.0, // Default value
- top_p: 1.0, // Default value
- // max_tokens should not be present when undefined
- }),
- );
- });
-
- it('should handle all config-level sampling parameters', async () => {
- const fullSamplingConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- samplingParams: {
- temperature: 0.8,
- max_tokens: 1500,
- top_p: 0.95,
- top_k: 40,
- repetition_penalty: 1.1,
- presence_penalty: 0.5,
- frequency_penalty: 0.3,
- },
- }),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
- apiKey: 'test-key',
- authType: AuthType.USE_OPENAI,
- samplingParams: {
- temperature: 0.8,
- max_tokens: 1500,
- top_p: 0.95,
- top_k: 40,
- repetition_penalty: 1.1,
- presence_penalty: 0.5,
- frequency_penalty: 0.3,
- },
- };
- const testGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- fullSamplingConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await testGenerator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- temperature: 0.8,
- max_tokens: 1500,
- top_p: 0.95,
- top_k: 40,
- repetition_penalty: 1.1,
- presence_penalty: 0.5,
- frequency_penalty: 0.3,
- }),
- );
- });
- });
-
- describe('token counting edge cases', () => {
- it('should handle tiktoken import failure with console warning', async () => {
- // Mock tiktoken to fail on import
- vi.doMock('tiktoken', () => {
- throw new Error('Failed to import tiktoken');
- });
-
- const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
-
- const request: CountTokensParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Test content' }] }],
- model: 'gpt-4',
- };
-
- const result = await generator.countTokens(request);
-
- expect(consoleSpy).toHaveBeenCalledWith(
- expect.stringMatching(/Failed to load tiktoken.*falling back/),
- expect.any(Error),
- );
-
- // Should use character approximation
- expect(result.totalTokens).toBeGreaterThan(0);
-
- consoleSpy.mockRestore();
- });
- });
-
- describe('metadata control', () => {
- it('should include metadata when authType is QWEN_OAUTH', async () => {
- const qwenConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'qwen-oauth',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('test-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
- apiKey: 'test-key',
- authType: AuthType.QWEN_OAUTH,
- enableOpenAILogging: false,
- };
- const qwenGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- qwenConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'qwen-turbo',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'qwen-turbo',
- };
-
- await qwenGenerator.generateContent(request, 'test-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- metadata: {
- sessionId: 'test-session-id',
- promptId: 'test-prompt-id',
- },
- }),
- );
- });
-
- it('should include metadata when baseURL is dashscope openai compatible mode', async () => {
- const dashscopeConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai', // Not QWEN_OAUTH
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
- apiKey: 'test-key',
- baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- authType: AuthType.USE_OPENAI,
- enableOpenAILogging: false,
- };
- const dashscopeGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- dashscopeConfig,
- );
-
- // Debug: Check if the client was created with the correct baseURL
- expect(vi.mocked(OpenAI)).toHaveBeenCalledWith(
- expect.objectContaining({
- baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- }),
- );
-
- // Mock the client's baseURL property to return the expected value
- Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
- value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- writable: true,
- });
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'qwen-turbo',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'qwen-turbo',
- };
-
- await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- metadata: {
- sessionId: 'dashscope-session-id',
- promptId: 'dashscope-prompt-id',
- },
- }),
- );
- });
-
- it('should NOT include metadata for regular OpenAI providers', async () => {
- const regularConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('regular-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const regularGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- regularConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await regularGenerator.generateContent(request, 'regular-prompt-id');
-
- // Should NOT include metadata
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
-
- it('should NOT include metadata for other auth types', async () => {
- const otherAuthConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'gemini-api-key',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('other-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const otherGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- otherAuthConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await otherGenerator.generateContent(request, 'other-prompt-id');
-
- // Should NOT include metadata
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
-
- it('should NOT include metadata for other base URLs', async () => {
- // Mock environment to set a different base URL
- vi.stubEnv('OPENAI_BASE_URL', 'https://api.openai.com/v1');
-
- const otherBaseUrlConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('other-base-url-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const otherBaseUrlGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- otherBaseUrlConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await otherBaseUrlGenerator.generateContent(
- request,
- 'other-base-url-prompt-id',
- );
-
- // Should NOT include metadata
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
-
- it('should include metadata in streaming requests when conditions are met', async () => {
- const qwenConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'qwen-oauth',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('streaming-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
-
- apiKey: 'test-key',
-
- authType: AuthType.QWEN_OAUTH,
-
- enableOpenAILogging: false,
- };
-
- const qwenGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- qwenConfig,
- );
-
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: 'Hello' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: ' there!' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'qwen-turbo',
- };
-
- const stream = await qwenGenerator.generateContentStream(
- request,
- 'streaming-prompt-id',
- );
-
- // Verify metadata was included in the streaming request
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- metadata: {
- sessionId: 'streaming-session-id',
- promptId: 'streaming-prompt-id',
- },
- }),
- );
-
- // Consume the stream to complete the test
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
- expect(responses).toHaveLength(2);
- });
-
- it('should NOT include metadata in streaming requests when conditions are not met', async () => {
- const regularConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('regular-streaming-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const regularGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- regularConfig,
- );
-
- const mockStream = [
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: 'Hello' },
- finish_reason: null,
- },
- ],
- created: 1677652288,
- },
- {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- delta: { content: ' there!' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- },
- ];
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue({
- async *[Symbol.asyncIterator]() {
- for (const chunk of mockStream) {
- yield chunk;
- }
- },
- });
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- const stream = await regularGenerator.generateContentStream(
- request,
- 'regular-streaming-prompt-id',
- );
-
- // Verify metadata was NOT included in the streaming request
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
-
- // Consume the stream to complete the test
- const responses = [];
- for await (const response of stream) {
- responses.push(response);
- }
- expect(responses).toHaveLength(2);
- });
-
- it('should handle undefined sessionId gracefully', async () => {
- const qwenConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'qwen-oauth',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue(undefined), // Undefined session ID
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
-
- apiKey: 'test-key',
-
- authType: AuthType.QWEN_OAUTH,
-
- enableOpenAILogging: false,
- };
-
- const qwenGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- qwenConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'qwen-turbo',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'qwen-turbo',
- };
-
- await qwenGenerator.generateContent(
- request,
- 'undefined-session-prompt-id',
- );
-
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- metadata: {
- sessionId: undefined,
- promptId: 'undefined-session-prompt-id',
- },
- }),
- );
- });
-
- it('should handle undefined baseURL gracefully', async () => {
- // Ensure no base URL is set
- vi.stubEnv('OPENAI_BASE_URL', '');
-
- const noBaseUrlConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('no-base-url-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const noBaseUrlGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- noBaseUrlConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await noBaseUrlGenerator.generateContent(
- request,
- 'no-base-url-prompt-id',
- );
-
- // Should NOT include metadata when baseURL is empty
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
-
- it('should handle undefined authType gracefully', async () => {
- const undefinedAuthConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: undefined, // Undefined auth type
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('undefined-auth-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const undefinedAuthGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- undefinedAuthConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await undefinedAuthGenerator.generateContent(
- request,
- 'undefined-auth-prompt-id',
- );
-
- // Should NOT include metadata when authType is undefined
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
-
- it('should handle undefined config gracefully', async () => {
- const undefinedConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue(undefined), // Undefined config
- getSessionId: vi.fn().mockReturnValue('undefined-config-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const undefinedConfigGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- undefinedConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- model: 'gpt-4',
- };
-
- await undefinedConfigGenerator.generateContent(
- request,
- 'undefined-config-prompt-id',
- );
-
- // Should NOT include metadata when config is undefined
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.not.objectContaining({
- metadata: expect.any(Object),
- }),
- );
- });
- });
-
- describe('cache control for DashScope', () => {
- it('should add cache control to system message for DashScope providers', async () => {
- // Mock environment to set dashscope base URL
- vi.stubEnv(
- 'OPENAI_BASE_URL',
- 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- );
-
- const dashscopeConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
-
- apiKey: 'test-key',
-
- authType: AuthType.QWEN_OAUTH,
-
- enableOpenAILogging: false,
- };
-
- const dashscopeGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- dashscopeConfig,
- );
-
- // Mock the client's baseURL property to return the expected value
- Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
- value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- writable: true,
- });
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'qwen-turbo',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- config: {
- systemInstruction: 'You are a helpful assistant.',
- },
- model: 'qwen-turbo',
- };
-
- await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id');
-
- // Should include cache control in system message
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: expect.arrayContaining([
- expect.objectContaining({
- role: 'system',
- content: expect.arrayContaining([
- expect.objectContaining({
- type: 'text',
- text: 'You are a helpful assistant.',
- cache_control: { type: 'ephemeral' },
- }),
- ]),
- }),
- ]),
- }),
- );
- });
-
- it('should add cache control to last message for DashScope providers', async () => {
- // Mock environment to set dashscope base URL
- vi.stubEnv(
- 'OPENAI_BASE_URL',
- 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- );
-
- const dashscopeConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('dashscope-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'qwen-turbo',
-
- apiKey: 'test-key',
-
- authType: AuthType.QWEN_OAUTH,
-
- enableOpenAILogging: false,
- };
-
- const dashscopeGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- dashscopeConfig,
- );
-
- // Mock the client's baseURL property to return the expected value
- Object.defineProperty(dashscopeGenerator['client'], 'baseURL', {
- value: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
- writable: true,
- });
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'qwen-turbo',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }],
- model: 'qwen-turbo',
- };
-
- await dashscopeGenerator.generateContentStream(
- request,
- 'dashscope-prompt-id',
- );
-
- // Should include cache control in last message
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: expect.arrayContaining([
- expect.objectContaining({
- role: 'user',
- content: expect.arrayContaining([
- expect.objectContaining({
- type: 'text',
- text: 'Hello, how are you?',
- }),
- ]),
- }),
- ]),
- }),
- );
- });
-
- it('should NOT add cache control for non-DashScope providers', async () => {
- const regularConfig = {
- getContentGeneratorConfig: vi.fn().mockReturnValue({
- authType: 'openai',
- enableOpenAILogging: false,
- }),
- getSessionId: vi.fn().mockReturnValue('regular-session-id'),
- getCliVersion: vi.fn().mockReturnValue('1.0.0'),
- } as unknown as Config;
-
- const contentGeneratorConfig = {
- model: 'gpt-4',
-
- apiKey: 'test-key',
-
- authType: AuthType.USE_OPENAI,
-
- enableOpenAILogging: false,
- };
-
- const regularGenerator = new OpenAIContentGenerator(
- contentGeneratorConfig,
- regularConfig,
- );
-
- const mockResponse = {
- id: 'chatcmpl-123',
- choices: [
- {
- index: 0,
- message: { role: 'assistant', content: 'Response' },
- finish_reason: 'stop',
- },
- ],
- created: 1677652288,
- model: 'gpt-4',
- };
-
- mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse);
-
- const request: GenerateContentParameters = {
- contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
- config: {
- systemInstruction: 'You are a helpful assistant.',
- },
- model: 'gpt-4',
- };
-
- await regularGenerator.generateContent(request, 'regular-prompt-id');
-
- // Should NOT include cache control (messages should be strings, not arrays)
- expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith(
- expect.objectContaining({
- messages: expect.arrayContaining([
- expect.objectContaining({
- role: 'system',
- content: 'You are a helpful assistant.',
- }),
- expect.objectContaining({
- role: 'user',
- content: 'Hello',
- }),
- ]),
- }),
- );
- });
- });
-});
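For orientation before the next file: a minimal sketch, with assumed values, of the message shape the DashScope cache-control tests above assert; the cast is only there because the stock OpenAI types do not know about `cache_control`.

    import OpenAI from 'openai';

    // Assumed values; the shape mirrors what the DashScope cache-control tests expect.
    const systemMessage: OpenAI.Chat.ChatCompletionMessageParam = {
      role: 'system',
      content: [
        {
          type: 'text',
          text: 'You are a helpful assistant.',
          // DashScope-specific marker; not part of the stock OpenAI types, hence the cast.
          cache_control: { type: 'ephemeral' },
        } as OpenAI.Chat.ChatCompletionContentPartText & {
          cache_control?: { type: 'ephemeral' };
        },
      ],
    };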
diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts
deleted file mode 100644
index e3e19533..00000000
--- a/packages/core/src/core/openaiContentGenerator.ts
+++ /dev/null
@@ -1,1711 +0,0 @@
-/**
- * @license
- * Copyright 2025 Qwen
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import type {
- CountTokensResponse,
- GenerateContentParameters,
- CountTokensParameters,
- EmbedContentResponse,
- EmbedContentParameters,
- Part,
- Content,
- Tool,
- ToolListUnion,
- CallableTool,
- FunctionCall,
- FunctionResponse,
-} from '@google/genai';
-import { GenerateContentResponse, FinishReason } from '@google/genai';
-import type {
- ContentGenerator,
- ContentGeneratorConfig,
-} from './contentGenerator.js';
-import { AuthType } from './contentGenerator.js';
-import OpenAI from 'openai';
-import { logApiError, logApiResponse } from '../telemetry/loggers.js';
-import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js';
-import type { Config } from '../config/config.js';
-import { openaiLogger } from '../utils/openaiLogger.js';
-import { safeJsonParse } from '../utils/safeJsonParse.js';
-
-// Extended types to support cache_control
-interface ChatCompletionContentPartTextWithCache
- extends OpenAI.Chat.ChatCompletionContentPartText {
- cache_control?: { type: 'ephemeral' };
-}
-
-type ChatCompletionContentPartWithCache =
- | ChatCompletionContentPartTextWithCache
- | OpenAI.Chat.ChatCompletionContentPartImage
- | OpenAI.Chat.ChatCompletionContentPartRefusal;
-
-// OpenAI API type definitions for logging
-interface OpenAIToolCall {
- id: string;
- type: 'function';
- function: {
- name: string;
- arguments: string;
- };
-}
-
-interface OpenAIContentItem {
- type: 'text';
- text: string;
- cache_control?: { type: 'ephemeral' };
-}
-
-interface OpenAIMessage {
- role: 'system' | 'user' | 'assistant' | 'tool';
- content: string | null | OpenAIContentItem[];
- tool_calls?: OpenAIToolCall[];
- tool_call_id?: string;
-}
-
-interface OpenAIUsage {
- prompt_tokens: number;
- completion_tokens: number;
- total_tokens: number;
- prompt_tokens_details?: {
- cached_tokens?: number;
- };
-}
-
-interface OpenAIChoice {
- index: number;
- message: OpenAIMessage;
- finish_reason: string;
-}
-
-interface OpenAIResponseFormat {
- id: string;
- object: string;
- created: number;
- model: string;
- choices: OpenAIChoice[];
- usage?: OpenAIUsage;
-}
-
-/**
- * @deprecated refactored to ./openaiContentGenerator
- * use `createOpenAIContentGenerator` instead
- * or extend `OpenAIContentGenerator` to add customized behavior
- */
-export class OpenAIContentGenerator implements ContentGenerator {
- protected client: OpenAI;
- private model: string;
- private contentGeneratorConfig: ContentGeneratorConfig;
- private config: Config;
- private streamingToolCalls: Map<
- number,
- {
- id?: string;
- name?: string;
- arguments: string;
- }
- > = new Map();
-
- constructor(
- contentGeneratorConfig: ContentGeneratorConfig,
- gcConfig: Config,
- ) {
- this.model = contentGeneratorConfig.model;
- this.contentGeneratorConfig = contentGeneratorConfig;
- this.config = gcConfig;
-
- const version = gcConfig.getCliVersion() || 'unknown';
- const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
-
- // Check if using OpenRouter and add required headers
- const isOpenRouterProvider = this.isOpenRouterProvider();
- const isDashScopeProvider = this.isDashScopeProvider();
-
- const defaultHeaders = {
- 'User-Agent': userAgent,
- ...(isOpenRouterProvider
- ? {
- 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
- 'X-Title': 'Qwen Code',
- }
- : isDashScopeProvider
- ? {
- 'X-DashScope-CacheControl': 'enable',
- 'X-DashScope-UserAgent': userAgent,
- 'X-DashScope-AuthType': contentGeneratorConfig.authType,
- }
- : {}),
- };
-
- this.client = new OpenAI({
- apiKey: contentGeneratorConfig.apiKey,
- baseURL: contentGeneratorConfig.baseUrl,
- timeout: contentGeneratorConfig.timeout ?? 120000,
- maxRetries: contentGeneratorConfig.maxRetries ?? 3,
- defaultHeaders,
- });
- }
-
- /**
- * Hook for subclasses to customize error handling behavior
- * @param error The error that occurred
- * @param request The original request
- * @returns true if error logging should be suppressed, false otherwise
- */
- protected shouldSuppressErrorLogging(
- _error: unknown,
- _request: GenerateContentParameters,
- ): boolean {
- return false; // Default behavior: never suppress error logging
- }
-
- /**
- * Check if an error is a timeout error
- */
- private isTimeoutError(error: unknown): boolean {
- if (!error) return false;
-
- const errorMessage =
- error instanceof Error
- ? error.message.toLowerCase()
- : String(error).toLowerCase();
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- const errorCode = (error as any)?.code;
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- const errorType = (error as any)?.type;
-
- // Check for common timeout indicators
- return (
- errorMessage.includes('timeout') ||
- errorMessage.includes('timed out') ||
- errorMessage.includes('connection timeout') ||
- errorMessage.includes('request timeout') ||
- errorMessage.includes('read timeout') ||
- errorMessage.includes('etimedout') || // Include ETIMEDOUT in message check
- errorMessage.includes('esockettimedout') || // Include ESOCKETTIMEDOUT in message check
- errorCode === 'ETIMEDOUT' ||
- errorCode === 'ESOCKETTIMEDOUT' ||
- errorType === 'timeout' ||
- // OpenAI specific timeout indicators
- errorMessage.includes('request timed out') ||
- errorMessage.includes('deadline exceeded')
- );
- }
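Two illustrative errors (assumed shapes, not taken from the diff) that these checks would classify as timeouts, one by error code and one by message text:

    // Matched via errorCode === 'ETIMEDOUT' even though the message mentions no timeout.
    const byCode = Object.assign(new Error('socket hang up'), { code: 'ETIMEDOUT' });
    // Matched via the 'timed out' substring in the lower-cased message.
    const byMessage = new Error('Request timed out after 120000ms');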
-
- private isOpenRouterProvider(): boolean {
- const baseURL = this.contentGeneratorConfig.baseUrl || '';
- return baseURL.includes('openrouter.ai');
- }
-
- /**
- * Determine if this is a DashScope provider.
- * DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs.
- *
- * @returns true if this is a DashScope provider, false otherwise
- */
- private isDashScopeProvider(): boolean {
- const authType = this.contentGeneratorConfig.authType;
- const baseUrl = this.contentGeneratorConfig.baseUrl;
-
- return (
- authType === AuthType.QWEN_OAUTH ||
- baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' ||
- baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
- );
- }
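For example, either of these config shapes (illustrative values only) satisfies the DashScope check:

    import { AuthType } from './contentGenerator.js';

    // Qwen OAuth always counts as DashScope, regardless of base URL.
    const viaOauth = { authType: AuthType.QWEN_OAUTH, baseUrl: undefined };
    // An explicit DashScope compatible-mode base URL also counts.
    const viaBaseUrl = {
      authType: AuthType.USE_OPENAI,
      baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
    };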
-
- /**
- * Check if cache control should be disabled based on configuration.
- *
- * @returns true if cache control should be disabled, false otherwise
- */
- private shouldDisableCacheControl(): boolean {
- return (
- this.config.getContentGeneratorConfig()?.disableCacheControl === true
- );
- }
-
- /**
- * Build metadata object for OpenAI API requests.
- *
- * @param userPromptId The user prompt ID to include in metadata
- * @returns metadata object if shouldIncludeMetadata() returns true, undefined otherwise
- */
- private buildMetadata(
- userPromptId: string,
- ): { metadata: { sessionId?: string; promptId: string } } | undefined {
- if (!this.isDashScopeProvider()) {
- return undefined;
- }
-
- return {
- metadata: {
- sessionId: this.config.getSessionId?.(),
- promptId: userPromptId,
- },
- };
- }
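A minimal sketch, with placeholder IDs, of what this returns for a DashScope request and how it is spread into the create parameters:

    // Shape mirrors the return value above; IDs are assumptions.
    const metadata = {
      metadata: {
        sessionId: 'session-123', // from this.config.getSessionId?.()
        promptId: 'prompt-456',   // the userPromptId argument
      },
    };
    // Spread into the request exactly as buildCreateParams does below.
    const createParams = { model: 'qwen-turbo', messages: [], ...metadata };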
-
- private async buildCreateParams(
- request: GenerateContentParameters,
- userPromptId: string,
- streaming: boolean = false,
-  ): Promise<Parameters<typeof this.client.chat.completions.create>[0]> {
- let messages = this.convertToOpenAIFormat(request);
-
- // Add cache control to system and last messages for DashScope providers
- // Only add cache control to system message for non-streaming requests
- if (this.isDashScopeProvider() && !this.shouldDisableCacheControl()) {
- messages = this.addDashScopeCacheControl(
- messages,
- streaming ? 'both' : 'system',
- );
- }
-
- // Build sampling parameters with clear priority:
- // 1. Request-level parameters (highest priority)
- // 2. Config-level sampling parameters (medium priority)
- // 3. Default values (lowest priority)
- const samplingParams = this.buildSamplingParameters(request);
-
- const createParams: Parameters<
- typeof this.client.chat.completions.create
- >[0] = {
- model: this.model,
- messages,
- ...samplingParams,
- ...(this.buildMetadata(userPromptId) || {}),
- };
-
- if (request.config?.tools) {
- createParams.tools = await this.convertGeminiToolsToOpenAI(
- request.config.tools,
- );
- }
-
- if (streaming) {
- createParams.stream = true;
- createParams.stream_options = { include_usage: true };
- }
-
- return createParams;
- }
-
- async generateContent(
- request: GenerateContentParameters,
- userPromptId: string,
-  ): Promise<GenerateContentResponse> {
- const startTime = Date.now();
- const createParams = await this.buildCreateParams(
- request,
- userPromptId,
- false,
- );
-
- try {
- const completion = (await this.client.chat.completions.create(
- createParams,
- )) as OpenAI.Chat.ChatCompletion;
-
- const response = this.convertToGeminiFormat(completion);
- const durationMs = Date.now() - startTime;
-
- // Log API response event for UI telemetry
- const responseEvent = new ApiResponseEvent(
- response.responseId || 'unknown',
- this.model,
- durationMs,
- userPromptId,
- this.contentGeneratorConfig.authType,
- response.usageMetadata,
- );
-
- logApiResponse(this.config, responseEvent);
-
- // Log interaction if enabled
- if (this.contentGeneratorConfig.enableOpenAILogging) {
- const openaiRequest = createParams;
- const openaiResponse = this.convertGeminiResponseToOpenAI(response);
- await openaiLogger.logInteraction(openaiRequest, openaiResponse);
- }
-
- return response;
- } catch (error) {
- const durationMs = Date.now() - startTime;
-
- // Identify timeout errors specifically
- const isTimeoutError = this.isTimeoutError(error);
- const errorMessage = isTimeoutError
- ? `Request timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.`
- : error instanceof Error
- ? error.message
- : String(error);
-
- // Log API error event for UI telemetry
- const errorEvent = new ApiErrorEvent(
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).requestID || 'unknown',
- this.model,
- errorMessage,
- durationMs,
- userPromptId,
- this.contentGeneratorConfig.authType,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).type,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).code,
- );
- logApiError(this.config, errorEvent);
-
- // Log error interaction if enabled
- if (this.contentGeneratorConfig.enableOpenAILogging) {
- await openaiLogger.logInteraction(
- createParams,
- undefined,
- error as Error,
- );
- }
-
- // Allow subclasses to suppress error logging for specific scenarios
- if (!this.shouldSuppressErrorLogging(error, request)) {
- console.error('OpenAI API Error:', errorMessage);
- }
-
- // Provide helpful timeout-specific error message
- if (isTimeoutError) {
- throw new Error(
- `${errorMessage}\n\nTroubleshooting tips:\n` +
- `- Reduce input length or complexity\n` +
- `- Increase timeout in config: contentGenerator.timeout\n` +
- `- Check network connectivity\n` +
- `- Consider using streaming mode for long responses`,
- );
- }
-
- throw error;
- }
- }
-
- async generateContentStream(
- request: GenerateContentParameters,
- userPromptId: string,
-  ): Promise<AsyncGenerator<GenerateContentResponse>> {
- const startTime = Date.now();
- const createParams = await this.buildCreateParams(
- request,
- userPromptId,
- true,
- );
-
- try {
- const stream = (await this.client.chat.completions.create(
- createParams,
-      )) as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>;
-
- const originalStream = this.streamGenerator(stream);
-
- // Collect all responses for final logging (don't log during streaming)
- const responses: GenerateContentResponse[] = [];
-
- // Return a new generator that both yields responses and collects them
- const wrappedGenerator = async function* (this: OpenAIContentGenerator) {
- try {
- for await (const response of originalStream) {
- responses.push(response);
- yield response;
- }
-
- const durationMs = Date.now() - startTime;
-
- // Get final usage metadata from the last response that has it
- const finalUsageMetadata = responses
- .slice()
- .reverse()
- .find((r) => r.usageMetadata)?.usageMetadata;
-
- // Log API response event for UI telemetry
- const responseEvent = new ApiResponseEvent(
- responses[responses.length - 1]?.responseId || 'unknown',
- this.model,
- durationMs,
- userPromptId,
- this.contentGeneratorConfig.authType,
- finalUsageMetadata,
- );
-
- logApiResponse(this.config, responseEvent);
-
- // Log interaction if enabled (same as generateContent method)
- if (this.contentGeneratorConfig.enableOpenAILogging) {
- const openaiRequest = createParams;
- // For streaming, we combine all responses into a single response for logging
- const combinedResponse =
- this.combineStreamResponsesForLogging(responses);
- const openaiResponse =
- this.convertGeminiResponseToOpenAI(combinedResponse);
- await openaiLogger.logInteraction(openaiRequest, openaiResponse);
- }
- } catch (error) {
- const durationMs = Date.now() - startTime;
-
- // Identify timeout errors specifically for streaming
- const isTimeoutError = this.isTimeoutError(error);
- const errorMessage = isTimeoutError
- ? `Streaming request timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.`
- : error instanceof Error
- ? error.message
- : String(error);
-
- // Log API error event for UI telemetry
- const errorEvent = new ApiErrorEvent(
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).requestID || 'unknown',
- this.model,
- errorMessage,
- durationMs,
- userPromptId,
- this.contentGeneratorConfig.authType,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).type,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).code,
- );
- logApiError(this.config, errorEvent);
-
- // Log error interaction if enabled
- if (this.contentGeneratorConfig.enableOpenAILogging) {
- await openaiLogger.logInteraction(
- createParams,
- undefined,
- error as Error,
- );
- }
-
- // Provide helpful timeout-specific error message for streaming
- if (isTimeoutError) {
- throw new Error(
- `${errorMessage}\n\nStreaming timeout troubleshooting:\n` +
- `- Reduce input length or complexity\n` +
- `- Increase timeout in config: contentGenerator.timeout\n` +
- `- Check network stability for streaming connections\n` +
- `- Consider using non-streaming mode for very long inputs`,
- );
- }
-
- throw error;
- }
- }.bind(this);
-
- return wrappedGenerator();
- } catch (error) {
- const durationMs = Date.now() - startTime;
-
- // Identify timeout errors specifically for streaming setup
- const isTimeoutError = this.isTimeoutError(error);
- const errorMessage = isTimeoutError
- ? `Streaming setup timeout after ${Math.round(durationMs / 1000)}s. Try reducing input length or increasing timeout in config.`
- : error instanceof Error
- ? error.message
- : String(error);
-
- // Log API error event for UI telemetry
- const errorEvent = new ApiErrorEvent(
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).requestID || 'unknown',
- this.model,
- errorMessage,
- durationMs,
- userPromptId,
- this.contentGeneratorConfig.authType,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).type,
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- (error as any).code,
- );
- logApiError(this.config, errorEvent);
-
- // Allow subclasses to suppress error logging for specific scenarios
- if (!this.shouldSuppressErrorLogging(error, request)) {
- console.error('OpenAI API Streaming Error:', errorMessage);
- }
-
- // Provide helpful timeout-specific error message for streaming setup
- if (isTimeoutError) {
- throw new Error(
- `${errorMessage}\n\nStreaming setup timeout troubleshooting:\n` +
- `- Reduce input length or complexity\n` +
- `- Increase timeout in config: contentGenerator.timeout\n` +
- `- Check network connectivity and firewall settings\n` +
- `- Consider using non-streaming mode for very long inputs`,
- );
- }
-
- throw error;
- }
- }
-
- private async *streamGenerator(
-    stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
-  ): AsyncGenerator<GenerateContentResponse> {
- // Reset the accumulator for each new stream
- this.streamingToolCalls.clear();
-
- for await (const chunk of stream) {
- const response = this.convertStreamChunkToGeminiFormat(chunk);
-
- // Ignore empty responses, which would cause problems with downstream code
- // that expects a valid response.
- if (
- response.candidates?.[0]?.content?.parts?.length === 0 &&
- !response.usageMetadata
- ) {
- continue;
- }
-
- yield response;
- }
- }
-
- /**
- * Combine streaming responses for logging purposes
- */
- private combineStreamResponsesForLogging(
- responses: GenerateContentResponse[],
- ): GenerateContentResponse {
- if (responses.length === 0) {
- return new GenerateContentResponse();
- }
-
- const lastResponse = responses[responses.length - 1];
-
- // Find the last response with usage metadata
- const finalUsageMetadata = responses
- .slice()
- .reverse()
- .find((r) => r.usageMetadata)?.usageMetadata;
-
- // Combine all text content from the stream
- const combinedParts: Part[] = [];
- let combinedText = '';
- const functionCalls: Part[] = [];
-
- for (const response of responses) {
- if (response.candidates?.[0]?.content?.parts) {
- for (const part of response.candidates[0].content.parts) {
- if ('text' in part && part.text) {
- combinedText += part.text;
- } else if ('functionCall' in part && part.functionCall) {
- functionCalls.push(part);
- }
- }
- }
- }
-
- // Add combined text if any
- if (combinedText) {
- combinedParts.push({ text: combinedText });
- }
-
- // Add function calls
- combinedParts.push(...functionCalls);
-
- // Create combined response
- const combinedResponse = new GenerateContentResponse();
- combinedResponse.candidates = [
- {
- content: {
- parts: combinedParts,
- role: 'model' as const,
- },
- finishReason:
- responses[responses.length - 1]?.candidates?.[0]?.finishReason ||
- FinishReason.FINISH_REASON_UNSPECIFIED,
- index: 0,
- safetyRatings: [],
- },
- ];
- combinedResponse.responseId = lastResponse?.responseId;
- combinedResponse.createTime = lastResponse?.createTime;
- combinedResponse.modelVersion = this.model;
- combinedResponse.promptFeedback = { safetyRatings: [] };
- combinedResponse.usageMetadata = finalUsageMetadata;
-
- return combinedResponse;
- }
-
- async countTokens(
- request: CountTokensParameters,
-  ): Promise<CountTokensResponse> {
- // Use tiktoken for accurate token counting
- const content = JSON.stringify(request.contents);
- let totalTokens = 0;
-
- try {
- const { get_encoding } = await import('tiktoken');
- const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
- totalTokens = encoding.encode(content).length;
- encoding.free();
- } catch (error) {
- console.warn(
- 'Failed to load tiktoken, falling back to character approximation:',
- error,
- );
- // Fallback: rough approximation using character count
- totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
- }
-
- return {
- totalTokens,
- };
- }
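A worked example of the fallback path, assuming tiktoken fails to load:

    // Character-based approximation only; tiktoken's cl100k_base gives the real count.
    const content = JSON.stringify([
      { role: 'user', parts: [{ text: 'Hello world' }] },
    ]);
    const approxTokens = Math.ceil(content.length / 4); // 1 token ≈ 4 characters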
-
- async embedContent(
- request: EmbedContentParameters,
-  ): Promise<EmbedContentResponse> {
- // Extract text from contents
- let text = '';
- if (Array.isArray(request.contents)) {
- text = request.contents
- .map((content) => {
- if (typeof content === 'string') return content;
- if ('parts' in content && content.parts) {
- return content.parts
- .map((part) =>
- typeof part === 'string'
- ? part
- : 'text' in part
- ? (part as { text?: string }).text || ''
- : '',
- )
- .join(' ');
- }
- return '';
- })
- .join(' ');
- } else if (request.contents) {
- if (typeof request.contents === 'string') {
- text = request.contents;
- } else if ('parts' in request.contents && request.contents.parts) {
- text = request.contents.parts
- .map((part: Part) =>
- typeof part === 'string' ? part : 'text' in part ? part.text : '',
- )
- .join(' ');
- }
- }
-
- try {
- const embedding = await this.client.embeddings.create({
- model: 'text-embedding-ada-002', // Default embedding model
- input: text,
- });
-
- return {
- embeddings: [
- {
- values: embedding.data[0].embedding,
- },
- ],
- };
- } catch (error) {
- console.error('OpenAI API Embedding Error:', error);
- throw new Error(
- `OpenAI API error: ${error instanceof Error ? error.message : String(error)}`,
- );
- }
- }
-
- private convertGeminiParametersToOpenAI(
-    parameters: Record<string, unknown>,
-  ): Record<string, unknown> | undefined {
- if (!parameters || typeof parameters !== 'object') {
- return parameters;
- }
-
- const converted = JSON.parse(JSON.stringify(parameters));
-
- const convertTypes = (obj: unknown): unknown => {
- if (typeof obj !== 'object' || obj === null) {
- return obj;
- }
-
- if (Array.isArray(obj)) {
- return obj.map(convertTypes);
- }
-
-      const result: Record<string, unknown> = {};
- for (const [key, value] of Object.entries(obj)) {
- if (key === 'type' && typeof value === 'string') {
- // Convert Gemini types to OpenAI JSON Schema types
- const lowerValue = value.toLowerCase();
- if (lowerValue === 'integer') {
- result[key] = 'integer';
- } else if (lowerValue === 'number') {
- result[key] = 'number';
- } else {
- result[key] = lowerValue;
- }
- } else if (
- key === 'minimum' ||
- key === 'maximum' ||
- key === 'multipleOf'
- ) {
- // Ensure numeric constraints are actual numbers, not strings
- if (typeof value === 'string' && !isNaN(Number(value))) {
- result[key] = Number(value);
- } else {
- result[key] = value;
- }
- } else if (
- key === 'minLength' ||
- key === 'maxLength' ||
- key === 'minItems' ||
- key === 'maxItems'
- ) {
- // Ensure length constraints are integers, not strings
- if (typeof value === 'string' && !isNaN(Number(value))) {
- result[key] = parseInt(value, 10);
- } else {
- result[key] = value;
- }
- } else if (typeof value === 'object') {
- result[key] = convertTypes(value);
- } else {
- result[key] = value;
- }
- }
- return result;
- };
-
- return convertTypes(converted) as Record<string, unknown> | undefined;
- }
-
- /**
- * Converts Gemini tools to OpenAI format for API compatibility.
- * Handles both Gemini tools (using 'parameters' field) and MCP tools (using 'parametersJsonSchema' field).
- *
- * Gemini tools use a custom parameter format that needs conversion to OpenAI JSON Schema format.
- * MCP tools already use JSON Schema format in the parametersJsonSchema field and can be used directly.
- *
- * @param geminiTools - Array of Gemini tools to convert
- * @returns Promise resolving to array of OpenAI-compatible tools
- */
- private async convertGeminiToolsToOpenAI(
- geminiTools: ToolListUnion,
- ): Promise<OpenAI.Chat.ChatCompletionTool[]> {
- const openAITools: OpenAI.Chat.ChatCompletionTool[] = [];
-
- for (const tool of geminiTools) {
- let actualTool: Tool;
-
- // Handle CallableTool vs Tool
- if ('tool' in tool) {
- // This is a CallableTool
- actualTool = await (tool as CallableTool).tool();
- } else {
- // This is already a Tool
- actualTool = tool as Tool;
- }
-
- if (actualTool.functionDeclarations) {
- for (const func of actualTool.functionDeclarations) {
- if (func.name && func.description) {
- let parameters: Record<string, unknown> | undefined;
-
- // Handle both Gemini tools (parameters) and MCP tools (parametersJsonSchema)
- if (func.parametersJsonSchema) {
- // MCP tool format - use parametersJsonSchema directly
- if (func.parametersJsonSchema) {
- // Create a shallow copy to avoid mutating the original object
- const paramsCopy = {
- ...(func.parametersJsonSchema as Record<string, unknown>),
- };
- parameters = paramsCopy;
- }
- } else if (func.parameters) {
- // Gemini tool format - convert parameters to OpenAI format
- parameters = this.convertGeminiParametersToOpenAI(
- func.parameters as Record<string, unknown>,
- );
- }
-
- openAITools.push({
- type: 'function',
- function: {
- name: func.name,
- description: func.description,
- parameters,
- },
- });
- }
- }
- }
- }
-
- // console.log(
- // 'OpenAI Tools Parameters:',
- // JSON.stringify(openAITools, null, 2),
- // );
- return openAITools;
- }
-
- private convertToOpenAIFormat(
- request: GenerateContentParameters,
- ): OpenAI.Chat.ChatCompletionMessageParam[] {
- const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [];
-
- // Handle system instruction from config
- if (request.config?.systemInstruction) {
- const systemInstruction = request.config.systemInstruction;
- let systemText = '';
-
- if (Array.isArray(systemInstruction)) {
- systemText = systemInstruction
- .map((content) => {
- if (typeof content === 'string') return content;
- if ('parts' in content) {
- const contentObj = content as Content;
- return (
- contentObj.parts
- ?.map((p: Part) =>
- typeof p === 'string' ? p : 'text' in p ? p.text : '',
- )
- .join('\n') || ''
- );
- }
- return '';
- })
- .join('\n');
- } else if (typeof systemInstruction === 'string') {
- systemText = systemInstruction;
- } else if (
- typeof systemInstruction === 'object' &&
- 'parts' in systemInstruction
- ) {
- const systemContent = systemInstruction as Content;
- systemText =
- systemContent.parts
- ?.map((p: Part) =>
- typeof p === 'string' ? p : 'text' in p ? p.text : '',
- )
- .join('\n') || '';
- }
-
- if (systemText) {
- messages.push({
- role: 'system' as const,
- content: systemText,
- });
- }
- }
-
- // Handle contents
- if (Array.isArray(request.contents)) {
- for (const content of request.contents) {
- if (typeof content === 'string') {
- messages.push({ role: 'user' as const, content });
- } else if ('role' in content && 'parts' in content) {
- // Check if this content has function calls or responses
- const functionCalls: FunctionCall[] = [];
- const functionResponses: FunctionResponse[] = [];
- const textParts: string[] = [];
-
- for (const part of content.parts || []) {
- if (typeof part === 'string') {
- textParts.push(part);
- } else if ('text' in part && part.text) {
- textParts.push(part.text);
- } else if ('functionCall' in part && part.functionCall) {
- functionCalls.push(part.functionCall);
- } else if ('functionResponse' in part && part.functionResponse) {
- functionResponses.push(part.functionResponse);
- }
- }
-
- // Handle function responses (tool results)
- if (functionResponses.length > 0) {
- for (const funcResponse of functionResponses) {
- messages.push({
- role: 'tool' as const,
- tool_call_id: funcResponse.id || '',
- content:
- typeof funcResponse.response === 'string'
- ? funcResponse.response
- : JSON.stringify(funcResponse.response),
- });
- }
- }
- // Handle model messages with function calls
- else if (content.role === 'model' && functionCalls.length > 0) {
- const toolCalls = functionCalls.map((fc, index) => ({
- id: fc.id || `call_${index}`,
- type: 'function' as const,
- function: {
- name: fc.name || '',
- arguments: JSON.stringify(fc.args || {}),
- },
- }));
-
- messages.push({
- role: 'assistant' as const,
- content: textParts.join('') || null,
- tool_calls: toolCalls,
- });
- }
- // Handle regular text messages
- else {
- const role =
- content.role === 'model'
- ? ('assistant' as const)
- : ('user' as const);
- const text = textParts.join('');
- if (text) {
- messages.push({ role, content: text });
- }
- }
- }
- }
- } else if (request.contents) {
- if (typeof request.contents === 'string') {
- messages.push({ role: 'user' as const, content: request.contents });
- } else if ('role' in request.contents && 'parts' in request.contents) {
- const content = request.contents;
- const role =
- content.role === 'model' ? ('assistant' as const) : ('user' as const);
- const text =
- content.parts
- ?.map((p: Part) =>
- typeof p === 'string' ? p : 'text' in p ? p.text : '',
- )
- .join('\n') || '';
- messages.push({ role, content: text });
- }
- }
-
- // Clean up orphaned tool calls and merge consecutive assistant messages
- const cleanedMessages = this.cleanOrphanedToolCalls(messages);
- const mergedMessages =
- this.mergeConsecutiveAssistantMessages(cleanedMessages);
-
- return mergedMessages;
- }
-
- /**
- * Add cache control flag to specified message(s) for DashScope providers
- */
- private addDashScopeCacheControl(
- messages: OpenAI.Chat.ChatCompletionMessageParam[],
- target: 'system' | 'last' | 'both' = 'both',
- ): OpenAI.Chat.ChatCompletionMessageParam[] {
- if (!this.isDashScopeProvider() || messages.length === 0) {
- return messages;
- }
-
- let updatedMessages = [...messages];
-
- // Add cache control to system message if requested
- if (target === 'system' || target === 'both') {
- updatedMessages = this.addCacheControlToMessage(
- updatedMessages,
- 'system',
- );
- }
-
- // Add cache control to last message if requested
- if (target === 'last' || target === 'both') {
- updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last');
- }
-
- return updatedMessages;
- }
-
- /**
- * Helper method to add cache control to a specific message
- */
- private addCacheControlToMessage(
- messages: OpenAI.Chat.ChatCompletionMessageParam[],
- target: 'system' | 'last',
- ): OpenAI.Chat.ChatCompletionMessageParam[] {
- const updatedMessages = [...messages];
- let messageIndex: number;
-
- if (target === 'system') {
- // Find the first system message
- messageIndex = messages.findIndex((msg) => msg.role === 'system');
- if (messageIndex === -1) {
- return updatedMessages;
- }
- } else {
- // Get the last message
- messageIndex = messages.length - 1;
- }
-
- const message = updatedMessages[messageIndex];
-
- // Only process messages that have content
- if ('content' in message && message.content !== null) {
- if (typeof message.content === 'string') {
- // Convert string content to array format with cache control
- const messageWithArrayContent = {
- ...message,
- content: [
- {
- type: 'text',
- text: message.content,
- cache_control: { type: 'ephemeral' },
- } as ChatCompletionContentPartTextWithCache,
- ],
- };
- updatedMessages[messageIndex] =
- messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam;
- } else if (Array.isArray(message.content)) {
- // If content is already an array, add cache_control to the last item
- const contentArray = [
- ...message.content,
- ] as ChatCompletionContentPartWithCache[];
- if (contentArray.length > 0) {
- const lastItem = contentArray[contentArray.length - 1];
- if (lastItem.type === 'text') {
- // Add cache_control to the last text item
- contentArray[contentArray.length - 1] = {
- ...lastItem,
- cache_control: { type: 'ephemeral' },
- } as ChatCompletionContentPartTextWithCache;
- } else {
- // If the last item is not text, add a new text item with cache_control
- contentArray.push({
- type: 'text',
- text: '',
- cache_control: { type: 'ephemeral' },
- } as ChatCompletionContentPartTextWithCache);
- }
-
- const messageWithCache = {
- ...message,
- content: contentArray,
- };
- updatedMessages[messageIndex] =
- messageWithCache as OpenAI.Chat.ChatCompletionMessageParam;
- }
- }
- }
-
- return updatedMessages;
- }
-
- /**
- * Clean up orphaned tool calls from message history to prevent OpenAI API errors
- */
- private cleanOrphanedToolCalls(
- messages: OpenAI.Chat.ChatCompletionMessageParam[],
- ): OpenAI.Chat.ChatCompletionMessageParam[] {
- const cleaned: OpenAI.Chat.ChatCompletionMessageParam[] = [];
- const toolCallIds = new Set<string>();
- const toolResponseIds = new Set<string>();
-
- // First pass: collect all tool call IDs and tool response IDs
- for (const message of messages) {
- if (
- message.role === 'assistant' &&
- 'tool_calls' in message &&
- message.tool_calls
- ) {
- for (const toolCall of message.tool_calls) {
- if (toolCall.id) {
- toolCallIds.add(toolCall.id);
- }
- }
- } else if (
- message.role === 'tool' &&
- 'tool_call_id' in message &&
- message.tool_call_id
- ) {
- toolResponseIds.add(message.tool_call_id);
- }
- }
-
- // Second pass: filter out orphaned messages
- for (const message of messages) {
- if (
- message.role === 'assistant' &&
- 'tool_calls' in message &&
- message.tool_calls
- ) {
- // Filter out tool calls that don't have corresponding responses
- const validToolCalls = message.tool_calls.filter(
- (toolCall) => toolCall.id && toolResponseIds.has(toolCall.id),
- );
-
- if (validToolCalls.length > 0) {
- // Keep the message but only with valid tool calls
- const cleanedMessage = { ...message };
- (
- cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).tool_calls = validToolCalls;
- cleaned.push(cleanedMessage);
- } else if (
- typeof message.content === 'string' &&
- message.content.trim()
- ) {
- // Keep the message if it has text content, but remove tool calls
- const cleanedMessage = { ...message };
- delete (
- cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).tool_calls;
- cleaned.push(cleanedMessage);
- }
- // If no valid tool calls and no content, skip the message entirely
- } else if (
- message.role === 'tool' &&
- 'tool_call_id' in message &&
- message.tool_call_id
- ) {
- // Only keep tool responses that have corresponding tool calls
- if (toolCallIds.has(message.tool_call_id)) {
- cleaned.push(message);
- }
- } else {
- // Keep all other messages as-is
- cleaned.push(message);
- }
- }
-
- // Final validation: ensure every assistant message with tool_calls has corresponding tool responses
- const finalCleaned: OpenAI.Chat.ChatCompletionMessageParam[] = [];
- const finalToolCallIds = new Set<string>();
-
- // Collect all remaining tool call IDs
- for (const message of cleaned) {
- if (
- message.role === 'assistant' &&
- 'tool_calls' in message &&
- message.tool_calls
- ) {
- for (const toolCall of message.tool_calls) {
- if (toolCall.id) {
- finalToolCallIds.add(toolCall.id);
- }
- }
- }
- }
-
- // Verify all tool calls have responses
- const finalToolResponseIds = new Set<string>();
- for (const message of cleaned) {
- if (
- message.role === 'tool' &&
- 'tool_call_id' in message &&
- message.tool_call_id
- ) {
- finalToolResponseIds.add(message.tool_call_id);
- }
- }
-
- // Remove any remaining orphaned tool calls
- for (const message of cleaned) {
- if (
- message.role === 'assistant' &&
- 'tool_calls' in message &&
- message.tool_calls
- ) {
- const finalValidToolCalls = message.tool_calls.filter(
- (toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id),
- );
-
- if (finalValidToolCalls.length > 0) {
- const cleanedMessage = { ...message };
- (
- cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).tool_calls = finalValidToolCalls;
- finalCleaned.push(cleanedMessage);
- } else if (
- typeof message.content === 'string' &&
- message.content.trim()
- ) {
- const cleanedMessage = { ...message };
- delete (
- cleanedMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).tool_calls;
- finalCleaned.push(cleanedMessage);
- }
- } else {
- finalCleaned.push(message);
- }
- }
-
- return finalCleaned;
- }
-
- /**
- * Merge consecutive assistant messages to combine split text and tool calls
- */
- private mergeConsecutiveAssistantMessages(
- messages: OpenAI.Chat.ChatCompletionMessageParam[],
- ): OpenAI.Chat.ChatCompletionMessageParam[] {
- const merged: OpenAI.Chat.ChatCompletionMessageParam[] = [];
-
- for (const message of messages) {
- if (message.role === 'assistant' && merged.length > 0) {
- const lastMessage = merged[merged.length - 1];
-
- // If the last message is also an assistant message, merge them
- if (lastMessage.role === 'assistant') {
- // Combine content
- const combinedContent = [
- typeof lastMessage.content === 'string' ? lastMessage.content : '',
- typeof message.content === 'string' ? message.content : '',
- ]
- .filter(Boolean)
- .join('');
-
- // Combine tool calls
- const lastToolCalls =
- 'tool_calls' in lastMessage ? lastMessage.tool_calls || [] : [];
- const currentToolCalls =
- 'tool_calls' in message ? message.tool_calls || [] : [];
- const combinedToolCalls = [...lastToolCalls, ...currentToolCalls];
-
- // Update the last message with combined data
- (
- lastMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- content: string | null;
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).content = combinedContent || null;
- if (combinedToolCalls.length > 0) {
- (
- lastMessage as OpenAI.Chat.ChatCompletionMessageParam & {
- content: string | null;
- tool_calls?: OpenAI.Chat.ChatCompletionMessageToolCall[];
- }
- ).tool_calls = combinedToolCalls;
- }
-
- continue; // Skip adding the current message since it's been merged
- }
- }
-
- // Add the message as-is if no merging is needed
- merged.push(message);
- }
-
- return merged;
- }
-
- private convertToGeminiFormat(
- openaiResponse: OpenAI.Chat.ChatCompletion,
- ): GenerateContentResponse {
- const choice = openaiResponse.choices[0];
- const response = new GenerateContentResponse();
-
- const parts: Part[] = [];
-
- // Handle text content
- if (choice.message.content) {
- parts.push({ text: choice.message.content });
- }
-
- // Handle tool calls
- if (choice.message.tool_calls) {
- for (const toolCall of choice.message.tool_calls) {
- if (toolCall.function) {
- let args: Record<string, unknown> = {};
- if (toolCall.function.arguments) {
- args = safeJsonParse(toolCall.function.arguments, {});
- }
-
- parts.push({
- functionCall: {
- id: toolCall.id,
- name: toolCall.function.name,
- args,
- },
- });
- }
- }
- }
-
- response.responseId = openaiResponse.id;
- response.createTime = openaiResponse.created
- ? openaiResponse.created.toString()
- : new Date().getTime().toString();
-
- response.candidates = [
- {
- content: {
- parts,
- role: 'model' as const,
- },
- finishReason: this.mapFinishReason(choice.finish_reason || 'stop'),
- index: 0,
- safetyRatings: [],
- },
- ];
-
- response.modelVersion = this.model;
- response.promptFeedback = { safetyRatings: [] };
-
- // Add usage metadata if available
- if (openaiResponse.usage) {
- const usage = openaiResponse.usage as OpenAIUsage;
-
- const promptTokens = usage.prompt_tokens || 0;
- const completionTokens = usage.completion_tokens || 0;
- const totalTokens = usage.total_tokens || 0;
- const cachedTokens = usage.prompt_tokens_details?.cached_tokens || 0;
-
- // If we only have total tokens but no breakdown, estimate the split
- // Typically input is ~70% and output is ~30% for most conversations
- let finalPromptTokens = promptTokens;
- let finalCompletionTokens = completionTokens;
-
- if (totalTokens > 0 && promptTokens === 0 && completionTokens === 0) {
- // Estimate: assume 70% input, 30% output
- finalPromptTokens = Math.round(totalTokens * 0.7);
- finalCompletionTokens = Math.round(totalTokens * 0.3);
- }
-
- response.usageMetadata = {
- promptTokenCount: finalPromptTokens,
- candidatesTokenCount: finalCompletionTokens,
- totalTokenCount: totalTokens,
- cachedContentTokenCount: cachedTokens,
- };
- }
-
- return response;
- }
-
- private convertStreamChunkToGeminiFormat(
- chunk: OpenAI.Chat.ChatCompletionChunk,
- ): GenerateContentResponse {
- const choice = chunk.choices?.[0];
- const response = new GenerateContentResponse();
-
- if (choice) {
- const parts: Part[] = [];
-
- // Handle text content
- if (choice.delta?.content) {
- if (typeof choice.delta.content === 'string') {
- parts.push({ text: choice.delta.content });
- }
- }
-
- // Handle tool calls - only accumulate during streaming, emit when complete
- if (choice.delta?.tool_calls) {
- for (const toolCall of choice.delta.tool_calls) {
- const index = toolCall.index ?? 0;
-
- // Get or create the tool call accumulator for this index
- let accumulatedCall = this.streamingToolCalls.get(index);
- if (!accumulatedCall) {
- accumulatedCall = { arguments: '' };
- this.streamingToolCalls.set(index, accumulatedCall);
- }
-
- // Update accumulated data
- if (toolCall.id) {
- accumulatedCall.id = toolCall.id;
- }
- if (toolCall.function?.name) {
- // If this is a new function name, reset the arguments
- if (accumulatedCall.name !== toolCall.function.name) {
- accumulatedCall.arguments = '';
- }
- accumulatedCall.name = toolCall.function.name;
- }
- if (toolCall.function?.arguments) {
- // Check if we already have a complete JSON object
- const currentArgs = accumulatedCall.arguments;
- const newArgs = toolCall.function.arguments;
-
- // If current arguments already form a complete JSON and new arguments start a new object,
- // this indicates a new tool call with the same name
- let shouldReset = false;
- if (currentArgs && newArgs.trim().startsWith('{')) {
- try {
- JSON.parse(currentArgs);
- // If we can parse current arguments as complete JSON and new args start with {,
- // this is likely a new tool call
- shouldReset = true;
- } catch {
- // Current arguments are not complete JSON, continue accumulating
- }
- }
-
- if (shouldReset) {
- accumulatedCall.arguments = newArgs;
- } else {
- accumulatedCall.arguments += newArgs;
- }
- }
- }
- }
-
- // Only emit function calls when streaming is complete (finish_reason is present)
- if (choice.finish_reason) {
- for (const [, accumulatedCall] of this.streamingToolCalls) {
- // TODO: Add back id once we have a way to generate tool_call_id from the VLLM parser.
- // if (accumulatedCall.id && accumulatedCall.name) {
- if (accumulatedCall.name) {
- let args: Record<string, unknown> = {};
- if (accumulatedCall.arguments) {
- args = safeJsonParse(accumulatedCall.arguments, {});
- }
-
- parts.push({
- functionCall: {
- id:
- accumulatedCall.id ||
- `call_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`,
- name: accumulatedCall.name,
- args,
- },
- });
- }
- }
- // Clear all accumulated tool calls
- this.streamingToolCalls.clear();
- }
-
- response.candidates = [
- {
- content: {
- parts,
- role: 'model' as const,
- },
- finishReason: choice.finish_reason
- ? this.mapFinishReason(choice.finish_reason)
- : FinishReason.FINISH_REASON_UNSPECIFIED,
- index: 0,
- safetyRatings: [],
- },
- ];
- } else {
- response.candidates = [];
- }
-
- response.responseId = chunk.id;
- response.createTime = chunk.created
- ? chunk.created.toString()
- : new Date().getTime().toString();
-
- response.modelVersion = this.model;
- response.promptFeedback = { safetyRatings: [] };
-
- // Add usage metadata if available in the chunk
- if (chunk.usage) {
- const usage = chunk.usage as OpenAIUsage;
-
- const promptTokens = usage.prompt_tokens || 0;
- const completionTokens = usage.completion_tokens || 0;
- const totalTokens = usage.total_tokens || 0;
- const cachedTokens = usage.prompt_tokens_details?.cached_tokens || 0;
-
- // If we only have total tokens but no breakdown, estimate the split
- // Typically input is ~70% and output is ~30% for most conversations
- let finalPromptTokens = promptTokens;
- let finalCompletionTokens = completionTokens;
-
- if (totalTokens > 0 && promptTokens === 0 && completionTokens === 0) {
- // Estimate: assume 70% input, 30% output
- finalPromptTokens = Math.round(totalTokens * 0.7);
- finalCompletionTokens = Math.round(totalTokens * 0.3);
- }
-
- response.usageMetadata = {
- promptTokenCount: finalPromptTokens,
- candidatesTokenCount: finalCompletionTokens,
- totalTokenCount: totalTokens,
- cachedContentTokenCount: cachedTokens,
- };
- }
-
- return response;
- }
-
- /**
- * Build sampling parameters with clear priority:
- * 1. Config-level sampling parameters (highest priority)
- * 2. Request-level parameters (medium priority)
- * 3. Default values (lowest priority)
- */
- private buildSamplingParameters(
- request: GenerateContentParameters,
- ): Record<string, unknown> {
- const configSamplingParams = this.contentGeneratorConfig.samplingParams;
-
- const params = {
- // Temperature: config > request > default
- temperature:
- configSamplingParams?.temperature !== undefined
- ? configSamplingParams.temperature
- : request.config?.temperature !== undefined
- ? request.config.temperature
- : 0.0,
-
- // Max tokens: config > request > undefined
- ...(configSamplingParams?.max_tokens !== undefined
- ? { max_tokens: configSamplingParams.max_tokens }
- : request.config?.maxOutputTokens !== undefined
- ? { max_tokens: request.config.maxOutputTokens }
- : {}),
-
- // Top-p: config > request > default
- top_p:
- configSamplingParams?.top_p !== undefined
- ? configSamplingParams.top_p
- : request.config?.topP !== undefined
- ? request.config.topP
- : 1.0,
-
- // Top-k: config only (not available in request)
- ...(configSamplingParams?.top_k !== undefined
- ? { top_k: configSamplingParams.top_k }
- : {}),
-
- // Repetition penalty: config only
- ...(configSamplingParams?.repetition_penalty !== undefined
- ? { repetition_penalty: configSamplingParams.repetition_penalty }
- : {}),
-
- // Presence penalty: config only
- ...(configSamplingParams?.presence_penalty !== undefined
- ? { presence_penalty: configSamplingParams.presence_penalty }
- : {}),
-
- // Frequency penalty: config only
- ...(configSamplingParams?.frequency_penalty !== undefined
- ? { frequency_penalty: configSamplingParams.frequency_penalty }
- : {}),
- };
-
- return params;
- }
-
- private mapFinishReason(openaiReason: string | null): FinishReason {
- if (!openaiReason) return FinishReason.FINISH_REASON_UNSPECIFIED;
- const mapping: Record<string, FinishReason> = {
- stop: FinishReason.STOP,
- length: FinishReason.MAX_TOKENS,
- content_filter: FinishReason.SAFETY,
- function_call: FinishReason.STOP,
- tool_calls: FinishReason.STOP,
- };
- return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED;
- }
-
- /**
- * Convert Gemini response format to OpenAI chat completion format for logging
- */
- private convertGeminiResponseToOpenAI(
- response: GenerateContentResponse,
- ): OpenAIResponseFormat {
- const candidate = response.candidates?.[0];
- const content = candidate?.content;
-
- let messageContent: string | null = null;
- const toolCalls: OpenAIToolCall[] = [];
-
- if (content?.parts) {
- const textParts: string[] = [];
-
- for (const part of content.parts) {
- if ('text' in part && part.text) {
- textParts.push(part.text);
- } else if ('functionCall' in part && part.functionCall) {
- toolCalls.push({
- id: part.functionCall.id || `call_${toolCalls.length}`,
- type: 'function' as const,
- function: {
- name: part.functionCall.name || '',
- arguments: JSON.stringify(part.functionCall.args || {}),
- },
- });
- }
- }
-
- messageContent = textParts.join('').trimEnd();
- }
-
- const choice: OpenAIChoice = {
- index: 0,
- message: {
- role: 'assistant',
- content: messageContent,
- },
- finish_reason: this.mapGeminiFinishReasonToOpenAI(
- candidate?.finishReason,
- ),
- };
-
- if (toolCalls.length > 0) {
- choice.message.tool_calls = toolCalls;
- }
-
- const openaiResponse: OpenAIResponseFormat = {
- id: response.responseId || `chatcmpl-${Date.now()}`,
- object: 'chat.completion',
- created: response.createTime
- ? Number(response.createTime)
- : Math.floor(Date.now() / 1000),
- model: this.model,
- choices: [choice],
- };
-
- // Add usage metadata if available
- if (response.usageMetadata) {
- openaiResponse.usage = {
- prompt_tokens: response.usageMetadata.promptTokenCount || 0,
- completion_tokens: response.usageMetadata.candidatesTokenCount || 0,
- total_tokens: response.usageMetadata.totalTokenCount || 0,
- };
-
- if (response.usageMetadata.cachedContentTokenCount) {
- openaiResponse.usage.prompt_tokens_details = {
- cached_tokens: response.usageMetadata.cachedContentTokenCount,
- };
- }
- }
-
- return openaiResponse;
- }
-
- /**
- * Map Gemini finish reasons to OpenAI finish reasons
- */
- private mapGeminiFinishReasonToOpenAI(geminiReason?: unknown): string {
- if (!geminiReason) return 'stop';
-
- switch (geminiReason) {
- case 'STOP':
- case 1: // FinishReason.STOP
- return 'stop';
- case 'MAX_TOKENS':
- case 2: // FinishReason.MAX_TOKENS
- return 'length';
- case 'SAFETY':
- case 3: // FinishReason.SAFETY
- return 'content_filter';
- case 'RECITATION':
- case 4: // FinishReason.RECITATION
- return 'content_filter';
- case 'OTHER':
- case 5: // FinishReason.OTHER
- return 'stop';
- default:
- return 'stop';
- }
- }
-}
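For readers skimming the removed generator: the sampling-parameter priority it encoded (config-level values beat request-level values, which beat defaults) resolves as in this minimal sketch. The concrete numbers are illustrative only; they are not taken from the project configuration.

// Minimal sketch of the config > request > default resolution described above (values illustrative).
const configSamplingParams: { temperature?: number; top_p?: number } = { temperature: 0.7 };
const requestConfig: { temperature?: number; topP?: number } = { temperature: 0.2, topP: 0.9 };

const temperature =
  configSamplingParams.temperature !== undefined
    ? configSamplingParams.temperature
    : requestConfig.temperature !== undefined
      ? requestConfig.temperature
      : 0.0; // -> 0.7: the config-level value wins

const top_p =
  configSamplingParams.top_p !== undefined
    ? configSamplingParams.top_p
    : requestConfig.topP !== undefined
      ? requestConfig.topP
      : 1.0; // -> 0.9: no config value, so the request-level value is used

console.log(temperature, top_p);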
diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts
index 60c85d23..3d1a516c 100644
--- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts
+++ b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.test.ts
@@ -5,6 +5,37 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+
+// Mock the request tokenizer module BEFORE importing the class that uses it
+const mockTokenizer = {
+ calculateTokens: vi.fn().mockResolvedValue({
+ totalTokens: 50,
+ breakdown: {
+ textTokens: 50,
+ imageTokens: 0,
+ audioTokens: 0,
+ otherTokens: 0,
+ },
+ processingTime: 1,
+ }),
+ dispose: vi.fn(),
+};
+
+vi.mock('../../../utils/request-tokenizer/index.js', () => ({
+ getDefaultTokenizer: vi.fn(() => mockTokenizer),
+ DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
+ disposeDefaultTokenizer: vi.fn(),
+}));
+
+// Mock tiktoken as well for completeness
+vi.mock('tiktoken', () => ({
+ get_encoding: vi.fn(() => ({
+ encode: vi.fn(() => new Array(50)), // Mock 50 tokens
+ free: vi.fn(),
+ })),
+}));
+
+// Now import the modules that depend on the mocked modules
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
@@ -15,14 +46,6 @@ import type {
import type { OpenAICompatibleProvider } from './provider/index.js';
import type OpenAI from 'openai';
-// Mock tiktoken
-vi.mock('tiktoken', () => ({
- get_encoding: vi.fn().mockReturnValue({
- encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
- free: vi.fn(),
- }),
-}));
-
describe('OpenAIContentGenerator (Refactored)', () => {
let generator: OpenAIContentGenerator;
let mockConfig: Config;
diff --git a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts
index fa738af3..91e69527 100644
--- a/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts
+++ b/packages/core/src/core/openaiContentGenerator/openaiContentGenerator.ts
@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
import { ContentGenerationPipeline } from './pipeline.js';
import { DefaultTelemetryService } from './telemetryService.js';
import { EnhancedErrorHandler } from './errorHandler.js';
+import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
export class OpenAIContentGenerator implements ContentGenerator {
@@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
- // Use tiktoken for accurate token counting
- const content = JSON.stringify(request.contents);
- let totalTokens = 0;
-
try {
- const { get_encoding } = await import('tiktoken');
- const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
- totalTokens = encoding.encode(content).length;
- encoding.free();
+ // Use the new high-performance request tokenizer
+ const tokenizer = getDefaultTokenizer();
+ const result = await tokenizer.calculateTokens(request, {
+ textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
+ });
+
+ return {
+ totalTokens: result.totalTokens,
+ };
} catch (error) {
console.warn(
- 'Failed to load tiktoken, falling back to character approximation:',
+ 'Failed to calculate tokens with new tokenizer, falling back to simple method:',
error,
);
- // Fallback: rough approximation using character count
- totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
- }
- return {
- totalTokens,
- };
+ // Fallback to original simple method
+ const content = JSON.stringify(request.contents);
+ const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
+
+ return {
+ totalTokens,
+ };
+ }
}
async embedContent(
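A minimal sketch of the countTokens behavior after this change, including the character-count fallback. The wrapper function is hypothetical; the import path and the cl100k_base option are the ones used by the generator above, and the model id and contents are illustrative.

// Sketch only: mirrors the primary tokenizer path and the char/4 fallback added above.
import type { CountTokensParameters } from '@google/genai';
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';

async function countTokensSketch(request: CountTokensParameters): Promise<number> {
  try {
    const result = await getDefaultTokenizer().calculateTokens(request, {
      textEncoding: 'cl100k_base', // same encoding the generator requests
    });
    return result.totalTokens;
  } catch {
    // Fallback mirrors the catch branch: roughly 1 token per 4 characters of the serialized contents.
    return Math.ceil(JSON.stringify(request.contents).length / 4);
  }
}

await countTokensSketch({
  model: 'qwen-vl-max-latest', // illustrative model id
  contents: [{ role: 'user', parts: [{ text: 'Describe the attached image' }] }],
});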
diff --git a/packages/core/src/core/openaiContentGenerator/pipeline.ts b/packages/core/src/core/openaiContentGenerator/pipeline.ts
index bf4b3892..85d279e6 100644
--- a/packages/core/src/core/openaiContentGenerator/pipeline.ts
+++ b/packages/core/src/core/openaiContentGenerator/pipeline.ts
@@ -10,14 +10,11 @@ import {
GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../../config/config.js';
-import { type ContentGeneratorConfig } from '../contentGenerator.js';
-import { type OpenAICompatibleProvider } from './provider/index.js';
+import type { ContentGeneratorConfig } from '../contentGenerator.js';
+import type { OpenAICompatibleProvider } from './provider/index.js';
import { OpenAIContentConverter } from './converter.js';
-import {
- type TelemetryService,
- type RequestContext,
-} from './telemetryService.js';
-import { type ErrorHandler } from './errorHandler.js';
+import type { TelemetryService, RequestContext } from './telemetryService.js';
+import type { ErrorHandler } from './errorHandler.js';
export interface PipelineConfig {
cliConfig: Config;
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
* 2. Filter empty responses
* 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
* 4. Collect both formats for logging
- * 5. Handle success/error logging with original OpenAI format
+ * 5. Handle success/error logging
*/
private async *processStreamWithLogging(
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -169,19 +166,11 @@ export class ContentGenerationPipeline {
collectedOpenAIChunks,
);
} catch (error) {
- // Stage 2e: Stream failed - handle error and logging
- context.duration = Date.now() - context.startTime;
-
// Clear streaming tool calls on error to prevent data pollution
this.converter.resetStreamingToolCalls();
- await this.config.telemetryService.logError(
- context,
- error,
- openaiRequest,
- );
-
- this.config.errorHandler.handle(error, context, request);
+ // Use shared error handling logic
+ await this.handleError(error, context, request);
}
}
@@ -365,25 +354,59 @@ export class ContentGenerationPipeline {
context.duration = Date.now() - context.startTime;
return result;
} catch (error) {
- context.duration = Date.now() - context.startTime;
-
- // Log error
- const openaiRequest = await this.buildRequest(
+ // Use shared error handling logic
+ return await this.handleError(
+ error,
+ context,
request,
userPromptId,
isStreaming,
);
- await this.config.telemetryService.logError(
- context,
- error,
- openaiRequest,
- );
-
- // Handle and throw enhanced error
- this.config.errorHandler.handle(error, context, request);
}
}
+ /**
+ * Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
+ * This centralizes the common error processing steps to avoid duplication
+ */
+ private async handleError(
+ error: unknown,
+ context: RequestContext,
+ request: GenerateContentParameters,
+ userPromptId?: string,
+ isStreaming?: boolean,
+ ): Promise<never> {
+ context.duration = Date.now() - context.startTime;
+
+ // Build request for logging (may fail, but we still want to log the error)
+ let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
+ try {
+ if (userPromptId !== undefined && isStreaming !== undefined) {
+ openaiRequest = await this.buildRequest(
+ request,
+ userPromptId,
+ isStreaming,
+ );
+ } else {
+ // For processStreamWithLogging, we don't have userPromptId/isStreaming,
+ // so create a minimal request
+ openaiRequest = {
+ model: this.contentGeneratorConfig.model,
+ messages: [],
+ };
+ }
+ } catch (_buildError) {
+ // If we can't build the request, create a minimal one for logging
+ openaiRequest = {
+ model: this.contentGeneratorConfig.model,
+ messages: [],
+ };
+ }
+
+ await this.config.telemetryService.logError(context, error, openaiRequest);
+ this.config.errorHandler.handle(error, context, request);
+ }
+
/**
* Create request context with common properties
*/
diff --git a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts
index 0999052f..86cb54c0 100644
--- a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts
+++ b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts
@@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider
messages = this.addDashScopeCacheControl(messages, cacheTarget);
}
+ if (request.model.startsWith('qwen-vl')) {
+ return {
+ ...request,
+ messages,
+ ...(this.buildMetadata(userPromptId) || {}),
+ /* @ts-expect-error dashscope exclusive */
+ vl_high_resolution_images: true,
+ };
+ }
+
return {
...request, // Preserve all original parameters including sampling params
messages,
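For illustration, the qwen-vl branch above produces a request shaped roughly like the following. The message content and values are placeholders, not the provider's exact output; `vl_high_resolution_images` is the DashScope-specific flag being added.

// Illustrative payload only; not the provider's exact output.
const visionRequest = {
  model: 'qwen-vl-max-latest',
  messages: [
    {
      role: 'user' as const,
      content: [
        { type: 'text', text: 'Describe this image' },
        { type: 'image_url', image_url: { url: 'data:image/png;base64,...' } },
      ],
    },
  ],
  // DashScope-only switch, hence the @ts-expect-error in the provider code above.
  vl_high_resolution_images: true,
};
console.log(visionRequest.model);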
diff --git a/packages/core/src/core/tokenLimits.ts b/packages/core/src/core/tokenLimits.ts
index e51becab..2e502037 100644
--- a/packages/core/src/core/tokenLimits.ts
+++ b/packages/core/src/core/tokenLimits.ts
@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
[/^qwen-flash-latest$/, LIMITS['1m']],
[/^qwen-turbo.*$/, LIMITS['128k']],
+ // Qwen Vision Models
+ [/^qwen-vl-max.*$/, LIMITS['128k']],
+
// -------------------
// ByteDance Seed-OSS (512K)
// -------------------
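Assuming LIMITS['128k'] maps to a 131,072-token window and that lookup scans PATTERNS in order until a regex matches (both inferred from context, not shown in this hunk), the qwen-vl-max entry added above resolves like this:

// Simplified sketch of the pattern lookup; the function name and numeric limit are assumptions.
const PATTERNS_SKETCH: Array<[RegExp, number]> = [
  [/^qwen-turbo.*$/, 131_072],
  [/^qwen-vl-max.*$/, 131_072], // entry added above (LIMITS['128k'])
];

function limitFor(model: string, fallback = 131_072): number {
  for (const [pattern, limit] of PATTERNS_SKETCH) {
    if (pattern.test(model)) return limit;
  }
  return fallback;
}

limitFor('qwen-vl-max-latest'); // 131072 — matched by the new vision pattern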
diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts
index 87be432b..55bcfa0d 100644
--- a/packages/core/src/core/turn.test.ts
+++ b/packages/core/src/core/turn.test.ts
@@ -242,7 +242,7 @@ describe('Turn', () => {
expect(turn.getDebugResponses().length).toBe(0);
expect(reportError).toHaveBeenCalledWith(
error,
- 'Error when talking to Gemini API',
+ 'Error when talking to API',
[...historyContent, reqParts],
'Turn.run-sendMessageStream',
);
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 8dd8377c..ad6f8319 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -310,7 +310,7 @@ export class Turn {
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError(
error,
- 'Error when talking to Gemini API',
+ 'Error when talking to API',
contextForReport,
'Turn.run-sendMessageStream',
);
diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts
index 8992e829..8efdd530 100644
--- a/packages/core/src/qwen/qwenContentGenerator.test.ts
+++ b/packages/core/src/qwen/qwenContentGenerator.test.ts
@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
});
- it('should count tokens with valid token', async () => {
- vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
- token: 'valid-token',
- });
- vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
+ it('should count tokens without requiring authentication', async () => {
+ // Clear any previous mock calls
+ vi.clearAllMocks();
const request: CountTokensParameters = {
model: 'qwen-turbo',
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
const result = await qwenContentGenerator.countTokens(request);
expect(result.totalTokens).toBe(15);
- expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
+ // countTokens is a local operation and should not require OAuth credentials
+ expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
});
it('should embed content with valid token', async () => {
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
SharedTokenManager.getInstance = originalGetInstance;
});
- it('should handle all method types with token failure', async () => {
+ it('should handle method types with token failure (except countTokens)', async () => {
const mockTokenManager = {
getValidCredentials: vi
.fn()
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
contents: [{ parts: [{ text: 'Embed' }] }],
};
- // All methods should fail with the same error
+ // Methods requiring authentication should fail
await expect(
newGenerator.generateContent(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
newGenerator.generateContentStream(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
- await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
- 'Failed to obtain valid Qwen access token',
- );
-
await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
+ // countTokens should succeed as it's a local operation
+ const countResult = await newGenerator.countTokens(countRequest);
+ expect(countResult.totalTokens).toBe(15);
+
SharedTokenManager.getInstance = originalGetInstance;
});
});
diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts
index 1795903e..0e3ca12e 100644
--- a/packages/core/src/qwen/qwenContentGenerator.ts
+++ b/packages/core/src/qwen/qwenContentGenerator.ts
@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
override async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
- return this.executeWithCredentialManagement(() =>
- super.countTokens(request),
- );
+ return super.countTokens(request);
}
/**
diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts
new file mode 100644
index 00000000..cdb5f35f
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts
@@ -0,0 +1,157 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect } from 'vitest';
+import { ImageTokenizer } from './imageTokenizer.js';
+
+describe('ImageTokenizer', () => {
+ const tokenizer = new ImageTokenizer();
+
+ describe('token calculation', () => {
+ it('should calculate tokens based on image dimensions with reference logic', () => {
+ const metadata = {
+ width: 28,
+ height: 28,
+ mimeType: 'image/png',
+ dataSize: 1000,
+ };
+
+ const tokens = tokenizer.calculateTokens(metadata);
+
+ // 28x28 = 784 pixels would be 1 image token + 2 special tokens,
+ // but small images are scaled up to the 4-token minimum (4 + 2 = 6)
+ expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
+ });
+
+ it('should calculate tokens for larger images', () => {
+ const metadata = {
+ width: 512,
+ height: 512,
+ mimeType: 'image/png',
+ dataSize: 10000,
+ };
+
+ const tokens = tokenizer.calculateTokens(metadata);
+
+ // 512x512 with reference logic: rounded dimensions + scaling + special tokens
+ expect(tokens).toBeGreaterThan(300);
+ expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
+ });
+
+ it('should enforce minimum tokens per image with scaling', () => {
+ const metadata = {
+ width: 1,
+ height: 1,
+ mimeType: 'image/png',
+ dataSize: 100,
+ };
+
+ const tokens = tokenizer.calculateTokens(metadata);
+
+ // Tiny images get scaled up to minimum pixels + special tokens
+ expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
+ });
+
+ it('should handle very large images with scaling', () => {
+ const metadata = {
+ width: 8192,
+ height: 8192,
+ mimeType: 'image/png',
+ dataSize: 100000,
+ };
+
+ const tokens = tokenizer.calculateTokens(metadata);
+
+ // Very large images should be scaled down to max limit + special tokens
+ expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
+ expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
+ });
+ });
+
+ describe('PNG dimension extraction', () => {
+ it('should extract dimensions from valid PNG', async () => {
+ // 1x1 PNG image in base64
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const metadata = await tokenizer.extractImageMetadata(
+ pngBase64,
+ 'image/png',
+ );
+
+ expect(metadata.width).toBe(1);
+ expect(metadata.height).toBe(1);
+ expect(metadata.mimeType).toBe('image/png');
+ });
+
+ it('should handle invalid PNG gracefully', async () => {
+ const invalidBase64 = 'invalid-png-data';
+
+ const metadata = await tokenizer.extractImageMetadata(
+ invalidBase64,
+ 'image/png',
+ );
+
+ // Should return default dimensions
+ expect(metadata.width).toBe(512);
+ expect(metadata.height).toBe(512);
+ expect(metadata.mimeType).toBe('image/png');
+ });
+ });
+
+ describe('batch processing', () => {
+ it('should process multiple images serially', async () => {
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const images = [
+ { data: pngBase64, mimeType: 'image/png' },
+ { data: pngBase64, mimeType: 'image/png' },
+ { data: pngBase64, mimeType: 'image/png' },
+ ];
+
+ const tokens = await tokenizer.calculateTokensBatch(images);
+
+ expect(tokens).toHaveLength(3);
+ expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
+ });
+
+ it('should handle mixed valid and invalid images', async () => {
+ const validPng =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+ const invalidPng = 'invalid-data';
+
+ const images = [
+ { data: validPng, mimeType: 'image/png' },
+ { data: invalidPng, mimeType: 'image/png' },
+ ];
+
+ const tokens = await tokenizer.calculateTokensBatch(images);
+
+ expect(tokens).toHaveLength(2);
+ expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
+ });
+ });
+
+ describe('different image formats', () => {
+ it('should handle different MIME types', async () => {
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];
+
+ for (const mimeType of formats) {
+ const metadata = await tokenizer.extractImageMetadata(
+ pngBase64,
+ mimeType,
+ );
+ expect(metadata.mimeType).toBe(mimeType);
+ expect(metadata.width).toBeGreaterThan(0);
+ expect(metadata.height).toBeGreaterThan(0);
+ }
+ });
+ });
+});
diff --git a/packages/core/src/utils/request-tokenizer/imageTokenizer.ts b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts
new file mode 100644
index 00000000..b55c6b9e
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/imageTokenizer.ts
@@ -0,0 +1,505 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { ImageMetadata } from './types.js';
+import { isSupportedImageMimeType } from './supportedImageFormats.js';
+
+/**
+ * Image tokenizer for calculating image tokens based on dimensions
+ *
+ * Key rules:
+ * - 28x28 pixels = 1 token
+ * - Minimum: 4 tokens per image
+ * - Maximum: 16384 tokens per image
+ * - Additional: 2 special tokens (vision_bos + vision_eos)
+ * - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
+ */
+export class ImageTokenizer {
+ /** 28x28 pixels = 1 token */
+ private static readonly PIXELS_PER_TOKEN = 28 * 28;
+
+ /** Minimum tokens per image */
+ private static readonly MIN_TOKENS_PER_IMAGE = 4;
+
+ /** Maximum tokens per image */
+ private static readonly MAX_TOKENS_PER_IMAGE = 16384;
+
+ /** Special tokens for vision markers */
+ private static readonly VISION_SPECIAL_TOKENS = 2;
+
+ /**
+ * Extract image metadata from base64 data
+ *
+ * @param base64Data Base64-encoded image data (with or without data URL prefix)
+ * @param mimeType MIME type of the image
+ * @returns Promise resolving to ImageMetadata with dimensions and format info
+ */
+ async extractImageMetadata(
+ base64Data: string,
+ mimeType: string,
+ ): Promise<ImageMetadata> {
+ try {
+ // Check if the MIME type is supported
+ if (!isSupportedImageMimeType(mimeType)) {
+ console.warn(`Unsupported image format: ${mimeType}`);
+ // Return default metadata for unsupported formats
+ return {
+ width: 512,
+ height: 512,
+ mimeType,
+ dataSize: Math.floor(base64Data.length * 0.75),
+ };
+ }
+
+ const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
+ const buffer = Buffer.from(cleanBase64, 'base64');
+ const dimensions = await this.extractDimensions(buffer, mimeType);
+
+ return {
+ width: dimensions.width,
+ height: dimensions.height,
+ mimeType,
+ dataSize: buffer.length,
+ };
+ } catch (error) {
+ console.warn('Failed to extract image metadata:', error);
+ // Return default metadata for fallback
+ return {
+ width: 512,
+ height: 512,
+ mimeType,
+ dataSize: Math.floor(base64Data.length * 0.75),
+ };
+ }
+ }
+
+ /**
+ * Extract image dimensions from buffer based on format
+ *
+ * @param buffer Binary image data buffer
+ * @param mimeType MIME type to determine parsing strategy
+ * @returns Promise resolving to width and height dimensions
+ */
+ private async extractDimensions(
+ buffer: Buffer,
+ mimeType: string,
+ ): Promise<{ width: number; height: number }> {
+ if (mimeType.includes('png')) {
+ return this.extractPngDimensions(buffer);
+ }
+
+ if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
+ return this.extractJpegDimensions(buffer);
+ }
+
+ if (mimeType.includes('webp')) {
+ return this.extractWebpDimensions(buffer);
+ }
+
+ if (mimeType.includes('gif')) {
+ return this.extractGifDimensions(buffer);
+ }
+
+ if (mimeType.includes('bmp')) {
+ return this.extractBmpDimensions(buffer);
+ }
+
+ if (mimeType.includes('tiff')) {
+ return this.extractTiffDimensions(buffer);
+ }
+
+ if (mimeType.includes('heic')) {
+ return this.extractHeicDimensions(buffer);
+ }
+
+ return { width: 512, height: 512 };
+ }
+
+ /**
+ * Extract PNG dimensions from IHDR chunk
+ * PNG signature: 89 50 4E 47 0D 0A 1A 0A
+ * Width/height at bytes 16-19 and 20-23 (big-endian)
+ */
+ private extractPngDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 24) {
+ throw new Error('Invalid PNG: buffer too short');
+ }
+
+ // Verify PNG signature
+ const signature = buffer.subarray(0, 8);
+ const expectedSignature = Buffer.from([
+ 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
+ ]);
+ if (!signature.equals(expectedSignature)) {
+ throw new Error('Invalid PNG signature');
+ }
+
+ const width = buffer.readUInt32BE(16);
+ const height = buffer.readUInt32BE(20);
+
+ return { width, height };
+ }
+
+ /**
+ * Extract JPEG dimensions from SOF (Start of Frame) markers
+ * JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
+ * Dimensions at offset +5 (height) and +7 (width) from SOF marker
+ */
+ private extractJpegDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
+ throw new Error('Invalid JPEG signature');
+ }
+
+ let offset = 2;
+
+ while (offset < buffer.length - 8) {
+ if (buffer[offset] !== 0xff) {
+ offset++;
+ continue;
+ }
+
+ const marker = buffer[offset + 1];
+
+ // SOF markers
+ if (
+ (marker >= 0xc0 && marker <= 0xc3) ||
+ (marker >= 0xc5 && marker <= 0xc7) ||
+ (marker >= 0xc9 && marker <= 0xcb) ||
+ (marker >= 0xcd && marker <= 0xcf)
+ ) {
+ const height = buffer.readUInt16BE(offset + 5);
+ const width = buffer.readUInt16BE(offset + 7);
+ return { width, height };
+ }
+
+ const segmentLength = buffer.readUInt16BE(offset + 2);
+ offset += 2 + segmentLength;
+ }
+
+ throw new Error('Could not find JPEG dimensions');
+ }
+
+ /**
+ * Extract WebP dimensions from RIFF container
+ * Supports VP8, VP8L, and VP8X formats
+ */
+ private extractWebpDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 30) {
+ throw new Error('Invalid WebP: too short');
+ }
+
+ const riffSignature = buffer.subarray(0, 4).toString('ascii');
+ const webpSignature = buffer.subarray(8, 12).toString('ascii');
+
+ if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
+ throw new Error('Invalid WebP signature');
+ }
+
+ const format = buffer.subarray(12, 16).toString('ascii');
+
+ if (format === 'VP8 ') {
+ const width = buffer.readUInt16LE(26) & 0x3fff;
+ const height = buffer.readUInt16LE(28) & 0x3fff;
+ return { width, height };
+ } else if (format === 'VP8L') {
+ const bits = buffer.readUInt32LE(21);
+ const width = (bits & 0x3fff) + 1;
+ const height = ((bits >> 14) & 0x3fff) + 1;
+ return { width, height };
+ } else if (format === 'VP8X') {
+ const width = (buffer.readUInt32LE(24) & 0xffffff) + 1;
+ const height = (buffer.readUInt32LE(26) & 0xffffff) + 1;
+ return { width, height };
+ }
+
+ throw new Error('Unsupported WebP format');
+ }
+
+ /**
+ * Extract GIF dimensions from header
+ * Supports GIF87a and GIF89a formats
+ */
+ private extractGifDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 10) {
+ throw new Error('Invalid GIF: too short');
+ }
+
+ const signature = buffer.subarray(0, 6).toString('ascii');
+ if (signature !== 'GIF87a' && signature !== 'GIF89a') {
+ throw new Error('Invalid GIF signature');
+ }
+
+ const width = buffer.readUInt16LE(6);
+ const height = buffer.readUInt16LE(8);
+
+ return { width, height };
+ }
+
+ /**
+ * Calculate tokens for an image based on its metadata
+ *
+ * @param metadata Image metadata containing width, height, and format info
+ * @returns Total token count including base image tokens and special tokens
+ */
+ calculateTokens(metadata: ImageMetadata): number {
+ return this.calculateTokensWithScaling(metadata.width, metadata.height);
+ }
+
+ /**
+ * Calculate tokens with scaling logic
+ *
+ * Steps:
+ * 1. Normalize to 28-pixel multiples
+ * 2. Scale large images down, small images up
+ * 3. Calculate tokens: pixels / 784 + 2 special tokens
+ *
+ * @param width Original image width in pixels
+ * @param height Original image height in pixels
+ * @returns Total token count for the image
+ */
+ private calculateTokensWithScaling(width: number, height: number): number {
+ // Normalize to 28-pixel multiples
+ let hBar = Math.round(height / 28) * 28;
+ let wBar = Math.round(width / 28) * 28;
+
+ // Define pixel boundaries
+ const minPixels =
+ ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
+ const maxPixels =
+ ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
+
+ // Apply scaling
+ if (hBar * wBar > maxPixels) {
+ // Scale down large images
+ const beta = Math.sqrt((height * width) / maxPixels);
+ hBar = Math.floor(height / beta / 28) * 28;
+ wBar = Math.floor(width / beta / 28) * 28;
+ } else if (hBar * wBar < minPixels) {
+ // Scale up small images
+ const beta = Math.sqrt(minPixels / (height * width));
+ hBar = Math.ceil((height * beta) / 28) * 28;
+ wBar = Math.ceil((width * beta) / 28) * 28;
+ }
+
+ // Calculate tokens
+ const imageTokens = Math.floor(
+ (hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
+ );
+
+ return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
+ }
+
+ /**
+ * Calculate tokens for multiple images serially
+ *
+ * @param base64DataArray Array of image data with MIME type information
+ * @returns Promise resolving to array of token counts in same order as input
+ */
+ async calculateTokensBatch(
+ base64DataArray: Array<{ data: string; mimeType: string }>,
+ ): Promise<number[]> {
+ const results: number[] = [];
+
+ for (const { data, mimeType } of base64DataArray) {
+ try {
+ const metadata = await this.extractImageMetadata(data, mimeType);
+ results.push(this.calculateTokens(metadata));
+ } catch (error) {
+ console.warn('Error calculating tokens for image:', error);
+ // Return minimum tokens as fallback
+ results.push(
+ ImageTokenizer.MIN_TOKENS_PER_IMAGE +
+ ImageTokenizer.VISION_SPECIAL_TOKENS,
+ );
+ }
+ }
+
+ return results;
+ }
+
+ /**
+ * Extract BMP dimensions from header
+ * BMP signature: 42 4D (BM)
+ * Width/height at bytes 18-21 and 22-25 (little-endian)
+ */
+ private extractBmpDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 26) {
+ throw new Error('Invalid BMP: buffer too short');
+ }
+
+ // Verify BMP signature
+ if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
+ throw new Error('Invalid BMP signature');
+ }
+
+ const width = buffer.readUInt32LE(18);
+ const height = buffer.readUInt32LE(22);
+
+ return { width, height: Math.abs(height) }; // Height can be negative for top-down BMPs
+ }
+
+ /**
+ * Extract TIFF dimensions from IFD (Image File Directory)
+ * TIFF can be little-endian (II) or big-endian (MM)
+ * Width/height are stored in IFD entries with tags 0x0100 and 0x0101
+ */
+ private extractTiffDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 8) {
+ throw new Error('Invalid TIFF: buffer too short');
+ }
+
+ // Check byte order
+ const byteOrder = buffer.subarray(0, 2).toString('ascii');
+ const isLittleEndian = byteOrder === 'II';
+ const isBigEndian = byteOrder === 'MM';
+
+ if (!isLittleEndian && !isBigEndian) {
+ throw new Error('Invalid TIFF byte order');
+ }
+
+ // Read magic number (should be 42)
+ const magic = isLittleEndian
+ ? buffer.readUInt16LE(2)
+ : buffer.readUInt16BE(2);
+ if (magic !== 42) {
+ throw new Error('Invalid TIFF magic number');
+ }
+
+ // Read IFD offset
+ const ifdOffset = isLittleEndian
+ ? buffer.readUInt32LE(4)
+ : buffer.readUInt32BE(4);
+
+ if (ifdOffset >= buffer.length) {
+ throw new Error('Invalid TIFF IFD offset');
+ }
+
+ // Read number of directory entries
+ const numEntries = isLittleEndian
+ ? buffer.readUInt16LE(ifdOffset)
+ : buffer.readUInt16BE(ifdOffset);
+
+ let width = 0;
+ let height = 0;
+
+ // Parse IFD entries
+ for (let i = 0; i < numEntries; i++) {
+ const entryOffset = ifdOffset + 2 + i * 12;
+
+ if (entryOffset + 12 > buffer.length) break;
+
+ const tag = isLittleEndian
+ ? buffer.readUInt16LE(entryOffset)
+ : buffer.readUInt16BE(entryOffset);
+
+ const type = isLittleEndian
+ ? buffer.readUInt16LE(entryOffset + 2)
+ : buffer.readUInt16BE(entryOffset + 2);
+
+ const value = isLittleEndian
+ ? buffer.readUInt32LE(entryOffset + 8)
+ : buffer.readUInt32BE(entryOffset + 8);
+
+ if (tag === 0x0100) {
+ // ImageWidth
+ width = type === 3 ? value : value; // SHORT or LONG
+ } else if (tag === 0x0101) {
+ // ImageLength (height)
+ height = type === 3 ? value : value; // SHORT or LONG
+ }
+
+ if (width > 0 && height > 0) break;
+ }
+
+ if (width === 0 || height === 0) {
+ throw new Error('Could not find TIFF dimensions');
+ }
+
+ return { width, height };
+ }
+
+ /**
+ * Extract HEIC dimensions from meta box
+ * HEIC is based on ISO Base Media File Format
+ * This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
+ */
+ private extractHeicDimensions(buffer: Buffer): {
+ width: number;
+ height: number;
+ } {
+ if (buffer.length < 12) {
+ throw new Error('Invalid HEIC: buffer too short');
+ }
+
+ // Check for ftyp box with HEIC brand
+ const ftypBox = buffer.subarray(4, 8).toString('ascii');
+ if (ftypBox !== 'ftyp') {
+ throw new Error('Invalid HEIC: missing ftyp box');
+ }
+
+ const brand = buffer.subarray(8, 12).toString('ascii');
+ if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
+ throw new Error('Invalid HEIC brand');
+ }
+
+ // Look for meta box and then ispe box
+ let offset = 0;
+ while (offset < buffer.length - 8) {
+ const boxSize = buffer.readUInt32BE(offset);
+ const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
+
+ if (boxType === 'meta') {
+ // Look for ispe box inside meta box
+ const metaOffset = offset + 8;
+ let innerOffset = metaOffset + 4; // Skip version and flags
+
+ while (innerOffset < offset + boxSize - 8) {
+ const innerBoxSize = buffer.readUInt32BE(innerOffset);
+ const innerBoxType = buffer
+ .subarray(innerOffset + 4, innerOffset + 8)
+ .toString('ascii');
+
+ if (innerBoxType === 'ispe') {
+ // Found Image Spatial Extents box
+ if (innerOffset + 20 <= buffer.length) {
+ const width = buffer.readUInt32BE(innerOffset + 12);
+ const height = buffer.readUInt32BE(innerOffset + 16);
+ return { width, height };
+ }
+ }
+
+ if (innerBoxSize === 0) break;
+ innerOffset += innerBoxSize;
+ }
+ }
+
+ if (boxSize === 0) break;
+ offset += boxSize;
+ }
+
+ // Fallback: return default dimensions if we can't parse the structure
+ console.warn('Could not extract HEIC dimensions, using default');
+ return { width: 512, height: 512 };
+ }
+}
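+
+// Usage sketch (illustrative only; `pngBase64` stands in for real base64 image data):
+//
+//   const imageTokenizer = new ImageTokenizer();
+//   const counts = await imageTokenizer.calculateTokensBatch([
+//     { data: pngBase64, mimeType: 'image/png' },
+//     { data: pngBase64, mimeType: 'image/webp' },
+//   ]);
+//   const imageTokens = counts.reduce((sum, n) => sum + n, 0);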
diff --git a/packages/core/src/utils/request-tokenizer/index.ts b/packages/core/src/utils/request-tokenizer/index.ts
new file mode 100644
index 00000000..064b93c1
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/index.ts
@@ -0,0 +1,40 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { DefaultRequestTokenizer } from './requestTokenizer.js';
+
+export { DefaultRequestTokenizer } from './requestTokenizer.js';
+export { TextTokenizer } from './textTokenizer.js';
+export { ImageTokenizer } from './imageTokenizer.js';
+
+export type {
+ RequestTokenizer,
+ TokenizerConfig,
+ TokenCalculationResult,
+ ImageMetadata,
+} from './types.js';
+
+// Singleton instance for convenient usage
+let defaultTokenizer: DefaultRequestTokenizer | null = null;
+
+/**
+ * Get the default request tokenizer instance
+ */
+export function getDefaultTokenizer(): DefaultRequestTokenizer {
+ if (!defaultTokenizer) {
+ defaultTokenizer = new DefaultRequestTokenizer();
+ }
+ return defaultTokenizer;
+}
+
+/**
+ * Dispose of the default tokenizer instance
+ */
+export async function disposeDefaultTokenizer(): Promise<void> {
+ if (defaultTokenizer) {
+ await defaultTokenizer.dispose();
+ defaultTokenizer = null;
+ }
+}
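+
+// Usage sketch (illustrative only; mirrors the exports above, request shape as in the tests):
+//
+//   import { getDefaultTokenizer, disposeDefaultTokenizer } from './index.js';
+//
+//   const result = await getDefaultTokenizer().calculateTokens({
+//     model: 'test-model',
+//     contents: [{ role: 'user', parts: [{ text: 'Hello, world!' }] }],
+//   });
+//   console.log(result.totalTokens, result.breakdown);
+//
+//   await disposeDefaultTokenizer();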
diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts
new file mode 100644
index 00000000..cb69163b
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts
@@ -0,0 +1,293 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, beforeEach, afterEach } from 'vitest';
+import { DefaultRequestTokenizer } from './requestTokenizer.js';
+import type { CountTokensParameters } from '@google/genai';
+
+describe('DefaultRequestTokenizer', () => {
+ let tokenizer: DefaultRequestTokenizer;
+
+ beforeEach(() => {
+ tokenizer = new DefaultRequestTokenizer();
+ });
+
+ afterEach(async () => {
+ await tokenizer.dispose();
+ });
+
+ describe('text token calculation', () => {
+ it('should calculate tokens for simple text content', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [{ text: 'Hello, world!' }],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThan(0);
+ expect(result.breakdown.textTokens).toBeGreaterThan(0);
+ expect(result.breakdown.imageTokens).toBe(0);
+ expect(result.processingTime).toBeGreaterThan(0);
+ });
+
+ it('should handle multiple text parts', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ { text: 'First part' },
+ { text: 'Second part' },
+ { text: 'Third part' },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThan(0);
+ expect(result.breakdown.textTokens).toBeGreaterThan(0);
+ });
+
+ it('should handle string content', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: ['Simple string content'],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThan(0);
+ expect(result.breakdown.textTokens).toBeGreaterThan(0);
+ });
+ });
+
+ describe('image token calculation', () => {
+ it('should calculate tokens for image content', async () => {
+ // Create a simple 1x1 PNG image in base64
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: pngBase64,
+ },
+ },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
+ expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
+ expect(result.breakdown.textTokens).toBe(0);
+ });
+
+ it('should handle multiple images', async () => {
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: pngBase64,
+ },
+ },
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: pngBase64,
+ },
+ },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
+ expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
+ });
+ });
+
+ describe('mixed content', () => {
+ it('should handle text and image content together', async () => {
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ { text: 'Here is an image:' },
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: pngBase64,
+ },
+ },
+ { text: 'What do you see?' },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThan(4);
+ expect(result.breakdown.textTokens).toBeGreaterThan(0);
+ expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
+ });
+ });
+
+ describe('function content', () => {
+ it('should handle function calls', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ {
+ functionCall: {
+ name: 'test_function',
+ args: { param1: 'value1', param2: 42 },
+ },
+ },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThan(0);
+ expect(result.breakdown.otherTokens).toBeGreaterThan(0);
+ });
+ });
+
+ describe('empty content', () => {
+ it('should handle empty request', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBe(0);
+ expect(result.breakdown.textTokens).toBe(0);
+ expect(result.breakdown.imageTokens).toBe(0);
+ });
+
+ it('should handle undefined contents', async () => {
+ // `contents` is required by the typed API, so an empty array stands in
+ // for the undefined case
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBe(0);
+ });
+ });
+
+ describe('configuration', () => {
+ it('should use custom text encoding', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [{ text: 'Test text for encoding' }],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request, {
+ textEncoding: 'cl100k_base',
+ });
+
+ expect(result.totalTokens).toBeGreaterThan(0);
+ });
+
+ it('should process multiple images serially', async () => {
+ const pngBase64 =
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
+
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: Array(10).fill({
+ inlineData: {
+ mimeType: 'image/png',
+ data: pngBase64,
+ },
+ }),
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
+ });
+ });
+
+ describe('error handling', () => {
+ it('should handle malformed image data gracefully', async () => {
+ const request: CountTokensParameters = {
+ model: 'test-model',
+ contents: [
+ {
+ role: 'user',
+ parts: [
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: 'invalid-base64-data',
+ },
+ },
+ ],
+ },
+ ],
+ };
+
+ const result = await tokenizer.calculateTokens(request);
+
+ // Should still return some tokens (fallback to minimum)
+ expect(result.totalTokens).toBeGreaterThanOrEqual(4);
+ });
+ });
+});
diff --git a/packages/core/src/utils/request-tokenizer/requestTokenizer.ts b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts
new file mode 100644
index 00000000..173bb261
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/requestTokenizer.ts
@@ -0,0 +1,341 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type {
+ CountTokensParameters,
+ Content,
+ Part,
+ PartUnion,
+} from '@google/genai';
+import type {
+ RequestTokenizer,
+ TokenizerConfig,
+ TokenCalculationResult,
+} from './types.js';
+import { TextTokenizer } from './textTokenizer.js';
+import { ImageTokenizer } from './imageTokenizer.js';
+
+/**
+ * Simple request tokenizer that handles text and image content serially
+ */
+export class DefaultRequestTokenizer implements RequestTokenizer {
+ private textTokenizer: TextTokenizer;
+ private imageTokenizer: ImageTokenizer;
+
+ constructor() {
+ this.textTokenizer = new TextTokenizer();
+ this.imageTokenizer = new ImageTokenizer();
+ }
+
+ /**
+ * Calculate tokens for a request using serial processing
+ */
+ async calculateTokens(
+ request: CountTokensParameters,
+ config: TokenizerConfig = {},
+ ): Promise<TokenCalculationResult> {
+ const startTime = performance.now();
+
+ // Apply configuration
+ if (config.textEncoding) {
+ // Free the previous encoding before swapping in a new tokenizer
+ this.textTokenizer.dispose();
+ this.textTokenizer = new TextTokenizer(config.textEncoding);
+ }
+
+ try {
+ // Process request content and group by type
+ const { textContents, imageContents, audioContents, otherContents } =
+ this.processAndGroupContents(request);
+
+ if (
+ textContents.length === 0 &&
+ imageContents.length === 0 &&
+ audioContents.length === 0 &&
+ otherContents.length === 0
+ ) {
+ return {
+ totalTokens: 0,
+ breakdown: {
+ textTokens: 0,
+ imageTokens: 0,
+ audioTokens: 0,
+ otherTokens: 0,
+ },
+ processingTime: performance.now() - startTime,
+ };
+ }
+
+ // Calculate tokens for each content type serially
+ const textTokens = await this.calculateTextTokens(textContents);
+ const imageTokens = await this.calculateImageTokens(imageContents);
+ const audioTokens = await this.calculateAudioTokens(audioContents);
+ const otherTokens = await this.calculateOtherTokens(otherContents);
+
+ const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
+ const processingTime = performance.now() - startTime;
+
+ return {
+ totalTokens,
+ breakdown: {
+ textTokens,
+ imageTokens,
+ audioTokens,
+ otherTokens,
+ },
+ processingTime,
+ };
+ } catch (error) {
+ console.error('Error calculating tokens:', error);
+
+ // Fallback calculation
+ const fallbackTokens = this.calculateFallbackTokens(request);
+
+ return {
+ totalTokens: fallbackTokens,
+ breakdown: {
+ textTokens: fallbackTokens,
+ imageTokens: 0,
+ audioTokens: 0,
+ otherTokens: 0,
+ },
+ processingTime: performance.now() - startTime,
+ };
+ }
+ }
+
+ /**
+ * Calculate tokens for text contents
+ */
+ private async calculateTextTokens(textContents: string[]): Promise<number> {
+ if (textContents.length === 0) return 0;
+
+ try {
+ const tokenCounts =
+ await this.textTokenizer.calculateTokensBatch(textContents);
+ return tokenCounts.reduce((sum, count) => sum + count, 0);
+ } catch (error) {
+ console.warn('Error calculating text tokens:', error);
+ // Fallback: character-based estimation
+ const totalChars = textContents.join('').length;
+ return Math.ceil(totalChars / 4);
+ }
+ }
+
+ /**
+ * Calculate tokens for image contents using serial processing
+ */
+ private async calculateImageTokens(
+ imageContents: Array<{ data: string; mimeType: string }>,
+ ): Promise<number> {
+ if (imageContents.length === 0) return 0;
+
+ try {
+ const tokenCounts =
+ await this.imageTokenizer.calculateTokensBatch(imageContents);
+ return tokenCounts.reduce((sum, count) => sum + count, 0);
+ } catch (error) {
+ console.warn('Error calculating image tokens:', error);
+ // Fallback: minimum tokens per image
+ return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
+ }
+ }
+
+ /**
+ * Calculate tokens for audio contents
+ * TODO: Implement proper audio token calculation
+ */
+ private async calculateAudioTokens(
+ audioContents: Array<{ data: string; mimeType: string }>,
+ ): Promise<number> {
+ if (audioContents.length === 0) return 0;
+
+ // Placeholder implementation - audio token calculation would depend on
+ // the specific model's audio processing capabilities
+ // For now, estimate based on data size
+ let totalTokens = 0;
+
+ for (const audioContent of audioContents) {
+ try {
+ const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
+ // Rough estimate: 1 token per 100 bytes of audio data
+ totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
+ } catch (error) {
+ console.warn('Error calculating audio tokens:', error);
+ totalTokens += 10; // Fallback minimum
+ }
+ }
+
+ return totalTokens;
+ }
+
+ /**
+ * Calculate tokens for other content types (functions, files, etc.)
+ */
+ private async calculateOtherTokens(otherContents: string[]): Promise<number> {
+ if (otherContents.length === 0) return 0;
+
+ try {
+ // Treat other content as text for token calculation
+ const tokenCounts =
+ await this.textTokenizer.calculateTokensBatch(otherContents);
+ return tokenCounts.reduce((sum, count) => sum + count, 0);
+ } catch (error) {
+ console.warn('Error calculating other content tokens:', error);
+ // Fallback: character-based estimation
+ const totalChars = otherContents.join('').length;
+ return Math.ceil(totalChars / 4);
+ }
+ }
+
+ /**
+ * Fallback token calculation using simple string serialization
+ */
+ private calculateFallbackTokens(request: CountTokensParameters): number {
+ try {
+ const content = JSON.stringify(request.contents);
+ return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
+ } catch (error) {
+ console.warn('Error in fallback token calculation:', error);
+ return 100; // Conservative fallback
+ }
+ }
+
+ /**
+ * Process request contents and group by type
+ */
+ private processAndGroupContents(request: CountTokensParameters): {
+ textContents: string[];
+ imageContents: Array<{ data: string; mimeType: string }>;
+ audioContents: Array<{ data: string; mimeType: string }>;
+ otherContents: string[];
+ } {
+ const textContents: string[] = [];
+ const imageContents: Array<{ data: string; mimeType: string }> = [];
+ const audioContents: Array<{ data: string; mimeType: string }> = [];
+ const otherContents: string[] = [];
+
+ if (!request.contents) {
+ return { textContents, imageContents, audioContents, otherContents };
+ }
+
+ const contents = Array.isArray(request.contents)
+ ? request.contents
+ : [request.contents];
+
+ for (const content of contents) {
+ this.processContent(
+ content,
+ textContents,
+ imageContents,
+ audioContents,
+ otherContents,
+ );
+ }
+
+ return { textContents, imageContents, audioContents, otherContents };
+ }
+
+ /**
+ * Process a single content item and add to appropriate arrays
+ */
+ private processContent(
+ content: Content | string | PartUnion,
+ textContents: string[],
+ imageContents: Array<{ data: string; mimeType: string }>,
+ audioContents: Array<{ data: string; mimeType: string }>,
+ otherContents: string[],
+ ): void {
+ if (typeof content === 'string') {
+ if (content.trim()) {
+ textContents.push(content);
+ }
+ return;
+ }
+
+ if ('parts' in content && content.parts) {
+ for (const part of content.parts) {
+ this.processPart(
+ part,
+ textContents,
+ imageContents,
+ audioContents,
+ otherContents,
+ );
+ }
+ }
+ }
+
+ /**
+ * Process a single part and add to appropriate arrays
+ */
+ private processPart(
+ part: Part | string,
+ textContents: string[],
+ imageContents: Array<{ data: string; mimeType: string }>,
+ audioContents: Array<{ data: string; mimeType: string }>,
+ otherContents: string[],
+ ): void {
+ if (typeof part === 'string') {
+ if (part.trim()) {
+ textContents.push(part);
+ }
+ return;
+ }
+
+ if ('text' in part && part.text) {
+ textContents.push(part.text);
+ return;
+ }
+
+ if ('inlineData' in part && part.inlineData) {
+ const { data, mimeType } = part.inlineData;
+ if (mimeType && mimeType.startsWith('image/')) {
+ imageContents.push({ data: data || '', mimeType });
+ return;
+ }
+ if (mimeType && mimeType.startsWith('audio/')) {
+ audioContents.push({ data: data || '', mimeType });
+ return;
+ }
+ }
+
+ if ('fileData' in part && part.fileData) {
+ otherContents.push(JSON.stringify(part.fileData));
+ return;
+ }
+
+ if ('functionCall' in part && part.functionCall) {
+ otherContents.push(JSON.stringify(part.functionCall));
+ return;
+ }
+
+ if ('functionResponse' in part && part.functionResponse) {
+ otherContents.push(JSON.stringify(part.functionResponse));
+ return;
+ }
+
+ // Unknown part type - try to serialize
+ try {
+ const serialized = JSON.stringify(part);
+ if (serialized && serialized !== '{}') {
+ otherContents.push(serialized);
+ }
+ } catch (error) {
+ console.warn('Failed to serialize unknown part type:', error);
+ }
+ }
+
+ /**
+ * Dispose of resources
+ */
+ async dispose(): Promise<void> {
+ try {
+ // Free the text tokenizer's tiktoken encoding
+ this.textTokenizer.dispose();
+ } catch (error) {
+ console.warn('Error disposing request tokenizer:', error);
+ }
+ }
+}
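+
+// Usage sketch (illustrative only): a standalone instance with a custom text
+// encoding; request shape as in the tests.
+//
+//   const tokenizer = new DefaultRequestTokenizer();
+//   const result = await tokenizer.calculateTokens(
+//     {
+//       model: 'test-model',
+//       contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
+//     },
+//     { textEncoding: 'cl100k_base' },
+//   );
+//   console.log(result.breakdown.textTokens, result.processingTime);
+//   await tokenizer.dispose();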
diff --git a/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts b/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts
new file mode 100644
index 00000000..fce679d7
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/supportedImageFormats.ts
@@ -0,0 +1,56 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/**
+ * Supported image MIME types for vision models
+ * These formats are supported by the vision model and can be processed by the image tokenizer
+ */
+export const SUPPORTED_IMAGE_MIME_TYPES = [
+ 'image/bmp',
+ 'image/jpeg',
+ 'image/jpg', // Alternative MIME type for JPEG
+ 'image/png',
+ 'image/tiff',
+ 'image/webp',
+ 'image/heic',
+] as const;
+
+/**
+ * Type for supported image MIME types
+ */
+export type SupportedImageMimeType =
+ (typeof SUPPORTED_IMAGE_MIME_TYPES)[number];
+
+/**
+ * Check if a MIME type is supported for vision processing
+ * @param mimeType The MIME type to check
+ * @returns True if the MIME type is supported
+ */
+export function isSupportedImageMimeType(
+ mimeType: string,
+): mimeType is SupportedImageMimeType {
+ return SUPPORTED_IMAGE_MIME_TYPES.includes(
+ mimeType as SupportedImageMimeType,
+ );
+}
+
+/**
+ * Get a human-readable list of supported image formats
+ * @returns Comma-separated string of supported formats
+ */
+export function getSupportedImageFormatsString(): string {
+ return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
+ type.replace('image/', '').toUpperCase(),
+ ).join(', ');
+}
+
+/**
+ * Get warning message for unsupported image formats
+ * @returns Warning message string
+ */
+export function getUnsupportedImageFormatWarning(): string {
+ return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
+}
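+
+// Usage sketch (illustrative only; `mimeType` is a caller-supplied string):
+//
+//   if (!isSupportedImageMimeType(mimeType)) {
+//     console.warn(getUnsupportedImageFormatWarning());
+//   }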
diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts
new file mode 100644
index 00000000..f29155a8
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/textTokenizer.test.ts
@@ -0,0 +1,347 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import { TextTokenizer } from './textTokenizer.js';
+
+// Mock tiktoken at the top level with hoisted functions
+const mockEncode = vi.hoisted(() => vi.fn());
+const mockFree = vi.hoisted(() => vi.fn());
+const mockGetEncoding = vi.hoisted(() => vi.fn());
+
+vi.mock('tiktoken', () => ({
+ get_encoding: mockGetEncoding,
+}));
+
+describe('TextTokenizer', () => {
+ let tokenizer: TextTokenizer;
+ let consoleWarnSpy: ReturnType<typeof vi.spyOn>;
+
+ beforeEach(() => {
+ vi.resetAllMocks();
+ consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
+
+ // Default mock implementation
+ mockGetEncoding.mockReturnValue({
+ encode: mockEncode,
+ free: mockFree,
+ });
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ tokenizer?.dispose();
+ });
+
+ describe('constructor', () => {
+ it('should create tokenizer with default encoding', () => {
+ tokenizer = new TextTokenizer();
+ expect(tokenizer).toBeInstanceOf(TextTokenizer);
+ });
+
+ it('should create tokenizer with custom encoding', () => {
+ tokenizer = new TextTokenizer('gpt2');
+ expect(tokenizer).toBeInstanceOf(TextTokenizer);
+ });
+ });
+
+ describe('calculateTokens', () => {
+ beforeEach(() => {
+ tokenizer = new TextTokenizer();
+ });
+
+ it('should return 0 for empty text', async () => {
+ const result = await tokenizer.calculateTokens('');
+ expect(result).toBe(0);
+ });
+
+ it('should return 0 for null/undefined text', async () => {
+ const result1 = await tokenizer.calculateTokens(
+ null as unknown as string,
+ );
+ const result2 = await tokenizer.calculateTokens(
+ undefined as unknown as string,
+ );
+ expect(result1).toBe(0);
+ expect(result2).toBe(0);
+ });
+
+ it('should calculate tokens using tiktoken when available', async () => {
+ const testText = 'Hello, world!';
+ const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(testText);
+
+ expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
+ expect(mockEncode).toHaveBeenCalledWith(testText);
+ expect(result).toBe(5);
+ });
+
+ it('should use fallback calculation when tiktoken fails to load', async () => {
+ mockGetEncoding.mockImplementation(() => {
+ throw new Error('Failed to load tiktoken');
+ });
+
+ const testText = 'Hello, world!'; // 13 characters
+ const result = await tokenizer.calculateTokens(testText);
+
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ 'Failed to load tiktoken with encoding cl100k_base:',
+ expect.any(Error),
+ );
+ // Fallback: Math.ceil(13 / 4) = 4
+ expect(result).toBe(4);
+ });
+
+ it('should use fallback calculation when encoding fails', async () => {
+ mockEncode.mockImplementation(() => {
+ throw new Error('Encoding failed');
+ });
+
+ const testText = 'Hello, world!'; // 13 characters
+ const result = await tokenizer.calculateTokens(testText);
+
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ 'Error encoding text with tiktoken:',
+ expect.any(Error),
+ );
+ // Fallback: Math.ceil(13 / 4) = 4
+ expect(result).toBe(4);
+ });
+
+ it('should handle very long text', async () => {
+ const longText = 'a'.repeat(10000);
+ const mockTokens = new Array(2500); // 2500 tokens
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(longText);
+
+ expect(result).toBe(2500);
+ });
+
+ it('should handle unicode characters', async () => {
+ const unicodeText = '你好世界 🌍';
+ const mockTokens = [1, 2, 3, 4, 5, 6];
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(unicodeText);
+
+ expect(result).toBe(6);
+ });
+
+ it('should use custom encoding when specified', async () => {
+ tokenizer = new TextTokenizer('gpt2');
+ const testText = 'Hello, world!';
+ const mockTokens = [1, 2, 3];
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(testText);
+
+ expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
+ expect(result).toBe(3);
+ });
+ });
+
+ describe('calculateTokensBatch', () => {
+ beforeEach(() => {
+ tokenizer = new TextTokenizer();
+ });
+
+ it('should process multiple texts and return token counts', async () => {
+ const texts = ['Hello', 'world', 'test'];
+ mockEncode
+ .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
+ .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
+ .mockReturnValueOnce([6]); // 1 token for 'test'
+
+ const result = await tokenizer.calculateTokensBatch(texts);
+
+ expect(result).toEqual([2, 3, 1]);
+ expect(mockEncode).toHaveBeenCalledTimes(3);
+ });
+
+ it('should handle empty array', async () => {
+ const result = await tokenizer.calculateTokensBatch([]);
+ expect(result).toEqual([]);
+ });
+
+ it('should handle array with empty strings', async () => {
+ const texts = ['', 'hello', ''];
+ mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'
+
+ const result = await tokenizer.calculateTokensBatch(texts);
+
+ expect(result).toEqual([0, 3, 0]);
+ expect(mockEncode).toHaveBeenCalledTimes(1);
+ expect(mockEncode).toHaveBeenCalledWith('hello');
+ });
+
+ it('should use fallback calculation when tiktoken fails to load', async () => {
+ mockGetEncoding.mockImplementation(() => {
+ throw new Error('Failed to load tiktoken');
+ });
+
+ const texts = ['Hello', 'world']; // 5 and 5 characters
+ const result = await tokenizer.calculateTokensBatch(texts);
+
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ 'Failed to load tiktoken with encoding cl100k_base:',
+ expect.any(Error),
+ );
+ // Fallback: Math.ceil(5/4) = 2 for both
+ expect(result).toEqual([2, 2]);
+ });
+
+ it('should use fallback calculation when encoding fails during batch processing', async () => {
+ mockEncode.mockImplementation(() => {
+ throw new Error('Encoding failed');
+ });
+
+ const texts = ['Hello', 'world']; // 5 and 5 characters
+ const result = await tokenizer.calculateTokensBatch(texts);
+
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ 'Error encoding texts with tiktoken:',
+ expect.any(Error),
+ );
+ // Fallback: Math.ceil(5/4) = 2 for both
+ expect(result).toEqual([2, 2]);
+ });
+
+ it('should handle null and undefined values in batch', async () => {
+ const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
+ mockEncode
+ .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
+ .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'
+
+ const result = await tokenizer.calculateTokensBatch(texts);
+
+ expect(result).toEqual([0, 3, 0, 2]);
+ });
+ });
+
+ describe('dispose', () => {
+ beforeEach(() => {
+ tokenizer = new TextTokenizer();
+ });
+
+ it('should free tiktoken encoding when disposing', async () => {
+ // Initialize the encoding by calling calculateTokens
+ await tokenizer.calculateTokens('test');
+
+ tokenizer.dispose();
+
+ expect(mockFree).toHaveBeenCalled();
+ });
+
+ it('should handle disposal when encoding is not initialized', () => {
+ expect(() => tokenizer.dispose()).not.toThrow();
+ expect(mockFree).not.toHaveBeenCalled();
+ });
+
+ it('should handle disposal when encoding is null', async () => {
+ // Force encoding to be null by making tiktoken fail
+ mockGetEncoding.mockImplementation(() => {
+ throw new Error('Failed to load');
+ });
+
+ await tokenizer.calculateTokens('test');
+
+ expect(() => tokenizer.dispose()).not.toThrow();
+ expect(mockFree).not.toHaveBeenCalled();
+ });
+
+ it('should handle errors during disposal gracefully', async () => {
+ await tokenizer.calculateTokens('test');
+
+ mockFree.mockImplementation(() => {
+ throw new Error('Free failed');
+ });
+
+ tokenizer.dispose();
+
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ 'Error freeing tiktoken encoding:',
+ expect.any(Error),
+ );
+ });
+
+ it('should allow multiple calls to dispose', async () => {
+ await tokenizer.calculateTokens('test');
+
+ tokenizer.dispose();
+ tokenizer.dispose(); // Second call should not throw
+
+ expect(mockFree).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ describe('lazy initialization', () => {
+ beforeEach(() => {
+ tokenizer = new TextTokenizer();
+ });
+
+ it('should not initialize tiktoken until first use', () => {
+ expect(mockGetEncoding).not.toHaveBeenCalled();
+ });
+
+ it('should initialize tiktoken on first calculateTokens call', async () => {
+ await tokenizer.calculateTokens('test');
+ expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+ });
+
+ it('should not reinitialize tiktoken on subsequent calls', async () => {
+ await tokenizer.calculateTokens('test1');
+ await tokenizer.calculateTokens('test2');
+
+ expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+ });
+
+ it('should initialize tiktoken on first calculateTokensBatch call', async () => {
+ await tokenizer.calculateTokensBatch(['test']);
+ expect(mockGetEncoding).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ describe('edge cases', () => {
+ beforeEach(() => {
+ tokenizer = new TextTokenizer();
+ });
+
+ it('should handle very short text', async () => {
+ const result = await tokenizer.calculateTokens('a');
+
+ if (mockGetEncoding.mock.calls.length > 0) {
+ // If tiktoken was called, use its result
+ expect(mockEncode).toHaveBeenCalledWith('a');
+ } else {
+ // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
+ expect(result).toBe(1);
+ }
+ });
+
+ it('should handle text with only whitespace', async () => {
+ const whitespaceText = ' \n\t ';
+ const mockTokens = [1];
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(whitespaceText);
+
+ expect(result).toBe(1);
+ });
+
+ it('should handle special characters and symbols', async () => {
+ const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
+ const mockTokens = new Array(10);
+ mockEncode.mockReturnValue(mockTokens);
+
+ const result = await tokenizer.calculateTokens(specialText);
+
+ expect(result).toBe(10);
+ });
+ });
+});
diff --git a/packages/core/src/utils/request-tokenizer/textTokenizer.ts b/packages/core/src/utils/request-tokenizer/textTokenizer.ts
new file mode 100644
index 00000000..86c71d4c
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/textTokenizer.ts
@@ -0,0 +1,97 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
+import { get_encoding } from 'tiktoken';
+
+/**
+ * Text tokenizer for calculating text tokens using tiktoken
+ */
+export class TextTokenizer {
+ private encoding: Tiktoken | null = null;
+ private encodingName: string;
+
+ constructor(encodingName: string = 'cl100k_base') {
+ this.encodingName = encodingName;
+ }
+
+ /**
+ * Initialize the tokenizer (lazy loading)
+ */
+ private async ensureEncoding(): Promise<void> {
+ if (this.encoding) return;
+
+ try {
+ // Use type assertion since we know the encoding name is valid
+ this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
+ } catch (error) {
+ console.warn(
+ `Failed to load tiktoken with encoding ${this.encodingName}:`,
+ error,
+ );
+ this.encoding = null;
+ }
+ }
+
+ /**
+ * Calculate tokens for text content
+ */
+ async calculateTokens(text: string): Promise<number> {
+ if (!text) return 0;
+
+ await this.ensureEncoding();
+
+ if (this.encoding) {
+ try {
+ return this.encoding.encode(text).length;
+ } catch (error) {
+ console.warn('Error encoding text with tiktoken:', error);
+ }
+ }
+
+ // Fallback: rough approximation using character count
+ // This is a conservative estimate: 1 token ≈ 4 characters for most languages
+ return Math.ceil(text.length / 4);
+ }
+
+ /**
+ * Calculate tokens for multiple text strings in a single batch
+ */
+ async calculateTokensBatch(texts: string[]): Promise<number[]> {
+ await this.ensureEncoding();
+
+ if (this.encoding) {
+ try {
+ return texts.map((text) => {
+ if (!text) return 0;
+ // Re-check for null inside the closure to satisfy the type checker
+ return this.encoding ? this.encoding.encode(text).length : 0;
+ });
+ } catch (error) {
+ console.warn('Error encoding texts with tiktoken:', error);
+ // In case of error, return fallback estimation for all texts
+ return texts.map((text) => Math.ceil((text || '').length / 4));
+ }
+ }
+
+ // Fallback for batch processing
+ return texts.map((text) => Math.ceil((text || '').length / 4));
+ }
+
+ /**
+ * Dispose of resources
+ */
+ dispose(): void {
+ if (this.encoding) {
+ try {
+ this.encoding.free();
+ } catch (error) {
+ console.warn('Error freeing tiktoken encoding:', error);
+ }
+ this.encoding = null;
+ }
+ }
+}
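+
+// Usage sketch (illustrative only): returns the tiktoken count when the encoding
+// loads, otherwise the ~4-characters-per-token fallback described above.
+//
+//   const textTokenizer = new TextTokenizer('cl100k_base');
+//   const tokens = await textTokenizer.calculateTokens('Hello, world!');
+//   const batch = await textTokenizer.calculateTokensBatch(['first', 'second']);
+//   textTokenizer.dispose();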
diff --git a/packages/core/src/utils/request-tokenizer/types.ts b/packages/core/src/utils/request-tokenizer/types.ts
new file mode 100644
index 00000000..38c47699
--- /dev/null
+++ b/packages/core/src/utils/request-tokenizer/types.ts
@@ -0,0 +1,64 @@
+/**
+ * @license
+ * Copyright 2025 Qwen
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { CountTokensParameters } from '@google/genai';
+
+/**
+ * Token calculation result for different content types
+ */
+export interface TokenCalculationResult {
+ /** Total tokens calculated */
+ totalTokens: number;
+ /** Breakdown by content type */
+ breakdown: {
+ textTokens: number;
+ imageTokens: number;
+ audioTokens: number;
+ otherTokens: number;
+ };
+ /** Processing time in milliseconds */
+ processingTime: number;
+}
+
+/**
+ * Configuration for token calculation
+ */
+export interface TokenizerConfig {
+ /** Custom text tokenizer encoding (defaults to cl100k_base) */
+ textEncoding?: string;
+}
+
+/**
+ * Image metadata extracted from base64 data
+ */
+export interface ImageMetadata {
+ /** Image width in pixels */
+ width: number;
+ /** Image height in pixels */
+ height: number;
+ /** MIME type of the image */
+ mimeType: string;
+ /** Size of the base64 data in bytes */
+ dataSize: number;
+}
+
+/**
+ * Request tokenizer interface
+ */
+export interface RequestTokenizer {
+ /**
+ * Calculate tokens for a request
+ */
+ calculateTokens(
+ request: CountTokensParameters,
+ config?: TokenizerConfig,
+ ): Promise<TokenCalculationResult>;
+
+ /**
+ * Dispose of resources (worker threads, etc.)
+ */
+ dispose(): Promise<void>;
+}