Compare commits

..

13 Commits

Author SHA1 Message Date
mingholy.lmh
bf206237fa fix: align supported image formats with bailian doc 2025-09-18 11:12:51 +08:00
mingholy.lmh
9229134086 fix: remove deprecated files 2025-09-18 10:08:01 +08:00
mingholy.lmh
a8ca4ebf89 feat: add visionModelPreview to control default visibility of vision models 2025-09-17 20:58:54 +08:00
mingholy.lmh
caedd8338f fix: lint and type errors 2025-09-17 20:29:16 +08:00
mingholy.lmh
b4ba23fd80 Merge branch 'main' into feat/vision-model-autoswitch 2025-09-17 20:20:28 +08:00
mingholy.lmh
93fbc54f88 feat: add image tokenizer to fit vlm context window 2025-09-17 20:10:19 +08:00
mingholy.lmh
30b463b7ee fix: lint error 2025-09-16 11:18:36 +08:00
mingholy.lmh
e969dbd5d2 Merge branch 'main' into feat/vision-model-autoswitch 2025-09-16 11:04:32 +08:00
mingholy.lmh
71cf4fbae0 feat: /model command for switching to vision model 2025-09-05 16:06:20 +08:00
mingholy.lmh
413be4467f fix: unit test cases 2025-09-04 12:18:41 +08:00
mingholy.lmh
6005051713 refactor: re-organize refactored files 2025-09-04 12:00:00 +08:00
mingholy.lmh
65549193c1 refactor: optimize stream handling 2025-09-03 21:43:25 +08:00
mingholy.lmh
002f1e2f36 refactor: openaiContentGenerator 2025-09-02 15:13:46 +08:00
58 changed files with 4222 additions and 5703 deletions

View File

@@ -133,28 +133,6 @@ Focus on creating clear, comprehensive documentation that helps both
new contributors and end users understand the project.
```
## Using Subagents Effectively
### Automatic Delegation
Qwen Code proactively delegates tasks based on:
- The task description in your request
- The description field in subagent configurations
- Current context and available tools
To encourage more proactive subagent use, include phrases like "use PROACTIVELY" or "MUST BE USED" in your description field.
### Explicit Invocation
Request a specific subagent by mentioning it in your command:
```
> Let the testing-expert subagent create unit tests for the payment module
> Have the documentation-writer subagent update the API reference
> Get the react-specialist subagent to optimize this component's performance
```
## Examples
### Development Workflow Agents

View File

@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
description: 'Enable extension management features.',
showInDialog: false,
},
visionModelPreview: {
type: 'boolean',
label: 'Vision Model Preview',
category: 'Experimental',
requiresRestart: false,
default: false,
description:
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
showInDialog: true,
},
},
},

View File

@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
kind: 'BUILT_IN',
},
}));
vi.mock('../ui/commands/modelCommand.js', () => ({
modelCommand: {
name: 'model',
description: 'Model command',
kind: 'BUILT_IN',
},
}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
const mcpCmd = commands.find((c) => c.name === 'mcp');
expect(mcpCmd).toBeDefined();
const modelCmd = commands.find((c) => c.name === 'model');
expect(modelCmd).toBeDefined();
});
});

View File

@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
import { vimCommand } from '../ui/commands/vimCommand.js';
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
import { modelCommand } from '../ui/commands/modelCommand.js';
import { agentsCommand } from '../ui/commands/agentsCommand.js';
/**
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
initCommand,
mcpCommand,
memoryCommand,
modelCommand,
privacyCommand,
quitCommand,
quitConfirmCommand,

View File

@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
import {
ModelSwitchDialog,
type VisionSwitchOutcome,
} from './components/ModelSwitchDialog.js';
import {
getOpenAIAvailableModelFromEnv,
getFilteredQwenModels,
type AvailableModel,
} from './models/availableModels.js';
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
import {
AgentCreationWizard,
AgentsManagerDialog,
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onWorkspaceMigrationDialogClose,
} = useWorkspaceMigration(settings);
// Model selection dialog states
const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
useState(false);
const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
useState(false);
const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
resolve: (result: {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}) => void;
reject: () => void;
} | null>(null);
useEffect(() => {
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
// Set the initial value
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
openAuthDialog();
}, [openAuthDialog, setAuthError]);
// Vision switch handler for auto-switch functionality
const handleVisionSwitchRequired = useCallback(
async (_query: unknown) =>
new Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>((resolve, reject) => {
setVisionSwitchResolver({ resolve, reject });
setIsVisionSwitchDialogOpen(true);
}),
[],
);
const handleVisionSwitchSelect = useCallback(
(outcome: VisionSwitchOutcome) => {
setIsVisionSwitchDialogOpen(false);
if (visionSwitchResolver) {
const result = processVisionSwitchOutcome(outcome);
visionSwitchResolver.resolve(result);
setVisionSwitchResolver(null);
}
},
[visionSwitchResolver],
);
const handleModelSelectionOpen = useCallback(() => {
setIsModelSelectionDialogOpen(true);
}, []);
const handleModelSelectionClose = useCallback(() => {
setIsModelSelectionDialogOpen(false);
}, []);
const handleModelSelect = useCallback(
(modelId: string) => {
config.setModel(modelId);
setCurrentModel(modelId);
setIsModelSelectionDialogOpen(false);
addItem(
{
type: MessageType.INFO,
text: `Switched model to \`${modelId}\` for this session.`,
},
Date.now(),
);
},
[config, setCurrentModel, addItem],
);
const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) return [];
const visionModelPreviewEnabled =
settings.merged.experimental?.visionModelPreview ?? false;
switch (contentGeneratorConfig.authType) {
case AuthType.QWEN_OAUTH:
return getFilteredQwenModels(visionModelPreviewEnabled);
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
return [];
}
}, [config, settings.merged.experimental?.visionModelPreview]);
// Core hooks and processors
const {
vimEnabled: vimModeEnabled,
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setQuittingMessages,
openPrivacyNotice,
openSettingsDialog,
handleModelSelectionOpen,
openSubagentCreateDialog,
openAgentsManagerDialog,
toggleVimEnabled,
@@ -664,16 +759,12 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setModelSwitchedFromQuotaError,
refreshStatic,
() => cancelHandlerRef.current(),
settings.merged.experimental?.visionModelPreview ?? false,
handleVisionSwitchRequired,
);
const pendingHistoryItems = useMemo(
() =>
[...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems].map(
(item, index) => ({
...item,
id: index,
}),
),
() => [...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems],
[pendingSlashCommandHistoryItems, pendingGeminiHistoryItems],
);
@@ -1034,6 +1125,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
!isAuthDialogOpen &&
!isThemeDialogOpen &&
!isEditorDialogOpen &&
!isModelSelectionDialogOpen &&
!isVisionSwitchDialogOpen &&
!isSubagentCreateDialogOpen &&
!showPrivacyNotice &&
!showWelcomeBackDialog &&
@@ -1055,6 +1148,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
showWelcomeBackDialog,
welcomeBackChoice,
geminiClient,
isModelSelectionDialogOpen,
isVisionSwitchDialogOpen,
]);
if (quittingMessages) {
@@ -1127,14 +1222,16 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
</Static>
<OverflowProvider>
<Box ref={pendingHistoryItemRef} flexDirection="column">
{pendingHistoryItems.map((item) => (
{pendingHistoryItems.map((item, i) => (
<HistoryItemDisplay
key={item.id}
key={i}
availableTerminalHeight={
constrainHeight ? availableTerminalHeight : undefined
}
terminalWidth={mainAreaWidth}
item={item}
// TODO(taehykim): It seems like references to ids aren't necessary in
// HistoryItemDisplay. Refactor later. Use a fake id for now.
item={{ ...item, id: 0 }}
isPending={true}
config={config}
isFocused={!isEditorDialogOpen}
@@ -1322,6 +1419,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onExit={exitEditorDialog}
/>
</Box>
) : isModelSelectionDialogOpen ? (
<ModelSelectionDialog
availableModels={getAvailableModelsForCurrentAuth()}
currentModel={currentModel}
onSelect={handleModelSelect}
onCancel={handleModelSelectionClose}
/>
) : isVisionSwitchDialogOpen ? (
<ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
) : showPrivacyNotice ? (
<PrivacyNotice
onExit={() => setShowPrivacyNotice(false)}

View File

@@ -0,0 +1,179 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { modelCommand } from './modelCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import {
AuthType,
type ContentGeneratorConfig,
type Config,
} from '@qwen-code/qwen-code-core';
import * as availableModelsModule from '../models/availableModels.js';
// Mock the availableModels module
vi.mock('../models/availableModels.js', () => ({
AVAILABLE_MODELS_QWEN: [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
],
getOpenAIAvailableModelFromEnv: vi.fn(),
}));
// Helper function to create a mock config
function createMockConfig(
contentGeneratorConfig: ContentGeneratorConfig | null,
): Partial<Config> {
return {
getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
};
}
describe('modelCommand', () => {
let mockContext: CommandContext;
const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
availableModelsModule.getOpenAIAvailableModelFromEnv,
);
beforeEach(() => {
mockContext = createMockCommandContext();
vi.clearAllMocks();
});
it('should have the correct name and description', () => {
expect(modelCommand.name).toBe('model');
expect(modelCommand.description).toBe('Switch the model for this session');
});
it('should return error when config is not available', async () => {
mockContext.services.config = null;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
});
});
it('should return error when content generator config is not available', async () => {
const mockConfig = createMockConfig(null);
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
});
});
it('should return error when auth type is not available', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
it('should return dialog action for QWEN_OAUTH auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.QWEN_OAUTH,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
id: 'gpt-4',
label: 'gpt-4',
});
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return error for USE_OPENAI auth type when no model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (openai).',
});
});
it('should return error for unsupported auth types', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
});
});
it('should handle undefined auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
});

View File

@@ -0,0 +1,88 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { AuthType } from '@qwen-code/qwen-code-core';
import type {
SlashCommand,
CommandContext,
OpenDialogActionReturn,
MessageActionReturn,
} from './types.js';
import { CommandKind } from './types.js';
import {
AVAILABLE_MODELS_QWEN,
getOpenAIAvailableModelFromEnv,
type AvailableModel,
} from '../models/availableModels.js';
function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
switch (authType) {
case AuthType.QWEN_OAUTH:
return AVAILABLE_MODELS_QWEN;
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
// For other auth types, return empty array for now
// This can be expanded later according to the design doc
return [];
}
}
export const modelCommand: SlashCommand = {
name: 'model',
description: 'Switch the model for this session',
kind: CommandKind.BUILT_IN,
action: async (
context: CommandContext,
): Promise<OpenDialogActionReturn | MessageActionReturn> => {
const { services } = context;
const { config } = services;
if (!config) {
return {
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
};
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) {
return {
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
};
}
const authType = contentGeneratorConfig.authType;
if (!authType) {
return {
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
};
}
const availableModels = getAvailableModelsForAuthType(authType);
if (availableModels.length === 0) {
return {
type: 'message',
messageType: 'error',
content: `No models available for the current authentication type (${authType}).`,
};
}
// Trigger model selection dialog
return {
type: 'dialog',
dialog: 'model',
};
},
};

View File

@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
| 'editor'
| 'privacy'
| 'settings'
| 'model'
| 'subagent_create'
| 'subagent_list';
}

View File

@@ -5,7 +5,6 @@
*/
import type React from 'react';
import { memo } from 'react';
import type { HistoryItem } from '../types.js';
import { UserMessage } from './messages/UserMessage.js';
import { UserShellMessage } from './messages/UserShellMessage.js';
@@ -36,7 +35,7 @@ interface HistoryItemDisplayProps {
commands?: readonly SlashCommand[];
}
const HistoryItemDisplayComponent: React.FC<HistoryItemDisplayProps> = ({
export const HistoryItemDisplay: React.FC<HistoryItemDisplayProps> = ({
item,
availableTerminalHeight,
terminalWidth,
@@ -102,7 +101,3 @@ const HistoryItemDisplayComponent: React.FC<HistoryItemDisplayProps> = ({
{item.type === 'summary' && <SummaryMessage summary={item.summary} />}
</Box>
);
HistoryItemDisplayComponent.displayName = 'HistoryItemDisplay';
export const HistoryItemDisplay = memo(HistoryItemDisplayComponent);

View File

@@ -0,0 +1,246 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
import type { AvailableModel } from '../models/availableModels.js';
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSelectionDialog', () => {
const mockAvailableModels: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
{ id: 'gpt-4', label: 'GPT-4' },
];
const mockOnSelect = vi.fn();
const mockOnCancel = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup escape key handler to call onCancel', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnCancel).toHaveBeenCalled();
});
it('should not call onCancel for non-escape keys', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnCancel).not.toHaveBeenCalled();
});
it('should set correct initial index for current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
});
it('should set initial index to 0 when current model is not found', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="non-existent-model"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
it('should call onSelect when a model is selected', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback('qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
});
it('should handle empty models array', () => {
render(
<ModelSelectionDialog
availableModels={[]}
currentModel=""
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual([]);
expect(callArgs.initialIndex).toBe(0);
});
it('should create correct option items with proper labels', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const expectedItems = [
{
label: 'qwen3-coder-plus (current)',
value: 'qwen3-coder-plus',
},
{
label: 'qwen-vl-max [Vision]',
value: 'qwen-vl-max-latest',
},
{
label: 'GPT-4',
value: 'gpt-4',
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
});
it('should show vision indicator for vision models', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="gpt-4"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const visionModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(visionModelItem?.label).toContain('[Vision]');
});
it('should show current indicator for the current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const currentModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(currentModelItem?.label).toContain('(current)');
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle multiple onSelect calls correctly', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback('qwen3-coder-plus');
onSelectCallback('qwen-vl-max-latest');
onSelectCallback('gpt-4');
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
});
});

