Vision model support for Qwen-OAuth (#525)

* refactor: openaiContentGenerator

* refactor: optimize stream handling

* refactor: re-organize refactored files

* fix: unit test cases

* feat: `/model` command for switching to vision model

* fix: lint error

* feat: add image tokenizer to fit vlm context window

* fix: lint and type errors

* feat: add `visionModelPreview` to control default visibility of vision models

* fix: remove deprecated files

* fix: align supported image formats with Bailian doc
Mingholy
2025-09-18 13:32:00 +08:00
committed by GitHub
parent 56808ac210
commit 761833c915
41 changed files with 4083 additions and 5336 deletions

View File

@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
description: 'Enable extension management features.',
showInDialog: false,
},
visionModelPreview: {
type: 'boolean',
label: 'Vision Model Preview',
category: 'Experimental',
requiresRestart: false,
default: false,
description:
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
showInDialog: true,
},
},
},
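With this schema entry, vision support is opt-in under the Experimental category. A minimal settings sketch, assuming the flag nests under an `experimental` key (consistent with the `settings.merged.experimental?.visionModelPreview` reads later in this commit); the settings file location depends on your installation:

{
  "experimental": {
    "visionModelPreview": true
  }
}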

View File

@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
kind: 'BUILT_IN',
},
}));
vi.mock('../ui/commands/modelCommand.js', () => ({
modelCommand: {
name: 'model',
description: 'Model command',
kind: 'BUILT_IN',
},
}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
const mcpCmd = commands.find((c) => c.name === 'mcp');
expect(mcpCmd).toBeDefined();
const modelCmd = commands.find((c) => c.name === 'model');
expect(modelCmd).toBeDefined();
});
});

View File

@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
import { vimCommand } from '../ui/commands/vimCommand.js';
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
import { modelCommand } from '../ui/commands/modelCommand.js';
import { agentsCommand } from '../ui/commands/agentsCommand.js';
/**
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
initCommand,
mcpCommand,
memoryCommand,
modelCommand,
privacyCommand,
quitCommand,
quitConfirmCommand,

View File

@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
import {
ModelSwitchDialog,
type VisionSwitchOutcome,
} from './components/ModelSwitchDialog.js';
import {
getOpenAIAvailableModelFromEnv,
getFilteredQwenModels,
type AvailableModel,
} from './models/availableModels.js';
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
import {
AgentCreationWizard,
AgentsManagerDialog,
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onWorkspaceMigrationDialogClose,
} = useWorkspaceMigration(settings);
// Model selection dialog states
const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
useState(false);
const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
useState(false);
const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
resolve: (result: {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}) => void;
reject: () => void;
} | null>(null);
useEffect(() => {
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
// Set the initial value
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
openAuthDialog();
}, [openAuthDialog, setAuthError]);
// Vision switch handler for auto-switch functionality
const handleVisionSwitchRequired = useCallback(
async (_query: unknown) =>
new Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>((resolve, reject) => {
setVisionSwitchResolver({ resolve, reject });
setIsVisionSwitchDialogOpen(true);
}),
[],
);
const handleVisionSwitchSelect = useCallback(
(outcome: VisionSwitchOutcome) => {
setIsVisionSwitchDialogOpen(false);
if (visionSwitchResolver) {
const result = processVisionSwitchOutcome(outcome);
visionSwitchResolver.resolve(result);
setVisionSwitchResolver(null);
}
},
[visionSwitchResolver],
);
const handleModelSelectionOpen = useCallback(() => {
setIsModelSelectionDialogOpen(true);
}, []);
const handleModelSelectionClose = useCallback(() => {
setIsModelSelectionDialogOpen(false);
}, []);
const handleModelSelect = useCallback(
(modelId: string) => {
config.setModel(modelId);
setCurrentModel(modelId);
setIsModelSelectionDialogOpen(false);
addItem(
{
type: MessageType.INFO,
text: `Switched model to \`${modelId}\` for this session.`,
},
Date.now(),
);
},
[config, setCurrentModel, addItem],
);
const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) return [];
const visionModelPreviewEnabled =
settings.merged.experimental?.visionModelPreview ?? false;
switch (contentGeneratorConfig.authType) {
case AuthType.QWEN_OAUTH:
return getFilteredQwenModels(visionModelPreviewEnabled);
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
return [];
}
}, [config, settings.merged.experimental?.visionModelPreview]);
// Core hooks and processors
const {
vimEnabled: vimModeEnabled,
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setQuittingMessages,
openPrivacyNotice,
openSettingsDialog,
handleModelSelectionOpen,
openSubagentCreateDialog,
openAgentsManagerDialog,
toggleVimEnabled,
@@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setModelSwitchedFromQuotaError,
refreshStatic,
() => cancelHandlerRef.current(),
settings.merged.experimental?.visionModelPreview ?? false,
handleVisionSwitchRequired,
);
const pendingHistoryItems = useMemo(
@@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
!isAuthDialogOpen &&
!isThemeDialogOpen &&
!isEditorDialogOpen &&
!isModelSelectionDialogOpen &&
!isVisionSwitchDialogOpen &&
!isSubagentCreateDialogOpen &&
!showPrivacyNotice &&
!showWelcomeBackDialog &&
@@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
showWelcomeBackDialog,
welcomeBackChoice,
geminiClient,
isModelSelectionDialogOpen,
isVisionSwitchDialogOpen,
]);
if (quittingMessages) {
@@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onExit={exitEditorDialog}
/>
</Box>
) : isModelSelectionDialogOpen ? (
<ModelSelectionDialog
availableModels={getAvailableModelsForCurrentAuth()}
currentModel={currentModel}
onSelect={handleModelSelect}
onCancel={handleModelSelectionClose}
/>
) : isVisionSwitchDialogOpen ? (
<ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
) : showPrivacyNotice ? (
<PrivacyNotice
onExit={() => setShowPrivacyNotice(false)}

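The dialog wiring above relies on a deferred-promise pattern: handleVisionSwitchRequired opens the dialog and parks a Promise's resolve function in React state, and handleVisionSwitchSelect settles it once the user picks an option. A distilled sketch of that pattern, with the React state plumbing stripped out for clarity (illustrative only):

type VisionSwitchResult = {
  modelOverride?: string;
  persistSessionModel?: string;
  showGuidance?: boolean;
};

let resolver: ((result: VisionSwitchResult) => void) | null = null;

// Submit pipeline side: open the dialog and park the promise.
function requestVisionSwitch(): Promise<VisionSwitchResult> {
  return new Promise((resolve) => {
    resolver = resolve; // App.tsx stores this via setVisionSwitchResolver
    // ...and flips isVisionSwitchDialogOpen to true here
  });
}

// Dialog side: settle the parked promise with the chosen outcome.
function onDialogSelect(result: VisionSwitchResult) {
  resolver?.(result);
  resolver = null;
}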
View File

@@ -0,0 +1,179 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { modelCommand } from './modelCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import {
AuthType,
type ContentGeneratorConfig,
type Config,
} from '@qwen-code/qwen-code-core';
import * as availableModelsModule from '../models/availableModels.js';
// Mock the availableModels module
vi.mock('../models/availableModels.js', () => ({
AVAILABLE_MODELS_QWEN: [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
],
getOpenAIAvailableModelFromEnv: vi.fn(),
}));
// Helper function to create a mock config
function createMockConfig(
contentGeneratorConfig: ContentGeneratorConfig | null,
): Partial<Config> {
return {
getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
};
}
describe('modelCommand', () => {
let mockContext: CommandContext;
const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
availableModelsModule.getOpenAIAvailableModelFromEnv,
);
beforeEach(() => {
mockContext = createMockCommandContext();
vi.clearAllMocks();
});
it('should have the correct name and description', () => {
expect(modelCommand.name).toBe('model');
expect(modelCommand.description).toBe('Switch the model for this session');
});
it('should return error when config is not available', async () => {
mockContext.services.config = null;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
});
});
it('should return error when content generator config is not available', async () => {
const mockConfig = createMockConfig(null);
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
});
});
it('should return error when auth type is not available', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
it('should return dialog action for QWEN_OAUTH auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.QWEN_OAUTH,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
id: 'gpt-4',
label: 'gpt-4',
});
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return error for USE_OPENAI auth type when no model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (openai).',
});
});
it('should return error for unsupported auth types', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
});
});
it('should handle undefined auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
});

View File

@@ -0,0 +1,88 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { AuthType } from '@qwen-code/qwen-code-core';
import type {
SlashCommand,
CommandContext,
OpenDialogActionReturn,
MessageActionReturn,
} from './types.js';
import { CommandKind } from './types.js';
import {
AVAILABLE_MODELS_QWEN,
getOpenAIAvailableModelFromEnv,
type AvailableModel,
} from '../models/availableModels.js';
function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
switch (authType) {
case AuthType.QWEN_OAUTH:
return AVAILABLE_MODELS_QWEN;
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
// For other auth types, return an empty array for now.
// This can be expanded later according to the design doc.
return [];
}
}
export const modelCommand: SlashCommand = {
name: 'model',
description: 'Switch the model for this session',
kind: CommandKind.BUILT_IN,
action: async (
context: CommandContext,
): Promise<OpenDialogActionReturn | MessageActionReturn> => {
const { services } = context;
const { config } = services;
if (!config) {
return {
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
};
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) {
return {
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
};
}
const authType = contentGeneratorConfig.authType;
if (!authType) {
return {
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
};
}
const availableModels = getAvailableModelsForAuthType(authType);
if (availableModels.length === 0) {
return {
type: 'message',
messageType: 'error',
content: `No models available for the current authentication type (${authType}).`,
};
}
// Trigger model selection dialog
return {
type: 'dialog',
dialog: 'model',
};
},
};
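The { type: 'dialog', dialog: 'model' } action returned here is handled by useSlashCommandProcessor (see the hunk further down), which opens the selection dialog rendered in App.tsx. Roughly, the dispatch path looks like this sketch (names taken from the other hunks in this commit):

const result = await modelCommand.action!(context, '');
if (result.type === 'dialog' && result.dialog === 'model') {
  openModelSelectionDialog(); // App.tsx then renders <ModelSelectionDialog />
}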

View File

@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
| 'editor'
| 'privacy'
| 'settings'
| 'model'
| 'subagent_create'
| 'subagent_list';
}

View File

@@ -0,0 +1,246 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
import type { AvailableModel } from '../models/availableModels.js';
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSelectionDialog', () => {
const mockAvailableModels: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
{ id: 'gpt-4', label: 'GPT-4' },
];
const mockOnSelect = vi.fn();
const mockOnCancel = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup escape key handler to call onCancel', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnCancel).toHaveBeenCalled();
});
it('should not call onCancel for non-escape keys', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnCancel).not.toHaveBeenCalled();
});
it('should set correct initial index for current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
});
it('should set initial index to 0 when current model is not found', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="non-existent-model"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
it('should call onSelect when a model is selected', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback('qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
});
it('should handle empty models array', () => {
render(
<ModelSelectionDialog
availableModels={[]}
currentModel=""
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual([]);
expect(callArgs.initialIndex).toBe(0);
});
it('should create correct option items with proper labels', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const expectedItems = [
{
label: 'qwen3-coder-plus (current)',
value: 'qwen3-coder-plus',
},
{
label: 'qwen-vl-max [Vision]',
value: 'qwen-vl-max-latest',
},
{
label: 'GPT-4',
value: 'gpt-4',
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
});
it('should show vision indicator for vision models', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="gpt-4"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const visionModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(visionModelItem?.label).toContain('[Vision]');
});
it('should show current indicator for the current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const currentModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(currentModelItem?.label).toContain('(current)');
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle multiple onSelect calls correctly', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback('qwen3-coder-plus');
onSelectCallback('qwen-vl-max-latest');
onSelectCallback('gpt-4');
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
});
});

View File

@@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
import type { AvailableModel } from '../models/availableModels.js';
export interface ModelSelectionDialogProps {
availableModels: AvailableModel[];
currentModel: string;
onSelect: (modelId: string) => void;
onCancel: () => void;
}
export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
availableModels,
currentModel,
onSelect,
onCancel,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onCancel();
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<string>> = availableModels.map(
(model) => {
const visionIndicator = model.isVision ? ' [Vision]' : '';
const currentIndicator = model.id === currentModel ? ' (current)' : '';
return {
label: `${model.label}${visionIndicator}${currentIndicator}`,
value: model.id,
};
},
);
const initialIndex = Math.max(
0,
availableModels.findIndex((model) => model.id === currentModel),
);
const handleSelect = (modelId: string) => {
onSelect(modelId);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentBlue}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Select Model</Text>
<Text>Choose a model for this session:</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={initialIndex}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};
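A quick usage sketch for the component above, rendered with Ink's top-level render; the model list here is illustrative:

import { render } from 'ink';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';

render(
  <ModelSelectionDialog
    availableModels={[
      { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
      { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
    ]}
    currentModel="qwen3-coder-plus"
    onSelect={(modelId) => console.log(`selected ${modelId}`)}
    onCancel={() => process.exit(0)}
  />,
);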

View File

@@ -0,0 +1,185 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSwitchDialog', () => {
const mockOnSelect = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup RadioButtonSelect with correct options', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const expectedItems = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
expect(callArgs.initialIndex).toBe(0);
expect(callArgs.isFocused).toBe(true);
});
it('should call onSelect when an option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection of "Switch for this request only"
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
});
it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.SwitchSessionToVL,
);
});
it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should not call onSelect for non-escape keys', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnSelect).not.toHaveBeenCalled();
});
it('should set initial index to 0 (first option)', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
describe('VisionSwitchOutcome enum', () => {
it('should have correct enum values', () => {
expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
'switch_session_to_vl',
);
expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
'disallow_with_guidance',
);
});
});
it('should handle multiple onSelect calls correctly', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(
1,
VisionSwitchOutcome.SwitchOnce,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
2,
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
3,
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle escape key multiple times', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
// Call escape multiple times
keypressHandler({ name: 'escape' });
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledTimes(2);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
});

View File

@@ -0,0 +1,89 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
export enum VisionSwitchOutcome {
SwitchOnce = 'switch_once',
SwitchSessionToVL = 'switch_session_to_vl',
DisallowWithGuidance = 'disallow_with_guidance',
}
export interface ModelSwitchDialogProps {
onSelect: (outcome: VisionSwitchOutcome) => void;
}
export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
onSelect,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onSelect(VisionSwitchOutcome.DisallowWithGuidance);
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const handleSelect = (outcome: VisionSwitchOutcome) => {
onSelect(outcome);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentYellow}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Vision Model Switch Required</Text>
<Text>
Your message contains an image, but the current model doesn&apos;t
support vision.
</Text>
<Text>How would you like to proceed?</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={0}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};

View File

@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
const mockLoadHistory = vi.fn();
const mockOpenThemeDialog = vi.fn();
const mockOpenAuthDialog = vi.fn();
const mockOpenModelSelectionDialog = vi.fn();
const mockSetQuittingMessages = vi.fn();
const mockConfig = makeFakeConfig({});
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
mockMcpLoadCommands.mockResolvedValue([]);
mockOpenModelSelectionDialog.mockClear();
});
const setupProcessorHook = (
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
mockOpenModelSelectionDialog,
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
setIsProcessing,
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
expect(mockOpenThemeDialog).toHaveBeenCalled();
});
it('should handle "dialog: model" action', async () => {
const command = createTestCommand({
name: 'modelcmd',
action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
});
const result = setupProcessorHook([command]);
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
await act(async () => {
await result.current.handleSlashCommand('/modelcmd');
});
expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
});
it('should handle "load_history" action', async () => {
const command = createTestCommand({
name: 'load',
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
vi.fn(), // openModelSelectionDialog
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
vi.fn(), // setIsProcessing
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);

View File

@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
setQuittingMessages: (message: HistoryItem[]) => void,
openPrivacyNotice: () => void,
openSettingsDialog: () => void,
openModelSelectionDialog: () => void,
openSubagentCreateDialog: () => void,
openAgentsManagerDialog: () => void,
toggleVimEnabled: () => Promise<boolean>,
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
case 'settings':
openSettingsDialog();
return { type: 'handled' };
case 'model':
openModelSelectionDialog();
return { type: 'handled' };
case 'subagent_create':
openSubagentCreateDialog();
return { type: 'handled' };
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
setSessionShellAllowlist,
setIsProcessing,
setConfirmationRequest,
openModelSelectionDialog,
session.stats,
],
);

View File

@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
);
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
// Vision auto-switch mocks (hoisted)
const mockHandleVisionSwitch = vi.hoisted(() =>
vi.fn().mockResolvedValue({ shouldProceed: true }),
);
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return {
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
};
});
vi.mock('./useVisionAutoSwitch.js', () => ({
useVisionAutoSwitch: vi.fn(() => ({
handleVisionSwitch: mockHandleVisionSwitch,
restoreOriginalModel: mockRestoreOriginalModel,
})),
}));
vi.mock('./useKeypress.js', () => ({
useKeypress: vi.fn(),
}));
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue(contentGeneratorConfig),
getMaxSessionTurns: vi.fn(() => 50),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
expect.any(String), // Argument 3: The prompt_id string
);
});
describe('Thought Reset', () => {
it('should reset thought to null when starting a new prompt', async () => {
// First, simulate a response with a thought
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
);
});
});
// --- New tests focused on recent modifications ---
describe('Vision Auto Switch Integration', () => {
it('should call handleVisionSwitch and proceed to send when allowed', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('image prompt');
});
await waitFor(() => {
expect(mockHandleVisionSwitch).toHaveBeenCalled();
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('vision-gated');
});
// No call to API, no restoreOriginalModel needed since no override occurred
expect(mockSendMessageStream).not.toHaveBeenCalled();
expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
// Next call allowed (flag reset path)
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
await act(async () => {
await result.current.submitQuery('after-gate');
});
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
});
describe('Model restore on completion and errors', () => {
it('should restore model after successful stream completion', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-success');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
it('should restore model when an error occurs during streaming', async () => {
const testError = new Error('stream failure');
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
throw testError;
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-error');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
});
});

View File

@@ -42,6 +42,7 @@ import type {
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
import { useShellCommandProcessor } from './shellCommandProcessor.js';
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
import { handleAtCommand } from './atCommandProcessor.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
@@ -88,6 +89,12 @@ export const useGeminiStream = (
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
onEditorClose: () => void,
onCancelSubmit: () => void,
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -155,6 +162,13 @@ export const useGeminiStream = (
geminiClient,
);
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
config,
addItem,
visionModelPreviewEnabled,
onVisionSwitchRequired,
);
const streamingState = useMemo(() => {
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
return StreamingState.WaitingForConfirmation;
@@ -715,6 +729,20 @@ export const useGeminiStream = (
return;
}
// Handle vision switch requirement
const visionSwitchResult = await handleVisionSwitch(
queryToSend,
userMessageTimestamp,
options?.isContinuation || false,
);
if (!visionSwitchResult.shouldProceed) {
isSubmittingQueryRef.current = false;
return;
}
const finalQueryToSend = queryToSend;
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
@@ -725,7 +753,7 @@ export const useGeminiStream = (
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
finalQueryToSend,
abortSignal,
prompt_id!,
);
@@ -736,6 +764,8 @@ export const useGeminiStream = (
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
isSubmittingQueryRef.current = false;
return;
}
@@ -748,7 +778,13 @@ export const useGeminiStream = (
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
// Restore original model if it was temporarily overridden
restoreOriginalModel();
} catch (error: unknown) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
if (error instanceof UnauthorizedError) {
onAuthError();
} else if (!isNodeError(error) || error.name !== 'AbortError') {
@@ -786,6 +822,8 @@ export const useGeminiStream = (
startNewPrompt,
getPromptCount,
handleLoopDetectedEvent,
handleVisionSwitch,
restoreOriginalModel,
],
);
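The hunks above restore the original model at three explicit exit points: user cancellation, normal completion, and the catch block. The same guarantee can be read as a try/finally around the stream; a sketch of the lifecycle, where streamResponse stands in for the geminiClient.sendMessageStream processing loop:

async function submitWithVisionSwitch(query: PartListUnion) {
  const { shouldProceed } = await handleVisionSwitch(query, Date.now(), false);
  if (!shouldProceed) return; // gated: guidance shown or dialog cancelled
  try {
    await streamResponse(query); // may run against a one-time model override
  } finally {
    restoreOriginalModel(); // no-op unless an override was applied
  }
}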

View File

@@ -0,0 +1,374 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { renderHook, act } from '@testing-library/react';
import type { Part, PartListUnion } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import {
shouldOfferVisionSwitch,
processVisionSwitchOutcome,
getVisionSwitchGuidanceMessage,
useVisionAutoSwitch,
} from './useVisionAutoSwitch.js';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import { MessageType } from '../types.js';
import { getDefaultVisionModel } from '../models/availableModels.js';
describe('useVisionAutoSwitch helpers', () => {
describe('shouldOfferVisionSwitch', () => {
it('returns false when authType is not QWEN_OAUTH', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.USE_GEMINI,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when current model is already a vision model', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen-vl-max-latest',
true,
);
expect(result).toBe(false);
});
it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
const parts: PartListUnion = [
{ text: 'hello' },
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('detects image when provided as a single Part object (non-array)', () => {
const singleImagePart: PartListUnion = {
fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
} as Part;
const result = shouldOfferVisionSwitch(
singleImagePart,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('returns false when parts contain no images', () => {
const parts: PartListUnion = [{ text: 'just text' }];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when parts is a plain string', () => {
const parts: PartListUnion = 'plain text';
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when visionModelPreviewEnabled is false', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
false,
);
expect(result).toBe(false);
});
});
describe('processVisionSwitchOutcome', () => {
it('maps SwitchOnce to a one-time model override', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
expect(result).toEqual({ modelOverride: vl });
});
it('maps SwitchSessionToVL to a persistent session model', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(result).toEqual({ persistSessionModel: vl });
});
it('maps DisallowWithGuidance to showGuidance', () => {
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.DisallowWithGuidance,
);
expect(result).toEqual({ showGuidance: true });
});
});
describe('getVisionSwitchGuidanceMessage', () => {
it('returns the expected guidance message', () => {
const vl = getDefaultVisionModel();
const expected =
'To use images with your query, you can:\n' +
`• Use /model set ${vl} to switch to a vision-capable model\n` +
'• Or remove the image and provide a text description instead';
expect(getVisionSwitchGuidanceMessage()).toBe(expected);
});
});
});
describe('useVisionAutoSwitch hook', () => {
type AddItemFn = (
item: { type: MessageType; text: string },
ts: number,
) => any;
const createMockConfig = (authType: AuthType, initialModel: string) => {
let currentModel = initialModel;
const mockConfig: Partial<Config> = {
getModel: vi.fn(() => currentModel),
setModel: vi.fn((m: string) => {
currentModel = m;
}),
getContentGeneratorConfig: vi.fn(() => ({
authType,
model: currentModel,
apiKey: 'test-key',
vertexai: false,
})),
};
return mockConfig as Config;
};
let addItem: AddItemFn;
beforeEach(() => {
vi.clearAllMocks();
addItem = vi.fn();
});
it('returns shouldProceed=true immediately for continuations', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, Date.now(), true);
});
expect(res).toEqual({ shouldProceed: true });
expect(addItem).not.toHaveBeenCalled();
});
it('does nothing when authType is not QWEN_OAUTH', async () => {
const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 123, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('does nothing when there are no image parts', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [{ text: 'no images here' }];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 456, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('shows guidance and blocks when dialog returns showGuidance', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ showGuidance: true });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const userTs = 1010;
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, userTs, false);
});
expect(addItem).toHaveBeenCalledWith(
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
userTs,
);
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('applies a one-time override and returns originalModel, then restores', async () => {
const initialModel = 'qwen3-coder-plus';
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 2020, false);
});
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Now restore
act(() => {
result.current.restoreOriginalModel();
});
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
});
it('persists session model when dialog requests persistence', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 3030, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Restore should be a no-op since no one-time override was used
act(() => {
result.current.restoreOriginalModel();
});
// Last call should still be the persisted model set
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
'qwen-vl-max-latest',
);
});
it('returns shouldProceed=true when dialog returns no special flags', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 4040, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).not.toHaveBeenCalled();
});
it('blocks when dialog throws or is cancelled', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 5050, false);
});
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('does nothing when visionModelPreviewEnabled is false', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(
config,
addItem as any,
false,
onVisionSwitchRequired,
),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 6060, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,304 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { type PartListUnion, type Part } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import { useCallback, useRef } from 'react';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import {
getDefaultVisionModel,
isVisionModel,
} from '../models/availableModels.js';
import { MessageType } from '../types.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import {
isSupportedImageMimeType,
getUnsupportedImageFormatWarning,
} from '@qwen-code/qwen-code-core';
/**
* Checks if a PartListUnion contains image parts
*/
function hasImageParts(parts: PartListUnion): boolean {
if (typeof parts === 'string') {
return false;
}
if (Array.isArray(parts)) {
return parts.some((part) => {
// Skip string parts
if (typeof part === 'string') return false;
return isImagePart(part);
});
}
// If it's a single Part (not a string), check if it's an image
if (typeof parts === 'object') {
return isImagePart(parts);
}
return false;
}
/**
* Checks if a single Part is an image part
*/
function isImagePart(part: Part): boolean {
// Check for inlineData with image mime type
if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
return true;
}
// Check for fileData with image mime type
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
return true;
}
return false;
}
/**
* Checks if image parts have supported formats and returns unsupported ones
*/
function checkImageFormatsSupport(parts: PartListUnion): {
hasImages: boolean;
hasUnsupportedFormats: boolean;
unsupportedMimeTypes: string[];
} {
const unsupportedMimeTypes: string[] = [];
let hasImages = false;
if (typeof parts === 'string') {
return {
hasImages: false,
hasUnsupportedFormats: false,
unsupportedMimeTypes: [],
};
}
const partsArray = Array.isArray(parts) ? parts : [parts];
for (const part of partsArray) {
if (typeof part === 'string') continue;
let mimeType: string | undefined;
// Check inlineData
if (
'inlineData' in part &&
part.inlineData?.mimeType?.startsWith('image/')
) {
hasImages = true;
mimeType = part.inlineData.mimeType;
}
// Check fileData
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
hasImages = true;
mimeType = part.fileData.mimeType;
}
// Check if the mime type is supported
if (mimeType && !isSupportedImageMimeType(mimeType)) {
unsupportedMimeTypes.push(mimeType);
}
}
return {
hasImages,
hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
unsupportedMimeTypes,
};
}
/**
* Determines if we should offer vision switch for the given parts, auth type, and current model
*/
export function shouldOfferVisionSwitch(
parts: PartListUnion,
authType: AuthType,
currentModel: string,
visionModelPreviewEnabled: boolean = false,
): boolean {
// Only trigger for qwen-oauth
if (authType !== AuthType.QWEN_OAUTH) {
return false;
}
// If vision model preview is disabled, never offer vision switch
if (!visionModelPreviewEnabled) {
return false;
}
// If current model is already a vision model, no need to switch
if (isVisionModel(currentModel)) {
return false;
}
// Check if the current message contains image parts
return hasImageParts(parts);
}
/**
* Interface for vision switch result
*/
export interface VisionSwitchResult {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}
/**
* Processes the vision switch outcome and returns the appropriate result
*/
export function processVisionSwitchOutcome(
outcome: VisionSwitchOutcome,
): VisionSwitchResult {
const vlModelId = getDefaultVisionModel();
switch (outcome) {
case VisionSwitchOutcome.SwitchOnce:
return { modelOverride: vlModelId };
case VisionSwitchOutcome.SwitchSessionToVL:
return { persistSessionModel: vlModelId };
case VisionSwitchOutcome.DisallowWithGuidance:
return { showGuidance: true };
default:
return { showGuidance: true };
}
}
/**
* Gets the guidance message for when vision switch is disallowed
*/
export function getVisionSwitchGuidanceMessage(): string {
const vlModelId = getDefaultVisionModel();
return `To use images with your query, you can:
• Use /model set ${vlModelId} to switch to a vision-capable model
• Or remove the image and provide a text description instead`;
}
/**
* Interface for vision switch handling result
*/
export interface VisionSwitchHandlingResult {
shouldProceed: boolean;
originalModel?: string;
}
/**
* Custom hook for handling vision model auto-switching
*/
export function useVisionAutoSwitch(
config: Config,
addItem: UseHistoryManagerReturn['addItem'],
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) {
const originalModelRef = useRef<string | null>(null);
const handleVisionSwitch = useCallback(
async (
query: PartListUnion,
userMessageTimestamp: number,
isContinuation: boolean,
): Promise<VisionSwitchHandlingResult> => {
// Skip vision switch handling for continuations or if no handler provided
if (isContinuation || !onVisionSwitchRequired) {
return { shouldProceed: true };
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
// Only handle qwen-oauth auth type
if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
return { shouldProceed: true };
}
// Check image format support first
const formatCheck = checkImageFormatsSupport(query);
// If there are unsupported image formats, show warning
if (formatCheck.hasUnsupportedFormats) {
addItem(
{
type: MessageType.INFO,
text: getUnsupportedImageFormatWarning(),
},
userMessageTimestamp,
);
// Continue processing but with warning shown
}
// Check if vision switch is needed
if (
!shouldOfferVisionSwitch(
query,
contentGeneratorConfig.authType,
config.getModel(),
visionModelPreviewEnabled,
)
) {
return { shouldProceed: true };
}
try {
const visionSwitchResult = await onVisionSwitchRequired(query);
if (visionSwitchResult.showGuidance) {
// Show guidance and don't proceed with the request
addItem(
{
type: MessageType.INFO,
text: getVisionSwitchGuidanceMessage(),
},
userMessageTimestamp,
);
return { shouldProceed: false };
}
if (visionSwitchResult.modelOverride) {
// One-time model override
originalModelRef.current = config.getModel();
config.setModel(visionSwitchResult.modelOverride);
return {
shouldProceed: true,
originalModel: originalModelRef.current,
};
} else if (visionSwitchResult.persistSessionModel) {
// Persistent session model change
config.setModel(visionSwitchResult.persistSessionModel);
return { shouldProceed: true };
}
return { shouldProceed: true };
} catch (_error) {
// If vision switch dialog was cancelled or errored, don't proceed
return { shouldProceed: false };
}
},
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
);
const restoreOriginalModel = useCallback(() => {
if (originalModelRef.current) {
config.setModel(originalModelRef.current);
originalModelRef.current = null;
}
}, [config]);
return {
handleVisionSwitch,
restoreOriginalModel,
};
}
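// Usage sketch (illustrative, not part of this diff): wiring the hook's two
// callbacks into a submit path. `submitQuery` and the dialog-backed resolver
// feeding `handleVisionSwitch` are assumed names supplied by the caller.
async function submitWithVisionSwitch(
  query: PartListUnion,
  handleVisionSwitch: (
    q: PartListUnion,
    ts: number,
    isContinuation: boolean,
  ) => Promise<VisionSwitchHandlingResult>,
  restoreOriginalModel: () => void,
  submitQuery: (q: PartListUnion) => Promise<void>,
): Promise<void> {
  const result = await handleVisionSwitch(query, Date.now(), false);
  if (!result.shouldProceed) return;
  try {
    await submitQuery(query);
  } finally {
    restoreOriginalModel(); // undo a one-time model override, if any
  }
}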

View File

@@ -0,0 +1,52 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export type AvailableModel = {
id: string;
label: string;
isVision?: boolean;
};
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
];
/**
* Get available Qwen models filtered by vision model preview setting
*/
export function getFilteredQwenModels(
visionModelPreviewEnabled: boolean,
): AvailableModel[] {
if (visionModelPreviewEnabled) {
return AVAILABLE_MODELS_QWEN;
}
return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
}
/**
 * Currently we read a single model from the `OPENAI_MODEL` environment variable.
 * In the future, once settings.json supports it, users will be able to configure this themselves.
*/
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
const id = process.env['OPENAI_MODEL']?.trim();
return id ? { id, label: id } : null;
}
/**
 * Hard-code the default vision model as a string literal
 * until our coding model supports multimodal input.
*/
export function getDefaultVisionModel(): string {
return 'qwen-vl-max-latest';
}
export function isVisionModel(modelId: string): boolean {
return AVAILABLE_MODELS_QWEN.some(
(model) => model.id === modelId && model.isVision,
);
}
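// Usage sketch (illustrative): composing these helpers for a model picker.
// The { label, value } item shape is an assumption, not part of this file.
function buildQwenModelItems(
  visionModelPreviewEnabled: boolean,
): Array<{ label: string; value: string }> {
  return getFilteredQwenModels(visionModelPreviewEnabled).map((m) => ({
    label: m.label,
    value: m.id,
  }));
}
// isVisionModel() drives the auto-switch decision elsewhere:
//   isVisionModel('qwen-vl-max-latest') === true
//   isVisionModel('qwen3-coder-plus')   === false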

View File

@@ -19,3 +19,4 @@ export {
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js';
export * from './src/utils/request-tokenizer/supportedImageFormats.js';

View File

@@ -5,9 +5,10 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
let mockConfig: Config;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockOpenAIClient: any;
let mockProvider: OpenAICompatibleProvider;
beforeEach(() => {
// Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
create: vi.fn(),
},
},
embeddings: {
create: vi.fn(),
},
};
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create mock provider
mockProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
// Create generator instance
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
generator = new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
});
afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
await expect(
generator.generateContentStream(request, 'test-prompt-id'),
).rejects.toThrow(
/Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
/Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
);
});
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
} catch (error: unknown) {
const errorMessage =
error instanceof Error ? error.message : String(error);
expect(errorMessage).toContain(
'Streaming setup timeout troubleshooting:',
);
expect(errorMessage).toContain(
'Check network connectivity and firewall settings',
);
expect(errorMessage).toContain('Streaming timeout troubleshooting:');
expect(errorMessage).toContain('Check network connectivity');
expect(errorMessage).toContain('Consider using non-streaming mode');
}
});
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
// Verify OpenAI client was created with timeout config
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Verify provider buildClient was called
expect(mockProvider.buildClient).toHaveBeenCalled();
});
it('should use custom timeout from config', () => {
const customConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const customMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
customConfig,
customMockProvider,
);
// Verify provider buildClient was called
expect(customMockProvider.buildClient).toHaveBeenCalled();
});
it('should handle missing timeout config gracefully', () => {
const noTimeoutConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const noTimeoutMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
noTimeoutConfig,
noTimeoutMockProvider,
);
// Verify provider buildClient was called
expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
});
});

View File

@@ -500,7 +500,7 @@ export class GeminiChat {
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
if (error.message.match(/^5\d{2}/)) return true;
}
return false;
},

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -5,6 +5,37 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
// Mock the request tokenizer module BEFORE importing the class that uses it
const mockTokenizer = {
calculateTokens: vi.fn().mockResolvedValue({
totalTokens: 50,
breakdown: {
textTokens: 50,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: 1,
}),
dispose: vi.fn(),
};
vi.mock('../../../utils/request-tokenizer/index.js', () => ({
getDefaultTokenizer: vi.fn(() => mockTokenizer),
DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
disposeDefaultTokenizer: vi.fn(),
}));
// Mock tiktoken as well for completeness
vi.mock('tiktoken', () => ({
get_encoding: vi.fn(() => ({
encode: vi.fn(() => new Array(50)), // Mock 50 tokens
free: vi.fn(),
})),
}));
// Now import the modules that depend on the mocked modules
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
@@ -15,14 +46,6 @@ import type {
import type { OpenAICompatibleProvider } from './provider/index.js';
import type OpenAI from 'openai';
// Mock tiktoken
vi.mock('tiktoken', () => ({
get_encoding: vi.fn().mockReturnValue({
encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
free: vi.fn(),
}),
}));
describe('OpenAIContentGenerator (Refactored)', () => {
let generator: OpenAIContentGenerator;
let mockConfig: Config;

View File

@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
import { ContentGenerationPipeline } from './pipeline.js';
import { DefaultTelemetryService } from './telemetryService.js';
import { EnhancedErrorHandler } from './errorHandler.js';
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
export class OpenAIContentGenerator implements ContentGenerator {
@@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
// Use tiktoken for accurate token counting
const content = JSON.stringify(request.contents);
let totalTokens = 0;
try {
const { get_encoding } = await import('tiktoken');
const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
totalTokens = encoding.encode(content).length;
encoding.free();
// Use the new high-performance request tokenizer
const tokenizer = getDefaultTokenizer();
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
});
return {
totalTokens: result.totalTokens,
};
} catch (error) {
console.warn(
'Failed to load tiktoken, falling back to character approximation:',
'Failed to calculate tokens with new tokenizer, falling back to simple method:',
error,
);
// Fallback: rough approximation using character count
totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
}
return {
totalTokens,
};
// Fallback to original simple method
const content = JSON.stringify(request.contents);
const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
return {
totalTokens,
};
}
}
async embedContent(

View File

@@ -10,14 +10,11 @@ import {
GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../../config/config.js';
import { type ContentGeneratorConfig } from '../contentGenerator.js';
import { type OpenAICompatibleProvider } from './provider/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from './provider/index.js';
import { OpenAIContentConverter } from './converter.js';
import {
type TelemetryService,
type RequestContext,
} from './telemetryService.js';
import { type ErrorHandler } from './errorHandler.js';
import type { TelemetryService, RequestContext } from './telemetryService.js';
import type { ErrorHandler } from './errorHandler.js';
export interface PipelineConfig {
cliConfig: Config;
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
* 2. Filter empty responses
* 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
* 4. Collect both formats for logging
* 5. Handle success/error logging with original OpenAI format
* 5. Handle success/error logging
*/
private async *processStreamWithLogging(
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -169,19 +166,11 @@ export class ContentGenerationPipeline {
collectedOpenAIChunks,
);
} catch (error) {
// Stage 2e: Stream failed - handle error and logging
context.duration = Date.now() - context.startTime;
// Clear streaming tool calls on error to prevent data pollution
this.converter.resetStreamingToolCalls();
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
this.config.errorHandler.handle(error, context, request);
// Use shared error handling logic
await this.handleError(error, context, request);
}
}
@@ -365,25 +354,59 @@ export class ContentGenerationPipeline {
context.duration = Date.now() - context.startTime;
return result;
} catch (error) {
context.duration = Date.now() - context.startTime;
// Log error
const openaiRequest = await this.buildRequest(
// Use shared error handling logic
return await this.handleError(
error,
context,
request,
userPromptId,
isStreaming,
);
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
// Handle and throw enhanced error
this.config.errorHandler.handle(error, context, request);
}
}
/**
* Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
* This centralizes the common error processing steps to avoid duplication
*/
private async handleError(
error: unknown,
context: RequestContext,
request: GenerateContentParameters,
userPromptId?: string,
isStreaming?: boolean,
): Promise<never> {
context.duration = Date.now() - context.startTime;
// Build request for logging (may fail, but we still want to log the error)
let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
try {
if (userPromptId !== undefined && isStreaming !== undefined) {
openaiRequest = await this.buildRequest(
request,
userPromptId,
isStreaming,
);
} else {
// For processStreamWithLogging, we don't have userPromptId/isStreaming,
// so create a minimal request
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
} catch (_buildError) {
// If we can't build the request, create a minimal one for logging
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
await this.config.telemetryService.logError(context, error, openaiRequest);
this.config.errorHandler.handle(error, context, request);
}
/**
* Create request context with common properties
*/

View File

@@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider
messages = this.addDashScopeCacheControl(messages, cacheTarget);
}
if (request.model.startsWith('qwen-vl')) {
return {
...request,
messages,
...(this.buildMetadata(userPromptId) || {}),
/* @ts-expect-error dashscope exclusive */
vl_high_resolution_images: true,
};
}
return {
...request, // Preserve all original parameters including sampling params
messages,

View File

@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
[/^qwen-flash-latest$/, LIMITS['1m']],
[/^qwen-turbo.*$/, LIMITS['128k']],
// Qwen Vision Models
[/^qwen-vl-max.*$/, LIMITS['128k']],
// -------------------
// ByteDance Seed-OSS (512K)
// -------------------

View File

@@ -242,7 +242,7 @@ describe('Turn', () => {
expect(turn.getDebugResponses().length).toBe(0);
expect(reportError).toHaveBeenCalledWith(
error,
'Error when talking to Gemini API',
'Error when talking to API',
[...historyContent, reqParts],
'Turn.run-sendMessageStream',
);

View File

@@ -310,7 +310,7 @@ export class Turn {
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError(
error,
'Error when talking to Gemini API',
'Error when talking to API',
contextForReport,
'Turn.run-sendMessageStream',
);

View File

@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
});
it('should count tokens with valid token', async () => {
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'valid-token',
});
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
it('should count tokens without requiring authentication', async () => {
// Clear any previous mock calls
vi.clearAllMocks();
const request: CountTokensParameters = {
model: 'qwen-turbo',
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
const result = await qwenContentGenerator.countTokens(request);
expect(result.totalTokens).toBe(15);
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
// countTokens is a local operation and should not require OAuth credentials
expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
});
it('should embed content with valid token', async () => {
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
SharedTokenManager.getInstance = originalGetInstance;
});
it('should handle all method types with token failure', async () => {
it('should handle method types with token failure (except countTokens)', async () => {
const mockTokenManager = {
getValidCredentials: vi
.fn()
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
contents: [{ parts: [{ text: 'Embed' }] }],
};
// All methods should fail with the same error
// Methods requiring authentication should fail
await expect(
newGenerator.generateContent(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
newGenerator.generateContentStream(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
// countTokens should succeed as it's a local operation
const countResult = await newGenerator.countTokens(countRequest);
expect(countResult.totalTokens).toBe(15);
SharedTokenManager.getInstance = originalGetInstance;
});
});

View File

@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
override async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.executeWithCredentialManagement(() =>
super.countTokens(request),
);
return super.countTokens(request);
}
/**

View File

@@ -0,0 +1,157 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { ImageTokenizer } from './imageTokenizer.js';
describe('ImageTokenizer', () => {
const tokenizer = new ImageTokenizer();
describe('token calculation', () => {
it('should calculate tokens based on image dimensions with reference logic', () => {
const metadata = {
width: 28,
height: 28,
mimeType: 'image/png',
dataSize: 1000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
// But minimum scaling may apply for small images
expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
});
it('should calculate tokens for larger images', () => {
const metadata = {
width: 512,
height: 512,
mimeType: 'image/png',
dataSize: 10000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 512x512 with reference logic: rounded dimensions + scaling + special tokens
expect(tokens).toBeGreaterThan(300);
expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
});
it('should enforce minimum tokens per image with scaling', () => {
const metadata = {
width: 1,
height: 1,
mimeType: 'image/png',
dataSize: 100,
};
const tokens = tokenizer.calculateTokens(metadata);
// Tiny images get scaled up to minimum pixels + special tokens
expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
});
it('should handle very large images with scaling', () => {
const metadata = {
width: 8192,
height: 8192,
mimeType: 'image/png',
dataSize: 100000,
};
const tokens = tokenizer.calculateTokens(metadata);
// Very large images should be scaled down to max limit + special tokens
expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
});
});
describe('PNG dimension extraction', () => {
it('should extract dimensions from valid PNG', async () => {
// 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
'image/png',
);
expect(metadata.width).toBe(1);
expect(metadata.height).toBe(1);
expect(metadata.mimeType).toBe('image/png');
});
it('should handle invalid PNG gracefully', async () => {
const invalidBase64 = 'invalid-png-data';
const metadata = await tokenizer.extractImageMetadata(
invalidBase64,
'image/png',
);
// Should return default dimensions
expect(metadata.width).toBe(512);
expect(metadata.height).toBe(512);
expect(metadata.mimeType).toBe('image/png');
});
});
describe('batch processing', () => {
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const images = [
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(3);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
});
it('should handle mixed valid and invalid images', async () => {
const validPng =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const invalidPng = 'invalid-data';
const images = [
{ data: validPng, mimeType: 'image/png' },
{ data: invalidPng, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(2);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
});
});
describe('different image formats', () => {
it('should handle different MIME types', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];
for (const mimeType of formats) {
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
mimeType,
);
expect(metadata.mimeType).toBe(mimeType);
expect(metadata.width).toBeGreaterThan(0);
expect(metadata.height).toBeGreaterThan(0);
}
});
});
});

View File

@@ -0,0 +1,505 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { ImageMetadata } from './types.js';
import { isSupportedImageMimeType } from './supportedImageFormats.js';
/**
* Image tokenizer for calculating image tokens based on dimensions
*
* Key rules:
* - 28x28 pixels = 1 token
* - Minimum: 4 tokens per image
* - Maximum: 16384 tokens per image
* - Additional: 2 special tokens (vision_bos + vision_eos)
* - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
*/
export class ImageTokenizer {
/** 28x28 pixels = 1 token */
private static readonly PIXELS_PER_TOKEN = 28 * 28;
/** Minimum tokens per image */
private static readonly MIN_TOKENS_PER_IMAGE = 4;
/** Maximum tokens per image */
private static readonly MAX_TOKENS_PER_IMAGE = 16384;
/** Special tokens for vision markers */
private static readonly VISION_SPECIAL_TOKENS = 2;
/**
* Extract image metadata from base64 data
*
* @param base64Data Base64-encoded image data (with or without data URL prefix)
* @param mimeType MIME type of the image
* @returns Promise resolving to ImageMetadata with dimensions and format info
*/
async extractImageMetadata(
base64Data: string,
mimeType: string,
): Promise<ImageMetadata> {
try {
// Check if the MIME type is supported
if (!isSupportedImageMimeType(mimeType)) {
console.warn(`Unsupported image format: ${mimeType}`);
// Return default metadata for unsupported formats
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
const buffer = Buffer.from(cleanBase64, 'base64');
const dimensions = await this.extractDimensions(buffer, mimeType);
return {
width: dimensions.width,
height: dimensions.height,
mimeType,
dataSize: buffer.length,
};
} catch (error) {
console.warn('Failed to extract image metadata:', error);
// Return default metadata for fallback
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
}
/**
* Extract image dimensions from buffer based on format
*
* @param buffer Binary image data buffer
* @param mimeType MIME type to determine parsing strategy
* @returns Promise resolving to width and height dimensions
*/
private async extractDimensions(
buffer: Buffer,
mimeType: string,
): Promise<{ width: number; height: number }> {
if (mimeType.includes('png')) {
return this.extractPngDimensions(buffer);
}
if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
return this.extractJpegDimensions(buffer);
}
if (mimeType.includes('webp')) {
return this.extractWebpDimensions(buffer);
}
if (mimeType.includes('gif')) {
return this.extractGifDimensions(buffer);
}
if (mimeType.includes('bmp')) {
return this.extractBmpDimensions(buffer);
}
if (mimeType.includes('tiff')) {
return this.extractTiffDimensions(buffer);
}
if (mimeType.includes('heic')) {
return this.extractHeicDimensions(buffer);
}
return { width: 512, height: 512 };
}
/**
* Extract PNG dimensions from IHDR chunk
* PNG signature: 89 50 4E 47 0D 0A 1A 0A
* Width/height at bytes 16-19 and 20-23 (big-endian)
*/
private extractPngDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 24) {
throw new Error('Invalid PNG: buffer too short');
}
// Verify PNG signature
const signature = buffer.subarray(0, 8);
const expectedSignature = Buffer.from([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
]);
if (!signature.equals(expectedSignature)) {
throw new Error('Invalid PNG signature');
}
const width = buffer.readUInt32BE(16);
const height = buffer.readUInt32BE(20);
return { width, height };
}
/**
* Extract JPEG dimensions from SOF (Start of Frame) markers
* JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
* Dimensions at offset +5 (height) and +7 (width) from SOF marker
*/
private extractJpegDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
throw new Error('Invalid JPEG signature');
}
let offset = 2;
while (offset < buffer.length - 8) {
if (buffer[offset] !== 0xff) {
offset++;
continue;
}
const marker = buffer[offset + 1];
// SOF markers
if (
(marker >= 0xc0 && marker <= 0xc3) ||
(marker >= 0xc5 && marker <= 0xc7) ||
(marker >= 0xc9 && marker <= 0xcb) ||
(marker >= 0xcd && marker <= 0xcf)
) {
const height = buffer.readUInt16BE(offset + 5);
const width = buffer.readUInt16BE(offset + 7);
return { width, height };
}
const segmentLength = buffer.readUInt16BE(offset + 2);
offset += 2 + segmentLength;
}
throw new Error('Could not find JPEG dimensions');
}
/**
* Extract WebP dimensions from RIFF container
* Supports VP8, VP8L, and VP8X formats
*/
private extractWebpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 30) {
throw new Error('Invalid WebP: too short');
}
const riffSignature = buffer.subarray(0, 4).toString('ascii');
const webpSignature = buffer.subarray(8, 12).toString('ascii');
if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
throw new Error('Invalid WebP signature');
}
const format = buffer.subarray(12, 16).toString('ascii');
if (format === 'VP8 ') {
const width = buffer.readUInt16LE(26) & 0x3fff;
const height = buffer.readUInt16LE(28) & 0x3fff;
return { width, height };
} else if (format === 'VP8L') {
const bits = buffer.readUInt32LE(21);
const width = (bits & 0x3fff) + 1;
const height = ((bits >> 14) & 0x3fff) + 1;
return { width, height };
} else if (format === 'VP8X') {
// VP8X stores 24-bit (value - 1) canvas dimensions at bytes 24-26 and 27-29
const width = buffer.readUIntLE(24, 3) + 1;
const height = buffer.readUIntLE(27, 3) + 1;
return { width, height };
}
throw new Error('Unsupported WebP format');
}
/**
* Extract GIF dimensions from header
* Supports GIF87a and GIF89a formats
*/
private extractGifDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 10) {
throw new Error('Invalid GIF: too short');
}
const signature = buffer.subarray(0, 6).toString('ascii');
if (signature !== 'GIF87a' && signature !== 'GIF89a') {
throw new Error('Invalid GIF signature');
}
const width = buffer.readUInt16LE(6);
const height = buffer.readUInt16LE(8);
return { width, height };
}
/**
* Calculate tokens for an image based on its metadata
*
* @param metadata Image metadata containing width, height, and format info
* @returns Total token count including base image tokens and special tokens
*/
calculateTokens(metadata: ImageMetadata): number {
return this.calculateTokensWithScaling(metadata.width, metadata.height);
}
/**
* Calculate tokens with scaling logic
*
* Steps:
* 1. Normalize to 28-pixel multiples
* 2. Scale large images down, small images up
* 3. Calculate tokens: pixels / 784 + 2 special tokens
*
* @param width Original image width in pixels
* @param height Original image height in pixels
* @returns Total token count for the image
*/
private calculateTokensWithScaling(width: number, height: number): number {
// Normalize to 28-pixel multiples
let hBar = Math.round(height / 28) * 28;
let wBar = Math.round(width / 28) * 28;
// Define pixel boundaries
const minPixels =
ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
const maxPixels =
ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
// Apply scaling
if (hBar * wBar > maxPixels) {
// Scale down large images
const beta = Math.sqrt((height * width) / maxPixels);
hBar = Math.floor(height / beta / 28) * 28;
wBar = Math.floor(width / beta / 28) * 28;
} else if (hBar * wBar < minPixels) {
// Scale up small images
const beta = Math.sqrt(minPixels / (height * width));
hBar = Math.ceil((height * beta) / 28) * 28;
wBar = Math.ceil((width * beta) / 28) * 28;
}
// Calculate tokens
const imageTokens = Math.floor(
(hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
);
return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
}
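/**
 * Worked example (illustrative): a 1024x768 image.
 *   hBar = round(768 / 28) * 28  = 756
 *   wBar = round(1024 / 28) * 28 = 1036
 *   756 * 1036 = 783,216 pixels, inside [3,136; 12,845,056] -> no rescale
 *   imageTokens = floor(783,216 / 784) = 999
 *   total = 999 + 2 special tokens = 1,001
 */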
/**
* Calculate tokens for multiple images serially
*
* @param base64DataArray Array of image data with MIME type information
* @returns Promise resolving to array of token counts in same order as input
*/
async calculateTokensBatch(
base64DataArray: Array<{ data: string; mimeType: string }>,
): Promise<number[]> {
const results: number[] = [];
for (const { data, mimeType } of base64DataArray) {
try {
const metadata = await this.extractImageMetadata(data, mimeType);
results.push(this.calculateTokens(metadata));
} catch (error) {
console.warn('Error calculating tokens for image:', error);
// Return minimum tokens as fallback
results.push(
ImageTokenizer.MIN_TOKENS_PER_IMAGE +
ImageTokenizer.VISION_SPECIAL_TOKENS,
);
}
}
return results;
}
/**
* Extract BMP dimensions from header
* BMP signature: 42 4D (BM)
* Width/height at bytes 18-21 and 22-25 (little-endian)
*/
private extractBmpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 26) {
throw new Error('Invalid BMP: buffer too short');
}
// Verify BMP signature
if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
throw new Error('Invalid BMP signature');
}
const width = buffer.readUInt32LE(18);
const height = buffer.readInt32LE(22); // signed: negative means top-down BMP
return { width, height: Math.abs(height) };
}
/**
* Extract TIFF dimensions from IFD (Image File Directory)
* TIFF can be little-endian (II) or big-endian (MM)
* Width/height are stored in IFD entries with tags 0x0100 and 0x0101
*/
private extractTiffDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 8) {
throw new Error('Invalid TIFF: buffer too short');
}
// Check byte order
const byteOrder = buffer.subarray(0, 2).toString('ascii');
const isLittleEndian = byteOrder === 'II';
const isBigEndian = byteOrder === 'MM';
if (!isLittleEndian && !isBigEndian) {
throw new Error('Invalid TIFF byte order');
}
// Read magic number (should be 42)
const magic = isLittleEndian
? buffer.readUInt16LE(2)
: buffer.readUInt16BE(2);
if (magic !== 42) {
throw new Error('Invalid TIFF magic number');
}
// Read IFD offset
const ifdOffset = isLittleEndian
? buffer.readUInt32LE(4)
: buffer.readUInt32BE(4);
if (ifdOffset >= buffer.length) {
throw new Error('Invalid TIFF IFD offset');
}
// Read number of directory entries
const numEntries = isLittleEndian
? buffer.readUInt16LE(ifdOffset)
: buffer.readUInt16BE(ifdOffset);
let width = 0;
let height = 0;
// Parse IFD entries
for (let i = 0; i < numEntries; i++) {
const entryOffset = ifdOffset + 2 + i * 12;
if (entryOffset + 12 > buffer.length) break;
const tag = isLittleEndian
? buffer.readUInt16LE(entryOffset)
: buffer.readUInt16BE(entryOffset);
const type = isLittleEndian
? buffer.readUInt16LE(entryOffset + 2)
: buffer.readUInt16BE(entryOffset + 2);
const value = isLittleEndian
? buffer.readUInt32LE(entryOffset + 8)
: buffer.readUInt32BE(entryOffset + 8);
// SHORT (type 3) values occupy only the first 2 bytes of the 4-byte value
// field, so re-read them at 16-bit width; LONG (type 4) uses the full value.
const shortValue = isLittleEndian
  ? buffer.readUInt16LE(entryOffset + 8)
  : buffer.readUInt16BE(entryOffset + 8);
if (tag === 0x0100) {
  // ImageWidth
  width = type === 3 ? shortValue : value;
} else if (tag === 0x0101) {
  // ImageLength (height)
  height = type === 3 ? shortValue : value;
}
if (width > 0 && height > 0) break;
}
if (width === 0 || height === 0) {
throw new Error('Could not find TIFF dimensions');
}
return { width, height };
}
/**
* Extract HEIC dimensions from meta box
* HEIC is based on ISO Base Media File Format
* This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
*/
private extractHeicDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 12) {
throw new Error('Invalid HEIC: buffer too short');
}
// Check for ftyp box with HEIC brand
const ftypBox = buffer.subarray(4, 8).toString('ascii');
if (ftypBox !== 'ftyp') {
throw new Error('Invalid HEIC: missing ftyp box');
}
const brand = buffer.subarray(8, 12).toString('ascii');
if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
throw new Error('Invalid HEIC brand');
}
// Look for meta box and then ispe box
let offset = 0;
while (offset < buffer.length - 8) {
const boxSize = buffer.readUInt32BE(offset);
const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
if (boxType === 'meta') {
// Look for ispe box inside meta box
const metaOffset = offset + 8;
let innerOffset = metaOffset + 4; // Skip version and flags
while (innerOffset < offset + boxSize - 8) {
const innerBoxSize = buffer.readUInt32BE(innerOffset);
const innerBoxType = buffer
.subarray(innerOffset + 4, innerOffset + 8)
.toString('ascii');
if (innerBoxType === 'ispe') {
// Found Image Spatial Extents box
if (innerOffset + 20 <= buffer.length) {
const width = buffer.readUInt32BE(innerOffset + 12);
const height = buffer.readUInt32BE(innerOffset + 16);
return { width, height };
}
}
if (innerBoxSize === 0) break;
innerOffset += innerBoxSize;
}
}
if (boxSize === 0) break;
offset += boxSize;
}
// Fallback: return default dimensions if we can't parse the structure
console.warn('Could not extract HEIC dimensions, using default');
return { width: 512, height: 512 };
}
}
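// Usage sketch (illustrative): end-to-end token estimate for one inline image;
// `base64Data` stands in for real attachment data.
async function estimateImageTokens(
  base64Data: string,
  mimeType: string,
): Promise<number> {
  const tokenizer = new ImageTokenizer();
  const metadata = await tokenizer.extractImageMetadata(base64Data, mimeType);
  return tokenizer.calculateTokens(metadata); // includes the 2 special tokens
}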

View File

@@ -0,0 +1,40 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export { DefaultRequestTokenizer } from './requestTokenizer.js';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
export { TextTokenizer } from './textTokenizer.js';
export { ImageTokenizer } from './imageTokenizer.js';
export type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
ImageMetadata,
} from './types.js';
// Singleton instance for convenient usage
let defaultTokenizer: DefaultRequestTokenizer | null = null;
/**
* Get the default request tokenizer instance
*/
export function getDefaultTokenizer(): DefaultRequestTokenizer {
if (!defaultTokenizer) {
defaultTokenizer = new DefaultRequestTokenizer();
}
return defaultTokenizer;
}
/**
* Dispose of the default tokenizer instance
*/
export async function disposeDefaultTokenizer(): Promise<void> {
if (defaultTokenizer) {
await defaultTokenizer.dispose();
defaultTokenizer = null;
}
}
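// Consumer sketch (illustrative): counting tokens for a request through the
// singleton, as countTokens() in openaiContentGenerator now does.
// CountTokensParameters comes from '@google/genai'.
import type { CountTokensParameters } from '@google/genai';

async function countRequestTokens(
  request: CountTokensParameters,
): Promise<number> {
  const tokenizer = getDefaultTokenizer();
  const { totalTokens, breakdown } = await tokenizer.calculateTokens(request, {
    textEncoding: 'cl100k_base',
  });
  console.debug('token breakdown:', breakdown); // text/image/audio/other split
  return totalTokens;
}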

View File

@@ -0,0 +1,293 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
import type { CountTokensParameters } from '@google/genai';
describe('DefaultRequestTokenizer', () => {
let tokenizer: DefaultRequestTokenizer;
beforeEach(() => {
tokenizer = new DefaultRequestTokenizer();
});
afterEach(async () => {
await tokenizer.dispose();
});
describe('text token calculation', () => {
it('should calculate tokens for simple text content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Hello, world!' }],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBe(0);
expect(result.processingTime).toBeGreaterThan(0);
});
it('should handle multiple text parts', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'First part' },
{ text: 'Second part' },
{ text: 'Third part' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
it('should handle string content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: ['Simple string content'],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
});
describe('image token calculation', () => {
it('should calculate tokens for image content', async () => {
// Create a simple 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
expect(result.breakdown.textTokens).toBe(0);
});
it('should handle multiple images', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
});
});
describe('mixed content', () => {
it('should handle text and image content together', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'Here is an image:' },
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{ text: 'What do you see?' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(4);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
});
});
describe('function content', () => {
it('should handle function calls', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
functionCall: {
name: 'test_function',
args: { param1: 'value1', param2: 42 },
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.otherTokens).toBeGreaterThan(0);
});
});
describe('empty content', () => {
it('should handle empty request', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
expect(result.breakdown.textTokens).toBe(0);
expect(result.breakdown.imageTokens).toBe(0);
});
it('should handle undefined contents', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
});
});
describe('configuration', () => {
it('should use custom text encoding', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Test text for encoding' }],
},
],
};
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base',
});
expect(result.totalTokens).toBeGreaterThan(0);
});
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: Array(10).fill({
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
}),
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
});
});
describe('error handling', () => {
it('should handle malformed image data gracefully', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: 'invalid-base64-data',
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
// Should still return some tokens (fallback to minimum)
expect(result.totalTokens).toBeGreaterThanOrEqual(4);
});
});
});

View File

@@ -0,0 +1,341 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
Content,
Part,
PartUnion,
} from '@google/genai';
import type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
} from './types.js';
import { TextTokenizer } from './textTokenizer.js';
import { ImageTokenizer } from './imageTokenizer.js';
/**
* Simple request tokenizer that handles text and image content serially
*/
export class DefaultRequestTokenizer implements RequestTokenizer {
private textTokenizer: TextTokenizer;
private imageTokenizer: ImageTokenizer;
constructor() {
this.textTokenizer = new TextTokenizer();
this.imageTokenizer = new ImageTokenizer();
}
/**
* Calculate tokens for a request using serial processing
*/
async calculateTokens(
request: CountTokensParameters,
config: TokenizerConfig = {},
): Promise<TokenCalculationResult> {
const startTime = performance.now();
// Apply configuration
if (config.textEncoding) {
this.textTokenizer = new TextTokenizer(config.textEncoding);
}
try {
// Process request content and group by type
const { textContents, imageContents, audioContents, otherContents } =
this.processAndGroupContents(request);
if (
textContents.length === 0 &&
imageContents.length === 0 &&
audioContents.length === 0 &&
otherContents.length === 0
) {
return {
totalTokens: 0,
breakdown: {
textTokens: 0,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
// Calculate tokens for each content type serially
const textTokens = await this.calculateTextTokens(textContents);
const imageTokens = await this.calculateImageTokens(imageContents);
const audioTokens = await this.calculateAudioTokens(audioContents);
const otherTokens = await this.calculateOtherTokens(otherContents);
const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
const processingTime = performance.now() - startTime;
return {
totalTokens,
breakdown: {
textTokens,
imageTokens,
audioTokens,
otherTokens,
},
processingTime,
};
} catch (error) {
console.error('Error calculating tokens:', error);
// Fallback calculation
const fallbackTokens = this.calculateFallbackTokens(request);
return {
totalTokens: fallbackTokens,
breakdown: {
textTokens: fallbackTokens,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
}
/**
* Calculate tokens for text contents
*/
private async calculateTextTokens(textContents: string[]): Promise<number> {
if (textContents.length === 0) return 0;
try {
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(textContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating text tokens:', error);
// Fallback: character-based estimation
const totalChars = textContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Calculate tokens for image contents using serial processing
*/
private async calculateImageTokens(
imageContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (imageContents.length === 0) return 0;
try {
const tokenCounts =
await this.imageTokenizer.calculateTokensBatch(imageContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating image tokens:', error);
// Fallback: minimum tokens per image
return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
}
}
/**
* Calculate tokens for audio contents
* TODO: Implement proper audio token calculation
*/
private async calculateAudioTokens(
audioContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (audioContents.length === 0) return 0;
// Placeholder implementation - audio token calculation would depend on
// the specific model's audio processing capabilities
// For now, estimate based on data size
let totalTokens = 0;
for (const audioContent of audioContents) {
try {
const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
// Rough estimate: 1 token per 100 bytes of audio data
totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
} catch (error) {
console.warn('Error calculating audio tokens:', error);
totalTokens += 10; // Fallback minimum
}
}
return totalTokens;
}
/**
* Calculate tokens for other content types (functions, files, etc.)
*/
private async calculateOtherTokens(otherContents: string[]): Promise<number> {
if (otherContents.length === 0) return 0;
try {
// Treat other content as text for token calculation
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(otherContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating other content tokens:', error);
// Fallback: character-based estimation
const totalChars = otherContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Fallback token calculation using simple string serialization
*/
private calculateFallbackTokens(request: CountTokensParameters): number {
try {
const content = JSON.stringify(request.contents);
return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
} catch (error) {
console.warn('Error in fallback token calculation:', error);
return 100; // Conservative fallback
}
}
/**
* Process request contents and group by type
*/
private processAndGroupContents(request: CountTokensParameters): {
textContents: string[];
imageContents: Array<{ data: string; mimeType: string }>;
audioContents: Array<{ data: string; mimeType: string }>;
otherContents: string[];
} {
const textContents: string[] = [];
const imageContents: Array<{ data: string; mimeType: string }> = [];
const audioContents: Array<{ data: string; mimeType: string }> = [];
const otherContents: string[] = [];
if (!request.contents) {
return { textContents, imageContents, audioContents, otherContents };
}
const contents = Array.isArray(request.contents)
? request.contents
: [request.contents];
for (const content of contents) {
this.processContent(
content,
textContents,
imageContents,
audioContents,
otherContents,
);
}
return { textContents, imageContents, audioContents, otherContents };
}
/**
* Process a single content item and add to appropriate arrays
*/
private processContent(
content: Content | string | PartUnion,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof content === 'string') {
if (content.trim()) {
textContents.push(content);
}
return;
}
if ('parts' in content && content.parts) {
for (const part of content.parts) {
this.processPart(
part,
textContents,
imageContents,
audioContents,
otherContents,
);
}
}
}
/**
* Process a single part and add to appropriate arrays
*/
private processPart(
part: Part | string,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof part === 'string') {
if (part.trim()) {
textContents.push(part);
}
return;
}
if ('text' in part && part.text) {
textContents.push(part.text);
return;
}
if ('inlineData' in part && part.inlineData) {
const { data, mimeType } = part.inlineData;
if (mimeType && mimeType.startsWith('image/')) {
imageContents.push({ data: data || '', mimeType });
return;
}
if (mimeType && mimeType.startsWith('audio/')) {
audioContents.push({ data: data || '', mimeType });
return;
}
}
if ('fileData' in part && part.fileData) {
otherContents.push(JSON.stringify(part.fileData));
return;
}
if ('functionCall' in part && part.functionCall) {
otherContents.push(JSON.stringify(part.functionCall));
return;
}
if ('functionResponse' in part && part.functionResponse) {
otherContents.push(JSON.stringify(part.functionResponse));
return;
}
// Unknown part type - try to serialize
try {
const serialized = JSON.stringify(part);
if (serialized && serialized !== '{}') {
otherContents.push(serialized);
}
} catch (error) {
console.warn('Failed to serialize unknown part type:', error);
}
}
/**
* Dispose of resources
*/
async dispose(): Promise<void> {
try {
// Dispose of tokenizers
this.textTokenizer.dispose();
} catch (error) {
console.warn('Error disposing request tokenizer:', error);
}
}
}

View File

@@ -0,0 +1,56 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Supported image MIME types for vision models
* These formats are supported by the vision model and can be processed by the image tokenizer
*/
export const SUPPORTED_IMAGE_MIME_TYPES = [
'image/bmp',
'image/jpeg',
'image/jpg', // Alternative MIME type for JPEG
'image/png',
'image/tiff',
'image/webp',
'image/heic',
] as const;
/**
* Type for supported image MIME types
*/
export type SupportedImageMimeType =
(typeof SUPPORTED_IMAGE_MIME_TYPES)[number];
/**
* Check if a MIME type is supported for vision processing
* @param mimeType The MIME type to check
* @returns True if the MIME type is supported
*/
export function isSupportedImageMimeType(
mimeType: string,
): mimeType is SupportedImageMimeType {
return SUPPORTED_IMAGE_MIME_TYPES.includes(
mimeType as SupportedImageMimeType,
);
}
/**
* Get a human-readable list of supported image formats
* @returns Comma-separated string of supported formats
*/
export function getSupportedImageFormatsString(): string {
return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
type.replace('image/', '').toUpperCase(),
).join(', ');
}
/**
* Get warning message for unsupported image formats
* @returns Warning message string
*/
export function getUnsupportedImageFormatWarning(): string {
return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
}

View File

@@ -0,0 +1,347 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';
// Mock tiktoken at the top level with hoisted functions
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());
vi.mock('tiktoken', () => ({
get_encoding: mockGetEncoding,
}));
describe('TextTokenizer', () => {
let tokenizer: TextTokenizer;
let consoleWarnSpy: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
vi.resetAllMocks();
consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
// Default mock implementation
mockGetEncoding.mockReturnValue({
encode: mockEncode,
free: mockFree,
});
});
afterEach(() => {
vi.restoreAllMocks();
tokenizer?.dispose();
});
describe('constructor', () => {
it('should create tokenizer with default encoding', () => {
tokenizer = new TextTokenizer();
expect(tokenizer).toBeInstanceOf(TextTokenizer);
});
it('should create tokenizer with custom encoding', () => {
tokenizer = new TextTokenizer('gpt2');
expect(tokenizer).toBeInstanceOf(TextTokenizer);
});
});
describe('calculateTokens', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should return 0 for empty text', async () => {
const result = await tokenizer.calculateTokens('');
expect(result).toBe(0);
});
it('should return 0 for null/undefined text', async () => {
const result1 = await tokenizer.calculateTokens(
null as unknown as string,
);
const result2 = await tokenizer.calculateTokens(
undefined as unknown as string,
);
expect(result1).toBe(0);
expect(result2).toBe(0);
});
it('should calculate tokens using tiktoken when available', async () => {
const testText = 'Hello, world!';
const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(testText);
expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
expect(mockEncode).toHaveBeenCalledWith(testText);
expect(result).toBe(5);
});
it('should use fallback calculation when tiktoken fails to load', async () => {
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load tiktoken');
});
const testText = 'Hello, world!'; // 13 characters
const result = await tokenizer.calculateTokens(testText);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Failed to load tiktoken with encoding cl100k_base:',
expect.any(Error),
);
// Fallback: Math.ceil(13 / 4) = 4
expect(result).toBe(4);
});
it('should use fallback calculation when encoding fails', async () => {
mockEncode.mockImplementation(() => {
throw new Error('Encoding failed');
});
const testText = 'Hello, world!'; // 13 characters
const result = await tokenizer.calculateTokens(testText);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error encoding text with tiktoken:',
expect.any(Error),
);
// Fallback: Math.ceil(13 / 4) = 4
expect(result).toBe(4);
});
it('should handle very long text', async () => {
const longText = 'a'.repeat(10000);
const mockTokens = new Array(2500); // 2500 tokens
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(longText);
expect(result).toBe(2500);
});
it('should handle unicode characters', async () => {
const unicodeText = '你好世界 🌍';
const mockTokens = [1, 2, 3, 4, 5, 6];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(unicodeText);
expect(result).toBe(6);
});
it('should use custom encoding when specified', async () => {
tokenizer = new TextTokenizer('gpt2');
const testText = 'Hello, world!';
const mockTokens = [1, 2, 3];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(testText);
expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
expect(result).toBe(3);
});
});
describe('calculateTokensBatch', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should process multiple texts and return token counts', async () => {
const texts = ['Hello', 'world', 'test'];
mockEncode
.mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
.mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
.mockReturnValueOnce([6]); // 1 token for 'test'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([2, 3, 1]);
expect(mockEncode).toHaveBeenCalledTimes(3);
});
it('should handle empty array', async () => {
const result = await tokenizer.calculateTokensBatch([]);
expect(result).toEqual([]);
});
it('should handle array with empty strings', async () => {
const texts = ['', 'hello', ''];
mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([0, 3, 0]);
expect(mockEncode).toHaveBeenCalledTimes(1);
expect(mockEncode).toHaveBeenCalledWith('hello');
});
it('should use fallback calculation when tiktoken fails to load', async () => {
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load tiktoken');
});
const texts = ['Hello', 'world']; // 5 and 5 characters
const result = await tokenizer.calculateTokensBatch(texts);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Failed to load tiktoken with encoding cl100k_base:',
expect.any(Error),
);
// Fallback: Math.ceil(5/4) = 2 for both
expect(result).toEqual([2, 2]);
});
it('should use fallback calculation when encoding fails during batch processing', async () => {
mockEncode.mockImplementation(() => {
throw new Error('Encoding failed');
});
const texts = ['Hello', 'world']; // 5 and 5 characters
const result = await tokenizer.calculateTokensBatch(texts);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error encoding texts with tiktoken:',
expect.any(Error),
);
// Fallback: Math.ceil(5/4) = 2 for both
expect(result).toEqual([2, 2]);
});
it('should handle null and undefined values in batch', async () => {
const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
mockEncode
.mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
.mockReturnValueOnce([4, 5]); // 2 tokens for 'world'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([0, 3, 0, 2]);
});
});
describe('dispose', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should free tiktoken encoding when disposing', async () => {
// Initialize the encoding by calling calculateTokens
await tokenizer.calculateTokens('test');
tokenizer.dispose();
expect(mockFree).toHaveBeenCalled();
});
it('should handle disposal when encoding is not initialized', () => {
expect(() => tokenizer.dispose()).not.toThrow();
expect(mockFree).not.toHaveBeenCalled();
});
it('should handle disposal when encoding is null', async () => {
// Force encoding to be null by making tiktoken fail
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load');
});
await tokenizer.calculateTokens('test');
expect(() => tokenizer.dispose()).not.toThrow();
expect(mockFree).not.toHaveBeenCalled();
});
it('should handle errors during disposal gracefully', async () => {
await tokenizer.calculateTokens('test');
mockFree.mockImplementation(() => {
throw new Error('Free failed');
});
tokenizer.dispose();
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error freeing tiktoken encoding:',
expect.any(Error),
);
});
it('should allow multiple calls to dispose', async () => {
await tokenizer.calculateTokens('test');
tokenizer.dispose();
tokenizer.dispose(); // Second call should not throw
expect(mockFree).toHaveBeenCalledTimes(1);
});
});
describe('lazy initialization', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should not initialize tiktoken until first use', () => {
expect(mockGetEncoding).not.toHaveBeenCalled();
});
it('should initialize tiktoken on first calculateTokens call', async () => {
await tokenizer.calculateTokens('test');
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
it('should not reinitialize tiktoken on subsequent calls', async () => {
await tokenizer.calculateTokens('test1');
await tokenizer.calculateTokens('test2');
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
it('should initialize tiktoken on first calculateTokensBatch call', async () => {
await tokenizer.calculateTokensBatch(['test']);
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
});
describe('edge cases', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should handle very short text', async () => {
mockEncode.mockReturnValue([1]); // a single-character string is one token
const result = await tokenizer.calculateTokens('a');
expect(mockEncode).toHaveBeenCalledWith('a');
expect(result).toBe(1);
});
it('should handle text with only whitespace', async () => {
const whitespaceText = ' \n\t ';
const mockTokens = [1];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(whitespaceText);
expect(result).toBe(1);
});
it('should handle special characters and symbols', async () => {
const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
const mockTokens = new Array(10);
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(specialText);
expect(result).toBe(10);
});
});
});

View File

@@ -0,0 +1,97 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
import { get_encoding } from 'tiktoken';
/**
* Counts text tokens using tiktoken, falling back to a character-based estimate
*/
export class TextTokenizer {
private encoding: Tiktoken | null = null;
private encodingName: string;
constructor(encodingName: string = 'cl100k_base') {
this.encodingName = encodingName;
}
/**
* Initialize the tokenizer (lazy loading)
*/
private async ensureEncoding(): Promise<void> {
if (this.encoding) return;
try {
// Cast to TiktokenEncoding; an invalid name will throw here and be handled below
this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
} catch (error) {
console.warn(
`Failed to load tiktoken with encoding ${this.encodingName}:`,
error,
);
this.encoding = null;
}
}
/**
* Calculate tokens for text content
*/
async calculateTokens(text: string): Promise<number> {
if (!text) return 0;
await this.ensureEncoding();
if (this.encoding) {
try {
return this.encoding.encode(text).length;
} catch (error) {
console.warn('Error encoding text with tiktoken:', error);
}
}
// Fallback: rough approximation using character count
// This is a conservative estimate: 1 token ≈ 4 characters for most languages
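// Worked example: 'Hello, world!' is 13 characters,
// so the estimate is Math.ceil(13 / 4) = 4 tokens.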
return Math.ceil(text.length / 4);
}
/**
* Calculate tokens for multiple text strings in a single batch
*/
async calculateTokensBatch(texts: string[]): Promise<number[]> {
await this.ensureEncoding();
if (this.encoding) {
try {
return texts.map((text) => {
if (!text) return 0;
// TypeScript cannot narrow this.encoding inside the callback, so re-check for null
return this.encoding ? this.encoding.encode(text).length : 0;
});
} catch (error) {
console.warn('Error encoding texts with tiktoken:', error);
// In case of error, return fallback estimation for all texts
return texts.map((text) => Math.ceil((text || '').length / 4));
}
}
// Fallback for batch processing
return texts.map((text) => Math.ceil((text || '').length / 4));
}
/**
* Dispose of resources
*/
dispose(): void {
if (this.encoding) {
try {
this.encoding.free();
} catch (error) {
console.warn('Error freeing tiktoken encoding:', error);
}
this.encoding = null;
}
}
}
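
A minimal usage sketch for the class above (standalone, not part of this commit; top-level await assumes an ES module context):

import { TextTokenizer } from './textTokenizer.js';

const tokenizer = new TextTokenizer(); // defaults to cl100k_base
try {
  const count = await tokenizer.calculateTokens('Hello, world!');
  console.log(`~${count} tokens`);
  const counts = await tokenizer.calculateTokensBatch(['Hello', 'world']);
  console.log(counts); // e.g. [1, 1] via tiktoken, [2, 2] via the fallback
} finally {
  tokenizer.dispose(); // frees the underlying WASM encoding
}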

View File

@@ -0,0 +1,64 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { CountTokensParameters } from '@google/genai';
/**
* Token calculation result for different content types
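*
* @example
* // A possible result for a text-only request:
* // {
* //   totalTokens: 42,
* //   breakdown: { textTokens: 42, imageTokens: 0, audioTokens: 0, otherTokens: 0 },
* //   processingTime: 3,
* // }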
*/
export interface TokenCalculationResult {
/** Total tokens calculated */
totalTokens: number;
/** Breakdown by content type */
breakdown: {
textTokens: number;
imageTokens: number;
audioTokens: number;
otherTokens: number;
};
/** Processing time in milliseconds */
processingTime: number;
}
/**
* Configuration for token calculation
*/
export interface TokenizerConfig {
/** Custom text tokenizer encoding (defaults to cl100k_base) */
textEncoding?: string;
}
/**
* Image metadata extracted from base64 data
*/
export interface ImageMetadata {
/** Image width in pixels */
width: number;
/** Image height in pixels */
height: number;
/** MIME type of the image */
mimeType: string;
/** Size of the base64 data in bytes */
dataSize: number;
}
/**
* Request tokenizer interface
*/
export interface RequestTokenizer {
/**
* Calculate tokens for a request
*/
calculateTokens(
request: CountTokensParameters,
config?: TokenizerConfig,
): Promise<TokenCalculationResult>;
/**
* Dispose of resources (worker threads, etc.)
*/
dispose(): Promise<void>;
}
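
For illustration, a minimal text-only implementation of this interface built on the TextTokenizer above. This is a sketch, not the commit's actual tokenizer: it flattens the request contents to a string and counts everything as text, whereas the real implementation also routes images through the image tokenizer. The './types.js' import path is assumed:

import type { CountTokensParameters } from '@google/genai';
import { TextTokenizer } from './textTokenizer.js';
import type {
  RequestTokenizer,
  TokenCalculationResult,
  TokenizerConfig,
} from './types.js';

class TextOnlyRequestTokenizer implements RequestTokenizer {
  async calculateTokens(
    request: CountTokensParameters,
    config?: TokenizerConfig,
  ): Promise<TokenCalculationResult> {
    const start = Date.now();
    const tokenizer = new TextTokenizer(config?.textEncoding);
    try {
      // Naive extraction: serialize the contents and count everything as text.
      const text =
        typeof request.contents === 'string'
          ? request.contents
          : JSON.stringify(request.contents ?? '');
      const textTokens = await tokenizer.calculateTokens(text);
      return {
        totalTokens: textTokens,
        breakdown: {
          textTokens,
          imageTokens: 0,
          audioTokens: 0,
          otherTokens: 0,
        },
        processingTime: Date.now() - start,
      };
    } finally {
      tokenizer.dispose();
    }
  }

  async dispose(): Promise<void> {
    // Nothing persistent to release in this sketch.
  }
}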