Mirror of https://github.com/QwenLM/qwen-code.git (synced 2025-12-19 09:33:53 +00:00)
Vision model support for Qwen-OAuth (#525)
* refactor: openaiContentGenerator
* refactor: optimize stream handling
* refactor: re-organize refactored files
* fix: unit test cases
* feat: `/model` command for switching to vision model
* fix: lint error
* feat: add image tokenizer to fit vlm context window
* fix: lint and type errors
* feat: add `visionModelPreview` to control default visibility of vision models
* fix: remove deprecated files
* fix: align supported image formats with bailian doc
@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
      description: 'Enable extension management features.',
      showInDialog: false,
    },
    visionModelPreview: {
      type: 'boolean',
      label: 'Vision Model Preview',
      category: 'Experimental',
      requiresRestart: false,
      default: false,
      description:
        'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
      showInDialog: true,
    },
  },
},
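The flag is read from the `experimental` namespace of the merged settings (App.tsx below reads `settings.merged.experimental?.visionModelPreview`), so opting in amounts to one settings entry. A minimal sketch; the exact settings-file location depends on your installation:

{
  "experimental": {
    "visionModelPreview": true
  }
}

With the flag left at its default of `false`, vision models stay hidden from `/model` and no auto-switch prompt is offered.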
@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
    kind: 'BUILT_IN',
  },
}));

vi.mock('../ui/commands/modelCommand.js', () => ({
  modelCommand: {
    name: 'model',
    description: 'Model command',
    kind: 'BUILT_IN',
  },
}));

describe('BuiltinCommandLoader', () => {
  let mockConfig: Config;

@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {

    const mcpCmd = commands.find((c) => c.name === 'mcp');
    expect(mcpCmd).toBeDefined();

    const modelCmd = commands.find((c) => c.name === 'model');
    expect(modelCmd).toBeDefined();
  });
});
@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
import { vimCommand } from '../ui/commands/vimCommand.js';
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
import { modelCommand } from '../ui/commands/modelCommand.js';
import { agentsCommand } from '../ui/commands/agentsCommand.js';

/**
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
      initCommand,
      mcpCommand,
      memoryCommand,
      modelCommand,
      privacyCommand,
      quitCommand,
      quitConfirmCommand,
@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
import {
  ModelSwitchDialog,
  type VisionSwitchOutcome,
} from './components/ModelSwitchDialog.js';
import {
  getOpenAIAvailableModelFromEnv,
  getFilteredQwenModels,
  type AvailableModel,
} from './models/availableModels.js';
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
import {
  AgentCreationWizard,
  AgentsManagerDialog,
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
    onWorkspaceMigrationDialogClose,
  } = useWorkspaceMigration(settings);

  // Model selection dialog states
  const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
    useState(false);
  const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
    useState(false);
  const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
    resolve: (result: {
      modelOverride?: string;
      persistSessionModel?: string;
      showGuidance?: boolean;
    }) => void;
    reject: () => void;
  } | null>(null);

  useEffect(() => {
    const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
    // Set the initial value
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
    openAuthDialog();
  }, [openAuthDialog, setAuthError]);

  // Vision switch handler for auto-switch functionality
  const handleVisionSwitchRequired = useCallback(
    async (_query: unknown) =>
      new Promise<{
        modelOverride?: string;
        persistSessionModel?: string;
        showGuidance?: boolean;
      }>((resolve, reject) => {
        setVisionSwitchResolver({ resolve, reject });
        setIsVisionSwitchDialogOpen(true);
      }),
    [],
  );

  const handleVisionSwitchSelect = useCallback(
    (outcome: VisionSwitchOutcome) => {
      setIsVisionSwitchDialogOpen(false);
      if (visionSwitchResolver) {
        const result = processVisionSwitchOutcome(outcome);
        visionSwitchResolver.resolve(result);
        setVisionSwitchResolver(null);
      }
    },
    [visionSwitchResolver],
  );

  const handleModelSelectionOpen = useCallback(() => {
    setIsModelSelectionDialogOpen(true);
  }, []);

  const handleModelSelectionClose = useCallback(() => {
    setIsModelSelectionDialogOpen(false);
  }, []);

  const handleModelSelect = useCallback(
    (modelId: string) => {
      config.setModel(modelId);
      setCurrentModel(modelId);
      setIsModelSelectionDialogOpen(false);
      addItem(
        {
          type: MessageType.INFO,
          text: `Switched model to \`${modelId}\` for this session.`,
        },
        Date.now(),
      );
    },
    [config, setCurrentModel, addItem],
  );

  const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
    const contentGeneratorConfig = config.getContentGeneratorConfig();
    if (!contentGeneratorConfig) return [];

    const visionModelPreviewEnabled =
      settings.merged.experimental?.visionModelPreview ?? false;

    switch (contentGeneratorConfig.authType) {
      case AuthType.QWEN_OAUTH:
        return getFilteredQwenModels(visionModelPreviewEnabled);
      case AuthType.USE_OPENAI: {
        const openAIModel = getOpenAIAvailableModelFromEnv();
        return openAIModel ? [openAIModel] : [];
      }
      default:
        return [];
    }
  }, [config, settings.merged.experimental?.visionModelPreview]);

  // Core hooks and processors
  const {
    vimEnabled: vimModeEnabled,
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
    setQuittingMessages,
    openPrivacyNotice,
    openSettingsDialog,
    handleModelSelectionOpen,
    openSubagentCreateDialog,
    openAgentsManagerDialog,
    toggleVimEnabled,
@@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
    setModelSwitchedFromQuotaError,
    refreshStatic,
    () => cancelHandlerRef.current(),
    settings.merged.experimental?.visionModelPreview ?? false,
    handleVisionSwitchRequired,
  );

  const pendingHistoryItems = useMemo(
@@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
      !isAuthDialogOpen &&
      !isThemeDialogOpen &&
      !isEditorDialogOpen &&
      !isModelSelectionDialogOpen &&
      !isVisionSwitchDialogOpen &&
      !isSubagentCreateDialogOpen &&
      !showPrivacyNotice &&
      !showWelcomeBackDialog &&
@@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
    showWelcomeBackDialog,
    welcomeBackChoice,
    geminiClient,
    isModelSelectionDialogOpen,
    isVisionSwitchDialogOpen,
  ]);

  if (quittingMessages) {
@@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
              onExit={exitEditorDialog}
            />
          </Box>
        ) : isModelSelectionDialogOpen ? (
          <ModelSelectionDialog
            availableModels={getAvailableModelsForCurrentAuth()}
            currentModel={currentModel}
            onSelect={handleModelSelect}
            onCancel={handleModelSelectionClose}
          />
        ) : isVisionSwitchDialogOpen ? (
          <ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
        ) : showPrivacyNotice ? (
          <PrivacyNotice
            onExit={() => setShowPrivacyNotice(false)}
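The `handleVisionSwitchRequired` / `handleVisionSwitchSelect` pair above bridges an async caller and a React dialog: the pending Promise's `resolve` is parked in state, and the dialog's selection handler settles it later. A generic distillation of that pattern (illustrative only, not code from the commit):

import { useCallback, useState } from 'react';

function useDialogRequest<T>() {
  // Park the pending Promise's resolve function in state; a non-null
  // resolver doubles as the "dialog open" flag.
  const [resolver, setResolver] = useState<((value: T) => void) | null>(null);

  // Awaited by the code that needs an answer from the user.
  const request = useCallback(
    () => new Promise<T>((resolve) => setResolver(() => resolve)),
    [],
  );

  // Called by the dialog's onSelect; settles the parked Promise.
  const settle = useCallback(
    (value: T) => {
      resolver?.(value);
      setResolver(null);
    },
    [resolver],
  );

  return { request, settle, isOpen: resolver !== null };
}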
packages/cli/src/ui/commands/modelCommand.test.ts (new file, 179 lines)
@@ -0,0 +1,179 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, beforeEach, vi } from 'vitest';
import { modelCommand } from './modelCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import {
  AuthType,
  type ContentGeneratorConfig,
  type Config,
} from '@qwen-code/qwen-code-core';
import * as availableModelsModule from '../models/availableModels.js';

// Mock the availableModels module
vi.mock('../models/availableModels.js', () => ({
  AVAILABLE_MODELS_QWEN: [
    { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
    { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
  ],
  getOpenAIAvailableModelFromEnv: vi.fn(),
}));

// Helper function to create a mock config
function createMockConfig(
  contentGeneratorConfig: ContentGeneratorConfig | null,
): Partial<Config> {
  return {
    getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
  };
}

describe('modelCommand', () => {
  let mockContext: CommandContext;
  const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
    availableModelsModule.getOpenAIAvailableModelFromEnv,
  );

  beforeEach(() => {
    mockContext = createMockCommandContext();
    vi.clearAllMocks();
  });

  it('should have the correct name and description', () => {
    expect(modelCommand.name).toBe('model');
    expect(modelCommand.description).toBe('Switch the model for this session');
  });

  it('should return error when config is not available', async () => {
    mockContext.services.config = null;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Configuration not available.',
    });
  });

  it('should return error when content generator config is not available', async () => {
    const mockConfig = createMockConfig(null);
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Content generator configuration not available.',
    });
  });

  it('should return error when auth type is not available', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: undefined,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Authentication type not available.',
    });
  });

  it('should return dialog action for QWEN_OAUTH auth type', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.QWEN_OAUTH,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'dialog',
      dialog: 'model',
    });
  });

  it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
    mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
      id: 'gpt-4',
      label: 'gpt-4',
    });

    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.USE_OPENAI,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'dialog',
      dialog: 'model',
    });
  });

  it('should return error for USE_OPENAI auth type when no model is available', async () => {
    mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);

    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.USE_OPENAI,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content:
        'No models available for the current authentication type (openai).',
    });
  });

  it('should return error for unsupported auth types', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content:
        'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
    });
  });

  it('should handle undefined auth type', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: undefined,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Authentication type not available.',
    });
  });
});
packages/cli/src/ui/commands/modelCommand.ts (new file, 88 lines)
@@ -0,0 +1,88 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { AuthType } from '@qwen-code/qwen-code-core';
import type {
  SlashCommand,
  CommandContext,
  OpenDialogActionReturn,
  MessageActionReturn,
} from './types.js';
import { CommandKind } from './types.js';
import {
  AVAILABLE_MODELS_QWEN,
  getOpenAIAvailableModelFromEnv,
  type AvailableModel,
} from '../models/availableModels.js';

function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
  switch (authType) {
    case AuthType.QWEN_OAUTH:
      return AVAILABLE_MODELS_QWEN;
    case AuthType.USE_OPENAI: {
      const openAIModel = getOpenAIAvailableModelFromEnv();
      return openAIModel ? [openAIModel] : [];
    }
    default:
      // For other auth types, return empty array for now
      // This can be expanded later according to the design doc
      return [];
  }
}

export const modelCommand: SlashCommand = {
  name: 'model',
  description: 'Switch the model for this session',
  kind: CommandKind.BUILT_IN,
  action: async (
    context: CommandContext,
  ): Promise<OpenDialogActionReturn | MessageActionReturn> => {
    const { services } = context;
    const { config } = services;

    if (!config) {
      return {
        type: 'message',
        messageType: 'error',
        content: 'Configuration not available.',
      };
    }

    const contentGeneratorConfig = config.getContentGeneratorConfig();
    if (!contentGeneratorConfig) {
      return {
        type: 'message',
        messageType: 'error',
        content: 'Content generator configuration not available.',
      };
    }

    const authType = contentGeneratorConfig.authType;
    if (!authType) {
      return {
        type: 'message',
        messageType: 'error',
        content: 'Authentication type not available.',
      };
    }

    const availableModels = getAvailableModelsForAuthType(authType);

    if (availableModels.length === 0) {
      return {
        type: 'message',
        messageType: 'error',
        content: `No models available for the current authentication type (${authType}).`,
      };
    }

    // Trigger model selection dialog
    return {
      type: 'dialog',
      dialog: 'model',
    };
  },
};
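availableModels.js itself is not part of this diff; from how it is consumed here and in App.tsx, the model descriptor it exports presumably looks like the following (inferred, not taken from the commit):

// Inferred shape of AvailableModel, based on usage in this diff:
export interface AvailableModel {
  id: string; // identifier passed to config.setModel()
  label: string; // display name shown in the selection dialog
  isVision?: boolean; // marks vision-capable models such as qwen-vl-max-latest
}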
@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
    | 'editor'
    | 'privacy'
    | 'settings'
    | 'model'
    | 'subagent_create'
    | 'subagent_list';
}
packages/cli/src/ui/components/ModelSelectionDialog.test.tsx (new file, 246 lines)
@@ -0,0 +1,246 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
import type { AvailableModel } from '../models/availableModels.js';
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';

// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
  useKeypress: mockUseKeypress,
}));

// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
  RadioButtonSelect: mockRadioButtonSelect,
}));

describe('ModelSelectionDialog', () => {
  const mockAvailableModels: AvailableModel[] = [
    { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
    { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
    { id: 'gpt-4', label: 'GPT-4' },
  ];

  const mockOnSelect = vi.fn();
  const mockOnCancel = vi.fn();

  beforeEach(() => {
    vi.clearAllMocks();

    // Mock RadioButtonSelect to return a simple div
    mockRadioButtonSelect.mockReturnValue(
      React.createElement('div', { 'data-testid': 'radio-select' }),
    );
  });

  it('should setup escape key handler to call onCancel', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
      isActive: true,
    });

    // Simulate escape key press
    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'escape' });

    expect(mockOnCancel).toHaveBeenCalled();
  });

  it('should not call onCancel for non-escape keys', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'enter' });

    expect(mockOnCancel).not.toHaveBeenCalled();
  });

  it('should set correct initial index for current model', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen-vl-max-latest"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
  });

  it('should set initial index to 0 when current model is not found', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="non-existent-model"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(0);
  });

  it('should call onSelect when a model is selected', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(typeof callArgs.onSelect).toBe('function');

    // Simulate selection
    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback('qwen-vl-max-latest');

    expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
  });

  it('should handle empty models array', () => {
    render(
      <ModelSelectionDialog
        availableModels={[]}
        currentModel=""
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual([]);
    expect(callArgs.initialIndex).toBe(0);
  });

  it('should create correct option items with proper labels', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const expectedItems = [
      {
        label: 'qwen3-coder-plus (current)',
        value: 'qwen3-coder-plus',
      },
      {
        label: 'qwen-vl-max [Vision]',
        value: 'qwen-vl-max-latest',
      },
      {
        label: 'GPT-4',
        value: 'gpt-4',
      },
    ];

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual(expectedItems);
  });

  it('should show vision indicator for vision models', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="gpt-4"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    const visionModelItem = callArgs.items.find(
      (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
    );

    expect(visionModelItem?.label).toContain('[Vision]');
  });

  it('should show current indicator for the current model', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen-vl-max-latest"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    const currentModelItem = callArgs.items.find(
      (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
    );

    expect(currentModelItem?.label).toContain('(current)');
  });

  it('should pass isFocused prop to RadioButtonSelect', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.isFocused).toBe(true);
  });

  it('should handle multiple onSelect calls correctly', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;

    // Call multiple times
    onSelectCallback('qwen3-coder-plus');
    onSelectCallback('qwen-vl-max-latest');
    onSelectCallback('gpt-4');

    expect(mockOnSelect).toHaveBeenCalledTimes(3);
    expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
    expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
    expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
  });
});
packages/cli/src/ui/components/ModelSelectionDialog.tsx (new file, 87 lines)
@@ -0,0 +1,87 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
  RadioButtonSelect,
  type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
import type { AvailableModel } from '../models/availableModels.js';

export interface ModelSelectionDialogProps {
  availableModels: AvailableModel[];
  currentModel: string;
  onSelect: (modelId: string) => void;
  onCancel: () => void;
}

export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
  availableModels,
  currentModel,
  onSelect,
  onCancel,
}) => {
  useKeypress(
    (key) => {
      if (key.name === 'escape') {
        onCancel();
      }
    },
    { isActive: true },
  );

  const options: Array<RadioSelectItem<string>> = availableModels.map(
    (model) => {
      const visionIndicator = model.isVision ? ' [Vision]' : '';
      const currentIndicator = model.id === currentModel ? ' (current)' : '';
      return {
        label: `${model.label}${visionIndicator}${currentIndicator}`,
        value: model.id,
      };
    },
  );

  const initialIndex = Math.max(
    0,
    availableModels.findIndex((model) => model.id === currentModel),
  );

  const handleSelect = (modelId: string) => {
    onSelect(modelId);
  };

  return (
    <Box
      flexDirection="column"
      borderStyle="round"
      borderColor={Colors.AccentBlue}
      padding={1}
      width="100%"
      marginLeft={1}
    >
      <Box flexDirection="column" marginBottom={1}>
        <Text bold>Select Model</Text>
        <Text>Choose a model for this session:</Text>
      </Box>

      <Box marginBottom={1}>
        <RadioButtonSelect
          items={options}
          initialIndex={initialIndex}
          onSelect={handleSelect}
          isFocused
        />
      </Box>

      <Box>
        <Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
      </Box>
    </Box>
  );
};
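For a feel of the component's contract, a standalone ink render would look roughly like this (illustrative usage with a hypothetical model list, not part of the commit):

import { render } from 'ink';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';

render(
  <ModelSelectionDialog
    availableModels={[
      { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
      { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
    ]}
    currentModel="qwen3-coder-plus"
    onSelect={(modelId) => console.log(`selected ${modelId}`)}
    onCancel={() => process.exit(0)}
  />,
);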
packages/cli/src/ui/components/ModelSwitchDialog.test.tsx (new file, 185 lines)
@@ -0,0 +1,185 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';

// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
  useKeypress: mockUseKeypress,
}));

// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
  RadioButtonSelect: mockRadioButtonSelect,
}));

describe('ModelSwitchDialog', () => {
  const mockOnSelect = vi.fn();

  beforeEach(() => {
    vi.clearAllMocks();

    // Mock RadioButtonSelect to return a simple div
    mockRadioButtonSelect.mockReturnValue(
      React.createElement('div', { 'data-testid': 'radio-select' }),
    );
  });

  it('should setup RadioButtonSelect with correct options', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const expectedItems = [
      {
        label: 'Switch for this request only',
        value: VisionSwitchOutcome.SwitchOnce,
      },
      {
        label: 'Switch session to vision model',
        value: VisionSwitchOutcome.SwitchSessionToVL,
      },
      {
        label: 'Do not switch, show guidance',
        value: VisionSwitchOutcome.DisallowWithGuidance,
      },
    ];

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual(expectedItems);
    expect(callArgs.initialIndex).toBe(0);
    expect(callArgs.isFocused).toBe(true);
  });

  it('should call onSelect when an option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(typeof callArgs.onSelect).toBe('function');

    // Simulate selection of "Switch for this request only"
    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.SwitchOnce);

    expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
  });

  it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.SwitchSessionToVL,
    );
  });

  it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.DisallowWithGuidance,
    );
  });

  it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
      isActive: true,
    });

    // Simulate escape key press
    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'escape' });

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.DisallowWithGuidance,
    );
  });

  it('should not call onSelect for non-escape keys', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'enter' });

    expect(mockOnSelect).not.toHaveBeenCalled();
  });

  it('should set initial index to 0 (first option)', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(0);
  });

  describe('VisionSwitchOutcome enum', () => {
    it('should have correct enum values', () => {
      expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
      expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
        'switch_session_to_vl',
      );
      expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
        'disallow_with_guidance',
      );
    });
  });

  it('should handle multiple onSelect calls correctly', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;

    // Call multiple times
    onSelectCallback(VisionSwitchOutcome.SwitchOnce);
    onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
    onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);

    expect(mockOnSelect).toHaveBeenCalledTimes(3);
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      1,
      VisionSwitchOutcome.SwitchOnce,
    );
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      2,
      VisionSwitchOutcome.SwitchSessionToVL,
    );
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      3,
      VisionSwitchOutcome.DisallowWithGuidance,
    );
  });

  it('should pass isFocused prop to RadioButtonSelect', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.isFocused).toBe(true);
  });

  it('should handle escape key multiple times', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const keypressHandler = mockUseKeypress.mock.calls[0][0];

    // Call escape multiple times
    keypressHandler({ name: 'escape' });
    keypressHandler({ name: 'escape' });

    expect(mockOnSelect).toHaveBeenCalledTimes(2);
    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.DisallowWithGuidance,
    );
  });
});
packages/cli/src/ui/components/ModelSwitchDialog.tsx (new file, 89 lines)
@@ -0,0 +1,89 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
  RadioButtonSelect,
  type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';

export enum VisionSwitchOutcome {
  SwitchOnce = 'switch_once',
  SwitchSessionToVL = 'switch_session_to_vl',
  DisallowWithGuidance = 'disallow_with_guidance',
}

export interface ModelSwitchDialogProps {
  onSelect: (outcome: VisionSwitchOutcome) => void;
}

export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
  onSelect,
}) => {
  useKeypress(
    (key) => {
      if (key.name === 'escape') {
        onSelect(VisionSwitchOutcome.DisallowWithGuidance);
      }
    },
    { isActive: true },
  );

  const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
    {
      label: 'Switch for this request only',
      value: VisionSwitchOutcome.SwitchOnce,
    },
    {
      label: 'Switch session to vision model',
      value: VisionSwitchOutcome.SwitchSessionToVL,
    },
    {
      label: 'Do not switch, show guidance',
      value: VisionSwitchOutcome.DisallowWithGuidance,
    },
  ];

  const handleSelect = (outcome: VisionSwitchOutcome) => {
    onSelect(outcome);
  };

  return (
    <Box
      flexDirection="column"
      borderStyle="round"
      borderColor={Colors.AccentYellow}
      padding={1}
      width="100%"
      marginLeft={1}
    >
      <Box flexDirection="column" marginBottom={1}>
        <Text bold>Vision Model Switch Required</Text>
        <Text>
          Your message contains an image, but the current model doesn't
          support vision.
        </Text>
        <Text>How would you like to proceed?</Text>
      </Box>

      <Box marginBottom={1}>
        <RadioButtonSelect
          items={options}
          initialIndex={0}
          onSelect={handleSelect}
          isFocused
        />
      </Box>

      <Box>
        <Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
      </Box>
    </Box>
  );
};
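The three outcomes map onto the result object that App.tsx feeds back through `processVisionSwitchOutcome`. That function's body is not shown in this diff, but the useVisionAutoSwitch tests further below pin down its behavior; a minimal sketch consistent with them:

import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import { getDefaultVisionModel } from '../models/availableModels.js';

// Sketch of processVisionSwitchOutcome, reconstructed from its tests:
// SwitchOnce -> one-off override, SwitchSessionToVL -> sticky session model,
// DisallowWithGuidance -> block the request and print guidance.
export function processVisionSwitchOutcome(outcome: VisionSwitchOutcome): {
  modelOverride?: string;
  persistSessionModel?: string;
  showGuidance?: boolean;
} {
  switch (outcome) {
    case VisionSwitchOutcome.SwitchOnce:
      return { modelOverride: getDefaultVisionModel() };
    case VisionSwitchOutcome.SwitchSessionToVL:
      return { persistSessionModel: getDefaultVisionModel() };
    case VisionSwitchOutcome.DisallowWithGuidance:
    default:
      return { showGuidance: true };
  }
}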
@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
  const mockLoadHistory = vi.fn();
  const mockOpenThemeDialog = vi.fn();
  const mockOpenAuthDialog = vi.fn();
  const mockOpenModelSelectionDialog = vi.fn();
  const mockSetQuittingMessages = vi.fn();

  const mockConfig = makeFakeConfig({});
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
    mockBuiltinLoadCommands.mockResolvedValue([]);
    mockFileLoadCommands.mockResolvedValue([]);
    mockMcpLoadCommands.mockResolvedValue([]);
    mockOpenModelSelectionDialog.mockClear();
  });

  const setupProcessorHook = (
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
        mockSetQuittingMessages,
        vi.fn(), // openPrivacyNotice
        vi.fn(), // openSettingsDialog
        mockOpenModelSelectionDialog,
        vi.fn(), // openSubagentCreateDialog
        vi.fn(), // openAgentsManagerDialog
        vi.fn(), // toggleVimEnabled
        setIsProcessing,
        vi.fn(), // setGeminiMdFileCount
        vi.fn(), // _showQuitConfirmation
      ),
    );

@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
    expect(mockOpenThemeDialog).toHaveBeenCalled();
  });

  it('should handle "dialog: model" action', async () => {
    const command = createTestCommand({
      name: 'modelcmd',
      action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
    });
    const result = setupProcessorHook([command]);
    await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));

    await act(async () => {
      await result.current.handleSlashCommand('/modelcmd');
    });

    expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
  });

  it('should handle "load_history" action', async () => {
    const command = createTestCommand({
      name: 'load',
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
        mockSetQuittingMessages,
        vi.fn(), // openPrivacyNotice
        vi.fn(), // openSettingsDialog
        vi.fn(), // openModelSelectionDialog
        vi.fn(), // openSubagentCreateDialog
        vi.fn(), // openAgentsManagerDialog
        vi.fn(), // toggleVimEnabled
        vi.fn(), // setIsProcessing
        vi.fn(), // setGeminiMdFileCount
        vi.fn(), // _showQuitConfirmation
      ),
    );
@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
  setQuittingMessages: (message: HistoryItem[]) => void,
  openPrivacyNotice: () => void,
  openSettingsDialog: () => void,
  openModelSelectionDialog: () => void,
  openSubagentCreateDialog: () => void,
  openAgentsManagerDialog: () => void,
  toggleVimEnabled: () => Promise<boolean>,
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
            case 'settings':
              openSettingsDialog();
              return { type: 'handled' };
            case 'model':
              openModelSelectionDialog();
              return { type: 'handled' };
            case 'subagent_create':
              openSubagentCreateDialog();
              return { type: 'handled' };
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
      setSessionShellAllowlist,
      setIsProcessing,
      setConfirmationRequest,
      openModelSelectionDialog,
      session.stats,
    ],
  );
@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
);
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());

// Vision auto-switch mocks (hoisted)
const mockHandleVisionSwitch = vi.hoisted(() =>
  vi.fn().mockResolvedValue({ shouldProceed: true }),
);
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());

vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
  const actualCoreModule = (await importOriginal()) as any;
  return {
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
  };
});

vi.mock('./useVisionAutoSwitch.js', () => ({
  useVisionAutoSwitch: vi.fn(() => ({
    handleVisionSwitch: mockHandleVisionSwitch,
    restoreOriginalModel: mockRestoreOriginalModel,
  })),
}));

vi.mock('./useKeypress.js', () => ({
  useKeypress: vi.fn(),
}));
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
      getContentGeneratorConfig: vi
        .fn()
        .mockReturnValue(contentGeneratorConfig),
      getMaxSessionTurns: vi.fn(() => 50),
    } as unknown as Config;
    mockOnDebugMessage = vi.fn();
    mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
      expect.any(String), // Argument 3: The prompt_id string
    );
  });

  describe('Thought Reset', () => {
    it('should reset thought to null when starting a new prompt', async () => {
      // First, simulate a response with a thought
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
      );
    });
  });

  // --- New tests focused on recent modifications ---
  describe('Vision Auto Switch Integration', () => {
    it('should call handleVisionSwitch and proceed to send when allowed', async () => {
      mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
      mockSendMessageStream.mockReturnValue(
        (async function* () {
          yield { type: ServerGeminiEventType.Content, value: 'ok' };
          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
        })(),
      );

      const { result } = renderHook(() =>
        useGeminiStream(
          new MockedGeminiClientClass(mockConfig),
          [],
          mockAddItem,
          mockConfig,
          mockOnDebugMessage,
          mockHandleSlashCommand,
          false,
          () => 'vscode' as EditorType,
          () => {},
          () => Promise.resolve(),
          false,
          () => {},
          () => {},
          () => {},
        ),
      );

      await act(async () => {
        await result.current.submitQuery('image prompt');
      });

      await waitFor(() => {
        expect(mockHandleVisionSwitch).toHaveBeenCalled();
        expect(mockSendMessageStream).toHaveBeenCalled();
      });
    });

    it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
      mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });

      const { result } = renderHook(() =>
        useGeminiStream(
          new MockedGeminiClientClass(mockConfig),
          [],
          mockAddItem,
          mockConfig,
          mockOnDebugMessage,
          mockHandleSlashCommand,
          false,
          () => 'vscode' as EditorType,
          () => {},
          () => Promise.resolve(),
          false,
          () => {},
          () => {},
          () => {},
        ),
      );

      await act(async () => {
        await result.current.submitQuery('vision-gated');
      });

      // No call to API, no restoreOriginalModel needed since no override occurred
      expect(mockSendMessageStream).not.toHaveBeenCalled();
      expect(mockRestoreOriginalModel).not.toHaveBeenCalled();

      // Next call allowed (flag reset path)
      mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
      mockSendMessageStream.mockReturnValue(
        (async function* () {
          yield { type: ServerGeminiEventType.Content, value: 'ok' };
          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
        })(),
      );
      await act(async () => {
        await result.current.submitQuery('after-gate');
      });
      await waitFor(() => {
        expect(mockSendMessageStream).toHaveBeenCalled();
      });
    });
  });

  describe('Model restore on completion and errors', () => {
    it('should restore model after successful stream completion', async () => {
      mockSendMessageStream.mockReturnValue(
        (async function* () {
          yield { type: ServerGeminiEventType.Content, value: 'content' };
          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
        })(),
      );

      const { result } = renderHook(() =>
        useGeminiStream(
          new MockedGeminiClientClass(mockConfig),
          [],
          mockAddItem,
          mockConfig,
          mockOnDebugMessage,
          mockHandleSlashCommand,
          false,
          () => 'vscode' as EditorType,
          () => {},
          () => Promise.resolve(),
          false,
          () => {},
          () => {},
          () => {},
        ),
      );

      await act(async () => {
        await result.current.submitQuery('restore-success');
      });

      await waitFor(() => {
        expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
      });
    });

    it('should restore model when an error occurs during streaming', async () => {
      const testError = new Error('stream failure');
      mockSendMessageStream.mockReturnValue(
        (async function* () {
          yield { type: ServerGeminiEventType.Content, value: 'content' };
          throw testError;
        })(),
      );

      const { result } = renderHook(() =>
        useGeminiStream(
          new MockedGeminiClientClass(mockConfig),
          [],
          mockAddItem,
          mockConfig,
          mockOnDebugMessage,
          mockHandleSlashCommand,
          false,
          () => 'vscode' as EditorType,
          () => {},
          () => Promise.resolve(),
          false,
          () => {},
          () => {},
          () => {},
        ),
      );

      await act(async () => {
        await result.current.submitQuery('restore-error');
      });

      await waitFor(() => {
        expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
      });
    });
  });
});
@@ -42,6 +42,7 @@ import type {
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
import { useShellCommandProcessor } from './shellCommandProcessor.js';
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
import { handleAtCommand } from './atCommandProcessor.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
@@ -88,6 +89,12 @@ export const useGeminiStream = (
  setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
  onEditorClose: () => void,
  onCancelSubmit: () => void,
  visionModelPreviewEnabled: boolean = false,
  onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
    modelOverride?: string;
    persistSessionModel?: string;
    showGuidance?: boolean;
  }>,
) => {
  const [initError, setInitError] = useState<string | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);
@@ -155,6 +162,13 @@
    geminiClient,
  );

  const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
    config,
    addItem,
    visionModelPreviewEnabled,
    onVisionSwitchRequired,
  );

  const streamingState = useMemo(() => {
    if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
      return StreamingState.WaitingForConfirmation;
@@ -715,6 +729,20 @@
        return;
      }

      // Handle vision switch requirement
      const visionSwitchResult = await handleVisionSwitch(
        queryToSend,
        userMessageTimestamp,
        options?.isContinuation || false,
      );

      if (!visionSwitchResult.shouldProceed) {
        isSubmittingQueryRef.current = false;
        return;
      }

      const finalQueryToSend = queryToSend;

      if (!options?.isContinuation) {
        startNewPrompt();
        setThought(null); // Reset thought when starting a new prompt
@@ -725,7 +753,7 @@

      try {
        const stream = geminiClient.sendMessageStream(
-          queryToSend,
+          finalQueryToSend,
          abortSignal,
          prompt_id!,
        );
@@ -736,6 +764,8 @@
        );

        if (processingStatus === StreamProcessingStatus.UserCancelled) {
          // Restore original model if it was temporarily overridden
          restoreOriginalModel();
          isSubmittingQueryRef.current = false;
          return;
        }
@@ -748,7 +778,13 @@
          loopDetectedRef.current = false;
          handleLoopDetectedEvent();
        }

        // Restore original model if it was temporarily overridden
        restoreOriginalModel();
      } catch (error: unknown) {
        // Restore original model if it was temporarily overridden
        restoreOriginalModel();

        if (error instanceof UnauthorizedError) {
          onAuthError();
        } else if (!isNodeError(error) || error.name !== 'AbortError') {
@@ -786,6 +822,8 @@
      startNewPrompt,
      getPromptCount,
      handleLoopDetectedEvent,
      handleVisionSwitch,
      restoreOriginalModel,
    ],
  );
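useVisionAutoSwitch.ts itself does not appear in this diff, but the test file that follows specifies `shouldOfferVisionSwitch` precisely: offer the switch only for Qwen OAuth auth, with the preview flag on, when the current model is not already a vision model and the outgoing parts actually contain image data. A sketch consistent with those tests; the `isVisionModel` predicate is an assumption (the real code presumably consults the vision flags on the model list):

import type { Part, PartListUnion } from '@google/genai';
import { AuthType } from '@qwen-code/qwen-code-core';

// Placeholder predicate; assumed to be derived from the isVision flags
// on AVAILABLE_MODELS_QWEN in the real implementation.
const isVisionModel = (model: string) => model === 'qwen-vl-max-latest';

function hasImageMimeType(part: Part): boolean {
  return (
    part.inlineData?.mimeType?.startsWith('image/') === true ||
    part.fileData?.mimeType?.startsWith('image/') === true
  );
}

export function shouldOfferVisionSwitch(
  parts: PartListUnion,
  authType: AuthType,
  currentModel: string,
  visionModelPreviewEnabled: boolean,
): boolean {
  if (!visionModelPreviewEnabled || authType !== AuthType.QWEN_OAUTH) {
    return false;
  }
  if (isVisionModel(currentModel)) {
    return false;
  }
  // A bare string prompt carries no image parts.
  const partList = Array.isArray(parts) ? parts : [parts];
  return partList.some(
    (part) => typeof part !== 'string' && hasImageMimeType(part),
  );
}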
packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts (new file, 374 lines)
@@ -0,0 +1,374 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { renderHook, act } from '@testing-library/react';
import type { Part, PartListUnion } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import {
  shouldOfferVisionSwitch,
  processVisionSwitchOutcome,
  getVisionSwitchGuidanceMessage,
  useVisionAutoSwitch,
} from './useVisionAutoSwitch.js';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import { MessageType } from '../types.js';
import { getDefaultVisionModel } from '../models/availableModels.js';

describe('useVisionAutoSwitch helpers', () => {
  describe('shouldOfferVisionSwitch', () => {
    it('returns false when authType is not QWEN_OAUTH', () => {
      const parts: PartListUnion = [
        { inlineData: { mimeType: 'image/png', data: '...' } },
      ];
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.USE_GEMINI,
        'qwen3-coder-plus',
        true,
      );
      expect(result).toBe(false);
    });

    it('returns false when current model is already a vision model', () => {
      const parts: PartListUnion = [
        { inlineData: { mimeType: 'image/png', data: '...' } },
      ];
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.QWEN_OAUTH,
        'qwen-vl-max-latest',
        true,
      );
      expect(result).toBe(false);
    });

    it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
      const parts: PartListUnion = [
        { text: 'hello' },
        { inlineData: { mimeType: 'image/jpeg', data: '...' } },
      ];
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.QWEN_OAUTH,
        'qwen3-coder-plus',
        true,
      );
      expect(result).toBe(true);
    });

    it('detects image when provided as a single Part object (non-array)', () => {
      const singleImagePart: PartListUnion = {
        fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
      } as Part;
      const result = shouldOfferVisionSwitch(
        singleImagePart,
        AuthType.QWEN_OAUTH,
        'qwen3-coder-plus',
        true,
      );
      expect(result).toBe(true);
    });

    it('returns false when parts contain no images', () => {
      const parts: PartListUnion = [{ text: 'just text' }];
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.QWEN_OAUTH,
        'qwen3-coder-plus',
        true,
      );
      expect(result).toBe(false);
    });

    it('returns false when parts is a plain string', () => {
      const parts: PartListUnion = 'plain text';
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.QWEN_OAUTH,
        'qwen3-coder-plus',
        true,
      );
      expect(result).toBe(false);
    });

    it('returns false when visionModelPreviewEnabled is false', () => {
      const parts: PartListUnion = [
        { inlineData: { mimeType: 'image/png', data: '...' } },
      ];
      const result = shouldOfferVisionSwitch(
        parts,
        AuthType.QWEN_OAUTH,
        'qwen3-coder-plus',
        false,
      );
      expect(result).toBe(false);
    });
  });

  describe('processVisionSwitchOutcome', () => {
    it('maps SwitchOnce to a one-time model override', () => {
      const vl = getDefaultVisionModel();
      const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
      expect(result).toEqual({ modelOverride: vl });
    });

    it('maps SwitchSessionToVL to a persistent session model', () => {
      const vl = getDefaultVisionModel();
      const result = processVisionSwitchOutcome(
        VisionSwitchOutcome.SwitchSessionToVL,
      );
      expect(result).toEqual({ persistSessionModel: vl });
    });

    it('maps DisallowWithGuidance to showGuidance', () => {
      const result = processVisionSwitchOutcome(
        VisionSwitchOutcome.DisallowWithGuidance,
      );
      expect(result).toEqual({ showGuidance: true });
    });
  });

  describe('getVisionSwitchGuidanceMessage', () => {
    it('returns the expected guidance message', () => {
      const vl = getDefaultVisionModel();
      const expected =
        'To use images with your query, you can:\n' +
        `• Use /model set ${vl} to switch to a vision-capable model\n` +
        '• Or remove the image and provide a text description instead';
      expect(getVisionSwitchGuidanceMessage()).toBe(expected);
    });
  });
});

describe('useVisionAutoSwitch hook', () => {
  type AddItemFn = (
    item: { type: MessageType; text: string },
    ts: number,
  ) => any;

  const createMockConfig = (authType: AuthType, initialModel: string) => {
    let currentModel = initialModel;
    const mockConfig: Partial<Config> = {
      getModel: vi.fn(() => currentModel),
      setModel: vi.fn((m: string) => {
        currentModel = m;
      }),
      getContentGeneratorConfig: vi.fn(() => ({
        authType,
        model: currentModel,
        apiKey: 'test-key',
        vertexai: false,
      })),
    };
    return mockConfig as Config;
  };

  let addItem: AddItemFn;

  beforeEach(() => {
    vi.clearAllMocks();
    addItem = vi.fn();
  });

  it('returns shouldProceed=true immediately for continuations', async () => {
    const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
    const { result } = renderHook(() =>
      useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
    );

    const parts: PartListUnion = [
      { inlineData: { mimeType: 'image/png', data: '...' } },
    ];
    let res: any;
    await act(async () => {
      res = await result.current.handleVisionSwitch(parts, Date.now(), true);
    });
    expect(res).toEqual({ shouldProceed: true });
    expect(addItem).not.toHaveBeenCalled();
  });

  it('does nothing when authType is not QWEN_OAUTH', async () => {
    const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
    const onVisionSwitchRequired = vi.fn();
    const { result } = renderHook(() =>
      useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
    );

    const parts: PartListUnion = [
      { inlineData: { mimeType: 'image/png', data: '...' } },
    ];
    let res: any;
    await act(async () => {
      res = await result.current.handleVisionSwitch(parts, 123, false);
    });
    expect(res).toEqual({ shouldProceed: true });
    expect(onVisionSwitchRequired).not.toHaveBeenCalled();
  });

  it('does nothing when there are no image parts', async () => {
    const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
    const onVisionSwitchRequired = vi.fn();
    const { result } = renderHook(() =>
      useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
    );

    const parts: PartListUnion = [{ text: 'no images here' }];
    let res: any;
    await act(async () => {
      res = await result.current.handleVisionSwitch(parts, 456, false);
    });
    expect(res).toEqual({ shouldProceed: true });
    expect(onVisionSwitchRequired).not.toHaveBeenCalled();
  });

  it('shows guidance and blocks when dialog returns showGuidance', async () => {
    const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
    const onVisionSwitchRequired = vi
      .fn()
      .mockResolvedValue({ showGuidance: true });
    const { result } = renderHook(() =>
      useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
    );

    const parts: PartListUnion = [
      { inlineData: { mimeType: 'image/png', data: '...' } },
    ];

    const userTs = 1010;
    let res: any;
    await act(async () => {
      res = await result.current.handleVisionSwitch(parts, userTs, false);
    });

    expect(addItem).toHaveBeenCalledWith(
      { type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
      userTs,
    );
    expect(res).toEqual({ shouldProceed: false });
    expect(config.setModel).not.toHaveBeenCalled();
  });

  it('applies a one-time override and returns originalModel, then restores', async () => {
    const initialModel = 'qwen3-coder-plus';
    const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
    const onVisionSwitchRequired = vi
      .fn()
      .mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 2020, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
|
||||
// Now restore
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
|
||||
});
|
||||
|
||||
it('persists session model when dialog requests persistence', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 3030, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
|
||||
// Restore should be a no-op since no one-time override was used
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
// Last call should still be the persisted model set
|
||||
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
|
||||
'qwen-vl-max-latest',
|
||||
);
|
||||
});
|
||||
|
||||
it('returns shouldProceed=true when dialog returns no special flags', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 4040, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('blocks when dialog throws or is cancelled', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 5050, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: false });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when visionModelPreviewEnabled is false', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
false,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 6060, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
304 packages/cli/src/ui/hooks/useVisionAutoSwitch.ts Normal file
@@ -0,0 +1,304 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { type PartListUnion, type Part } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import { useCallback, useRef } from 'react';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import {
  getDefaultVisionModel,
  isVisionModel,
} from '../models/availableModels.js';
import { MessageType } from '../types.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import {
  isSupportedImageMimeType,
  getUnsupportedImageFormatWarning,
} from '@qwen-code/qwen-code-core';

/**
 * Checks if a PartListUnion contains image parts
 */
function hasImageParts(parts: PartListUnion): boolean {
  if (typeof parts === 'string') {
    return false;
  }

  if (Array.isArray(parts)) {
    return parts.some((part) => {
      // Skip string parts
      if (typeof part === 'string') return false;
      return isImagePart(part);
    });
  }

  // If it's a single Part (not a string), check if it's an image
  if (typeof parts === 'object') {
    return isImagePart(parts);
  }

  return false;
}

/**
 * Checks if a single Part is an image part
 */
function isImagePart(part: Part): boolean {
  // Check for inlineData with image mime type
  if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
    return true;
  }

  // Check for fileData with image mime type
  if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
    return true;
  }

  return false;
}

/**
 * Checks if image parts have supported formats and returns unsupported ones
 */
function checkImageFormatsSupport(parts: PartListUnion): {
  hasImages: boolean;
  hasUnsupportedFormats: boolean;
  unsupportedMimeTypes: string[];
} {
  const unsupportedMimeTypes: string[] = [];
  let hasImages = false;

  if (typeof parts === 'string') {
    return {
      hasImages: false,
      hasUnsupportedFormats: false,
      unsupportedMimeTypes: [],
    };
  }

  const partsArray = Array.isArray(parts) ? parts : [parts];

  for (const part of partsArray) {
    if (typeof part === 'string') continue;

    let mimeType: string | undefined;

    // Check inlineData
    if (
      'inlineData' in part &&
      part.inlineData?.mimeType?.startsWith('image/')
    ) {
      hasImages = true;
      mimeType = part.inlineData.mimeType;
    }

    // Check fileData
    if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
      hasImages = true;
      mimeType = part.fileData.mimeType;
    }

    // Check if the mime type is supported
    if (mimeType && !isSupportedImageMimeType(mimeType)) {
      unsupportedMimeTypes.push(mimeType);
    }
  }

  return {
    hasImages,
    hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
    unsupportedMimeTypes,
  };
}

/**
 * Determines if we should offer vision switch for the given parts, auth type, and current model
 */
export function shouldOfferVisionSwitch(
  parts: PartListUnion,
  authType: AuthType,
  currentModel: string,
  visionModelPreviewEnabled: boolean = false,
): boolean {
  // Only trigger for qwen-oauth
  if (authType !== AuthType.QWEN_OAUTH) {
    return false;
  }

  // If vision model preview is disabled, never offer vision switch
  if (!visionModelPreviewEnabled) {
    return false;
  }

  // If current model is already a vision model, no need to switch
  if (isVisionModel(currentModel)) {
    return false;
  }

  // Check if the current message contains image parts
  return hasImageParts(parts);
}

/**
 * Interface for vision switch result
 */
export interface VisionSwitchResult {
  modelOverride?: string;
  persistSessionModel?: string;
  showGuidance?: boolean;
}

/**
 * Processes the vision switch outcome and returns the appropriate result
 */
export function processVisionSwitchOutcome(
  outcome: VisionSwitchOutcome,
): VisionSwitchResult {
  const vlModelId = getDefaultVisionModel();

  switch (outcome) {
    case VisionSwitchOutcome.SwitchOnce:
      return { modelOverride: vlModelId };

    case VisionSwitchOutcome.SwitchSessionToVL:
      return { persistSessionModel: vlModelId };

    case VisionSwitchOutcome.DisallowWithGuidance:
      return { showGuidance: true };

    default:
      return { showGuidance: true };
  }
}

/**
 * Gets the guidance message for when vision switch is disallowed
 */
export function getVisionSwitchGuidanceMessage(): string {
  const vlModelId = getDefaultVisionModel();
  return `To use images with your query, you can:
• Use /model set ${vlModelId} to switch to a vision-capable model
• Or remove the image and provide a text description instead`;
}

/**
 * Interface for vision switch handling result
 */
export interface VisionSwitchHandlingResult {
  shouldProceed: boolean;
  originalModel?: string;
}

/**
 * Custom hook for handling vision model auto-switching
 */
export function useVisionAutoSwitch(
  config: Config,
  addItem: UseHistoryManagerReturn['addItem'],
  visionModelPreviewEnabled: boolean = false,
  onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
    modelOverride?: string;
    persistSessionModel?: string;
    showGuidance?: boolean;
  }>,
) {
  const originalModelRef = useRef<string | null>(null);

  const handleVisionSwitch = useCallback(
    async (
      query: PartListUnion,
      userMessageTimestamp: number,
      isContinuation: boolean,
    ): Promise<VisionSwitchHandlingResult> => {
      // Skip vision switch handling for continuations or if no handler provided
      if (isContinuation || !onVisionSwitchRequired) {
        return { shouldProceed: true };
      }

      const contentGeneratorConfig = config.getContentGeneratorConfig();

      // Only handle qwen-oauth auth type
      if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
        return { shouldProceed: true };
      }

      // Check image format support first
      const formatCheck = checkImageFormatsSupport(query);

      // If there are unsupported image formats, show warning
      if (formatCheck.hasUnsupportedFormats) {
        addItem(
          {
            type: MessageType.INFO,
            text: getUnsupportedImageFormatWarning(),
          },
          userMessageTimestamp,
        );
        // Continue processing but with warning shown
      }

      // Check if vision switch is needed
      if (
        !shouldOfferVisionSwitch(
          query,
          contentGeneratorConfig.authType,
          config.getModel(),
          visionModelPreviewEnabled,
        )
      ) {
        return { shouldProceed: true };
      }

      try {
        const visionSwitchResult = await onVisionSwitchRequired(query);

        if (visionSwitchResult.showGuidance) {
          // Show guidance and don't proceed with the request
          addItem(
            {
              type: MessageType.INFO,
              text: getVisionSwitchGuidanceMessage(),
            },
            userMessageTimestamp,
          );
          return { shouldProceed: false };
        }

        if (visionSwitchResult.modelOverride) {
          // One-time model override
          originalModelRef.current = config.getModel();
          config.setModel(visionSwitchResult.modelOverride);
          return {
            shouldProceed: true,
            originalModel: originalModelRef.current,
          };
        } else if (visionSwitchResult.persistSessionModel) {
          // Persistent session model change
          config.setModel(visionSwitchResult.persistSessionModel);
          return { shouldProceed: true };
        }

        return { shouldProceed: true };
      } catch (_error) {
        // If vision switch dialog was cancelled or errored, don't proceed
        return { shouldProceed: false };
      }
    },
    [config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
  );

  const restoreOriginalModel = useCallback(() => {
    if (originalModelRef.current) {
      config.setModel(originalModelRef.current);
      originalModelRef.current = null;
    }
  }, [config]);

  return {
    handleVisionSwitch,
    restoreOriginalModel,
  };
}
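Note (editorial, not part of the commit): a minimal sketch of how a caller could wire this hook into a submit path. The names submitQuery, sendToModel, and showModelSwitchDialog are illustrative only; the actual wiring lives in App.tsx and useGeminiStream:

  const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
    config,
    addItem,
    settings.merged.visionModelPreview ?? false,
    showModelSwitchDialog, // resolves with the user's ModelSwitchDialog choice
  );

  async function submitQuery(parts: PartListUnion) {
    const { shouldProceed, originalModel } = await handleVisionSwitch(
      parts,
      Date.now(),
      false, // not a continuation
    );
    if (!shouldProceed) return; // guidance was shown or the dialog was cancelled
    try {
      await sendToModel(parts); // hypothetical request helper
    } finally {
      if (originalModel) restoreOriginalModel(); // undo a one-time override
    }
  }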
52 packages/cli/src/ui/models/availableModels.ts Normal file
@@ -0,0 +1,52 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

export type AvailableModel = {
  id: string;
  label: string;
  isVision?: boolean;
};

export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
  { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
  { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
];

/**
 * Get available Qwen models filtered by the vision model preview setting
 */
export function getFilteredQwenModels(
  visionModelPreviewEnabled: boolean,
): AvailableModel[] {
  if (visionModelPreviewEnabled) {
    return AVAILABLE_MODELS_QWEN;
  }
  return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
}

/**
 * Currently we use the single model from the `OPENAI_MODEL` environment variable.
 * In the future, after settings.json is updated, we will allow users to configure this themselves.
 */
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
  const id = process.env['OPENAI_MODEL']?.trim();
  return id ? { id, label: id } : null;
}

/**
 * Hard-code the default vision model as a string literal
 * until our coding model supports multimodal input.
 */
export function getDefaultVisionModel(): string {
  return 'qwen-vl-max-latest';
}

export function isVisionModel(modelId: string): boolean {
  return AVAILABLE_MODELS_QWEN.some(
    (model) => model.id === modelId && model.isVision,
  );
}
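A quick illustration of the filtering behavior (editorial, not from the patch):

  getFilteredQwenModels(true);  // both entries, including qwen-vl-max-latest
  getFilteredQwenModels(false); // only [{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }]

  isVisionModel('qwen-vl-max-latest'); // true
  isVisionModel('qwen3-coder-plus');   // false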
@@ -19,3 +19,4 @@ export {
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js';
export * from './src/utils/request-tokenizer/supportedImageFormats.js';
@@ -5,9 +5,10 @@
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
import OpenAI from 'openai';

// Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
  let mockConfig: Config;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let mockOpenAIClient: any;
  let mockProvider: OpenAICompatibleProvider;

  beforeEach(() => {
    // Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
    mockConfig = {
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        authType: 'openai',
        enableOpenAILogging: false,
      }),
      getCliVersion: vi.fn().mockReturnValue('1.0.0'),
    } as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
          create: vi.fn(),
        },
      },
      embeddings: {
        create: vi.fn(),
      },
    };

    vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);

    // Create mock provider
    mockProvider = {
      buildHeaders: vi.fn().mockReturnValue({
        'User-Agent': 'QwenCode/1.0.0 (test; test)',
      }),
      buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
      buildRequest: vi.fn().mockImplementation((req) => req),
    };

    // Create generator instance
    const contentGeneratorConfig = {
      model: 'gpt-4',
      apiKey: 'test-key',
      authType: AuthType.USE_OPENAI,
      enableOpenAILogging: false,
    };
    generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
    generator = new OpenAIContentGenerator(
      contentGeneratorConfig,
      mockConfig,
      mockProvider,
    );
  });

  afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
    await expect(
      generator.generateContentStream(request, 'test-prompt-id'),
    ).rejects.toThrow(
      /Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
      /Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
    );
  });

@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
    } catch (error: unknown) {
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      expect(errorMessage).toContain(
        'Streaming setup timeout troubleshooting:',
      );
      expect(errorMessage).toContain(
        'Check network connectivity and firewall settings',
      );
      expect(errorMessage).toContain('Streaming timeout troubleshooting:');
      expect(errorMessage).toContain('Check network connectivity');
      expect(errorMessage).toContain('Consider using non-streaming mode');
    }
  });
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
      authType: AuthType.USE_OPENAI,
      baseUrl: 'http://localhost:8080',
    };
    new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
    new OpenAIContentGenerator(
      contentGeneratorConfig,
      mockConfig,
      mockProvider,
    );

    // Verify OpenAI client was created with timeout config
    expect(OpenAI).toHaveBeenCalledWith({
      apiKey: 'test-key',
      baseURL: 'http://localhost:8080',
      timeout: 120000,
      maxRetries: 3,
      defaultHeaders: {
        'User-Agent': expect.stringMatching(/^QwenCode/),
      },
    });
    // Verify provider buildClient was called
    expect(mockProvider.buildClient).toHaveBeenCalled();
  });

  it('should use custom timeout from config', () => {
    const customConfig = {
      getContentGeneratorConfig: vi.fn().mockReturnValue({}),
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        enableOpenAILogging: false,
      }),
      getCliVersion: vi.fn().mockReturnValue('1.0.0'),
    } as unknown as Config;

@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
      timeout: 300000,
      maxRetries: 5,
    };
    new OpenAIContentGenerator(contentGeneratorConfig, customConfig);

    expect(OpenAI).toHaveBeenCalledWith({
      apiKey: 'test-key',
      baseURL: 'http://localhost:8080',
      timeout: 300000,
      maxRetries: 5,
      defaultHeaders: {
        'User-Agent': expect.stringMatching(/^QwenCode/),
      },
    });
    // Create a custom mock provider for this test
    const customMockProvider: OpenAICompatibleProvider = {
      buildHeaders: vi.fn().mockReturnValue({
        'User-Agent': 'QwenCode/1.0.0 (test; test)',
      }),
      buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
      buildRequest: vi.fn().mockImplementation((req) => req),
    };

    new OpenAIContentGenerator(
      contentGeneratorConfig,
      customConfig,
      customMockProvider,
    );

    // Verify provider buildClient was called
    expect(customMockProvider.buildClient).toHaveBeenCalled();
  });

  it('should handle missing timeout config gracefully', () => {
    const noTimeoutConfig = {
      getContentGeneratorConfig: vi.fn().mockReturnValue({}),
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        enableOpenAILogging: false,
      }),
      getCliVersion: vi.fn().mockReturnValue('1.0.0'),
    } as unknown as Config;

@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
      authType: AuthType.USE_OPENAI,
      baseUrl: 'http://localhost:8080',
    };
    new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);

    expect(OpenAI).toHaveBeenCalledWith({
      apiKey: 'test-key',
      baseURL: 'http://localhost:8080',
      timeout: 120000, // default
      maxRetries: 3, // default
      defaultHeaders: {
        'User-Agent': expect.stringMatching(/^QwenCode/),
      },
    });
    // Create a custom mock provider for this test
    const noTimeoutMockProvider: OpenAICompatibleProvider = {
      buildHeaders: vi.fn().mockReturnValue({
        'User-Agent': 'QwenCode/1.0.0 (test; test)',
      }),
      buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
      buildRequest: vi.fn().mockImplementation((req) => req),
    };

    new OpenAIContentGenerator(
      contentGeneratorConfig,
      noTimeoutConfig,
      noTimeoutMockProvider,
    );

    // Verify provider buildClient was called
    expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
  });
});

@@ -500,7 +500,7 @@ export class GeminiChat {
  if (error instanceof Error && error.message) {
    if (isSchemaDepthError(error.message)) return false;
    if (error.message.includes('429')) return true;
    if (error.message.match(/5\d{2}/)) return true;
    if (error.message.match(/^5\d{2}/)) return true;
  }
  return false;
},
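The added ^ anchor makes the retry check stricter: the message must start with a 5xx status code rather than merely contain three such digits anywhere. An illustrative comparison (editorial, not from the patch):

  /5\d{2}/.test('context window is 51200 tokens');  // true  — false positive triggers a retry
  /^5\d{2}/.test('context window is 51200 tokens'); // false — no longer retried
  /^5\d{2}/.test('503 Service Unavailable');        // true  — genuine 5xx still retried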
File diff suppressed because it is too large
@@ -5,6 +5,37 @@
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';

// Mock the request tokenizer module BEFORE importing the class that uses it
const mockTokenizer = {
  calculateTokens: vi.fn().mockResolvedValue({
    totalTokens: 50,
    breakdown: {
      textTokens: 50,
      imageTokens: 0,
      audioTokens: 0,
      otherTokens: 0,
    },
    processingTime: 1,
  }),
  dispose: vi.fn(),
};

vi.mock('../../../utils/request-tokenizer/index.js', () => ({
  getDefaultTokenizer: vi.fn(() => mockTokenizer),
  DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
  disposeDefaultTokenizer: vi.fn(),
}));

// Mock tiktoken as well for completeness
vi.mock('tiktoken', () => ({
  get_encoding: vi.fn(() => ({
    encode: vi.fn(() => new Array(50)), // Mock 50 tokens
    free: vi.fn(),
  })),
}));

// Now import the modules that depend on the mocked modules
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
@@ -15,14 +46,6 @@ import type {
import type { OpenAICompatibleProvider } from './provider/index.js';
import type OpenAI from 'openai';

// Mock tiktoken
vi.mock('tiktoken', () => ({
  get_encoding: vi.fn().mockReturnValue({
    encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
    free: vi.fn(),
  }),
}));

describe('OpenAIContentGenerator (Refactored)', () => {
  let generator: OpenAIContentGenerator;
  let mockConfig: Config;

@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
import { ContentGenerationPipeline } from './pipeline.js';
import { DefaultTelemetryService } from './telemetryService.js';
import { EnhancedErrorHandler } from './errorHandler.js';
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';

export class OpenAIContentGenerator implements ContentGenerator {
@@ -71,28 +72,31 @@ export class OpenAIContentGenerator implements ContentGenerator {
  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    // Use tiktoken for accurate token counting
    const content = JSON.stringify(request.contents);
    let totalTokens = 0;

    try {
      const { get_encoding } = await import('tiktoken');
      const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
      totalTokens = encoding.encode(content).length;
      encoding.free();
      // Use the new high-performance request tokenizer
      const tokenizer = getDefaultTokenizer();
      const result = await tokenizer.calculateTokens(request, {
        textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
      });

      return {
        totalTokens: result.totalTokens,
      };
    } catch (error) {
      console.warn(
        'Failed to load tiktoken, falling back to character approximation:',
        'Failed to calculate tokens with new tokenizer, falling back to simple method:',
        error,
      );
      // Fallback: rough approximation using character count
      totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
    }

      // Fallback to original simple method
      const content = JSON.stringify(request.contents);
      const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters

      return {
        totalTokens,
      };
    }
  }

  async embedContent(
    request: EmbedContentParameters,

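The catch branch trades accuracy for availability: when the request tokenizer fails, tokens are estimated as ceil(length / 4) over the serialized contents. A rough illustration of the heuristic (editorial; the numbers are approximate by design):

  const content = JSON.stringify(request.contents); // say, 4,000 characters
  Math.ceil(content.length / 4); // → reported as about 1,000 tokens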
@@ -10,14 +10,11 @@ import {
  GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../../config/config.js';
import { type ContentGeneratorConfig } from '../contentGenerator.js';
import { type OpenAICompatibleProvider } from './provider/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from './provider/index.js';
import { OpenAIContentConverter } from './converter.js';
import {
  type TelemetryService,
  type RequestContext,
} from './telemetryService.js';
import { type ErrorHandler } from './errorHandler.js';
import type { TelemetryService, RequestContext } from './telemetryService.js';
import type { ErrorHandler } from './errorHandler.js';

export interface PipelineConfig {
  cliConfig: Config;
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
   * 2. Filter empty responses
   * 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
   * 4. Collect both formats for logging
   * 5. Handle success/error logging with original OpenAI format
   * 5. Handle success/error logging
   */
  private async *processStreamWithLogging(
    stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -169,19 +166,11 @@ export class ContentGenerationPipeline {
        collectedOpenAIChunks,
      );
    } catch (error) {
      // Stage 2e: Stream failed - handle error and logging
      context.duration = Date.now() - context.startTime;

      // Clear streaming tool calls on error to prevent data pollution
      this.converter.resetStreamingToolCalls();

      await this.config.telemetryService.logError(
        context,
        error,
        openaiRequest,
      );

      this.config.errorHandler.handle(error, context, request);
      // Use shared error handling logic
      await this.handleError(error, context, request);
    }
  }

@@ -365,25 +354,59 @@ export class ContentGenerationPipeline {
      context.duration = Date.now() - context.startTime;
      return result;
    } catch (error) {
      context.duration = Date.now() - context.startTime;

      // Log error
      const openaiRequest = await this.buildRequest(
      // Use shared error handling logic
      return await this.handleError(
        error,
        context,
        request,
        userPromptId,
        isStreaming,
      );
      await this.config.telemetryService.logError(
        context,
        error,
        openaiRequest,
      );

      // Handle and throw enhanced error
      this.config.errorHandler.handle(error, context, request);
    }
  }

  /**
   * Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
   * This centralizes the common error processing steps to avoid duplication
   */
  private async handleError(
    error: unknown,
    context: RequestContext,
    request: GenerateContentParameters,
    userPromptId?: string,
    isStreaming?: boolean,
  ): Promise<never> {
    context.duration = Date.now() - context.startTime;

    // Build request for logging (may fail, but we still want to log the error)
    let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
    try {
      if (userPromptId !== undefined && isStreaming !== undefined) {
        openaiRequest = await this.buildRequest(
          request,
          userPromptId,
          isStreaming,
        );
      } else {
        // For processStreamWithLogging, we don't have userPromptId/isStreaming,
        // so create a minimal request
        openaiRequest = {
          model: this.contentGeneratorConfig.model,
          messages: [],
        };
      }
    } catch (_buildError) {
      // If we can't build the request, create a minimal one for logging
      openaiRequest = {
        model: this.contentGeneratorConfig.model,
        messages: [],
      };
    }

    await this.config.telemetryService.logError(context, error, openaiRequest);
    this.config.errorHandler.handle(error, context, request);
  }

  /**
   * Create request context with common properties
   */

@@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider
    messages = this.addDashScopeCacheControl(messages, cacheTarget);
  }

  if (request.model.startsWith('qwen-vl')) {
    return {
      ...request,
      messages,
      ...(this.buildMetadata(userPromptId) || {}),
      /* @ts-expect-error dashscope exclusive */
      vl_high_resolution_images: true,
    };
  }

  return {
    ...request, // Preserve all original parameters including sampling params
    messages,
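With this branch, a request for a qwen-vl model ends up with roughly the following shape (editorial sketch; field values are illustrative):

  {
    model: 'qwen-vl-max-latest',
    messages: [...],
    vl_high_resolution_images: true, // DashScope-only field, hence the @ts-expect-error above
  }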
@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
  [/^qwen-flash-latest$/, LIMITS['1m']],
  [/^qwen-turbo.*$/, LIMITS['128k']],

  // Qwen Vision Models
  [/^qwen-vl-max.*$/, LIMITS['128k']],

  // -------------------
  // ByteDance Seed-OSS (512K)
  // -------------------
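The hunk only adds a table entry; a lookup over a PATTERNS table of this shape is typically first-match-wins, along these lines (illustrative — the actual helper and DEFAULT_LIMIT are outside this hunk):

  const limit =
    PATTERNS.find(([pattern]) => pattern.test(modelId))?.[1] ?? DEFAULT_LIMIT;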
@@ -242,7 +242,7 @@ describe('Turn', () => {
  expect(turn.getDebugResponses().length).toBe(0);
  expect(reportError).toHaveBeenCalledWith(
    error,
    'Error when talking to Gemini API',
    'Error when talking to API',
    [...historyContent, reqParts],
    'Turn.run-sendMessageStream',
  );

@@ -310,7 +310,7 @@ export class Turn {
  const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
  await reportError(
    error,
    'Error when talking to Gemini API',
    'Error when talking to API',
    contextForReport,
    'Turn.run-sendMessageStream',
  );
@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
  });

  it('should count tokens with valid token', async () => {
    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'valid-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
  it('should count tokens without requiring authentication', async () => {
    // Clear any previous mock calls
    vi.clearAllMocks();

    const request: CountTokensParameters = {
      model: 'qwen-turbo',
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
    const result = await qwenContentGenerator.countTokens(request);

    expect(result.totalTokens).toBe(15);
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
    // countTokens is a local operation and should not require OAuth credentials
    expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
  });

  it('should embed content with valid token', async () => {
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
    SharedTokenManager.getInstance = originalGetInstance;
  });

  it('should handle all method types with token failure', async () => {
  it('should handle method types with token failure (except countTokens)', async () => {
    const mockTokenManager = {
      getValidCredentials: vi
        .fn()
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
      contents: [{ parts: [{ text: 'Embed' }] }],
    };

    // All methods should fail with the same error
    // Methods requiring authentication should fail
    await expect(
      newGenerator.generateContent(generateRequest, 'test-id'),
    ).rejects.toThrow('Failed to obtain valid Qwen access token');
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
    await expect(
      newGenerator.generateContentStream(generateRequest, 'test-id'),
    ).rejects.toThrow('Failed to obtain valid Qwen access token');

    await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
      'Failed to obtain valid Qwen access token',
    );

    await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
      'Failed to obtain valid Qwen access token',
    );

    // countTokens should succeed as it's a local operation
    const countResult = await newGenerator.countTokens(countRequest);
    expect(countResult.totalTokens).toBe(15);

    SharedTokenManager.getInstance = originalGetInstance;
  });
});

@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
  override async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    return this.executeWithCredentialManagement(() =>
      super.countTokens(request),
    );
    return super.countTokens(request);
  }

  /**
157 packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts Normal file
@@ -0,0 +1,157 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect } from 'vitest';
import { ImageTokenizer } from './imageTokenizer.js';

describe('ImageTokenizer', () => {
  const tokenizer = new ImageTokenizer();

  describe('token calculation', () => {
    it('should calculate tokens based on image dimensions with reference logic', () => {
      const metadata = {
        width: 28,
        height: 28,
        mimeType: 'image/png',
        dataSize: 1000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
      // But minimum scaling may apply for small images
      expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
    });

    it('should calculate tokens for larger images', () => {
      const metadata = {
        width: 512,
        height: 512,
        mimeType: 'image/png',
        dataSize: 10000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // 512x512 with reference logic: rounded dimensions + scaling + special tokens
      expect(tokens).toBeGreaterThan(300);
      expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
    });

    it('should enforce minimum tokens per image with scaling', () => {
      const metadata = {
        width: 1,
        height: 1,
        mimeType: 'image/png',
        dataSize: 100,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // Tiny images get scaled up to minimum pixels + special tokens
      expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
    });

    it('should handle very large images with scaling', () => {
      const metadata = {
        width: 8192,
        height: 8192,
        mimeType: 'image/png',
        dataSize: 100000,
      };

      const tokens = tokenizer.calculateTokens(metadata);

      // Very large images should be scaled down to max limit + special tokens
      expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
      expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
    });
  });

  describe('PNG dimension extraction', () => {
    it('should extract dimensions from valid PNG', async () => {
      // 1x1 PNG image in base64
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const metadata = await tokenizer.extractImageMetadata(
        pngBase64,
        'image/png',
      );

      expect(metadata.width).toBe(1);
      expect(metadata.height).toBe(1);
      expect(metadata.mimeType).toBe('image/png');
    });

    it('should handle invalid PNG gracefully', async () => {
      const invalidBase64 = 'invalid-png-data';

      const metadata = await tokenizer.extractImageMetadata(
        invalidBase64,
        'image/png',
      );

      // Should return default dimensions
      expect(metadata.width).toBe(512);
      expect(metadata.height).toBe(512);
      expect(metadata.mimeType).toBe('image/png');
    });
  });

  describe('batch processing', () => {
    it('should process multiple images serially', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const images = [
        { data: pngBase64, mimeType: 'image/png' },
        { data: pngBase64, mimeType: 'image/png' },
        { data: pngBase64, mimeType: 'image/png' },
      ];

      const tokens = await tokenizer.calculateTokensBatch(images);

      expect(tokens).toHaveLength(3);
      expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
    });

    it('should handle mixed valid and invalid images', async () => {
      const validPng =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
      const invalidPng = 'invalid-data';

      const images = [
        { data: validPng, mimeType: 'image/png' },
        { data: invalidPng, mimeType: 'image/png' },
      ];

      const tokens = await tokenizer.calculateTokensBatch(images);

      expect(tokens).toHaveLength(2);
      expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
    });
  });

  describe('different image formats', () => {
    it('should handle different MIME types', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];

      for (const mimeType of formats) {
        const metadata = await tokenizer.extractImageMetadata(
          pngBase64,
          mimeType,
        );
        expect(metadata.mimeType).toBe(mimeType);
        expect(metadata.width).toBeGreaterThan(0);
        expect(metadata.height).toBeGreaterThan(0);
      }
    });
  });
});
505 packages/core/src/utils/request-tokenizer/imageTokenizer.ts Normal file
@@ -0,0 +1,505 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { ImageMetadata } from './types.js';
import { isSupportedImageMimeType } from './supportedImageFormats.js';

/**
 * Image tokenizer for calculating image tokens based on dimensions
 *
 * Key rules:
 * - 28x28 pixels = 1 token
 * - Minimum: 4 tokens per image
 * - Maximum: 16384 tokens per image
 * - Additional: 2 special tokens (vision_bos + vision_eos)
 * - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
 */
export class ImageTokenizer {
  /** 28x28 pixels = 1 token */
  private static readonly PIXELS_PER_TOKEN = 28 * 28;

  /** Minimum tokens per image */
  private static readonly MIN_TOKENS_PER_IMAGE = 4;

  /** Maximum tokens per image */
  private static readonly MAX_TOKENS_PER_IMAGE = 16384;

  /** Special tokens for vision markers */
  private static readonly VISION_SPECIAL_TOKENS = 2;

  /**
   * Extract image metadata from base64 data
   *
   * @param base64Data Base64-encoded image data (with or without data URL prefix)
   * @param mimeType MIME type of the image
   * @returns Promise resolving to ImageMetadata with dimensions and format info
   */
  async extractImageMetadata(
    base64Data: string,
    mimeType: string,
  ): Promise<ImageMetadata> {
    try {
      // Check if the MIME type is supported
      if (!isSupportedImageMimeType(mimeType)) {
        console.warn(`Unsupported image format: ${mimeType}`);
        // Return default metadata for unsupported formats
        return {
          width: 512,
          height: 512,
          mimeType,
          dataSize: Math.floor(base64Data.length * 0.75),
        };
      }

      const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
      const buffer = Buffer.from(cleanBase64, 'base64');
      const dimensions = await this.extractDimensions(buffer, mimeType);

      return {
        width: dimensions.width,
        height: dimensions.height,
        mimeType,
        dataSize: buffer.length,
      };
    } catch (error) {
      console.warn('Failed to extract image metadata:', error);
      // Return default metadata for fallback
      return {
        width: 512,
        height: 512,
        mimeType,
        dataSize: Math.floor(base64Data.length * 0.75),
      };
    }
  }

  /**
   * Extract image dimensions from buffer based on format
   *
   * @param buffer Binary image data buffer
   * @param mimeType MIME type to determine parsing strategy
   * @returns Promise resolving to width and height dimensions
   */
  private async extractDimensions(
    buffer: Buffer,
    mimeType: string,
  ): Promise<{ width: number; height: number }> {
    if (mimeType.includes('png')) {
      return this.extractPngDimensions(buffer);
    }

    if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
      return this.extractJpegDimensions(buffer);
    }

    if (mimeType.includes('webp')) {
      return this.extractWebpDimensions(buffer);
    }

    if (mimeType.includes('gif')) {
      return this.extractGifDimensions(buffer);
    }

    if (mimeType.includes('bmp')) {
      return this.extractBmpDimensions(buffer);
    }

    if (mimeType.includes('tiff')) {
      return this.extractTiffDimensions(buffer);
    }

    if (mimeType.includes('heic')) {
      return this.extractHeicDimensions(buffer);
    }

    return { width: 512, height: 512 };
  }

  /**
   * Extract PNG dimensions from IHDR chunk
   * PNG signature: 89 50 4E 47 0D 0A 1A 0A
   * Width/height at bytes 16-19 and 20-23 (big-endian)
   */
  private extractPngDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 24) {
      throw new Error('Invalid PNG: buffer too short');
    }

    // Verify PNG signature
    const signature = buffer.subarray(0, 8);
    const expectedSignature = Buffer.from([
      0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
    ]);
    if (!signature.equals(expectedSignature)) {
      throw new Error('Invalid PNG signature');
    }

    const width = buffer.readUInt32BE(16);
    const height = buffer.readUInt32BE(20);

    return { width, height };
  }
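Why bytes 16 and 20 are the right offsets: the 8-byte PNG signature is followed by the IHDR chunk's 4-byte length and 4-byte type fields, so the IHDR data (width first, then height) begins at byte 16. For the 1×1 PNG used in the tests above (editorial check, not part of the patch):

  buffer.readUInt32BE(16); // 1 (width)
  buffer.readUInt32BE(20); // 1 (height)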
  /**
   * Extract JPEG dimensions from SOF (Start of Frame) markers
   * JPEG starts with FF D8; SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
   * Dimensions at offset +5 (height) and +7 (width) from the SOF marker
   */
  private extractJpegDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
      throw new Error('Invalid JPEG signature');
    }

    let offset = 2;

    while (offset < buffer.length - 8) {
      if (buffer[offset] !== 0xff) {
        offset++;
        continue;
      }

      const marker = buffer[offset + 1];

      // SOF markers
      if (
        (marker >= 0xc0 && marker <= 0xc3) ||
        (marker >= 0xc5 && marker <= 0xc7) ||
        (marker >= 0xc9 && marker <= 0xcb) ||
        (marker >= 0xcd && marker <= 0xcf)
      ) {
        const height = buffer.readUInt16BE(offset + 5);
        const width = buffer.readUInt16BE(offset + 7);
        return { width, height };
      }

      const segmentLength = buffer.readUInt16BE(offset + 2);
      offset += 2 + segmentLength;
    }

    throw new Error('Could not find JPEG dimensions');
  }

  /**
   * Extract WebP dimensions from RIFF container
   * Supports VP8, VP8L, and VP8X formats
   */
  private extractWebpDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 30) {
      throw new Error('Invalid WebP: too short');
    }

    const riffSignature = buffer.subarray(0, 4).toString('ascii');
    const webpSignature = buffer.subarray(8, 12).toString('ascii');

    if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
      throw new Error('Invalid WebP signature');
    }

    const format = buffer.subarray(12, 16).toString('ascii');

    if (format === 'VP8 ') {
      const width = buffer.readUInt16LE(26) & 0x3fff;
      const height = buffer.readUInt16LE(28) & 0x3fff;
      return { width, height };
    } else if (format === 'VP8L') {
      const bits = buffer.readUInt32LE(21);
      const width = (bits & 0x3fff) + 1;
      const height = ((bits >> 14) & 0x3fff) + 1;
      return { width, height };
    } else if (format === 'VP8X') {
      const width = (buffer.readUInt32LE(24) & 0xffffff) + 1;
      const height = (buffer.readUInt32LE(26) & 0xffffff) + 1;
      return { width, height };
    }

    throw new Error('Unsupported WebP format');
  }

  /**
   * Extract GIF dimensions from header
   * Supports GIF87a and GIF89a formats
   */
  private extractGifDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 10) {
      throw new Error('Invalid GIF: too short');
    }

    const signature = buffer.subarray(0, 6).toString('ascii');
    if (signature !== 'GIF87a' && signature !== 'GIF89a') {
      throw new Error('Invalid GIF signature');
    }

    const width = buffer.readUInt16LE(6);
    const height = buffer.readUInt16LE(8);

    return { width, height };
  }

  /**
   * Calculate tokens for an image based on its metadata
   *
   * @param metadata Image metadata containing width, height, and format info
   * @returns Total token count including base image tokens and special tokens
   */
  calculateTokens(metadata: ImageMetadata): number {
    return this.calculateTokensWithScaling(metadata.width, metadata.height);
  }

  /**
   * Calculate tokens with scaling logic
   *
   * Steps:
   * 1. Normalize to 28-pixel multiples
   * 2. Scale large images down, small images up
   * 3. Calculate tokens: pixels / 784 + 2 special tokens
   *
   * @param width Original image width in pixels
   * @param height Original image height in pixels
   * @returns Total token count for the image
   */
  private calculateTokensWithScaling(width: number, height: number): number {
    // Normalize to 28-pixel multiples
    let hBar = Math.round(height / 28) * 28;
    let wBar = Math.round(width / 28) * 28;

    // Define pixel boundaries
    const minPixels =
      ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
    const maxPixels =
      ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;

    // Apply scaling
    if (hBar * wBar > maxPixels) {
      // Scale down large images
      const beta = Math.sqrt((height * width) / maxPixels);
      hBar = Math.floor(height / beta / 28) * 28;
      wBar = Math.floor(width / beta / 28) * 28;
    } else if (hBar * wBar < minPixels) {
      // Scale up small images
      const beta = Math.sqrt(minPixels / (height * width));
      hBar = Math.ceil((height * beta) / 28) * 28;
      wBar = Math.ceil((width * beta) / 28) * 28;
    }

    // Calculate tokens
    const imageTokens = Math.floor(
      (hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
    );

    return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
  }
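A worked example of the scaling arithmetic (editorial; computed from the constants above): for a 1024×768 image, hBar = round(768 / 28) * 28 = 756 and wBar = round(1024 / 28) * 28 = 1036. The product 756 * 1036 = 783,216 pixels lies between minPixels (4 * 784 = 3,136) and maxPixels (16384 * 784 = 12,845,056), so no rescaling occurs, and the result is floor(783,216 / 784) + 2 = 999 + 2 = 1,001 tokens.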

  /**
   * Calculate tokens for multiple images serially
   *
   * @param base64DataArray Array of image data with MIME type information
   * @returns Promise resolving to array of token counts in same order as input
   */
  async calculateTokensBatch(
    base64DataArray: Array<{ data: string; mimeType: string }>,
  ): Promise<number[]> {
    const results: number[] = [];

    for (const { data, mimeType } of base64DataArray) {
      try {
        const metadata = await this.extractImageMetadata(data, mimeType);
        results.push(this.calculateTokens(metadata));
      } catch (error) {
        console.warn('Error calculating tokens for image:', error);
        // Return minimum tokens as fallback
        results.push(
          ImageTokenizer.MIN_TOKENS_PER_IMAGE +
            ImageTokenizer.VISION_SPECIAL_TOKENS,
        );
      }
    }

    return results;
  }

  /**
   * Extract BMP dimensions from header
   * BMP signature: 42 4D (BM)
   * Width/height at bytes 18-21 and 22-25 (little-endian)
   */
  private extractBmpDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 26) {
      throw new Error('Invalid BMP: buffer too short');
    }

    // Verify BMP signature
    if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
      throw new Error('Invalid BMP signature');
    }

    const width = buffer.readUInt32LE(18);
    // Height is a signed field; top-down BMPs store it as a negative number,
    // so read it as Int32 before taking the absolute value
    const height = buffer.readInt32LE(22);

    return { width, height: Math.abs(height) };
  }

  /**
   * Extract TIFF dimensions from IFD (Image File Directory)
   * TIFF can be little-endian (II) or big-endian (MM)
   * Width/height are stored in IFD entries with tags 0x0100 and 0x0101
   */
  private extractTiffDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 8) {
      throw new Error('Invalid TIFF: buffer too short');
    }

    // Check byte order
    const byteOrder = buffer.subarray(0, 2).toString('ascii');
    const isLittleEndian = byteOrder === 'II';
    const isBigEndian = byteOrder === 'MM';

    if (!isLittleEndian && !isBigEndian) {
      throw new Error('Invalid TIFF byte order');
    }

    // Read magic number (should be 42)
    const magic = isLittleEndian
      ? buffer.readUInt16LE(2)
      : buffer.readUInt16BE(2);
    if (magic !== 42) {
      throw new Error('Invalid TIFF magic number');
    }

    // Read IFD offset
    const ifdOffset = isLittleEndian
      ? buffer.readUInt32LE(4)
      : buffer.readUInt32BE(4);

    if (ifdOffset >= buffer.length) {
      throw new Error('Invalid TIFF IFD offset');
    }

    // Read number of directory entries
    const numEntries = isLittleEndian
      ? buffer.readUInt16LE(ifdOffset)
      : buffer.readUInt16BE(ifdOffset);

    let width = 0;
    let height = 0;

    // Parse IFD entries
    for (let i = 0; i < numEntries; i++) {
      const entryOffset = ifdOffset + 2 + i * 12;

      if (entryOffset + 12 > buffer.length) break;

      const tag = isLittleEndian
        ? buffer.readUInt16LE(entryOffset)
        : buffer.readUInt16BE(entryOffset);

      const type = isLittleEndian
        ? buffer.readUInt16LE(entryOffset + 2)
        : buffer.readUInt16BE(entryOffset + 2);

      const value = isLittleEndian
        ? buffer.readUInt32LE(entryOffset + 8)
        : buffer.readUInt32BE(entryOffset + 8);

      if (tag === 0x0100) {
        // ImageWidth: SHORT (type 3) values occupy only the first two bytes
        // of the four-byte value field, so read them as 16-bit; LONG (type 4)
        // values use all four bytes
        width =
          type === 3
            ? isLittleEndian
              ? buffer.readUInt16LE(entryOffset + 8)
              : buffer.readUInt16BE(entryOffset + 8)
            : value;
      } else if (tag === 0x0101) {
        // ImageLength (height): same SHORT vs LONG handling
        height =
          type === 3
            ? isLittleEndian
              ? buffer.readUInt16LE(entryOffset + 8)
              : buffer.readUInt16BE(entryOffset + 8)
            : value;
      }

      if (width > 0 && height > 0) break;
    }

    if (width === 0 || height === 0) {
      throw new Error('Could not find TIFF dimensions');
    }

    return { width, height };
  }

  /**
   * Extract HEIC dimensions from meta box
   * HEIC is based on ISO Base Media File Format
   * This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
   */
  private extractHeicDimensions(buffer: Buffer): {
    width: number;
    height: number;
  } {
    if (buffer.length < 12) {
      throw new Error('Invalid HEIC: buffer too short');
    }

    // Check for ftyp box with HEIC brand
    const ftypBox = buffer.subarray(4, 8).toString('ascii');
    if (ftypBox !== 'ftyp') {
      throw new Error('Invalid HEIC: missing ftyp box');
    }

    const brand = buffer.subarray(8, 12).toString('ascii');
    if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
      throw new Error('Invalid HEIC brand');
    }

    // Look for meta box and then ispe box
    let offset = 0;
    while (offset < buffer.length - 8) {
      const boxSize = buffer.readUInt32BE(offset);
      const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');

      if (boxType === 'meta') {
        // Look for ispe box inside meta box
        const metaOffset = offset + 8;
        let innerOffset = metaOffset + 4; // Skip version and flags

        while (innerOffset < offset + boxSize - 8) {
          const innerBoxSize = buffer.readUInt32BE(innerOffset);
          const innerBoxType = buffer
            .subarray(innerOffset + 4, innerOffset + 8)
            .toString('ascii');

          if (innerBoxType === 'ispe') {
            // Found Image Spatial Extents box
            if (innerOffset + 20 <= buffer.length) {
              const width = buffer.readUInt32BE(innerOffset + 12);
              const height = buffer.readUInt32BE(innerOffset + 16);
              return { width, height };
            }
          }

          if (innerBoxSize === 0) break;
          innerOffset += innerBoxSize;
        }
      }

      if (boxSize === 0) break;
      offset += boxSize;
    }

    // Fallback: return default dimensions if we can't parse the structure
    console.warn('Could not extract HEIC dimensions, using default');
    return { width: 512, height: 512 };
  }
}
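A minimal usage sketch of the tokenizer above, assuming a base64-encoded PNG in `pngBase64`:

    import { ImageTokenizer } from './imageTokenizer.js';

    async function countImageTokens(pngBase64: string): Promise<number[]> {
      const imageTokenizer = new ImageTokenizer();
      // Unparseable images fall back to MIN_TOKENS_PER_IMAGE + VISION_SPECIAL_TOKENS
      return imageTokenizer.calculateTokensBatch([
        { data: pngBase64, mimeType: 'image/png' },
      ]);
    }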
packages/core/src/utils/request-tokenizer/index.ts (new file, 40 lines)
@@ -0,0 +1,40 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { DefaultRequestTokenizer } from './requestTokenizer.js';

export { DefaultRequestTokenizer } from './requestTokenizer.js';
export { TextTokenizer } from './textTokenizer.js';
export { ImageTokenizer } from './imageTokenizer.js';

export type {
  RequestTokenizer,
  TokenizerConfig,
  TokenCalculationResult,
  ImageMetadata,
} from './types.js';

// Singleton instance for convenient usage
let defaultTokenizer: DefaultRequestTokenizer | null = null;

/**
 * Get the default request tokenizer instance
 */
export function getDefaultTokenizer(): DefaultRequestTokenizer {
  if (!defaultTokenizer) {
    defaultTokenizer = new DefaultRequestTokenizer();
  }
  return defaultTokenizer;
}

/**
 * Dispose of the default tokenizer instance
 */
export async function disposeDefaultTokenizer(): Promise<void> {
  if (defaultTokenizer) {
    await defaultTokenizer.dispose();
    defaultTokenizer = null;
  }
}
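A minimal usage sketch of the singleton helpers above (the request shape follows the tests below; the model name is illustrative):

    import { getDefaultTokenizer, disposeDefaultTokenizer } from './index.js';

    async function demo(): Promise<void> {
      const result = await getDefaultTokenizer().calculateTokens({
        model: 'qwen-vl-max-latest',
        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
      });
      console.log(result.totalTokens, result.breakdown, result.processingTime);
      await disposeDefaultTokenizer(); // frees the underlying tiktoken encoding
    }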
packages/core/src/utils/request-tokenizer/requestTokenizer.test.ts (new file, 293 lines)
@@ -0,0 +1,293 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
import type { CountTokensParameters } from '@google/genai';

describe('DefaultRequestTokenizer', () => {
  let tokenizer: DefaultRequestTokenizer;

  beforeEach(() => {
    tokenizer = new DefaultRequestTokenizer();
  });

  afterEach(async () => {
    await tokenizer.dispose();
  });

  describe('text token calculation', () => {
    it('should calculate tokens for simple text content', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [{ text: 'Hello, world!' }],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
      expect(result.breakdown.imageTokens).toBe(0);
      expect(result.processingTime).toBeGreaterThan(0);
    });

    it('should handle multiple text parts', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              { text: 'First part' },
              { text: 'Second part' },
              { text: 'Third part' },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
    });

    it('should handle string content', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: ['Simple string content'],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
    });
  });

  describe('image token calculation', () => {
    it('should calculate tokens for image content', async () => {
      // Create a simple 1x1 PNG image in base64
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
      expect(result.breakdown.textTokens).toBe(0);
    });

    it('should handle multiple images', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
    });
  });

  describe('mixed content', () => {
    it('should handle text and image content together', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              { text: 'Here is an image:' },
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: pngBase64,
                },
              },
              { text: 'What do you see?' },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(4);
      expect(result.breakdown.textTokens).toBeGreaterThan(0);
      expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
    });
  });

  describe('function content', () => {
    it('should handle function calls', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                functionCall: {
                  name: 'test_function',
                  args: { param1: 'value1', param2: 42 },
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThan(0);
      expect(result.breakdown.otherTokens).toBeGreaterThan(0);
    });
  });

  describe('empty content', () => {
    it('should handle empty request', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBe(0);
      expect(result.breakdown.textTokens).toBe(0);
      expect(result.breakdown.imageTokens).toBe(0);
    });

    it('should handle undefined contents', async () => {
      // `contents` is required by CountTokensParameters, so force undefined
      // through a cast to actually exercise the missing-contents guard
      const request = {
        model: 'test-model',
        contents: undefined,
      } as unknown as CountTokensParameters;

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBe(0);
    });
  });

  describe('configuration', () => {
    it('should use custom text encoding', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [{ text: 'Test text for encoding' }],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request, {
        textEncoding: 'cl100k_base',
      });

      expect(result.totalTokens).toBeGreaterThan(0);
    });

    it('should process multiple images serially', async () => {
      const pngBase64 =
        'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';

      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: Array(10).fill({
              inlineData: {
                mimeType: 'image/png',
                data: pngBase64,
              },
            }),
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
    });
  });

  describe('error handling', () => {
    it('should handle malformed image data gracefully', async () => {
      const request: CountTokensParameters = {
        model: 'test-model',
        contents: [
          {
            role: 'user',
            parts: [
              {
                inlineData: {
                  mimeType: 'image/png',
                  data: 'invalid-base64-data',
                },
              },
            ],
          },
        ],
      };

      const result = await tokenizer.calculateTokens(request);

      // Should still return some tokens (fallback to minimum)
      expect(result.totalTokens).toBeGreaterThanOrEqual(4);
    });
  });
});
packages/core/src/utils/request-tokenizer/requestTokenizer.ts (new file, 341 lines)
@@ -0,0 +1,341 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type {
  CountTokensParameters,
  Content,
  Part,
  PartUnion,
} from '@google/genai';
import type {
  RequestTokenizer,
  TokenizerConfig,
  TokenCalculationResult,
} from './types.js';
import { TextTokenizer } from './textTokenizer.js';
import { ImageTokenizer } from './imageTokenizer.js';

/**
 * Simple request tokenizer that handles text and image content serially
 */
export class DefaultRequestTokenizer implements RequestTokenizer {
  private textTokenizer: TextTokenizer;
  private imageTokenizer: ImageTokenizer;

  constructor() {
    this.textTokenizer = new TextTokenizer();
    this.imageTokenizer = new ImageTokenizer();
  }

  /**
   * Calculate tokens for a request using serial processing
   */
  async calculateTokens(
    request: CountTokensParameters,
    config: TokenizerConfig = {},
  ): Promise<TokenCalculationResult> {
    const startTime = performance.now();

    // Apply configuration
    if (config.textEncoding) {
      this.textTokenizer = new TextTokenizer(config.textEncoding);
    }

    try {
      // Process request content and group by type
      const { textContents, imageContents, audioContents, otherContents } =
        this.processAndGroupContents(request);

      if (
        textContents.length === 0 &&
        imageContents.length === 0 &&
        audioContents.length === 0 &&
        otherContents.length === 0
      ) {
        return {
          totalTokens: 0,
          breakdown: {
            textTokens: 0,
            imageTokens: 0,
            audioTokens: 0,
            otherTokens: 0,
          },
          processingTime: performance.now() - startTime,
        };
      }

      // Calculate tokens for each content type serially
      const textTokens = await this.calculateTextTokens(textContents);
      const imageTokens = await this.calculateImageTokens(imageContents);
      const audioTokens = await this.calculateAudioTokens(audioContents);
      const otherTokens = await this.calculateOtherTokens(otherContents);

      const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
      const processingTime = performance.now() - startTime;

      return {
        totalTokens,
        breakdown: {
          textTokens,
          imageTokens,
          audioTokens,
          otherTokens,
        },
        processingTime,
      };
    } catch (error) {
      console.error('Error calculating tokens:', error);

      // Fallback calculation
      const fallbackTokens = this.calculateFallbackTokens(request);

      return {
        totalTokens: fallbackTokens,
        breakdown: {
          textTokens: fallbackTokens,
          imageTokens: 0,
          audioTokens: 0,
          otherTokens: 0,
        },
        processingTime: performance.now() - startTime,
      };
    }
  }

  /**
   * Calculate tokens for text contents
   */
  private async calculateTextTokens(textContents: string[]): Promise<number> {
    if (textContents.length === 0) return 0;

    try {
      const tokenCounts =
        await this.textTokenizer.calculateTokensBatch(textContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating text tokens:', error);
      // Fallback: character-based estimation
      const totalChars = textContents.join('').length;
      return Math.ceil(totalChars / 4);
    }
  }

  /**
   * Calculate tokens for image contents using serial processing
   */
  private async calculateImageTokens(
    imageContents: Array<{ data: string; mimeType: string }>,
  ): Promise<number> {
    if (imageContents.length === 0) return 0;

    try {
      const tokenCounts =
        await this.imageTokenizer.calculateTokensBatch(imageContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating image tokens:', error);
      // Fallback: minimum tokens per image
      return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
    }
  }

  /**
   * Calculate tokens for audio contents
   * TODO: Implement proper audio token calculation
   */
  private async calculateAudioTokens(
    audioContents: Array<{ data: string; mimeType: string }>,
  ): Promise<number> {
    if (audioContents.length === 0) return 0;

    // Placeholder implementation - audio token calculation would depend on
    // the specific model's audio processing capabilities
    // For now, estimate based on data size
    let totalTokens = 0;

    for (const audioContent of audioContents) {
      try {
        const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
        // Rough estimate: 1 token per 100 bytes of audio data
        totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
      } catch (error) {
        console.warn('Error calculating audio tokens:', error);
        totalTokens += 10; // Fallback minimum
      }
    }

    return totalTokens;
  }
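  // Worked example (a sketch of the estimate above): a 60,000-character
  // base64 payload decodes to roughly 60,000 * 0.75 = 45,000 bytes, so the
  // estimate is ceil(45,000 / 100) = 450 audio tokens.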

  /**
   * Calculate tokens for other content types (functions, files, etc.)
   */
  private async calculateOtherTokens(otherContents: string[]): Promise<number> {
    if (otherContents.length === 0) return 0;

    try {
      // Treat other content as text for token calculation
      const tokenCounts =
        await this.textTokenizer.calculateTokensBatch(otherContents);
      return tokenCounts.reduce((sum, count) => sum + count, 0);
    } catch (error) {
      console.warn('Error calculating other content tokens:', error);
      // Fallback: character-based estimation
      const totalChars = otherContents.join('').length;
      return Math.ceil(totalChars / 4);
    }
  }

  /**
   * Fallback token calculation using simple string serialization
   */
  private calculateFallbackTokens(request: CountTokensParameters): number {
    try {
      const content = JSON.stringify(request.contents);
      return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
    } catch (error) {
      console.warn('Error in fallback token calculation:', error);
      return 100; // Conservative fallback
    }
  }

  /**
   * Process request contents and group by type
   */
  private processAndGroupContents(request: CountTokensParameters): {
    textContents: string[];
    imageContents: Array<{ data: string; mimeType: string }>;
    audioContents: Array<{ data: string; mimeType: string }>;
    otherContents: string[];
  } {
    const textContents: string[] = [];
    const imageContents: Array<{ data: string; mimeType: string }> = [];
    const audioContents: Array<{ data: string; mimeType: string }> = [];
    const otherContents: string[] = [];

    if (!request.contents) {
      return { textContents, imageContents, audioContents, otherContents };
    }

    const contents = Array.isArray(request.contents)
      ? request.contents
      : [request.contents];

    for (const content of contents) {
      this.processContent(
        content,
        textContents,
        imageContents,
        audioContents,
        otherContents,
      );
    }

    return { textContents, imageContents, audioContents, otherContents };
  }

  /**
   * Process a single content item and add to appropriate arrays
   */
  private processContent(
    content: Content | string | PartUnion,
    textContents: string[],
    imageContents: Array<{ data: string; mimeType: string }>,
    audioContents: Array<{ data: string; mimeType: string }>,
    otherContents: string[],
  ): void {
    if (typeof content === 'string') {
      if (content.trim()) {
        textContents.push(content);
      }
      return;
    }

    if ('parts' in content && content.parts) {
      for (const part of content.parts) {
        this.processPart(
          part,
          textContents,
          imageContents,
          audioContents,
          otherContents,
        );
      }
    }
  }

  /**
   * Process a single part and add to appropriate arrays
   */
  private processPart(
    part: Part | string,
    textContents: string[],
    imageContents: Array<{ data: string; mimeType: string }>,
    audioContents: Array<{ data: string; mimeType: string }>,
    otherContents: string[],
  ): void {
    if (typeof part === 'string') {
      if (part.trim()) {
        textContents.push(part);
      }
      return;
    }

    if ('text' in part && part.text) {
      textContents.push(part.text);
      return;
    }

    if ('inlineData' in part && part.inlineData) {
      const { data, mimeType } = part.inlineData;
      if (mimeType && mimeType.startsWith('image/')) {
        imageContents.push({ data: data || '', mimeType });
        return;
      }
      if (mimeType && mimeType.startsWith('audio/')) {
        audioContents.push({ data: data || '', mimeType });
        return;
      }
    }

    if ('fileData' in part && part.fileData) {
      otherContents.push(JSON.stringify(part.fileData));
      return;
    }

    if ('functionCall' in part && part.functionCall) {
      otherContents.push(JSON.stringify(part.functionCall));
      return;
    }

    if ('functionResponse' in part && part.functionResponse) {
      otherContents.push(JSON.stringify(part.functionResponse));
      return;
    }

    // Unknown part type - try to serialize
    try {
      const serialized = JSON.stringify(part);
      if (serialized && serialized !== '{}') {
        otherContents.push(serialized);
      }
    } catch (error) {
      console.warn('Failed to serialize unknown part type:', error);
    }
  }

  /**
   * Dispose of resources
   */
  async dispose(): Promise<void> {
    try {
      // Dispose of tokenizers
      this.textTokenizer.dispose();
    } catch (error) {
      console.warn('Error disposing request tokenizer:', error);
    }
  }
}
@@ -0,0 +1,56 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

/**
 * Supported image MIME types for vision models
 * These formats are supported by the vision model and can be processed by the image tokenizer
 */
export const SUPPORTED_IMAGE_MIME_TYPES = [
  'image/bmp',
  'image/jpeg',
  'image/jpg', // Alternative MIME type for JPEG
  'image/png',
  'image/tiff',
  'image/webp',
  'image/heic',
] as const;

/**
 * Type for supported image MIME types
 */
export type SupportedImageMimeType =
  (typeof SUPPORTED_IMAGE_MIME_TYPES)[number];

/**
 * Check if a MIME type is supported for vision processing
 * @param mimeType The MIME type to check
 * @returns True if the MIME type is supported
 */
export function isSupportedImageMimeType(
  mimeType: string,
): mimeType is SupportedImageMimeType {
  return SUPPORTED_IMAGE_MIME_TYPES.includes(
    mimeType as SupportedImageMimeType,
  );
}

/**
 * Get a human-readable list of supported image formats
 * @returns Comma-separated string of supported formats
 */
export function getSupportedImageFormatsString(): string {
  return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
    type.replace('image/', '').toUpperCase(),
  ).join(', ');
}

/**
 * Get warning message for unsupported image formats
 * @returns Warning message string
 */
export function getUnsupportedImageFormatWarning(): string {
  return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
}
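A minimal usage sketch of these helpers, assuming the module lives at `./supportedImageFormats.js` (the file name is not shown in this hunk):

    import {
      isSupportedImageMimeType,
      getUnsupportedImageFormatWarning,
    } from './supportedImageFormats.js'; // assumed path

    function warnIfUnsupported(mimeType: string): void {
      if (!isSupportedImageMimeType(mimeType)) {
        // e.g. "Only the following image formats are supported: BMP, JPEG, ..."
        console.warn(getUnsupportedImageFormatWarning());
      }
    }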
packages/core/src/utils/request-tokenizer/textTokenizer.test.ts (new file, 347 lines)
@@ -0,0 +1,347 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';

// Mock tiktoken at the top level with hoisted functions
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());

vi.mock('tiktoken', () => ({
  get_encoding: mockGetEncoding,
}));

describe('TextTokenizer', () => {
  let tokenizer: TextTokenizer;
  let consoleWarnSpy: ReturnType<typeof vi.spyOn>;

  beforeEach(() => {
    vi.resetAllMocks();
    consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

    // Default mock implementation
    mockGetEncoding.mockReturnValue({
      encode: mockEncode,
      free: mockFree,
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    tokenizer?.dispose();
  });

  describe('constructor', () => {
    it('should create tokenizer with default encoding', () => {
      tokenizer = new TextTokenizer();
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });

    it('should create tokenizer with custom encoding', () => {
      tokenizer = new TextTokenizer('gpt2');
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });
  });

  describe('calculateTokens', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should return 0 for empty text', async () => {
      const result = await tokenizer.calculateTokens('');
      expect(result).toBe(0);
    });

    it('should return 0 for null/undefined text', async () => {
      const result1 = await tokenizer.calculateTokens(
        null as unknown as string,
      );
      const result2 = await tokenizer.calculateTokens(
        undefined as unknown as string,
      );
      expect(result1).toBe(0);
      expect(result2).toBe(0);
    });

    it('should calculate tokens using tiktoken when available', async () => {
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
      expect(mockEncode).toHaveBeenCalledWith(testText);
      expect(result).toBe(5);
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should use fallback calculation when encoding fails', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding text with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should handle very long text', async () => {
      const longText = 'a'.repeat(10000);
      const mockTokens = new Array(2500); // 2500 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(longText);

      expect(result).toBe(2500);
    });

    it('should handle unicode characters', async () => {
      const unicodeText = '你好世界 🌍';
      const mockTokens = [1, 2, 3, 4, 5, 6];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(unicodeText);

      expect(result).toBe(6);
    });

    it('should use custom encoding when specified', async () => {
      tokenizer = new TextTokenizer('gpt2');
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
      expect(result).toBe(3);
    });
  });

  describe('calculateTokensBatch', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should process multiple texts and return token counts', async () => {
      const texts = ['Hello', 'world', 'test'];
      mockEncode
        .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
        .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
        .mockReturnValueOnce([6]); // 1 token for 'test'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([2, 3, 1]);
      expect(mockEncode).toHaveBeenCalledTimes(3);
    });

    it('should handle empty array', async () => {
      const result = await tokenizer.calculateTokensBatch([]);
      expect(result).toEqual([]);
    });

    it('should handle array with empty strings', async () => {
      const texts = ['', 'hello', ''];
      mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0]);
      expect(mockEncode).toHaveBeenCalledTimes(1);
      expect(mockEncode).toHaveBeenCalledWith('hello');
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should use fallback calculation when encoding fails during batch processing', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding texts with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should handle null and undefined values in batch', async () => {
      const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
      mockEncode
        .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
        .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0, 2]);
    });
  });

  describe('dispose', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should free tiktoken encoding when disposing', async () => {
      // Initialize the encoding by calling calculateTokens
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();

      expect(mockFree).toHaveBeenCalled();
    });

    it('should handle disposal when encoding is not initialized', () => {
      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle disposal when encoding is null', async () => {
      // Force encoding to be null by making tiktoken fail
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load');
      });

      await tokenizer.calculateTokens('test');

      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle errors during disposal gracefully', async () => {
      await tokenizer.calculateTokens('test');

      mockFree.mockImplementation(() => {
        throw new Error('Free failed');
      });

      tokenizer.dispose();

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error freeing tiktoken encoding:',
        expect.any(Error),
      );
    });

    it('should allow multiple calls to dispose', async () => {
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();
      tokenizer.dispose(); // Second call should not throw

      expect(mockFree).toHaveBeenCalledTimes(1);
    });
  });

  describe('lazy initialization', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should not initialize tiktoken until first use', () => {
      expect(mockGetEncoding).not.toHaveBeenCalled();
    });

    it('should initialize tiktoken on first calculateTokens call', async () => {
      await tokenizer.calculateTokens('test');
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should not reinitialize tiktoken on subsequent calls', async () => {
      await tokenizer.calculateTokens('test1');
      await tokenizer.calculateTokens('test2');

      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should initialize tiktoken on first calculateTokensBatch call', async () => {
      await tokenizer.calculateTokensBatch(['test']);
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should handle very short text', async () => {
      const result = await tokenizer.calculateTokens('a');

      if (mockGetEncoding.mock.calls.length > 0) {
        // If tiktoken was called, use its result
        expect(mockEncode).toHaveBeenCalledWith('a');
      } else {
        // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
        expect(result).toBe(1);
      }
    });

    it('should handle text with only whitespace', async () => {
      const whitespaceText = ' \n\t ';
      const mockTokens = [1];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(whitespaceText);

      expect(result).toBe(1);
    });

    it('should handle special characters and symbols', async () => {
      const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
      const mockTokens = new Array(10);
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(specialText);

      expect(result).toBe(10);
    });
  });
});
packages/core/src/utils/request-tokenizer/textTokenizer.ts (new file, 97 lines)
@@ -0,0 +1,97 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
import { get_encoding } from 'tiktoken';

/**
 * Text tokenizer for calculating text tokens using tiktoken
 */
export class TextTokenizer {
  private encoding: Tiktoken | null = null;
  private encodingName: string;

  constructor(encodingName: string = 'cl100k_base') {
    this.encodingName = encodingName;
  }

  /**
   * Initialize the tokenizer (lazy loading)
   */
  private async ensureEncoding(): Promise<void> {
    if (this.encoding) return;

    try {
      // Use type assertion since we know the encoding name is valid
      this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
    } catch (error) {
      console.warn(
        `Failed to load tiktoken with encoding ${this.encodingName}:`,
        error,
      );
      this.encoding = null;
    }
  }

  /**
   * Calculate tokens for text content
   */
  async calculateTokens(text: string): Promise<number> {
    if (!text) return 0;

    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return this.encoding.encode(text).length;
      } catch (error) {
        console.warn('Error encoding text with tiktoken:', error);
      }
    }

    // Fallback: rough approximation using character count
    // This is a conservative estimate: 1 token ≈ 4 characters for most languages
    return Math.ceil(text.length / 4);
  }

  /**
   * Calculate tokens for multiple text strings in one pass
   */
  async calculateTokensBatch(texts: string[]): Promise<number[]> {
    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return texts.map((text) => {
          if (!text) return 0;
          // this.encoding may be null, add a null check to satisfy lint
          return this.encoding ? this.encoding.encode(text).length : 0;
        });
      } catch (error) {
        console.warn('Error encoding texts with tiktoken:', error);
        // In case of error, return fallback estimation for all texts
        return texts.map((text) => Math.ceil((text || '').length / 4));
      }
    }

    // Fallback for batch processing
    return texts.map((text) => Math.ceil((text || '').length / 4));
  }

  /**
   * Dispose of resources
   */
  dispose(): void {
    if (this.encoding) {
      try {
        this.encoding.free();
      } catch (error) {
        console.warn('Error freeing tiktoken encoding:', error);
      }
      this.encoding = null;
    }
  }
}
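A minimal usage sketch of TextTokenizer, including the character-count fallback:

    import { TextTokenizer } from './textTokenizer.js';

    async function demo(): Promise<void> {
      const textTokenizer = new TextTokenizer(); // defaults to cl100k_base
      const count = await textTokenizer.calculateTokens('Hello, world!');
      // With tiktoken loaded this is the exact encoding length; if loading
      // fails, the fallback estimate is ceil(13 / 4) = 4 tokens
      console.log(count);
      textTokenizer.dispose(); // free the WASM-backed encoding when done
    }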
packages/core/src/utils/request-tokenizer/types.ts (new file, 64 lines)
@@ -0,0 +1,64 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { CountTokensParameters } from '@google/genai';

/**
 * Token calculation result for different content types
 */
export interface TokenCalculationResult {
  /** Total tokens calculated */
  totalTokens: number;
  /** Breakdown by content type */
  breakdown: {
    textTokens: number;
    imageTokens: number;
    audioTokens: number;
    otherTokens: number;
  };
  /** Processing time in milliseconds */
  processingTime: number;
}

/**
 * Configuration for token calculation
 */
export interface TokenizerConfig {
  /** Custom text tokenizer encoding (defaults to cl100k_base) */
  textEncoding?: string;
}

/**
 * Image metadata extracted from base64 data
 */
export interface ImageMetadata {
  /** Image width in pixels */
  width: number;
  /** Image height in pixels */
  height: number;
  /** MIME type of the image */
  mimeType: string;
  /** Size of the base64 data in bytes */
  dataSize: number;
}

/**
 * Request tokenizer interface
 */
export interface RequestTokenizer {
  /**
   * Calculate tokens for a request
   */
  calculateTokens(
    request: CountTokensParameters,
    config?: TokenizerConfig,
  ): Promise<TokenCalculationResult>;

  /**
   * Dispose of resources (worker threads, etc.)
   */
  dispose(): Promise<void>;
}
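For reference, a minimal sketch of a custom RequestTokenizer built on these types (a toy character-count implementation, not the DefaultRequestTokenizer above):

    import type { CountTokensParameters } from '@google/genai';
    import type {
      RequestTokenizer,
      TokenCalculationResult,
      TokenizerConfig,
    } from './types.js';

    class CharCountTokenizer implements RequestTokenizer {
      async calculateTokens(
        request: CountTokensParameters,
        _config?: TokenizerConfig,
      ): Promise<TokenCalculationResult> {
        const start = performance.now();
        // Rough estimate: 1 token per 4 characters of serialized content
        const textTokens = Math.ceil(
          JSON.stringify(request.contents ?? '').length / 4,
        );
        return {
          totalTokens: textTokens,
          breakdown: { textTokens, imageTokens: 0, audioTokens: 0, otherTokens: 0 },
          processingTime: performance.now() - start,
        };
      }

      async dispose(): Promise<void> {
        // No resources to release in this toy implementation
      }
    }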