View File

@@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
import type { AvailableModel } from '../models/availableModels.js';
export interface ModelSelectionDialogProps {
availableModels: AvailableModel[];
currentModel: string;
onSelect: (modelId: string) => void;
onCancel: () => void;
}
export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
availableModels,
currentModel,
onSelect,
onCancel,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onCancel();
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<string>> = availableModels.map(
(model) => {
const visionIndicator = model.isVision ? ' [Vision]' : '';
const currentIndicator = model.id === currentModel ? ' (current)' : '';
return {
label: `${model.label}${visionIndicator}${currentIndicator}`,
value: model.id,
};
},
);
const initialIndex = Math.max(
0,
availableModels.findIndex((model) => model.id === currentModel),
);
const handleSelect = (modelId: string) => {
onSelect(modelId);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentBlue}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Select Model</Text>
<Text>Choose a model for this session:</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={initialIndex}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};

View File

@@ -0,0 +1,185 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSwitchDialog', () => {
const mockOnSelect = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup RadioButtonSelect with correct options', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const expectedItems = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
expect(callArgs.initialIndex).toBe(0);
expect(callArgs.isFocused).toBe(true);
});
it('should call onSelect when an option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection of "Switch for this request only"
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
});
it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.SwitchSessionToVL,
);
});
it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should not call onSelect for non-escape keys', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnSelect).not.toHaveBeenCalled();
});
it('should set initial index to 0 (first option)', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
describe('VisionSwitchOutcome enum', () => {
it('should have correct enum values', () => {
expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
'switch_session_to_vl',
);
expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
'disallow_with_guidance',
);
});
});
it('should handle multiple onSelect calls correctly', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(
1,
VisionSwitchOutcome.SwitchOnce,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
2,
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
3,
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle escape key multiple times', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
// Call escape multiple times
keypressHandler({ name: 'escape' });
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledTimes(2);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
});

View File

@@ -0,0 +1,89 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
export enum VisionSwitchOutcome {
SwitchOnce = 'switch_once',
SwitchSessionToVL = 'switch_session_to_vl',
DisallowWithGuidance = 'disallow_with_guidance',
}
export interface ModelSwitchDialogProps {
onSelect: (outcome: VisionSwitchOutcome) => void;
}
export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
onSelect,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onSelect(VisionSwitchOutcome.DisallowWithGuidance);
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const handleSelect = (outcome: VisionSwitchOutcome) => {
onSelect(outcome);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentYellow}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Vision Model Switch Required</Text>
<Text>
Your message contains an image, but the current model doesn&apos;t
support vision.
</Text>
<Text>How would you like to proceed?</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={0}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};

View File

@@ -27,7 +27,6 @@ export interface ToolConfirmationMessageProps {
isFocused?: boolean;
availableTerminalHeight?: number;
terminalWidth: number;
compactMode?: boolean;
}
export const ToolConfirmationMessage: React.FC<
@@ -38,7 +37,6 @@ export const ToolConfirmationMessage: React.FC<
isFocused = true,
availableTerminalHeight,
terminalWidth,
compactMode = false,
}) => {
const { onConfirm } = confirmationDetails;
const childWidth = terminalWidth - 2; // 2 for padding
@@ -72,40 +70,6 @@ export const ToolConfirmationMessage: React.FC<
const handleSelect = (item: ToolConfirmationOutcome) => handleConfirm(item);
// Compact mode: return simple 3-option display
if (compactMode) {
const compactOptions: Array<RadioSelectItem<ToolConfirmationOutcome>> = [
{
label: 'Yes, allow once',
value: ToolConfirmationOutcome.ProceedOnce,
},
{
label: 'Allow always',
value: ToolConfirmationOutcome.ProceedAlways,
},
{
label: 'No',
value: ToolConfirmationOutcome.Cancel,
},
];
return (
<Box flexDirection="column">
<Box>
<Text wrap="truncate">Do you want to proceed?</Text>
</Box>
<Box>
<RadioButtonSelect
items={compactOptions}
onSelect={handleSelect}
isFocused={isFocused}
/>
</Box>
</Box>
);
}
// Original logic continues unchanged below
let bodyContent: React.ReactNode | null = null; // Removed contextDisplay here
let question: string;

View File

@@ -5,7 +5,7 @@
*/
import { useReducer, useCallback, useMemo } from 'react';
import { Box, Text } from 'ink';
import { Box, Text, useInput } from 'ink';
import { wizardReducer, initialWizardState } from '../reducers.js';
import { LocationSelector } from './LocationSelector.js';
import { GenerationMethodSelector } from './GenerationMethodSelector.js';
@@ -20,7 +20,6 @@ import type { Config } from '@qwen-code/qwen-code-core';
import { Colors } from '../../../colors.js';
import { theme } from '../../../semantic-colors.js';
import { TextEntryStep } from './TextEntryStep.js';
import { useKeypress } from '../../../hooks/useKeypress.js';
interface AgentCreationWizardProps {
onClose: () => void;
@@ -50,12 +49,8 @@ export function AgentCreationWizard({
}, [onClose]);
// Centralized ESC key handling for the entire wizard
useKeypress(
(key) => {
if (key.name !== 'escape') {
return;
}
useInput((input, key) => {
if (key.escape) {
// LLM DescriptionInput handles its own ESC logic when generating
const kind = getStepKind(state.generationMethod, state.currentStep);
if (kind === 'LLM_DESC' && state.isGenerating) {
@@ -69,9 +64,8 @@ export function AgentCreationWizard({
// On other steps, ESC goes back to previous step
handlePrevious();
}
},
{ isActive: true },
);
}
});
const stepProps: WizardStepProps = useMemo(
() => ({

View File

@@ -227,7 +227,7 @@ export const AgentSelectionStep = ({
const textColor = isSelected ? theme.text.accent : theme.text.primary;
return (
<Box key={`${agent.name}-${agent.level}`} alignItems="center">
<Box key={agent.name} alignItems="center">
<Box minWidth={2} flexShrink={0}>
<Text color={isSelected ? theme.text.accent : theme.text.primary}>
{isSelected ? '●' : ' '}

View File

@@ -5,7 +5,7 @@
*/
import { useState, useCallback, useMemo, useEffect } from 'react';
import { Box, Text } from 'ink';
import { Box, Text, useInput } from 'ink';
import { AgentSelectionStep } from './AgentSelectionStep.js';
import { ActionSelectionStep } from './ActionSelectionStep.js';
import { AgentViewerStep } from './AgentViewerStep.js';
@@ -17,8 +17,7 @@ import { MANAGEMENT_STEPS } from '../types.js';
import { Colors } from '../../../colors.js';
import { theme } from '../../../semantic-colors.js';
import { getColorForDisplay, shouldShowColor } from '../utils.js';
import type { SubagentConfig, Config } from '@qwen-code/qwen-code-core';
import { useKeypress } from '../../../hooks/useKeypress.js';
import type { Config, SubagentConfig } from '@qwen-code/qwen-code-core';
interface AgentsManagerDialogProps {
onClose: () => void;
@@ -53,7 +52,18 @@ export function AgentsManagerDialog({
const manager = config.getSubagentManager();
// Load agents from all levels separately to show all agents including conflicts
const allAgents = await manager.listSubagents();
const [projectAgents, userAgents, builtinAgents] = await Promise.all([
manager.listSubagents({ level: 'project' }),
manager.listSubagents({ level: 'user' }),
manager.listSubagents({ level: 'builtin' }),
]);
// Combine all agents (project, user, and builtin level)
const allAgents = [
...(projectAgents || []),
...(userAgents || []),
...(builtinAgents || []),
];
setAvailableAgents(allAgents);
}, [config]);
@@ -112,12 +122,8 @@ export function AgentsManagerDialog({
);
// Centralized ESC key handling for the entire dialog
useKeypress(
(key) => {
if (key.name !== 'escape') {
return;
}
useInput((input, key) => {
if (key.escape) {
const currentStep = getCurrentStep();
if (currentStep === MANAGEMENT_STEPS.AGENT_SELECTION) {
// On first step, ESC cancels the entire dialog
@@ -126,9 +132,8 @@ export function AgentsManagerDialog({
// On other steps, ESC goes back to previous step in navigation stack
handleNavigateBack();
}
},
{ isActive: true },
);
}
});
// Props for child components - now using direct state and callbacks
const commonProps = useMemo(

View File

@@ -18,12 +18,12 @@ import { COLOR_OPTIONS } from '../constants.js';
import { fmtDuration } from '../utils.js';
import { ToolConfirmationMessage } from '../../messages/ToolConfirmationMessage.js';
export type DisplayMode = 'compact' | 'default' | 'verbose';
export type DisplayMode = 'default' | 'verbose';
export interface AgentExecutionDisplayProps {
data: TaskResultDisplay;
availableHeight?: number;
childWidth: number;
childWidth?: number;
config: Config;
}
@@ -80,7 +80,7 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
childWidth,
config,
}) => {
const [displayMode, setDisplayMode] = React.useState<DisplayMode>('compact');
const [displayMode, setDisplayMode] = React.useState<DisplayMode>('default');
const agentColor = useMemo(() => {
const colorOption = COLOR_OPTIONS.find(
@@ -93,6 +93,8 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
// This component only listens to keyboard shortcut events when the subagent is running
if (data.status !== 'running') return '';
if (displayMode === 'verbose') return 'Press ctrl+r to show less.';
if (displayMode === 'default') {
const hasMoreLines =
data.taskPrompt.split('\n').length > MAX_TASK_PROMPT_LINES;
@@ -100,28 +102,17 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
data.toolCalls && data.toolCalls.length > MAX_TOOL_CALLS;
if (hasMoreToolCalls || hasMoreLines) {
return 'Press ctrl+r to show less, ctrl+e to show more.';
return 'Press ctrl+r to show more.';
}
return 'Press ctrl+r to show less.';
return '';
}
if (displayMode === 'verbose') {
return 'Press ctrl+e to show less.';
}
return '';
}, [displayMode, data]);
}, [displayMode, data.toolCalls, data.taskPrompt, data.status]);
// Handle keyboard shortcuts to control display mode
// Handle ctrl+r keypresses to control display mode
useKeypress(
(key) => {
if (key.ctrl && key.name === 'r') {
// ctrl+r toggles between compact and default
setDisplayMode((current) =>
current === 'compact' ? 'default' : 'compact',
);
} else if (key.ctrl && key.name === 'e') {
// ctrl+e toggles between default and verbose
setDisplayMode((current) =>
current === 'default' ? 'verbose' : 'default',
);
@@ -130,82 +121,6 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
{ isActive: true },
);
if (displayMode === 'compact') {
return (
<Box flexDirection="column">
{/* Header: Agent name and status */}
{!data.pendingConfirmation && (
<Box flexDirection="row">
<Text bold color={agentColor}>
{data.subagentName}
</Text>
<StatusDot status={data.status} />
<StatusIndicator status={data.status} />
</Box>
)}
{/* Running state: Show current tool call and progress */}
{data.status === 'running' && (
<>
{/* Current tool call */}
{data.toolCalls && data.toolCalls.length > 0 && (
<Box flexDirection="column">
<ToolCallItem
toolCall={data.toolCalls[data.toolCalls.length - 1]}
compact={true}
/>
{/* Show count of additional tool calls if there are more than 1 */}
{data.toolCalls.length > 1 && !data.pendingConfirmation && (
<Box flexDirection="row" paddingLeft={4}>
<Text color={Colors.Gray}>
+{data.toolCalls.length - 1} more tool calls (ctrl+r to
expand)
</Text>
</Box>
)}
</Box>
)}
{/* Inline approval prompt when awaiting confirmation */}
{data.pendingConfirmation && (
<Box flexDirection="column" marginTop={1} paddingLeft={1}>
<ToolConfirmationMessage
confirmationDetails={data.pendingConfirmation}
isFocused={true}
availableTerminalHeight={availableHeight}
terminalWidth={childWidth}
compactMode={true}
config={config}
/>
</Box>
)}
</>
)}
{/* Completed state: Show summary line */}
{data.status === 'completed' && data.executionSummary && (
<Box flexDirection="row" marginTop={1}>
<Text color={theme.text.secondary}>
Execution Summary: {data.executionSummary.totalToolCalls} tool
uses · {data.executionSummary.totalTokens.toLocaleString()} tokens
· {fmtDuration(data.executionSummary.totalDurationMs)}
</Text>
</Box>
)}
{/* Failed/Cancelled state: Show error reason */}
{data.status === 'failed' && (
<Box flexDirection="row" marginTop={1}>
<Text color={theme.status.error}>
Failed: {data.terminateReason}
</Text>
</Box>
)}
</Box>
);
}
// Default and verbose modes use normal layout
return (
<Box flexDirection="column" paddingX={1} gap={1}>
{/* Header with subagent name and status */}
@@ -243,8 +158,7 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
config={config}
isFocused={true}
availableTerminalHeight={availableHeight}
terminalWidth={childWidth}
compactMode={true}
terminalWidth={childWidth ?? 80}
/>
</Box>
)}
@@ -366,8 +280,7 @@ const ToolCallItem: React.FC<{
resultDisplay?: string;
description?: string;
};
compact?: boolean;
}> = ({ toolCall, compact = false }) => {
}> = ({ toolCall }) => {
const STATUS_INDICATOR_WIDTH = 3;
// Map subagent status to ToolCallStatus-like display
@@ -422,8 +335,8 @@ const ToolCallItem: React.FC<{
</Text>
</Box>
{/* Second line: truncated returnDisplay output - hidden in compact mode */}
{!compact && truncatedOutput && (
{/* Second line: truncated returnDisplay output */}
{truncatedOutput && (
<Box flexDirection="row" paddingLeft={STATUS_INDICATOR_WIDTH}>
<Text color={Colors.Gray}>{truncatedOutput}</Text>
</Box>

View File

@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
const mockLoadHistory = vi.fn();
const mockOpenThemeDialog = vi.fn();
const mockOpenAuthDialog = vi.fn();
const mockOpenModelSelectionDialog = vi.fn();
const mockSetQuittingMessages = vi.fn();
const mockConfig = makeFakeConfig({});
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
mockMcpLoadCommands.mockResolvedValue([]);
mockOpenModelSelectionDialog.mockClear();
});
const setupProcessorHook = (
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
mockOpenModelSelectionDialog,
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
setIsProcessing,
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
expect(mockOpenThemeDialog).toHaveBeenCalled();
});
it('should handle "dialog: model" action', async () => {
const command = createTestCommand({
name: 'modelcmd',
action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
});
const result = setupProcessorHook([command]);
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
await act(async () => {
await result.current.handleSlashCommand('/modelcmd');
});
expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
});
it('should handle "load_history" action', async () => {
const command = createTestCommand({
name: 'load',
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
vi.fn(), // openModelSelectionDialog
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
vi.fn(), // setIsProcessing
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);

View File

@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
setQuittingMessages: (message: HistoryItem[]) => void,
openPrivacyNotice: () => void,
openSettingsDialog: () => void,
openModelSelectionDialog: () => void,
openSubagentCreateDialog: () => void,
openAgentsManagerDialog: () => void,
toggleVimEnabled: () => Promise<boolean>,
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
case 'settings':
openSettingsDialog();
return { type: 'handled' };
case 'model':
openModelSelectionDialog();
return { type: 'handled' };
case 'subagent_create':
openSubagentCreateDialog();
return { type: 'handled' };
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
setSessionShellAllowlist,
setIsProcessing,
setConfirmationRequest,
openModelSelectionDialog,
session.stats,
],
);

View File

@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
);
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
// Vision auto-switch mocks (hoisted)
const mockHandleVisionSwitch = vi.hoisted(() =>
vi.fn().mockResolvedValue({ shouldProceed: true }),
);
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return {
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
};
});
vi.mock('./useVisionAutoSwitch.js', () => ({
useVisionAutoSwitch: vi.fn(() => ({
handleVisionSwitch: mockHandleVisionSwitch,
restoreOriginalModel: mockRestoreOriginalModel,
})),
}));
vi.mock('./useKeypress.js', () => ({
useKeypress: vi.fn(),
}));
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue(contentGeneratorConfig),
getMaxSessionTurns: vi.fn(() => 50),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
expect.any(String), // Argument 3: The prompt_id string
);
});
describe('Thought Reset', () => {
it('should reset thought to null when starting a new prompt', async () => {
// First, simulate a response with a thought
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
);
});
});
// --- New tests focused on recent modifications ---
describe('Vision Auto Switch Integration', () => {
it('should call handleVisionSwitch and proceed to send when allowed', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('image prompt');
});
await waitFor(() => {
expect(mockHandleVisionSwitch).toHaveBeenCalled();
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('vision-gated');
});
// No call to API, no restoreOriginalModel needed since no override occurred
expect(mockSendMessageStream).not.toHaveBeenCalled();
expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
// Next call allowed (flag reset path)
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
await act(async () => {
await result.current.submitQuery('after-gate');
});
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
});
describe('Model restore on completion and errors', () => {
it('should restore model after successful stream completion', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-success');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
it('should restore model when an error occurs during streaming', async () => {
const testError = new Error('stream failure');
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
throw testError;
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-error');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
});
});

View File

@@ -42,6 +42,7 @@ import type {
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
import { useShellCommandProcessor } from './shellCommandProcessor.js';
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
import { handleAtCommand } from './atCommandProcessor.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
@@ -88,6 +89,12 @@ export const useGeminiStream = (
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
onEditorClose: () => void,
onCancelSubmit: () => void,
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -155,6 +162,13 @@ export const useGeminiStream = (
geminiClient,
);
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
config,
addItem,
visionModelPreviewEnabled,
onVisionSwitchRequired,
);
const streamingState = useMemo(() => {
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
return StreamingState.WaitingForConfirmation;
@@ -715,6 +729,20 @@ export const useGeminiStream = (
return;
}
// Handle vision switch requirement
const visionSwitchResult = await handleVisionSwitch(
queryToSend,
userMessageTimestamp,
options?.isContinuation || false,
);
if (!visionSwitchResult.shouldProceed) {
isSubmittingQueryRef.current = false;
return;
}
const finalQueryToSend = queryToSend;
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
@@ -725,7 +753,7 @@ export const useGeminiStream = (
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
finalQueryToSend,
abortSignal,
prompt_id!,
);
@@ -736,6 +764,8 @@ export const useGeminiStream = (
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
isSubmittingQueryRef.current = false;
return;
}
@@ -748,7 +778,13 @@ export const useGeminiStream = (
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
// Restore original model if it was temporarily overridden
restoreOriginalModel();
} catch (error: unknown) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
if (error instanceof UnauthorizedError) {
onAuthError();
} else if (!isNodeError(error) || error.name !== 'AbortError') {
@@ -786,6 +822,8 @@ export const useGeminiStream = (
startNewPrompt,
getPromptCount,
handleLoopDetectedEvent,
handleVisionSwitch,
restoreOriginalModel,
],
);
@@ -911,13 +949,10 @@ export const useGeminiStream = (
],
);
const pendingHistoryItems = useMemo(
() =>
[pendingHistoryItemRef.current, pendingToolCallGroupDisplay].filter(
(i) => i !== undefined && i !== null,
),
[pendingHistoryItemRef, pendingToolCallGroupDisplay],
);
const pendingHistoryItems = [
pendingHistoryItemRef.current,
pendingToolCallGroupDisplay,
].filter((i) => i !== undefined && i !== null);
useEffect(() => {
const saveRestorableToolCalls = async () => {

View File

@@ -0,0 +1,374 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { renderHook, act } from '@testing-library/react';
import type { Part, PartListUnion } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import {
shouldOfferVisionSwitch,
processVisionSwitchOutcome,
getVisionSwitchGuidanceMessage,
useVisionAutoSwitch,
} from './useVisionAutoSwitch.js';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import { MessageType } from '../types.js';
import { getDefaultVisionModel } from '../models/availableModels.js';
describe('useVisionAutoSwitch helpers', () => {
describe('shouldOfferVisionSwitch', () => {
it('returns false when authType is not QWEN_OAUTH', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.USE_GEMINI,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when current model is already a vision model', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen-vl-max-latest',
true,
);
expect(result).toBe(false);
});
it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
const parts: PartListUnion = [
{ text: 'hello' },
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('detects image when provided as a single Part object (non-array)', () => {
const singleImagePart: PartListUnion = {
fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
} as Part;
const result = shouldOfferVisionSwitch(
singleImagePart,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('returns false when parts contain no images', () => {
const parts: PartListUnion = [{ text: 'just text' }];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when parts is a plain string', () => {
const parts: PartListUnion = 'plain text';
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when visionModelPreviewEnabled is false', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
false,
);
expect(result).toBe(false);
});
});
describe('processVisionSwitchOutcome', () => {
it('maps SwitchOnce to a one-time model override', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
expect(result).toEqual({ modelOverride: vl });
});
it('maps SwitchSessionToVL to a persistent session model', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(result).toEqual({ persistSessionModel: vl });
});
it('maps DisallowWithGuidance to showGuidance', () => {
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.DisallowWithGuidance,
);
expect(result).toEqual({ showGuidance: true });
});
});
describe('getVisionSwitchGuidanceMessage', () => {
it('returns the expected guidance message', () => {
const vl = getDefaultVisionModel();
const expected =
'To use images with your query, you can:\n' +
`• Use /model set ${vl} to switch to a vision-capable model\n` +
'• Or remove the image and provide a text description instead';
expect(getVisionSwitchGuidanceMessage()).toBe(expected);
});
});
});
describe('useVisionAutoSwitch hook', () => {
type AddItemFn = (
item: { type: MessageType; text: string },
ts: number,
) => any;
const createMockConfig = (authType: AuthType, initialModel: string) => {
let currentModel = initialModel;
const mockConfig: Partial<Config> = {
getModel: vi.fn(() => currentModel),
setModel: vi.fn((m: string) => {
currentModel = m;
}),
getContentGeneratorConfig: vi.fn(() => ({
authType,
model: currentModel,
apiKey: 'test-key',
vertexai: false,
})),
};
return mockConfig as Config;
};
let addItem: AddItemFn;
beforeEach(() => {
vi.clearAllMocks();
addItem = vi.fn();
});
it('returns shouldProceed=true immediately for continuations', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, Date.now(), true);
});
expect(res).toEqual({ shouldProceed: true });
expect(addItem).not.toHaveBeenCalled();
});
it('does nothing when authType is not QWEN_OAUTH', async () => {
const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 123, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('does nothing when there are no image parts', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [{ text: 'no images here' }];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 456, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('shows guidance and blocks when dialog returns showGuidance', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ showGuidance: true });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const userTs = 1010;
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, userTs, false);
});
expect(addItem).toHaveBeenCalledWith(
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
userTs,
);
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('applies a one-time override and returns originalModel, then restores', async () => {
const initialModel = 'qwen3-coder-plus';
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 2020, false);
});
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Now restore
act(() => {
result.current.restoreOriginalModel();
});
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
});
it('persists session model when dialog requests persistence', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 3030, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Restore should be a no-op since no one-time override was used
act(() => {
result.current.restoreOriginalModel();
});
// Last call should still be the persisted model set
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
'qwen-vl-max-latest',
);
});
it('returns shouldProceed=true when dialog returns no special flags', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 4040, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).not.toHaveBeenCalled();
});
it('blocks when dialog throws or is cancelled', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 5050, false);
});
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('does nothing when visionModelPreviewEnabled is false', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(
config,
addItem as any,
false,
onVisionSwitchRequired,
),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 6060, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,304 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { type PartListUnion, type Part } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import { useCallback, useRef } from 'react';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import {
getDefaultVisionModel,
isVisionModel,
} from '../models/availableModels.js';
import { MessageType } from '../types.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import {
isSupportedImageMimeType,
getUnsupportedImageFormatWarning,
} from '@qwen-code/qwen-code-core';
/**
* Checks if a PartListUnion contains image parts
*/
function hasImageParts(parts: PartListUnion): boolean {
if (typeof parts === 'string') {
return false;
}
if (Array.isArray(parts)) {
return parts.some((part) => {
// Skip string parts
if (typeof part === 'string') return false;
return isImagePart(part);
});
}
// If it's a single Part (not a string), check if it's an image
if (typeof parts === 'object') {
return isImagePart(parts);
}
return false;
}
/**
* Checks if a single Part is an image part
*/
function isImagePart(part: Part): boolean {
// Check for inlineData with image mime type
if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
return true;
}
// Check for fileData with image mime type
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
return true;
}
return false;
}
/**
* Checks if image parts have supported formats and returns unsupported ones
*/
function checkImageFormatsSupport(parts: PartListUnion): {
hasImages: boolean;
hasUnsupportedFormats: boolean;
unsupportedMimeTypes: string[];
} {
const unsupportedMimeTypes: string[] = [];
let hasImages = false;
if (typeof parts === 'string') {
return {
hasImages: false,
hasUnsupportedFormats: false,
unsupportedMimeTypes: [],
};
}
const partsArray = Array.isArray(parts) ? parts : [parts];
for (const part of partsArray) {
if (typeof part === 'string') continue;
let mimeType: string | undefined;
// Check inlineData
if (
'inlineData' in part &&
part.inlineData?.mimeType?.startsWith('image/')
) {
hasImages = true;
mimeType = part.inlineData.mimeType;
}
// Check fileData
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
hasImages = true;
mimeType = part.fileData.mimeType;
}
// Check if the mime type is supported
if (mimeType && !isSupportedImageMimeType(mimeType)) {
unsupportedMimeTypes.push(mimeType);
}
}
return {
hasImages,
hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
unsupportedMimeTypes,
};
}
/**
* Determines if we should offer vision switch for the given parts, auth type, and current model
*/
export function shouldOfferVisionSwitch(
parts: PartListUnion,
authType: AuthType,
currentModel: string,
visionModelPreviewEnabled: boolean = false,
): boolean {
// Only trigger for qwen-oauth
if (authType !== AuthType.QWEN_OAUTH) {
return false;
}
// If vision model preview is disabled, never offer vision switch
if (!visionModelPreviewEnabled) {
return false;
}
// If current model is already a vision model, no need to switch
if (isVisionModel(currentModel)) {
return false;
}
// Check if the current message contains image parts
return hasImageParts(parts);
}
/**
* Interface for vision switch result
*/
export interface VisionSwitchResult {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}
/**
* Processes the vision switch outcome and returns the appropriate result
*/
export function processVisionSwitchOutcome(
outcome: VisionSwitchOutcome,
): VisionSwitchResult {
const vlModelId = getDefaultVisionModel();
switch (outcome) {
case VisionSwitchOutcome.SwitchOnce:
return { modelOverride: vlModelId };
case VisionSwitchOutcome.SwitchSessionToVL:
return { persistSessionModel: vlModelId };
case VisionSwitchOutcome.DisallowWithGuidance:
return { showGuidance: true };
default:
return { showGuidance: true };
}
}
/**
* Gets the guidance message for when vision switch is disallowed
*/
export function getVisionSwitchGuidanceMessage(): string {
const vlModelId = getDefaultVisionModel();
return `To use images with your query, you can:
• Use /model set ${vlModelId} to switch to a vision-capable model
• Or remove the image and provide a text description instead`;
}
/**
* Interface for vision switch handling result
*/
export interface VisionSwitchHandlingResult {
shouldProceed: boolean;
originalModel?: string;
}
/**
* Custom hook for handling vision model auto-switching
*/
export function useVisionAutoSwitch(
config: Config,
addItem: UseHistoryManagerReturn['addItem'],
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) {
const originalModelRef = useRef<string | null>(null);
const handleVisionSwitch = useCallback(
async (
query: PartListUnion,
userMessageTimestamp: number,
isContinuation: boolean,
): Promise<VisionSwitchHandlingResult> => {
// Skip vision switch handling for continuations or if no handler provided
if (isContinuation || !onVisionSwitchRequired) {
return { shouldProceed: true };
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
// Only handle qwen-oauth auth type
if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
return { shouldProceed: true };
}
// Check image format support first
const formatCheck = checkImageFormatsSupport(query);
// If there are unsupported image formats, show warning
if (formatCheck.hasUnsupportedFormats) {
addItem(
{
type: MessageType.INFO,
text: getUnsupportedImageFormatWarning(),
},
userMessageTimestamp,
);
// Continue processing but with warning shown
}
// Check if vision switch is needed
if (
!shouldOfferVisionSwitch(
query,
contentGeneratorConfig.authType,
config.getModel(),
visionModelPreviewEnabled,
)
) {
return { shouldProceed: true };
}
try {
const visionSwitchResult = await onVisionSwitchRequired(query);
if (visionSwitchResult.showGuidance) {
// Show guidance and don't proceed with the request
addItem(
{
type: MessageType.INFO,
text: getVisionSwitchGuidanceMessage(),
},
userMessageTimestamp,
);
return { shouldProceed: false };
}
if (visionSwitchResult.modelOverride) {
// One-time model override
originalModelRef.current = config.getModel();
config.setModel(visionSwitchResult.modelOverride);
return {
shouldProceed: true,
originalModel: originalModelRef.current,
};
} else if (visionSwitchResult.persistSessionModel) {
// Persistent session model change
config.setModel(visionSwitchResult.persistSessionModel);
return { shouldProceed: true };
}
return { shouldProceed: true };
} catch (_error) {
// If vision switch dialog was cancelled or errored, don't proceed
return { shouldProceed: false };
}
},
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
);
const restoreOriginalModel = useCallback(() => {
if (originalModelRef.current) {
config.setModel(originalModelRef.current);
originalModelRef.current = null;
}
}, [config]);
return {
handleVisionSwitch,
restoreOriginalModel,
};
}

View File

@@ -0,0 +1,52 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export type AvailableModel = {
id: string;
label: string;
isVision?: boolean;
};
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
];
/**
* Get available Qwen models filtered by vision model preview setting
*/
export function getFilteredQwenModels(
visionModelPreviewEnabled: boolean,
): AvailableModel[] {
if (visionModelPreviewEnabled) {
return AVAILABLE_MODELS_QWEN;
}
return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
}
/**
* Currently we use the single model of `OPENAI_MODEL` in the env.
* In the future, after settings.json is updated, we will allow users to configure this themselves.
*/
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
const id = process.env['OPENAI_MODEL']?.trim();
return id ? { id, label: id } : null;
}
/**
/**
* Hard code the default vision model as a string literal,
* until our coding model supports multimodal.
*/
export function getDefaultVisionModel(): string {
return 'qwen-vl-max-latest';
}
export function isVisionModel(modelId: string): boolean {
return AVAILABLE_MODELS_QWEN.some(
(model) => model.id === modelId && model.isVision,
);
}

View File

@@ -19,3 +19,4 @@ export {
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js';
export * from './src/utils/request-tokenizer/supportedImageFormats.js';

View File

@@ -5,9 +5,10 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
let mockConfig: Config;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockOpenAIClient: any;
let mockProvider: OpenAICompatibleProvider;
beforeEach(() => {
// Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
create: vi.fn(),
},
},
embeddings: {
create: vi.fn(),
},
};
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create mock provider
mockProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
// Create generator instance
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
generator = new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
});
afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
await expect(
generator.generateContentStream(request, 'test-prompt-id'),
).rejects.toThrow(
/Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
/Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
);
});
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
} catch (error: unknown) {
const errorMessage =
error instanceof Error ? error.message : String(error);
expect(errorMessage).toContain(
'Streaming setup timeout troubleshooting:',
);
expect(errorMessage).toContain(
'Check network connectivity and firewall settings',
);
expect(errorMessage).toContain('Streaming timeout troubleshooting:');
expect(errorMessage).toContain('Check network connectivity');
expect(errorMessage).toContain('Consider using non-streaming mode');
}
});
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
// Verify OpenAI client was created with timeout config
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Verify provider buildClient was called
expect(mockProvider.buildClient).toHaveBeenCalled();
});
it('should use custom timeout from config', () => {
const customConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const customMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
customConfig,
customMockProvider,
);
// Verify provider buildClient was called
expect(customMockProvider.buildClient).toHaveBeenCalled();
});
it('should handle missing timeout config gracefully', () => {
const noTimeoutConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const noTimeoutMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
noTimeoutConfig,
noTimeoutMockProvider,
);
// Verify provider buildClient was called
expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
});
});

View File

@@ -226,9 +226,6 @@ describe('Gemini Client (client.ts)', () => {
vertexai: false,
authType: AuthType.USE_GEMINI,
};
const mockSubagentManager = {
listSubagents: vi.fn().mockResolvedValue([]),
};
const mockConfigObject = {
getContentGeneratorConfig: vi
.fn()
@@ -263,7 +260,6 @@ describe('Gemini Client (client.ts)', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
getChatCompression: vi.fn().mockReturnValue(undefined),
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
getSubagentManager: vi.fn().mockReturnValue(mockSubagentManager),
getSkipLoopDetection: vi.fn().mockReturnValue(false),
};
const MockedConfig = vi.mocked(Config, true);

View File

@@ -29,7 +29,6 @@ import {
makeChatCompressionEvent,
NextSpeakerCheckEvent,
} from '../telemetry/types.js';
import { TaskTool } from '../tools/task.js';
import {
getDirectoryContextString,
getEnvironmentContext,
@@ -456,8 +455,7 @@ export class GeminiClient {
turns: number = MAX_TURNS,
originalModel?: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
const isNewPrompt = this.lastPromptId !== prompt_id;
if (isNewPrompt) {
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.lastPromptId = prompt_id;
}
@@ -554,24 +552,6 @@ export class GeminiClient {
this.forceFullIdeContext = false;
}
if (isNewPrompt) {
const taskTool = this.config.getToolRegistry().getTool(TaskTool.Name);
const subagents = (
await this.config.getSubagentManager().listSubagents()
).filter((subagent) => subagent.level !== 'builtin');
if (taskTool && subagents.length > 0) {
this.getChat().addHistory({
role: 'user',
parts: [
{
text: `<system-reminder>You have powerful specialized agents at your disposal, available agent types are: ${subagents.map((subagent) => subagent.name).join(', ')}. PROACTIVELY use the ${TaskTool.Name} tool to delegate user's task to appropriate agent when user's task matches agent capabilities. Ignore this message if user's task is not relevant to any agent. This message is for internal use only. Do not mention this to user in your response.</system-reminder>`,
},
],
});
}
}
const turn = new Turn(this.getChat(), prompt_id);
if (!this.config.getSkipLoopDetection()) {

View File

@@ -500,7 +500,7 @@ export class GeminiChat {
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
if (error.message.match(/^5\d{2}/)) return true;
}
return false;
},

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,37 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
// Mock the request tokenizer module BEFORE importing the class that uses it
const mockTokenizer = {
calculateTokens: vi.fn().mockResolvedValue({
totalTokens: 50,
breakdown: {
textTokens: 50,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: 1,
}),
dispose: vi.fn(),
};
vi.mock('../../../utils/request-tokenizer/index.js', () => ({
getDefaultTokenizer: vi.fn(() => mockTokenizer),
DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
disposeDefaultTokenizer: vi.fn(),
}));
// Mock tiktoken as well for completeness
vi.mock('tiktoken', () => ({
get_encoding: vi.fn(() => ({
encode: vi.fn(() => new Array(50)), // Mock 50 tokens
free: vi.fn(),
})),
}));
// Now import the modules that depend on the mocked modules
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
@@ -15,14 +46,6 @@ import type {
import type { OpenAICompatibleProvider } from './provider/index.js';
import type OpenAI from 'openai';
// Mock tiktoken
vi.mock('tiktoken', () => ({
get_encoding: vi.fn().mockReturnValue({
encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
free: vi.fn(),
}),
}));
describe('OpenAIContentGenerator (Refactored)', () => {
let generator: OpenAIContentGenerator;
let mockConfig: Config;

View File

@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
import { ContentGenerationPipeline } from './pipeline.js';
import { DefaultTelemetryService } from './telemetryService.js';
import { EnhancedErrorHandler } from './errorHandler.js';
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
export class OpenAIContentGenerator implements ContentGenerator {
@@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
// Use tiktoken for accurate token counting
const content = JSON.stringify(request.contents);
let totalTokens = 0;
try {
const { get_encoding } = await import('tiktoken');
const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
totalTokens = encoding.encode(content).length;
encoding.free();
// Use the new high-performance request tokenizer
const tokenizer = getDefaultTokenizer();
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
});
return {
totalTokens: result.totalTokens,
};
} catch (error) {
console.warn(
'Failed to load tiktoken, falling back to character approximation:',
'Failed to calculate tokens with new tokenizer, falling back to simple method:',
error,
);
// Fallback: rough approximation using character count
totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
}
return {
totalTokens,
};
// Fallback to original simple method
const content = JSON.stringify(request.contents);
const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
return {
totalTokens,
};
}
}
async embedContent(

View File

@@ -10,14 +10,11 @@ import {
GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../../config/config.js';
import { type ContentGeneratorConfig } from '../contentGenerator.js';
import { type OpenAICompatibleProvider } from './provider/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from './provider/index.js';
import { OpenAIContentConverter } from './converter.js';
import {
type TelemetryService,
type RequestContext,
} from './telemetryService.js';
import { type ErrorHandler } from './errorHandler.js';
import type { TelemetryService, RequestContext } from './telemetryService.js';
import type { ErrorHandler } from './errorHandler.js';
export interface PipelineConfig {
cliConfig: Config;
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
* 2. Filter empty responses
* 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
* 4. Collect both formats for logging
* 5. Handle success/error logging with original OpenAI format
* 5. Handle success/error logging
*/
private async *processStreamWithLogging(
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -169,19 +166,11 @@ export class ContentGenerationPipeline {
collectedOpenAIChunks,
);
} catch (error) {
// Stage 2e: Stream failed - handle error and logging
context.duration = Date.now() - context.startTime;
// Clear streaming tool calls on error to prevent data pollution
this.converter.resetStreamingToolCalls();
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
this.config.errorHandler.handle(error, context, request);
// Use shared error handling logic
await this.handleError(error, context, request);
}
}
@@ -365,25 +354,59 @@ export class ContentGenerationPipeline {
context.duration = Date.now() - context.startTime;
return result;
} catch (error) {
context.duration = Date.now() - context.startTime;
// Log error
const openaiRequest = await this.buildRequest(
// Use shared error handling logic
return await this.handleError(
error,
context,
request,
userPromptId,
isStreaming,
);
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
// Handle and throw enhanced error
this.config.errorHandler.handle(error, context, request);
}
}
/**
* Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
* This centralizes the common error processing steps to avoid duplication
*/
private async handleError(
error: unknown,
context: RequestContext,
request: GenerateContentParameters,
userPromptId?: string,
isStreaming?: boolean,
): Promise<never> {
context.duration = Date.now() - context.startTime;
// Build request for logging (may fail, but we still want to log the error)
let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
try {
if (userPromptId !== undefined && isStreaming !== undefined) {
openaiRequest = await this.buildRequest(
request,
userPromptId,
isStreaming,
);
} else {
// For processStreamWithLogging, we don't have userPromptId/isStreaming,
// so create a minimal request
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
} catch (_buildError) {
// If we can't build the request, create a minimal one for logging
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
await this.config.telemetryService.logError(context, error, openaiRequest);
this.config.errorHandler.handle(error, context, request);
}
/**
* Create request context with common properties
*/

View File

@@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider
messages = this.addDashScopeCacheControl(messages, cacheTarget);
}
if (request.model.startsWith('qwen-vl')) {
return {
...request,
messages,
...(this.buildMetadata(userPromptId) || {}),
/* @ts-expect-error dashscope exclusive */
vl_high_resolution_images: true,
};
}
return {
...request, // Preserve all original parameters including sampling params
messages,

View File

@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
[/^qwen-flash-latest$/, LIMITS['1m']],
[/^qwen-turbo.*$/, LIMITS['128k']],
// Qwen Vision Models
[/^qwen-vl-max.*$/, LIMITS['128k']],
// -------------------
// ByteDance Seed-OSS (512K)
// -------------------

View File

@@ -242,7 +242,7 @@ describe('Turn', () => {
expect(turn.getDebugResponses().length).toBe(0);
expect(reportError).toHaveBeenCalledWith(
error,
'Error when talking to Gemini API',
'Error when talking to API',
[...historyContent, reqParts],
'Turn.run-sendMessageStream',
);

View File

@@ -310,7 +310,7 @@ export class Turn {
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError(
error,
'Error when talking to Gemini API',
'Error when talking to API',
contextForReport,
'Turn.run-sendMessageStream',
);

View File

@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
});
it('should count tokens with valid token', async () => {
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'valid-token',
});
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
it('should count tokens without requiring authentication', async () => {
// Clear any previous mock calls
vi.clearAllMocks();
const request: CountTokensParameters = {
model: 'qwen-turbo',
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
const result = await qwenContentGenerator.countTokens(request);
expect(result.totalTokens).toBe(15);
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
// countTokens is a local operation and should not require OAuth credentials
expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
});
it('should embed content with valid token', async () => {
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
SharedTokenManager.getInstance = originalGetInstance;
});
it('should handle all method types with token failure', async () => {
it('should handle method types with token failure (except countTokens)', async () => {
const mockTokenManager = {
getValidCredentials: vi
.fn()
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
contents: [{ parts: [{ text: 'Embed' }] }],
};
// All methods should fail with the same error
// Methods requiring authentication should fail
await expect(
newGenerator.generateContent(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
newGenerator.generateContentStream(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
// countTokens should succeed as it's a local operation
const countResult = await newGenerator.countTokens(countRequest);
expect(countResult.totalTokens).toBe(15);
SharedTokenManager.getInstance = originalGetInstance;
});
});

View File

@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
override async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.executeWithCredentialManagement(() =>
super.countTokens(request),
);
return super.countTokens(request);
}
/**

View File

@@ -185,7 +185,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
validMarkdown,
validConfig.filePath,
'project',
);
expect(config.name).toBe('test-agent');
@@ -210,7 +209,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
markdownWithTools,
validConfig.filePath,
'project',
);
expect(config.tools).toEqual(['read_file', 'write_file']);
@@ -231,7 +229,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
markdownWithModel,
validConfig.filePath,
'project',
);
expect(config.modelConfig).toEqual({ model: 'custom-model', temp: 0.5 });
@@ -252,7 +249,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
markdownWithRun,
validConfig.filePath,
'project',
);
expect(config.runConfig).toEqual({ max_time_minutes: 5, max_turns: 10 });
@@ -270,7 +266,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
markdownWithNumeric,
validConfig.filePath,
'project',
);
expect(config.name).toBe('11');
@@ -291,7 +286,6 @@ You are a helpful assistant.
const config = manager.parseSubagentContent(
markdownWithBoolean,
validConfig.filePath,
'project',
);
expect(config.name).toBe('true');
@@ -307,13 +301,8 @@ You are a helpful assistant.
const projectConfig = manager.parseSubagentContent(
validMarkdown,
projectPath,
'project',
);
const userConfig = manager.parseSubagentContent(
validMarkdown,
userPath,
'user',
);
const userConfig = manager.parseSubagentContent(validMarkdown, userPath);
expect(projectConfig.level).toBe('project');
expect(userConfig.level).toBe('user');
@@ -324,11 +313,7 @@ You are a helpful assistant.
Just content`;
expect(() =>
manager.parseSubagentContent(
invalidMarkdown,
validConfig.filePath,
'project',
),
manager.parseSubagentContent(invalidMarkdown, validConfig.filePath),
).toThrow(SubagentError);
});
@@ -341,11 +326,7 @@ You are a helpful assistant.
`;
expect(() =>
manager.parseSubagentContent(
markdownWithoutName,
validConfig.filePath,
'project',
),
manager.parseSubagentContent(markdownWithoutName, validConfig.filePath),
).toThrow(SubagentError);
});
@@ -361,20 +342,39 @@ You are a helpful assistant.
manager.parseSubagentContent(
markdownWithoutDescription,
validConfig.filePath,
'project',
),
).toThrow(SubagentError);
});
it('should warn when filename does not match subagent name', () => {
const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
const mismatchedPath = '/test/project/.qwen/agents/wrong-filename.md';
const config = manager.parseSubagentContent(
validMarkdown,
mismatchedPath,
);
expect(config.name).toBe('test-agent');
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Warning: Subagent file "wrong-filename.md" contains name "test-agent"',
),
);
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Consider renaming the file to "test-agent.md"',
),
);
consoleSpy.mockRestore();
});
it('should not warn when filename matches subagent name', () => {
const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
const matchingPath = '/test/project/.qwen/agents/test-agent.md';
const config = manager.parseSubagentContent(
validMarkdown,
matchingPath,
'project',
);
const config = manager.parseSubagentContent(validMarkdown, matchingPath);
expect(config.name).toBe('test-agent');
expect(consoleSpy).not.toHaveBeenCalled();

View File

@@ -39,7 +39,6 @@ const AGENT_CONFIG_DIR = 'agents';
*/
export class SubagentManager {
private readonly validator: SubagentValidator;
private subagentsCache: Map<SubagentLevel, SubagentConfig[]> | null = null;
constructor(private readonly config: Config) {
this.validator = new SubagentValidator();
@@ -93,8 +92,6 @@ export class SubagentManager {
try {
await fs.writeFile(filePath, content, 'utf8');
// Clear cache after successful creation
this.clearCache();
} catch (error) {
throw new SubagentError(
`Failed to write subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
@@ -183,8 +180,6 @@ export class SubagentManager {
try {
await fs.writeFile(existing.filePath, content, 'utf8');
// Clear cache after successful update
this.clearCache();
} catch (error) {
throw new SubagentError(
`Failed to update subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
@@ -241,9 +236,6 @@ export class SubagentManager {
name,
);
}
// Clear cache after successful deletion
this.clearCache();
}
/**
@@ -262,17 +254,9 @@ export class SubagentManager {
? [options.level]
: ['project', 'user', 'builtin'];
// Check if we should use cache or force refresh
const shouldUseCache = !options.force && this.subagentsCache !== null;
// Initialize cache if it doesn't exist or we're forcing a refresh
if (!shouldUseCache) {
await this.refreshCache();
}
// Collect subagents from each level (project takes precedence over user, user takes precedence over builtin)
for (const level of levelsToCheck) {
const levelSubagents = this.subagentsCache?.get(level) || [];
const levelSubagents = await this.listSubagentsAtLevel(level);
for (const subagent of levelSubagents) {
// Skip if we've already seen this name (precedence: project > user > builtin)
@@ -320,30 +304,6 @@ export class SubagentManager {
return subagents;
}
/**
* Refreshes the subagents cache by loading all subagents from disk.
* This method is called automatically when cache is null or when force=true.
*
* @private
*/
private async refreshCache(): Promise<void> {
this.subagentsCache = new Map();
const levels: SubagentLevel[] = ['project', 'user', 'builtin'];
for (const level of levels) {
const levelSubagents = await this.listSubagentsAtLevel(level);
this.subagentsCache.set(level, levelSubagents);
}
}
/**
* Clears the subagents cache, forcing the next listSubagents call to reload from disk.
*/
clearCache(): void {
this.subagentsCache = null;
}
/**
* Finds a subagent by name and returns its metadata.
*
@@ -369,10 +329,7 @@ export class SubagentManager {
* @returns SubagentConfig
* @throws SubagentError if parsing fails
*/
async parseSubagentFile(
filePath: string,
level: SubagentLevel,
): Promise<SubagentConfig> {
async parseSubagentFile(filePath: string): Promise<SubagentConfig> {
let content: string;
try {
@@ -384,7 +341,7 @@ export class SubagentManager {
);
}
return this.parseSubagentContent(content, filePath, level);
return this.parseSubagentContent(content, filePath);
}
/**
@@ -395,11 +352,7 @@ export class SubagentManager {
* @returns SubagentConfig
* @throws SubagentError if parsing fails
*/
parseSubagentContent(
content: string,
filePath: string,
level: SubagentLevel,
): SubagentConfig {
parseSubagentContent(content: string, filePath: string): SubagentConfig {
try {
// Split frontmatter and content
const frontmatterRegex = /^---\n([\s\S]*?)\n---\n([\s\S]*)$/;
@@ -440,16 +393,31 @@ export class SubagentManager {
| undefined;
const color = frontmatter['color'] as string | undefined;
// Determine level from file path using robust, cross-platform check
// A project-level agent lives under <projectRoot>/.qwen/agents
const projectAgentsDir = path.join(
this.config.getProjectRoot(),
QWEN_CONFIG_DIR,
AGENT_CONFIG_DIR,
);
const rel = path.relative(
path.normalize(projectAgentsDir),
path.normalize(filePath),
);
const isProjectLevel =
rel !== '' && !rel.startsWith('..') && !path.isAbsolute(rel);
const level: SubagentLevel = isProjectLevel ? 'project' : 'user';
const config: SubagentConfig = {
name,
description,
tools,
systemPrompt: systemPrompt.trim(),
level,
filePath,
modelConfig: modelConfig as Partial<ModelConfig>,
runConfig: runConfig as Partial<RunConfig>,
color,
level,
};
// Validate the parsed configuration
@@ -458,6 +426,16 @@ export class SubagentManager {
throw new Error(`Validation failed: ${validation.errors.join(', ')}`);
}
// Warn if filename doesn't match subagent name (potential issue)
const expectedFilename = `${config.name}.md`;
const actualFilename = path.basename(filePath);
if (actualFilename !== expectedFilename) {
console.warn(
`Warning: Subagent file "${actualFilename}" contains name "${config.name}" but filename suggests "${path.basename(actualFilename, '.md')}". ` +
`Consider renaming the file to "${expectedFilename}" for consistency.`,
);
}
return config;
} catch (error) {
throw new SubagentError(
@@ -700,18 +678,14 @@ export class SubagentManager {
return BuiltinAgentRegistry.getBuiltinAgents();
}
const projectRoot = this.config.getProjectRoot();
const homeDir = os.homedir();
const isHomeDirectory = path.resolve(projectRoot) === path.resolve(homeDir);
// If project level is requested but project root is same as home directory,
// return empty array to avoid conflicts between project and global agents
if (level === 'project' && isHomeDirectory) {
return [];
}
let baseDir = level === 'project' ? projectRoot : homeDir;
baseDir = path.join(baseDir, QWEN_CONFIG_DIR, AGENT_CONFIG_DIR);
const baseDir =
level === 'project'
? path.join(
this.config.getProjectRoot(),
QWEN_CONFIG_DIR,
AGENT_CONFIG_DIR,
)
: path.join(os.homedir(), QWEN_CONFIG_DIR, AGENT_CONFIG_DIR);
try {
const files = await fs.readdir(baseDir);
@@ -723,7 +697,7 @@ export class SubagentManager {
const filePath = path.join(baseDir, file);
try {
const config = await this.parseSubagentFile(filePath, level);
const config = await this.parseSubagentFile(filePath);
subagents.push(config);
} catch (_error) {
// Ignore invalid files

View File

@@ -116,9 +116,6 @@ export interface ListSubagentsOptions {
/** Sort direction */
sortOrder?: 'asc' | 'desc';
/** Force refresh from disk, bypassing cache. Defaults to false. */
force?: boolean;
}
/**

View File

@@ -62,9 +62,6 @@ describe('GlobTool', () => {
// Ensure a noticeable difference in modification time
await new Promise((resolve) => setTimeout(resolve, 50));
await fs.writeFile(path.join(tempRootDir, 'newer.sortme'), 'newer_content');
// For type coercion testing
await fs.mkdir(path.join(tempRootDir, '123'));
});
afterEach(async () => {
@@ -282,20 +279,26 @@ describe('GlobTool', () => {
);
});
it('should pass if path is provided but is not a string (type coercion)', () => {
it('should return error if path is provided but is not a string (schema validation)', () => {
const params = {
pattern: '*.ts',
path: 123,
} as unknown as GlobToolParams; // Force incorrect type
expect(globTool.validateToolParams(params)).toBeNull();
};
// @ts-expect-error - We're intentionally creating invalid params for testing
expect(globTool.validateToolParams(params)).toBe(
'params/path must be string',
);
});
it('should pass if case_sensitive is provided but is not a boolean (type coercion)', () => {
it('should return error if case_sensitive is provided but is not a boolean (schema validation)', () => {
const params = {
pattern: '*.ts',
case_sensitive: 'true',
} as unknown as GlobToolParams; // Force incorrect type
expect(globTool.validateToolParams(params)).toBeNull();
};
// @ts-expect-error - We're intentionally creating invalid params for testing
expect(globTool.validateToolParams(params)).toBe(
'params/case_sensitive must be boolean',
);
});
it("should return error if search path resolves outside the tool's root directory", () => {

View File

@@ -191,12 +191,14 @@ describe('ReadManyFilesTool', () => {
);
});
it('should coerce non-string elements in include array', () => {
it('should throw error if include array contains non-string elements', () => {
const params = {
paths: ['file1.txt'],
include: ['*.ts', 123] as string[],
};
expect(() => tool.build(params)).toBeDefined();
expect(() => tool.build(params)).toThrow(
'params/include/1 must be string',
);
});
it('should throw error if exclude array contains non-string elements', () => {

View File

@@ -419,11 +419,6 @@ export class ShellTool extends BaseDeclarativeTool<
type: 'string',
description: getCommandDescription(),
},
is_background: {
type: 'boolean',
description:
'Whether to run the command in background. Default is false. Set to true for long-running processes like development servers, watchers, or daemons that should continue running without blocking further commands.',
},
description: {
type: 'string',
description:

View File

@@ -220,12 +220,14 @@ describe('WriteFileTool', () => {
);
});
it('should coerce null content into an empty string', () => {
it('should throw an error if the content is null', () => {
const dirAsFilePath = path.join(rootDir, 'a_directory');
fs.mkdirSync(dirAsFilePath);
const params = {
file_path: path.join(rootDir, 'test.txt'),
file_path: dirAsFilePath,
content: null,
} as unknown as WriteFileToolParams; // Intentionally non-conforming
expect(() => tool.build(params)).toBeDefined();
expect(() => tool.build(params)).toThrow('params/content must be string');
});
it('should throw error if the file_path is empty', () => {

View File

@@ -0,0 +1,157 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { ImageTokenizer } from './imageTokenizer.js';
describe('ImageTokenizer', () => {
const tokenizer = new ImageTokenizer();
describe('token calculation', () => {
it('should calculate tokens based on image dimensions with reference logic', () => {
const metadata = {
width: 28,
height: 28,
mimeType: 'image/png',
dataSize: 1000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
// But minimum scaling may apply for small images
expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
});
it('should calculate tokens for larger images', () => {
const metadata = {
width: 512,
height: 512,
mimeType: 'image/png',
dataSize: 10000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 512x512 with reference logic: rounded dimensions + scaling + special tokens
expect(tokens).toBeGreaterThan(300);
expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
});
it('should enforce minimum tokens per image with scaling', () => {
const metadata = {
width: 1,
height: 1,
mimeType: 'image/png',
dataSize: 100,
};
const tokens = tokenizer.calculateTokens(metadata);
// Tiny images get scaled up to minimum pixels + special tokens
expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
});
it('should handle very large images with scaling', () => {
const metadata = {
width: 8192,
height: 8192,
mimeType: 'image/png',
dataSize: 100000,
};
const tokens = tokenizer.calculateTokens(metadata);
// Very large images should be scaled down to max limit + special tokens
expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
});
});
describe('PNG dimension extraction', () => {
it('should extract dimensions from valid PNG', async () => {
// 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
'image/png',
);
expect(metadata.width).toBe(1);
expect(metadata.height).toBe(1);
expect(metadata.mimeType).toBe('image/png');
});
it('should handle invalid PNG gracefully', async () => {
const invalidBase64 = 'invalid-png-data';
const metadata = await tokenizer.extractImageMetadata(
invalidBase64,
'image/png',
);
// Should return default dimensions
expect(metadata.width).toBe(512);
expect(metadata.height).toBe(512);
expect(metadata.mimeType).toBe('image/png');
});
});
describe('batch processing', () => {
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const images = [
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(3);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
});
it('should handle mixed valid and invalid images', async () => {
const validPng =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const invalidPng = 'invalid-data';
const images = [
{ data: validPng, mimeType: 'image/png' },
{ data: invalidPng, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(2);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
});
});
describe('different image formats', () => {
it('should handle different MIME types', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];
for (const mimeType of formats) {
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
mimeType,
);
expect(metadata.mimeType).toBe(mimeType);
expect(metadata.width).toBeGreaterThan(0);
expect(metadata.height).toBeGreaterThan(0);
}
});
});
});

View File

@@ -0,0 +1,505 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { ImageMetadata } from './types.js';
import { isSupportedImageMimeType } from './supportedImageFormats.js';
/**
* Image tokenizer for calculating image tokens based on dimensions
*
* Key rules:
* - 28x28 pixels = 1 token
* - Minimum: 4 tokens per image
* - Maximum: 16384 tokens per image
* - Additional: 2 special tokens (vision_bos + vision_eos)
* - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
*/
export class ImageTokenizer {
/** 28x28 pixels = 1 token */
private static readonly PIXELS_PER_TOKEN = 28 * 28;
/** Minimum tokens per image */
private static readonly MIN_TOKENS_PER_IMAGE = 4;
/** Maximum tokens per image */
private static readonly MAX_TOKENS_PER_IMAGE = 16384;
/** Special tokens for vision markers */
private static readonly VISION_SPECIAL_TOKENS = 2;
/**
* Extract image metadata from base64 data
*
* @param base64Data Base64-encoded image data (with or without data URL prefix)
* @param mimeType MIME type of the image
* @returns Promise resolving to ImageMetadata with dimensions and format info
*/
async extractImageMetadata(
base64Data: string,
mimeType: string,
): Promise<ImageMetadata> {
try {
// Check if the MIME type is supported
if (!isSupportedImageMimeType(mimeType)) {
console.warn(`Unsupported image format: ${mimeType}`);
// Return default metadata for unsupported formats
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
const buffer = Buffer.from(cleanBase64, 'base64');
const dimensions = await this.extractDimensions(buffer, mimeType);
return {
width: dimensions.width,
height: dimensions.height,
mimeType,
dataSize: buffer.length,
};
} catch (error) {
console.warn('Failed to extract image metadata:', error);
// Return default metadata for fallback
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
}
/**
* Extract image dimensions from buffer based on format
*
* @param buffer Binary image data buffer
* @param mimeType MIME type to determine parsing strategy
* @returns Promise resolving to width and height dimensions
*/
private async extractDimensions(
buffer: Buffer,
mimeType: string,
): Promise<{ width: number; height: number }> {
if (mimeType.includes('png')) {
return this.extractPngDimensions(buffer);
}
if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
return this.extractJpegDimensions(buffer);
}
if (mimeType.includes('webp')) {
return this.extractWebpDimensions(buffer);
}
if (mimeType.includes('gif')) {
return this.extractGifDimensions(buffer);
}
if (mimeType.includes('bmp')) {
return this.extractBmpDimensions(buffer);
}
if (mimeType.includes('tiff')) {
return this.extractTiffDimensions(buffer);
}
if (mimeType.includes('heic')) {
return this.extractHeicDimensions(buffer);
}
return { width: 512, height: 512 };
}
/**
* Extract PNG dimensions from IHDR chunk
* PNG signature: 89 50 4E 47 0D 0A 1A 0A
* Width/height at bytes 16-19 and 20-23 (big-endian)
*/
private extractPngDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 24) {
throw new Error('Invalid PNG: buffer too short');
}
// Verify PNG signature
const signature = buffer.subarray(0, 8);
const expectedSignature = Buffer.from([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
]);
if (!signature.equals(expectedSignature)) {
throw new Error('Invalid PNG signature');
}
const width = buffer.readUInt32BE(16);
const height = buffer.readUInt32BE(20);
return { width, height };
}
/**
* Extract JPEG dimensions from SOF (Start of Frame) markers
* JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
* Dimensions at offset +5 (height) and +7 (width) from SOF marker
*/
private extractJpegDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
throw new Error('Invalid JPEG signature');
}
let offset = 2;
while (offset < buffer.length - 8) {
if (buffer[offset] !== 0xff) {
offset++;
continue;
}
const marker = buffer[offset + 1];
// SOF markers
if (
(marker >= 0xc0 && marker <= 0xc3) ||
(marker >= 0xc5 && marker <= 0xc7) ||
(marker >= 0xc9 && marker <= 0xcb) ||
(marker >= 0xcd && marker <= 0xcf)
) {
const height = buffer.readUInt16BE(offset + 5);
const width = buffer.readUInt16BE(offset + 7);
return { width, height };
}
const segmentLength = buffer.readUInt16BE(offset + 2);
offset += 2 + segmentLength;
}
throw new Error('Could not find JPEG dimensions');
}
/**
* Extract WebP dimensions from RIFF container
* Supports VP8, VP8L, and VP8X formats
*/
private extractWebpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 30) {
throw new Error('Invalid WebP: too short');
}
const riffSignature = buffer.subarray(0, 4).toString('ascii');
const webpSignature = buffer.subarray(8, 12).toString('ascii');
if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
throw new Error('Invalid WebP signature');
}
const format = buffer.subarray(12, 16).toString('ascii');
if (format === 'VP8 ') {
const width = buffer.readUInt16LE(26) & 0x3fff;
const height = buffer.readUInt16LE(28) & 0x3fff;
return { width, height };
} else if (format === 'VP8L') {
const bits = buffer.readUInt32LE(21);
const width = (bits & 0x3fff) + 1;
const height = ((bits >> 14) & 0x3fff) + 1;
return { width, height };
} else if (format === 'VP8X') {
const width = (buffer.readUInt32LE(24) & 0xffffff) + 1;
const height = (buffer.readUInt32LE(26) & 0xffffff) + 1;
return { width, height };
}
throw new Error('Unsupported WebP format');
}
/**
* Extract GIF dimensions from header
* Supports GIF87a and GIF89a formats
*/
private extractGifDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 10) {
throw new Error('Invalid GIF: too short');
}
const signature = buffer.subarray(0, 6).toString('ascii');
if (signature !== 'GIF87a' && signature !== 'GIF89a') {
throw new Error('Invalid GIF signature');
}
const width = buffer.readUInt16LE(6);
const height = buffer.readUInt16LE(8);
return { width, height };
}
/**
* Calculate tokens for an image based on its metadata
*
* @param metadata Image metadata containing width, height, and format info
* @returns Total token count including base image tokens and special tokens
*/
calculateTokens(metadata: ImageMetadata): number {
return this.calculateTokensWithScaling(metadata.width, metadata.height);
}
/**
* Calculate tokens with scaling logic
*
* Steps:
* 1. Normalize to 28-pixel multiples
* 2. Scale large images down, small images up
* 3. Calculate tokens: pixels / 784 + 2 special tokens
*
* @param width Original image width in pixels
* @param height Original image height in pixels
* @returns Total token count for the image
*/
private calculateTokensWithScaling(width: number, height: number): number {
// Normalize to 28-pixel multiples
let hBar = Math.round(height / 28) * 28;
let wBar = Math.round(width / 28) * 28;
// Define pixel boundaries
const minPixels =
ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
const maxPixels =
ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
// Apply scaling
if (hBar * wBar > maxPixels) {
// Scale down large images
const beta = Math.sqrt((height * width) / maxPixels);
hBar = Math.floor(height / beta / 28) * 28;
wBar = Math.floor(width / beta / 28) * 28;
} else if (hBar * wBar < minPixels) {
// Scale up small images
const beta = Math.sqrt(minPixels / (height * width));
hBar = Math.ceil((height * beta) / 28) * 28;
wBar = Math.ceil((width * beta) / 28) * 28;
}
// Calculate tokens
const imageTokens = Math.floor(
(hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
);
return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
}
/**
* Calculate tokens for multiple images serially
*
* @param base64DataArray Array of image data with MIME type information
* @returns Promise resolving to array of token counts in same order as input
*/
async calculateTokensBatch(
base64DataArray: Array<{ data: string; mimeType: string }>,
): Promise<number[]> {
const results: number[] = [];
for (const { data, mimeType } of base64DataArray) {
try {
const metadata = await this.extractImageMetadata(data, mimeType);
results.push(this.calculateTokens(metadata));
} catch (error) {
console.warn('Error calculating tokens for image:', error);
// Return minimum tokens as fallback
results.push(
ImageTokenizer.MIN_TOKENS_PER_IMAGE +
ImageTokenizer.VISION_SPECIAL_TOKENS,
);
}
}
return results;
}
/**
* Extract BMP dimensions from header
* BMP signature: 42 4D (BM)
* Width/height at bytes 18-21 and 22-25 (little-endian)
*/
private extractBmpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 26) {
throw new Error('Invalid BMP: buffer too short');
}
// Verify BMP signature
if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
throw new Error('Invalid BMP signature');
}
const width = buffer.readUInt32LE(18);
const height = buffer.readUInt32LE(22);
return { width, height: Math.abs(height) }; // Height can be negative for top-down BMPs
}
/**
* Extract TIFF dimensions from IFD (Image File Directory)
* TIFF can be little-endian (II) or big-endian (MM)
* Width/height are stored in IFD entries with tags 0x0100 and 0x0101
*/
private extractTiffDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 8) {
throw new Error('Invalid TIFF: buffer too short');
}
// Check byte order
const byteOrder = buffer.subarray(0, 2).toString('ascii');
const isLittleEndian = byteOrder === 'II';
const isBigEndian = byteOrder === 'MM';
if (!isLittleEndian && !isBigEndian) {
throw new Error('Invalid TIFF byte order');
}
// Read magic number (should be 42)
const magic = isLittleEndian
? buffer.readUInt16LE(2)
: buffer.readUInt16BE(2);
if (magic !== 42) {
throw new Error('Invalid TIFF magic number');
}
// Read IFD offset
const ifdOffset = isLittleEndian
? buffer.readUInt32LE(4)
: buffer.readUInt32BE(4);
if (ifdOffset >= buffer.length) {
throw new Error('Invalid TIFF IFD offset');
}
// Read number of directory entries
const numEntries = isLittleEndian
? buffer.readUInt16LE(ifdOffset)
: buffer.readUInt16BE(ifdOffset);
let width = 0;
let height = 0;
// Parse IFD entries
for (let i = 0; i < numEntries; i++) {
const entryOffset = ifdOffset + 2 + i * 12;
if (entryOffset + 12 > buffer.length) break;
const tag = isLittleEndian
? buffer.readUInt16LE(entryOffset)
: buffer.readUInt16BE(entryOffset);
const type = isLittleEndian
? buffer.readUInt16LE(entryOffset + 2)
: buffer.readUInt16BE(entryOffset + 2);
const value = isLittleEndian
? buffer.readUInt32LE(entryOffset + 8)
: buffer.readUInt32BE(entryOffset + 8);
if (tag === 0x0100) {
// ImageWidth
width = type === 3 ? value : value; // SHORT or LONG
} else if (tag === 0x0101) {
// ImageLength (height)
height = type === 3 ? value : value; // SHORT or LONG
}
if (width > 0 && height > 0) break;
}
if (width === 0 || height === 0) {
throw new Error('Could not find TIFF dimensions');
}
return { width, height };
}
/**
* Extract HEIC dimensions from meta box
* HEIC is based on ISO Base Media File Format
* This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
*/
private extractHeicDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 12) {
throw new Error('Invalid HEIC: buffer too short');
}
// Check for ftyp box with HEIC brand
const ftypBox = buffer.subarray(4, 8).toString('ascii');
if (ftypBox !== 'ftyp') {
throw new Error('Invalid HEIC: missing ftyp box');
}
const brand = buffer.subarray(8, 12).toString('ascii');
if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
throw new Error('Invalid HEIC brand');
}
// Look for meta box and then ispe box
let offset = 0;
while (offset < buffer.length - 8) {
const boxSize = buffer.readUInt32BE(offset);
const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
if (boxType === 'meta') {
// Look for ispe box inside meta box
const metaOffset = offset + 8;
let innerOffset = metaOffset + 4; // Skip version and flags
while (innerOffset < offset + boxSize - 8) {
const innerBoxSize = buffer.readUInt32BE(innerOffset);
const innerBoxType = buffer
.subarray(innerOffset + 4, innerOffset + 8)
.toString('ascii');
if (innerBoxType === 'ispe') {
// Found Image Spatial Extents box
if (innerOffset + 20 <= buffer.length) {
const width = buffer.readUInt32BE(innerOffset + 12);
const height = buffer.readUInt32BE(innerOffset + 16);
return { width, height };
}
}
if (innerBoxSize === 0) break;
innerOffset += innerBoxSize;
}
}
if (boxSize === 0) break;
offset += boxSize;
}
// Fallback: return default dimensions if we can't parse the structure
console.warn('Could not extract HEIC dimensions, using default');
return { width: 512, height: 512 };
}
}

View File

@@ -0,0 +1,40 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export { DefaultRequestTokenizer } from './requestTokenizer.js';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
export { TextTokenizer } from './textTokenizer.js';
export { ImageTokenizer } from './imageTokenizer.js';
export type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
ImageMetadata,
} from './types.js';
// Singleton instance for convenient usage
let defaultTokenizer: DefaultRequestTokenizer | null = null;
/**
* Get the default request tokenizer instance
*/
export function getDefaultTokenizer(): DefaultRequestTokenizer {
if (!defaultTokenizer) {
defaultTokenizer = new DefaultRequestTokenizer();
}
return defaultTokenizer;
}
/**
* Dispose of the default tokenizer instance
*/
export async function disposeDefaultTokenizer(): Promise<void> {
if (defaultTokenizer) {
await defaultTokenizer.dispose();
defaultTokenizer = null;
}
}

View File

@@ -0,0 +1,293 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
import type { CountTokensParameters } from '@google/genai';
describe('DefaultRequestTokenizer', () => {
let tokenizer: DefaultRequestTokenizer;
beforeEach(() => {
tokenizer = new DefaultRequestTokenizer();
});
afterEach(async () => {
await tokenizer.dispose();
});
describe('text token calculation', () => {
it('should calculate tokens for simple text content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Hello, world!' }],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBe(0);
expect(result.processingTime).toBeGreaterThan(0);
});
it('should handle multiple text parts', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'First part' },
{ text: 'Second part' },
{ text: 'Third part' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
it('should handle string content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: ['Simple string content'],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
});
describe('image token calculation', () => {
it('should calculate tokens for image content', async () => {
// Create a simple 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
expect(result.breakdown.textTokens).toBe(0);
});
it('should handle multiple images', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
});
});
describe('mixed content', () => {
it('should handle text and image content together', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'Here is an image:' },
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{ text: 'What do you see?' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(4);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
});
});
describe('function content', () => {
it('should handle function calls', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
functionCall: {
name: 'test_function',
args: { param1: 'value1', param2: 42 },
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.otherTokens).toBeGreaterThan(0);
});
});
describe('empty content', () => {
it('should handle empty request', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
expect(result.breakdown.textTokens).toBe(0);
expect(result.breakdown.imageTokens).toBe(0);
});
it('should handle undefined contents', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
});
});
describe('configuration', () => {
it('should use custom text encoding', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Test text for encoding' }],
},
],
};
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base',
});
expect(result.totalTokens).toBeGreaterThan(0);
});
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: Array(10).fill({
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
}),
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
});
});
describe('error handling', () => {
it('should handle malformed image data gracefully', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: 'invalid-base64-data',
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
// Should still return some tokens (fallback to minimum)
expect(result.totalTokens).toBeGreaterThanOrEqual(4);
});
});
});

View File

@@ -0,0 +1,341 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
Content,
Part,
PartUnion,
} from '@google/genai';
import type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
} from './types.js';
import { TextTokenizer } from './textTokenizer.js';
import { ImageTokenizer } from './imageTokenizer.js';
/**
* Simple request tokenizer that handles text and image content serially
*/
export class DefaultRequestTokenizer implements RequestTokenizer {
private textTokenizer: TextTokenizer;
private imageTokenizer: ImageTokenizer;
constructor() {
this.textTokenizer = new TextTokenizer();
this.imageTokenizer = new ImageTokenizer();
}
/**
* Calculate tokens for a request using serial processing
*/
async calculateTokens(
request: CountTokensParameters,
config: TokenizerConfig = {},
): Promise<TokenCalculationResult> {
const startTime = performance.now();
// Apply configuration
if (config.textEncoding) {
this.textTokenizer = new TextTokenizer(config.textEncoding);
}
try {
// Process request content and group by type
const { textContents, imageContents, audioContents, otherContents } =
this.processAndGroupContents(request);
if (
textContents.length === 0 &&
imageContents.length === 0 &&
audioContents.length === 0 &&
otherContents.length === 0
) {
return {
totalTokens: 0,
breakdown: {
textTokens: 0,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
// Calculate tokens for each content type serially
const textTokens = await this.calculateTextTokens(textContents);
const imageTokens = await this.calculateImageTokens(imageContents);
const audioTokens = await this.calculateAudioTokens(audioContents);
const otherTokens = await this.calculateOtherTokens(otherContents);
const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
const processingTime = performance.now() - startTime;
return {
totalTokens,
breakdown: {
textTokens,
imageTokens,
audioTokens,
otherTokens,
},
processingTime,
};
} catch (error) {
console.error('Error calculating tokens:', error);
// Fallback calculation
const fallbackTokens = this.calculateFallbackTokens(request);
return {
totalTokens: fallbackTokens,
breakdown: {
textTokens: fallbackTokens,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
}
/**
* Calculate tokens for text contents
*/
private async calculateTextTokens(textContents: string[]): Promise<number> {
if (textContents.length === 0) return 0;
try {
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(textContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating text tokens:', error);
// Fallback: character-based estimation
const totalChars = textContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Calculate tokens for image contents using serial processing
*/
private async calculateImageTokens(
imageContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (imageContents.length === 0) return 0;
try {
const tokenCounts =
await this.imageTokenizer.calculateTokensBatch(imageContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating image tokens:', error);
// Fallback: minimum tokens per image
return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
}
}
/**
* Calculate tokens for audio contents
* TODO: Implement proper audio token calculation
*/
private async calculateAudioTokens(
audioContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (audioContents.length === 0) return 0;
// Placeholder implementation - audio token calculation would depend on
// the specific model's audio processing capabilities
// For now, estimate based on data size
let totalTokens = 0;
for (const audioContent of audioContents) {
try {
const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
// Rough estimate: 1 token per 100 bytes of audio data
totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
} catch (error) {
console.warn('Error calculating audio tokens:', error);
totalTokens += 10; // Fallback minimum
}
}
return totalTokens;
}
/**
* Calculate tokens for other content types (functions, files, etc.)
*/
private async calculateOtherTokens(otherContents: string[]): Promise<number> {
if (otherContents.length === 0) return 0;
try {
// Treat other content as text for token calculation
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(otherContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating other content tokens:', error);
// Fallback: character-based estimation
const totalChars = otherContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Fallback token calculation using simple string serialization
*/
private calculateFallbackTokens(request: CountTokensParameters): number {
try {
const content = JSON.stringify(request.contents);
return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
} catch (error) {
console.warn('Error in fallback token calculation:', error);
return 100; // Conservative fallback
}
}
/**
* Process request contents and group by type
*/
private processAndGroupContents(request: CountTokensParameters): {
textContents: string[];
imageContents: Array<{ data: string; mimeType: string }>;
audioContents: Array<{ data: string; mimeType: string }>;
otherContents: string[];
} {
const textContents: string[] = [];
const imageContents: Array<{ data: string; mimeType: string }> = [];
const audioContents: Array<{ data: string; mimeType: string }> = [];
const otherContents: string[] = [];
if (!request.contents) {
return { textContents, imageContents, audioContents, otherContents };
}
const contents = Array.isArray(request.contents)
? request.contents
: [request.contents];
for (const content of contents) {
this.processContent(
content,
textContents,
imageContents,
audioContents,
otherContents,
);
}
return { textContents, imageContents, audioContents, otherContents };
}
/**
* Process a single content item and add to appropriate arrays
*/
private processContent(
content: Content | string | PartUnion,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof content === 'string') {
if (content.trim()) {
textContents.push(content);
}
return;
}
if ('parts' in content && content.parts) {
for (const part of content.parts) {
this.processPart(
part,
textContents,
imageContents,
audioContents,
otherContents,
);
}
}
}
/**
* Process a single part and add to appropriate arrays
*/
private processPart(
part: Part | string,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof part === 'string') {
if (part.trim()) {
textContents.push(part);
}
return;
}
if ('text' in part && part.text) {
textContents.push(part.text);
return;
}
if ('inlineData' in part && part.inlineData) {
const { data, mimeType } = part.inlineData;
if (mimeType && mimeType.startsWith('image/')) {
imageContents.push({ data: data || '', mimeType });
return;
}
if (mimeType && mimeType.startsWith('audio/')) {
audioContents.push({ data: data || '', mimeType });
return;
}
}
if ('fileData' in part && part.fileData) {
otherContents.push(JSON.stringify(part.fileData));
return;
}
if ('functionCall' in part && part.functionCall) {
otherContents.push(JSON.stringify(part.functionCall));
return;
}
if ('functionResponse' in part && part.functionResponse) {
otherContents.push(JSON.stringify(part.functionResponse));
return;
}
// Unknown part type - try to serialize
try {
const serialized = JSON.stringify(part);
if (serialized && serialized !== '{}') {
otherContents.push(serialized);
}
} catch (error) {
console.warn('Failed to serialize unknown part type:', error);
}
}
/**
* Dispose of resources
*/
async dispose(): Promise<void> {
try {
// Dispose of tokenizers
this.textTokenizer.dispose();
} catch (error) {
console.warn('Error disposing request tokenizer:', error);
}
}
}

View File

@@ -0,0 +1,56 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Supported image MIME types for vision models
* These formats are supported by the vision model and can be processed by the image tokenizer
*/
export const SUPPORTED_IMAGE_MIME_TYPES = [
'image/bmp',
'image/jpeg',
'image/jpg', // Alternative MIME type for JPEG
'image/png',
'image/tiff',
'image/webp',
'image/heic',
] as const;
/**
* Type for supported image MIME types
*/
export type SupportedImageMimeType =
(typeof SUPPORTED_IMAGE_MIME_TYPES)[number];
/**
* Check if a MIME type is supported for vision processing
* @param mimeType The MIME type to check
* @returns True if the MIME type is supported
*/
export function isSupportedImageMimeType(
mimeType: string,
): mimeType is SupportedImageMimeType {
return SUPPORTED_IMAGE_MIME_TYPES.includes(
mimeType as SupportedImageMimeType,
);
}
/**
* Get a human-readable list of supported image formats
* @returns Comma-separated string of supported formats
*/
export function getSupportedImageFormatsString(): string {
return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
type.replace('image/', '').toUpperCase(),
).join(', ');
}
/**
* Get warning message for unsupported image formats
* @returns Warning message string
*/
export function getUnsupportedImageFormatWarning(): string {
return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
}

View File

@@ -0,0 +1,347 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';
// Mock tiktoken at the top level with hoisted functions
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());
vi.mock('tiktoken', () => ({
get_encoding: mockGetEncoding,
}));
describe('TextTokenizer', () => {
let tokenizer: TextTokenizer;
let consoleWarnSpy: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
vi.resetAllMocks();
consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
// Default mock implementation
mockGetEncoding.mockReturnValue({
encode: mockEncode,
free: mockFree,
});
});
afterEach(() => {
vi.restoreAllMocks();
tokenizer?.dispose();
});
describe('constructor', () => {
it('should create tokenizer with default encoding', () => {
tokenizer = new TextTokenizer();
expect(tokenizer).toBeInstanceOf(TextTokenizer);
});
it('should create tokenizer with custom encoding', () => {
tokenizer = new TextTokenizer('gpt2');
expect(tokenizer).toBeInstanceOf(TextTokenizer);
});
});
describe('calculateTokens', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should return 0 for empty text', async () => {
const result = await tokenizer.calculateTokens('');
expect(result).toBe(0);
});
it('should return 0 for null/undefined text', async () => {
const result1 = await tokenizer.calculateTokens(
null as unknown as string,
);
const result2 = await tokenizer.calculateTokens(
undefined as unknown as string,
);
expect(result1).toBe(0);
expect(result2).toBe(0);
});
it('should calculate tokens using tiktoken when available', async () => {
const testText = 'Hello, world!';
const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(testText);
expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
expect(mockEncode).toHaveBeenCalledWith(testText);
expect(result).toBe(5);
});
it('should use fallback calculation when tiktoken fails to load', async () => {
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load tiktoken');
});
const testText = 'Hello, world!'; // 13 characters
const result = await tokenizer.calculateTokens(testText);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Failed to load tiktoken with encoding cl100k_base:',
expect.any(Error),
);
// Fallback: Math.ceil(13 / 4) = 4
expect(result).toBe(4);
});
it('should use fallback calculation when encoding fails', async () => {
mockEncode.mockImplementation(() => {
throw new Error('Encoding failed');
});
const testText = 'Hello, world!'; // 13 characters
const result = await tokenizer.calculateTokens(testText);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error encoding text with tiktoken:',
expect.any(Error),
);
// Fallback: Math.ceil(13 / 4) = 4
expect(result).toBe(4);
});
it('should handle very long text', async () => {
const longText = 'a'.repeat(10000);
const mockTokens = new Array(2500); // 2500 tokens
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(longText);
expect(result).toBe(2500);
});
it('should handle unicode characters', async () => {
const unicodeText = '你好世界 🌍';
const mockTokens = [1, 2, 3, 4, 5, 6];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(unicodeText);
expect(result).toBe(6);
});
it('should use custom encoding when specified', async () => {
tokenizer = new TextTokenizer('gpt2');
const testText = 'Hello, world!';
const mockTokens = [1, 2, 3];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(testText);
expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
expect(result).toBe(3);
});
});
describe('calculateTokensBatch', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should process multiple texts and return token counts', async () => {
const texts = ['Hello', 'world', 'test'];
mockEncode
.mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
.mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
.mockReturnValueOnce([6]); // 1 token for 'test'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([2, 3, 1]);
expect(mockEncode).toHaveBeenCalledTimes(3);
});
it('should handle empty array', async () => {
const result = await tokenizer.calculateTokensBatch([]);
expect(result).toEqual([]);
});
it('should handle array with empty strings', async () => {
const texts = ['', 'hello', ''];
mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([0, 3, 0]);
expect(mockEncode).toHaveBeenCalledTimes(1);
expect(mockEncode).toHaveBeenCalledWith('hello');
});
it('should use fallback calculation when tiktoken fails to load', async () => {
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load tiktoken');
});
const texts = ['Hello', 'world']; // 5 and 5 characters
const result = await tokenizer.calculateTokensBatch(texts);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Failed to load tiktoken with encoding cl100k_base:',
expect.any(Error),
);
// Fallback: Math.ceil(5/4) = 2 for both
expect(result).toEqual([2, 2]);
});
it('should use fallback calculation when encoding fails during batch processing', async () => {
mockEncode.mockImplementation(() => {
throw new Error('Encoding failed');
});
const texts = ['Hello', 'world']; // 5 and 5 characters
const result = await tokenizer.calculateTokensBatch(texts);
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error encoding texts with tiktoken:',
expect.any(Error),
);
// Fallback: Math.ceil(5/4) = 2 for both
expect(result).toEqual([2, 2]);
});
it('should handle null and undefined values in batch', async () => {
const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
mockEncode
.mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
.mockReturnValueOnce([4, 5]); // 2 tokens for 'world'
const result = await tokenizer.calculateTokensBatch(texts);
expect(result).toEqual([0, 3, 0, 2]);
});
});
describe('dispose', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should free tiktoken encoding when disposing', async () => {
// Initialize the encoding by calling calculateTokens
await tokenizer.calculateTokens('test');
tokenizer.dispose();
expect(mockFree).toHaveBeenCalled();
});
it('should handle disposal when encoding is not initialized', () => {
expect(() => tokenizer.dispose()).not.toThrow();
expect(mockFree).not.toHaveBeenCalled();
});
it('should handle disposal when encoding is null', async () => {
// Force encoding to be null by making tiktoken fail
mockGetEncoding.mockImplementation(() => {
throw new Error('Failed to load');
});
await tokenizer.calculateTokens('test');
expect(() => tokenizer.dispose()).not.toThrow();
expect(mockFree).not.toHaveBeenCalled();
});
it('should handle errors during disposal gracefully', async () => {
await tokenizer.calculateTokens('test');
mockFree.mockImplementation(() => {
throw new Error('Free failed');
});
tokenizer.dispose();
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Error freeing tiktoken encoding:',
expect.any(Error),
);
});
it('should allow multiple calls to dispose', async () => {
await tokenizer.calculateTokens('test');
tokenizer.dispose();
tokenizer.dispose(); // Second call should not throw
expect(mockFree).toHaveBeenCalledTimes(1);
});
});
describe('lazy initialization', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should not initialize tiktoken until first use', () => {
expect(mockGetEncoding).not.toHaveBeenCalled();
});
it('should initialize tiktoken on first calculateTokens call', async () => {
await tokenizer.calculateTokens('test');
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
it('should not reinitialize tiktoken on subsequent calls', async () => {
await tokenizer.calculateTokens('test1');
await tokenizer.calculateTokens('test2');
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
it('should initialize tiktoken on first calculateTokensBatch call', async () => {
await tokenizer.calculateTokensBatch(['test']);
expect(mockGetEncoding).toHaveBeenCalledTimes(1);
});
});
describe('edge cases', () => {
beforeEach(() => {
tokenizer = new TextTokenizer();
});
it('should handle very short text', async () => {
const result = await tokenizer.calculateTokens('a');
if (mockGetEncoding.mock.calls.length > 0) {
// If tiktoken was called, use its result
expect(mockEncode).toHaveBeenCalledWith('a');
} else {
// If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
expect(result).toBe(1);
}
});
it('should handle text with only whitespace', async () => {
const whitespaceText = ' \n\t ';
const mockTokens = [1];
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(whitespaceText);
expect(result).toBe(1);
});
it('should handle special characters and symbols', async () => {
const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
const mockTokens = new Array(10);
mockEncode.mockReturnValue(mockTokens);
const result = await tokenizer.calculateTokens(specialText);
expect(result).toBe(10);
});
});
});

View File

@@ -0,0 +1,97 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
import { get_encoding } from 'tiktoken';
/**
* Text tokenizer for calculating text tokens using tiktoken
*/
export class TextTokenizer {
private encoding: Tiktoken | null = null;
private encodingName: string;
constructor(encodingName: string = 'cl100k_base') {
this.encodingName = encodingName;
}
/**
* Initialize the tokenizer (lazy loading)
*/
private async ensureEncoding(): Promise<void> {
if (this.encoding) return;
try {
// Use type assertion since we know the encoding name is valid
this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
} catch (error) {
console.warn(
`Failed to load tiktoken with encoding ${this.encodingName}:`,
error,
);
this.encoding = null;
}
}
/**
* Calculate tokens for text content
*/
async calculateTokens(text: string): Promise<number> {
if (!text) return 0;
await this.ensureEncoding();
if (this.encoding) {
try {
return this.encoding.encode(text).length;
} catch (error) {
console.warn('Error encoding text with tiktoken:', error);
}
}
// Fallback: rough approximation using character count
// This is a conservative estimate: 1 token ≈ 4 characters for most languages
return Math.ceil(text.length / 4);
}
/**
* Calculate tokens for multiple text strings in parallel
*/
async calculateTokensBatch(texts: string[]): Promise<number[]> {
await this.ensureEncoding();
if (this.encoding) {
try {
return texts.map((text) => {
if (!text) return 0;
// this.encoding may be null, add a null check to satisfy lint
return this.encoding ? this.encoding.encode(text).length : 0;
});
} catch (error) {
console.warn('Error encoding texts with tiktoken:', error);
// In case of error, return fallback estimation for all texts
return texts.map((text) => Math.ceil((text || '').length / 4));
}
}
// Fallback for batch processing
return texts.map((text) => Math.ceil((text || '').length / 4));
}
/**
* Dispose of resources
*/
dispose(): void {
if (this.encoding) {
try {
this.encoding.free();
} catch (error) {
console.warn('Error freeing tiktoken encoding:', error);
}
this.encoding = null;
}
}
}

View File

@@ -0,0 +1,64 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { CountTokensParameters } from '@google/genai';
/**
* Token calculation result for different content types
*/
export interface TokenCalculationResult {
/** Total tokens calculated */
totalTokens: number;
/** Breakdown by content type */
breakdown: {
textTokens: number;
imageTokens: number;
audioTokens: number;
otherTokens: number;
};
/** Processing time in milliseconds */
processingTime: number;
}
/**
* Configuration for token calculation
*/
export interface TokenizerConfig {
/** Custom text tokenizer encoding (defaults to cl100k_base) */
textEncoding?: string;
}
/**
* Image metadata extracted from base64 data
*/
export interface ImageMetadata {
/** Image width in pixels */
width: number;
/** Image height in pixels */
height: number;
/** MIME type of the image */
mimeType: string;
/** Size of the base64 data in bytes */
dataSize: number;
}
/**
* Request tokenizer interface
*/
export interface RequestTokenizer {
/**
* Calculate tokens for a request
*/
calculateTokens(
request: CountTokensParameters,
config?: TokenizerConfig,
): Promise<TokenCalculationResult>;
/**
* Dispose of resources (worker threads, etc.)
*/
dispose(): Promise<void>;
}

View File

@@ -9,7 +9,7 @@ import * as addFormats from 'ajv-formats';
// Ajv's ESM/CJS interop: use 'any' for compatibility as recommended by Ajv docs
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const AjvClass = (AjvPkg as any).default || AjvPkg;
const ajValidator = new AjvClass({ coerceTypes: true });
const ajValidator = new AjvClass();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const addFormatsFunc = (addFormats as any).default || addFormats;
addFormatsFunc(ajValidator);
@@ -32,27 +32,8 @@ export class SchemaValidator {
const validate = ajValidator.compile(schema);
const valid = validate(data);
if (!valid && validate.errors) {
// Find any True or False values and lowercase them
fixBooleanCasing(data as Record<string, unknown>);
const validate = ajValidator.compile(schema);
const valid = validate(data);
if (!valid && validate.errors) {
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
}
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
}
return null;
}
}
function fixBooleanCasing(data: Record<string, unknown>) {
for (const key of Object.keys(data)) {
if (!(key in data)) continue;
if (typeof data[key] === 'object') {
fixBooleanCasing(data[key] as Record<string, unknown>);
} else if (data[key] === 'True') data[key] = 'true';
else if (data[key] === 'False') data[key] = 'false';
}
}