Compare commits

..

5 Commits

Author        SHA1        Message                                                           Date
mingholy.lmh  bb9820d27f  fix: regExp issue                                                 2025-09-18 20:32:18 +08:00
mingholy.lmh  7fd6a8b73d  fix: circular dependency issue and configurable tool-call style   2025-09-18 20:17:41 +08:00
mingholy.lmh  23a523df66  fix: switch system prompt to avoid malformed tool_calls           2025-09-18 18:51:14 +08:00
Mingholy      761833c915  Vision model support for Qwen-OAuth (#525)                        2025-09-18 13:32:00 +08:00
              * refactor: openaiContentGenerator
              * refactor: optimize stream handling
              * refactor: re-organize refactored files
              * fix: unit test cases
              * feat: `/model` command for switching to vision model
              * fix: lint error
              * feat: add image tokenizer to fit vlm context window
              * fix: lint and type errors
              * feat: add `visionModelPreview` to control default visibility of vision models
              * fix: remove deprecated files
              * fix: align supported image formats with bailian doc
Mingholy      56808ac210  fix: reset is_background (#644)                                   2025-09-18 13:27:09 +08:00
58 changed files with 6274 additions and 6214 deletions

View File

@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
description: 'Enable extension management features.',
showInDialog: false,
},
visionModelPreview: {
type: 'boolean',
label: 'Vision Model Preview',
category: 'Experimental',
requiresRestart: false,
default: false,
description:
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
showInDialog: true,
},
},
},
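
The new flag sits in the experimental namespace: App.tsx (further down) reads it as settings.merged.experimental?.visionModelPreview. A minimal sketch of opting in from settings.json (exact file location depends on the install):

{
  "experimental": {
    "visionModelPreview": true
  }
}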

View File

@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
kind: 'BUILT_IN',
},
}));
vi.mock('../ui/commands/modelCommand.js', () => ({
modelCommand: {
name: 'model',
description: 'Model command',
kind: 'BUILT_IN',
},
}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
const mcpCmd = commands.find((c) => c.name === 'mcp');
expect(mcpCmd).toBeDefined();
const modelCmd = commands.find((c) => c.name === 'model');
expect(modelCmd).toBeDefined();
});
});

View File

@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
import { vimCommand } from '../ui/commands/vimCommand.js';
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
import { modelCommand } from '../ui/commands/modelCommand.js';
import { agentsCommand } from '../ui/commands/agentsCommand.js';
/**
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
initCommand,
mcpCommand,
memoryCommand,
modelCommand,
privacyCommand,
quitCommand,
quitConfirmCommand,

View File

@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
import {
ModelSwitchDialog,
type VisionSwitchOutcome,
} from './components/ModelSwitchDialog.js';
import {
getOpenAIAvailableModelFromEnv,
getFilteredQwenModels,
type AvailableModel,
} from './models/availableModels.js';
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
import {
AgentCreationWizard,
AgentsManagerDialog,
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onWorkspaceMigrationDialogClose,
} = useWorkspaceMigration(settings);
// Model selection dialog states
const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
useState(false);
const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
useState(false);
const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
resolve: (result: {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}) => void;
reject: () => void;
} | null>(null);
useEffect(() => {
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
// Set the initial value
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
openAuthDialog();
}, [openAuthDialog, setAuthError]);
// Vision switch handler for auto-switch functionality
const handleVisionSwitchRequired = useCallback(
async (_query: unknown) =>
new Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>((resolve, reject) => {
setVisionSwitchResolver({ resolve, reject });
setIsVisionSwitchDialogOpen(true);
}),
[],
);
const handleVisionSwitchSelect = useCallback(
(outcome: VisionSwitchOutcome) => {
setIsVisionSwitchDialogOpen(false);
if (visionSwitchResolver) {
const result = processVisionSwitchOutcome(outcome);
visionSwitchResolver.resolve(result);
setVisionSwitchResolver(null);
}
},
[visionSwitchResolver],
);
const handleModelSelectionOpen = useCallback(() => {
setIsModelSelectionDialogOpen(true);
}, []);
const handleModelSelectionClose = useCallback(() => {
setIsModelSelectionDialogOpen(false);
}, []);
const handleModelSelect = useCallback(
(modelId: string) => {
config.setModel(modelId);
setCurrentModel(modelId);
setIsModelSelectionDialogOpen(false);
addItem(
{
type: MessageType.INFO,
text: `Switched model to \`${modelId}\` for this session.`,
},
Date.now(),
);
},
[config, setCurrentModel, addItem],
);
const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) return [];
const visionModelPreviewEnabled =
settings.merged.experimental?.visionModelPreview ?? false;
switch (contentGeneratorConfig.authType) {
case AuthType.QWEN_OAUTH:
return getFilteredQwenModels(visionModelPreviewEnabled);
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
return [];
}
}, [config, settings.merged.experimental?.visionModelPreview]);
// Core hooks and processors
const {
vimEnabled: vimModeEnabled,
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setQuittingMessages,
openPrivacyNotice,
openSettingsDialog,
handleModelSelectionOpen,
openSubagentCreateDialog,
openAgentsManagerDialog,
toggleVimEnabled,
@@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
setModelSwitchedFromQuotaError,
refreshStatic,
() => cancelHandlerRef.current(),
settings.merged.experimental?.visionModelPreview ?? false,
handleVisionSwitchRequired,
);
const pendingHistoryItems = useMemo(
@@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
!isAuthDialogOpen &&
!isThemeDialogOpen &&
!isEditorDialogOpen &&
!isModelSelectionDialogOpen &&
!isVisionSwitchDialogOpen &&
!isSubagentCreateDialogOpen &&
!showPrivacyNotice &&
!showWelcomeBackDialog &&
@@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
showWelcomeBackDialog,
welcomeBackChoice,
geminiClient,
isModelSelectionDialogOpen,
isVisionSwitchDialogOpen,
]);
if (quittingMessages) {
@@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
onExit={exitEditorDialog}
/>
</Box>
) : isModelSelectionDialogOpen ? (
<ModelSelectionDialog
availableModels={getAvailableModelsForCurrentAuth()}
currentModel={currentModel}
onSelect={handleModelSelect}
onCancel={handleModelSelectionClose}
/>
) : isVisionSwitchDialogOpen ? (
<ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
) : showPrivacyNotice ? (
<PrivacyNotice
onExit={() => setShowPrivacyNotice(false)}
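
handleVisionSwitchRequired above bridges an async request to a UI dialog by parking the promise's resolve callback in React state until the dialog settles it. The same pattern in isolation, as a minimal self-contained sketch (hook name and return shape hypothetical):

import { useCallback, useState } from 'react';

function usePromiseDialog<T>() {
  const [resolver, setResolver] = useState<((value: T) => void) | null>(null);

  // Called by the code that needs an answer; the promise stays pending
  // until the dialog settles it.
  const request = useCallback(
    () =>
      new Promise<T>((resolve) => {
        // Functional update so React stores the function itself, not its result.
        setResolver(() => resolve);
      }),
    [],
  );

  // Called by the dialog's onSelect/onCancel handler.
  const settle = useCallback(
    (value: T) => {
      resolver?.(value);
      setResolver(null);
    },
    [resolver],
  );

  return { request, settle, isDialogOpen: resolver !== null };
}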

View File

@@ -0,0 +1,179 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { modelCommand } from './modelCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import {
AuthType,
type ContentGeneratorConfig,
type Config,
} from '@qwen-code/qwen-code-core';
import * as availableModelsModule from '../models/availableModels.js';
// Mock the availableModels module
vi.mock('../models/availableModels.js', () => ({
AVAILABLE_MODELS_QWEN: [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
],
getOpenAIAvailableModelFromEnv: vi.fn(),
}));
// Helper function to create a mock config
function createMockConfig(
contentGeneratorConfig: ContentGeneratorConfig | null,
): Partial<Config> {
return {
getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
};
}
describe('modelCommand', () => {
let mockContext: CommandContext;
const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
availableModelsModule.getOpenAIAvailableModelFromEnv,
);
beforeEach(() => {
mockContext = createMockCommandContext();
vi.clearAllMocks();
});
it('should have the correct name and description', () => {
expect(modelCommand.name).toBe('model');
expect(modelCommand.description).toBe('Switch the model for this session');
});
it('should return error when config is not available', async () => {
mockContext.services.config = null;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
});
});
it('should return error when content generator config is not available', async () => {
const mockConfig = createMockConfig(null);
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
});
});
it('should return error when auth type is not available', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
it('should return dialog action for QWEN_OAUTH auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.QWEN_OAUTH,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
id: 'gpt-4',
label: 'gpt-4',
});
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});
it('should return error for USE_OPENAI auth type when no model is available', async () => {
mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);
const mockConfig = createMockConfig({
model: 'test-model',
authType: AuthType.USE_OPENAI,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (openai).',
});
});
it('should return error for unsupported auth types', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content:
'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
});
});
it('should handle undefined auth type', async () => {
const mockConfig = createMockConfig({
model: 'test-model',
authType: undefined,
});
mockContext.services.config = mockConfig as Config;
const result = await modelCommand.action!(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
});
});
});

View File

@@ -0,0 +1,88 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { AuthType } from '@qwen-code/qwen-code-core';
import type {
SlashCommand,
CommandContext,
OpenDialogActionReturn,
MessageActionReturn,
} from './types.js';
import { CommandKind } from './types.js';
import {
AVAILABLE_MODELS_QWEN,
getOpenAIAvailableModelFromEnv,
type AvailableModel,
} from '../models/availableModels.js';
function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
switch (authType) {
case AuthType.QWEN_OAUTH:
return AVAILABLE_MODELS_QWEN;
case AuthType.USE_OPENAI: {
const openAIModel = getOpenAIAvailableModelFromEnv();
return openAIModel ? [openAIModel] : [];
}
default:
// For other auth types, return empty array for now
// This can be expanded later according to the design doc
return [];
}
}
export const modelCommand: SlashCommand = {
name: 'model',
description: 'Switch the model for this session',
kind: CommandKind.BUILT_IN,
action: async (
context: CommandContext,
): Promise<OpenDialogActionReturn | MessageActionReturn> => {
const { services } = context;
const { config } = services;
if (!config) {
return {
type: 'message',
messageType: 'error',
content: 'Configuration not available.',
};
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (!contentGeneratorConfig) {
return {
type: 'message',
messageType: 'error',
content: 'Content generator configuration not available.',
};
}
const authType = contentGeneratorConfig.authType;
if (!authType) {
return {
type: 'message',
messageType: 'error',
content: 'Authentication type not available.',
};
}
const availableModels = getAvailableModelsForAuthType(authType);
if (availableModels.length === 0) {
return {
type: 'message',
messageType: 'error',
content: `No models available for the current authentication type (${authType}).`,
};
}
// Trigger model selection dialog
return {
type: 'dialog',
dialog: 'model',
};
},
};
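
Mirroring the unit tests above, the command can also be driven directly; a sketch (the outcome depends on what the context's config reports):

import { modelCommand } from './modelCommand.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';

const context = createMockCommandContext();
const result = await modelCommand.action!(context, '');
if (result.type === 'dialog') {
  // dialog === 'model': the UI layer opens ModelSelectionDialog (wired in App.tsx above).
} else {
  console.error(result.content); // e.g. 'Configuration not available.'
}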

View File

@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
| 'editor'
| 'privacy'
| 'settings'
| 'model'
| 'subagent_create'
| 'subagent_list';
}

View File

@@ -0,0 +1,246 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
import type { AvailableModel } from '../models/availableModels.js';
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSelectionDialog', () => {
const mockAvailableModels: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
{ id: 'gpt-4', label: 'GPT-4' },
];
const mockOnSelect = vi.fn();
const mockOnCancel = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup escape key handler to call onCancel', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnCancel).toHaveBeenCalled();
});
it('should not call onCancel for non-escape keys', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnCancel).not.toHaveBeenCalled();
});
it('should set correct initial index for current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
});
it('should set initial index to 0 when current model is not found', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="non-existent-model"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
it('should call onSelect when a model is selected', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback('qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
});
it('should handle empty models array', () => {
render(
<ModelSelectionDialog
availableModels={[]}
currentModel=""
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual([]);
expect(callArgs.initialIndex).toBe(0);
});
it('should create correct option items with proper labels', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const expectedItems = [
{
label: 'qwen3-coder-plus (current)',
value: 'qwen3-coder-plus',
},
{
label: 'qwen-vl-max [Vision]',
value: 'qwen-vl-max-latest',
},
{
label: 'GPT-4',
value: 'gpt-4',
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
});
it('should show vision indicator for vision models', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="gpt-4"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const visionModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(visionModelItem?.label).toContain('[Vision]');
});
it('should show current indicator for the current model', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen-vl-max-latest"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
const currentModelItem = callArgs.items.find(
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
);
expect(currentModelItem?.label).toContain('(current)');
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle multiple onSelect calls correctly', () => {
render(
<ModelSelectionDialog
availableModels={mockAvailableModels}
currentModel="qwen3-coder-plus"
onSelect={mockOnSelect}
onCancel={mockOnCancel}
/>,
);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback('qwen3-coder-plus');
onSelectCallback('qwen-vl-max-latest');
onSelectCallback('gpt-4');
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
});
});

View File

@@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
import type { AvailableModel } from '../models/availableModels.js';
export interface ModelSelectionDialogProps {
availableModels: AvailableModel[];
currentModel: string;
onSelect: (modelId: string) => void;
onCancel: () => void;
}
export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
availableModels,
currentModel,
onSelect,
onCancel,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onCancel();
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<string>> = availableModels.map(
(model) => {
const visionIndicator = model.isVision ? ' [Vision]' : '';
const currentIndicator = model.id === currentModel ? ' (current)' : '';
return {
label: `${model.label}${visionIndicator}${currentIndicator}`,
value: model.id,
};
},
);
const initialIndex = Math.max(
0,
availableModels.findIndex((model) => model.id === currentModel),
);
const handleSelect = (modelId: string) => {
onSelect(modelId);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentBlue}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Select Model</Text>
<Text>Choose a model for this session:</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={initialIndex}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};
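
A minimal render sketch for the dialog above (the model list here is illustrative; in the app it comes from getAvailableModelsForCurrentAuth):

import React from 'react';
import { render } from 'ink';
import { ModelSelectionDialog } from './ModelSelectionDialog.js';

render(
  <ModelSelectionDialog
    availableModels={[
      { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
      { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
    ]}
    currentModel="qwen3-coder-plus"
    onSelect={(modelId) => console.log(`selected ${modelId}`)}
    onCancel={() => process.exit(0)}
  />,
);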

View File

@@ -0,0 +1,185 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
// Mock the useKeypress hook
const mockUseKeypress = vi.hoisted(() => vi.fn());
vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: mockUseKeypress,
}));
// Mock the RadioButtonSelect component
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: mockRadioButtonSelect,
}));
describe('ModelSwitchDialog', () => {
const mockOnSelect = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
// Mock RadioButtonSelect to return a simple div
mockRadioButtonSelect.mockReturnValue(
React.createElement('div', { 'data-testid': 'radio-select' }),
);
});
it('should setup RadioButtonSelect with correct options', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const expectedItems = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.items).toEqual(expectedItems);
expect(callArgs.initialIndex).toBe(0);
expect(callArgs.isFocused).toBe(true);
});
it('should call onSelect when an option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(typeof callArgs.onSelect).toBe('function');
// Simulate selection of "Switch for this request only"
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
});
it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.SwitchSessionToVL,
);
});
it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
isActive: true,
});
// Simulate escape key press
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should not call onSelect for non-escape keys', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
keypressHandler({ name: 'enter' });
expect(mockOnSelect).not.toHaveBeenCalled();
});
it('should set initial index to 0 (first option)', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.initialIndex).toBe(0);
});
describe('VisionSwitchOutcome enum', () => {
it('should have correct enum values', () => {
expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
'switch_session_to_vl',
);
expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
'disallow_with_guidance',
);
});
});
it('should handle multiple onSelect calls correctly', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
// Call multiple times
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
expect(mockOnSelect).toHaveBeenCalledTimes(3);
expect(mockOnSelect).toHaveBeenNthCalledWith(
1,
VisionSwitchOutcome.SwitchOnce,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
2,
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(mockOnSelect).toHaveBeenNthCalledWith(
3,
VisionSwitchOutcome.DisallowWithGuidance,
);
});
it('should pass isFocused prop to RadioButtonSelect', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
expect(callArgs.isFocused).toBe(true);
});
it('should handle escape key multiple times', () => {
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
const keypressHandler = mockUseKeypress.mock.calls[0][0];
// Call escape multiple times
keypressHandler({ name: 'escape' });
keypressHandler({ name: 'escape' });
expect(mockOnSelect).toHaveBeenCalledTimes(2);
expect(mockOnSelect).toHaveBeenCalledWith(
VisionSwitchOutcome.DisallowWithGuidance,
);
});
});

View File

@@ -0,0 +1,89 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { Colors } from '../colors.js';
import {
RadioButtonSelect,
type RadioSelectItem,
} from './shared/RadioButtonSelect.js';
import { useKeypress } from '../hooks/useKeypress.js';
export enum VisionSwitchOutcome {
SwitchOnce = 'switch_once',
SwitchSessionToVL = 'switch_session_to_vl',
DisallowWithGuidance = 'disallow_with_guidance',
}
export interface ModelSwitchDialogProps {
onSelect: (outcome: VisionSwitchOutcome) => void;
}
export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
onSelect,
}) => {
useKeypress(
(key) => {
if (key.name === 'escape') {
onSelect(VisionSwitchOutcome.DisallowWithGuidance);
}
},
{ isActive: true },
);
const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
{
label: 'Switch for this request only',
value: VisionSwitchOutcome.SwitchOnce,
},
{
label: 'Switch session to vision model',
value: VisionSwitchOutcome.SwitchSessionToVL,
},
{
label: 'Do not switch, show guidance',
value: VisionSwitchOutcome.DisallowWithGuidance,
},
];
const handleSelect = (outcome: VisionSwitchOutcome) => {
onSelect(outcome);
};
return (
<Box
flexDirection="column"
borderStyle="round"
borderColor={Colors.AccentYellow}
padding={1}
width="100%"
marginLeft={1}
>
<Box flexDirection="column" marginBottom={1}>
<Text bold>Vision Model Switch Required</Text>
<Text>
Your message contains an image, but the current model doesn&apos;t
support vision.
</Text>
<Text>How would you like to proceed?</Text>
</Box>
<Box marginBottom={1}>
<RadioButtonSelect
items={options}
initialIndex={0}
onSelect={handleSelect}
isFocused
/>
</Box>
<Box>
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
</Box>
</Box>
);
};
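
Each outcome maps to a concrete switch behavior via processVisionSwitchOutcome (implemented in useVisionAutoSwitch.ts below); the mapping, per the unit tests:

import { VisionSwitchOutcome } from './ModelSwitchDialog.js';
import { processVisionSwitchOutcome } from '../hooks/useVisionAutoSwitch.js';

processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
// -> { modelOverride: 'qwen-vl-max-latest' }       (this request only)
processVisionSwitchOutcome(VisionSwitchOutcome.SwitchSessionToVL);
// -> { persistSessionModel: 'qwen-vl-max-latest' } (rest of the session)
processVisionSwitchOutcome(VisionSwitchOutcome.DisallowWithGuidance);
// -> { showGuidance: true }                        (blocked; guidance message shown)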

View File

@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
const mockLoadHistory = vi.fn();
const mockOpenThemeDialog = vi.fn();
const mockOpenAuthDialog = vi.fn();
const mockOpenModelSelectionDialog = vi.fn();
const mockSetQuittingMessages = vi.fn();
const mockConfig = makeFakeConfig({});
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
mockMcpLoadCommands.mockResolvedValue([]);
mockOpenModelSelectionDialog.mockClear();
});
const setupProcessorHook = (
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
mockOpenModelSelectionDialog,
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
setIsProcessing,
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
expect(mockOpenThemeDialog).toHaveBeenCalled();
});
it('should handle "dialog: model" action', async () => {
const command = createTestCommand({
name: 'modelcmd',
action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
});
const result = setupProcessorHook([command]);
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
await act(async () => {
await result.current.handleSlashCommand('/modelcmd');
});
expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
});
it('should handle "load_history" action', async () => {
const command = createTestCommand({
name: 'load',
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
mockSetQuittingMessages,
vi.fn(), // openPrivacyNotice
vi.fn(), // openSettingsDialog
vi.fn(), // openModelSelectionDialog
vi.fn(), // openSubagentCreateDialog
vi.fn(), // openAgentsManagerDialog
vi.fn(), // toggleVimEnabled
vi.fn(), // setIsProcessing
vi.fn(), // setGeminiMdFileCount
vi.fn(), // _showQuitConfirmation
),
);

View File

@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
setQuittingMessages: (message: HistoryItem[]) => void,
openPrivacyNotice: () => void,
openSettingsDialog: () => void,
openModelSelectionDialog: () => void,
openSubagentCreateDialog: () => void,
openAgentsManagerDialog: () => void,
toggleVimEnabled: () => Promise<boolean>,
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
case 'settings':
openSettingsDialog();
return { type: 'handled' };
case 'model':
openModelSelectionDialog();
return { type: 'handled' };
case 'subagent_create':
openSubagentCreateDialog();
return { type: 'handled' };
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
setSessionShellAllowlist,
setIsProcessing,
setConfirmationRequest,
openModelSelectionDialog,
session.stats,
],
);

View File

@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
);
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
// Vision auto-switch mocks (hoisted)
const mockHandleVisionSwitch = vi.hoisted(() =>
vi.fn().mockResolvedValue({ shouldProceed: true }),
);
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return {
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
};
});
vi.mock('./useVisionAutoSwitch.js', () => ({
useVisionAutoSwitch: vi.fn(() => ({
handleVisionSwitch: mockHandleVisionSwitch,
restoreOriginalModel: mockRestoreOriginalModel,
})),
}));
vi.mock('./useKeypress.js', () => ({
useKeypress: vi.fn(),
}));
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue(contentGeneratorConfig),
getMaxSessionTurns: vi.fn(() => 50),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
expect.any(String), // Argument 3: The prompt_id string
);
});
describe('Thought Reset', () => {
it('should reset thought to null when starting a new prompt', async () => {
// First, simulate a response with a thought
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
);
});
});
// --- New tests focused on recent modifications ---
describe('Vision Auto Switch Integration', () => {
it('should call handleVisionSwitch and proceed to send when allowed', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('image prompt');
});
await waitFor(() => {
expect(mockHandleVisionSwitch).toHaveBeenCalled();
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('vision-gated');
});
// No call to API, no restoreOriginalModel needed since no override occurred
expect(mockSendMessageStream).not.toHaveBeenCalled();
expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
// Next call allowed (flag reset path)
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'ok' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
await act(async () => {
await result.current.submitQuery('after-gate');
});
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalled();
});
});
});
describe('Model restore on completion and errors', () => {
it('should restore model after successful stream completion', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-success');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
it('should restore model when an error occurs during streaming', async () => {
const testError = new Error('stream failure');
mockSendMessageStream.mockReturnValue(
(async function* () {
yield { type: ServerGeminiEventType.Content, value: 'content' };
throw testError;
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('restore-error');
});
await waitFor(() => {
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
});
});
});
});

View File

@@ -42,6 +42,7 @@ import type {
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
import { useShellCommandProcessor } from './shellCommandProcessor.js';
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
import { handleAtCommand } from './atCommandProcessor.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
@@ -88,6 +89,12 @@ export const useGeminiStream = (
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
onEditorClose: () => void,
onCancelSubmit: () => void,
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -155,6 +162,13 @@ export const useGeminiStream = (
geminiClient,
);
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
config,
addItem,
visionModelPreviewEnabled,
onVisionSwitchRequired,
);
const streamingState = useMemo(() => {
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
return StreamingState.WaitingForConfirmation;
@@ -715,6 +729,20 @@ export const useGeminiStream = (
return;
}
// Handle vision switch requirement
const visionSwitchResult = await handleVisionSwitch(
queryToSend,
userMessageTimestamp,
options?.isContinuation || false,
);
if (!visionSwitchResult.shouldProceed) {
isSubmittingQueryRef.current = false;
return;
}
const finalQueryToSend = queryToSend;
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
@@ -725,7 +753,7 @@ export const useGeminiStream = (
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
finalQueryToSend,
abortSignal,
prompt_id!,
);
@@ -736,6 +764,8 @@ export const useGeminiStream = (
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
isSubmittingQueryRef.current = false;
return;
}
@@ -748,7 +778,13 @@ export const useGeminiStream = (
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
// Restore original model if it was temporarily overridden
restoreOriginalModel();
} catch (error: unknown) {
// Restore original model if it was temporarily overridden
restoreOriginalModel();
if (error instanceof UnauthorizedError) {
onAuthError();
} else if (!isNodeError(error) || error.name !== 'AbortError') {
@@ -786,6 +822,8 @@ export const useGeminiStream = (
startNewPrompt,
getPromptCount,
handleLoopDetectedEvent,
handleVisionSwitch,
restoreOriginalModel,
],
);

View File

@@ -0,0 +1,374 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { renderHook, act } from '@testing-library/react';
import type { Part, PartListUnion } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import {
shouldOfferVisionSwitch,
processVisionSwitchOutcome,
getVisionSwitchGuidanceMessage,
useVisionAutoSwitch,
} from './useVisionAutoSwitch.js';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import { MessageType } from '../types.js';
import { getDefaultVisionModel } from '../models/availableModels.js';
describe('useVisionAutoSwitch helpers', () => {
describe('shouldOfferVisionSwitch', () => {
it('returns false when authType is not QWEN_OAUTH', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.USE_GEMINI,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when current model is already a vision model', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen-vl-max-latest',
true,
);
expect(result).toBe(false);
});
it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
const parts: PartListUnion = [
{ text: 'hello' },
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('detects image when provided as a single Part object (non-array)', () => {
const singleImagePart: PartListUnion = {
fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
} as Part;
const result = shouldOfferVisionSwitch(
singleImagePart,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(true);
});
it('returns false when parts contain no images', () => {
const parts: PartListUnion = [{ text: 'just text' }];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when parts is a plain string', () => {
const parts: PartListUnion = 'plain text';
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
true,
);
expect(result).toBe(false);
});
it('returns false when visionModelPreviewEnabled is false', () => {
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const result = shouldOfferVisionSwitch(
parts,
AuthType.QWEN_OAUTH,
'qwen3-coder-plus',
false,
);
expect(result).toBe(false);
});
});
describe('processVisionSwitchOutcome', () => {
it('maps SwitchOnce to a one-time model override', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
expect(result).toEqual({ modelOverride: vl });
});
it('maps SwitchSessionToVL to a persistent session model', () => {
const vl = getDefaultVisionModel();
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.SwitchSessionToVL,
);
expect(result).toEqual({ persistSessionModel: vl });
});
it('maps DisallowWithGuidance to showGuidance', () => {
const result = processVisionSwitchOutcome(
VisionSwitchOutcome.DisallowWithGuidance,
);
expect(result).toEqual({ showGuidance: true });
});
});
describe('getVisionSwitchGuidanceMessage', () => {
it('returns the expected guidance message', () => {
const vl = getDefaultVisionModel();
const expected =
'To use images with your query, you can:\n' +
`• Use /model set ${vl} to switch to a vision-capable model\n` +
'• Or remove the image and provide a text description instead';
expect(getVisionSwitchGuidanceMessage()).toBe(expected);
});
});
});
describe('useVisionAutoSwitch hook', () => {
type AddItemFn = (
item: { type: MessageType; text: string },
ts: number,
) => any;
const createMockConfig = (authType: AuthType, initialModel: string) => {
let currentModel = initialModel;
const mockConfig: Partial<Config> = {
getModel: vi.fn(() => currentModel),
setModel: vi.fn((m: string) => {
currentModel = m;
}),
getContentGeneratorConfig: vi.fn(() => ({
authType,
model: currentModel,
apiKey: 'test-key',
vertexai: false,
})),
};
return mockConfig as Config;
};
let addItem: AddItemFn;
beforeEach(() => {
vi.clearAllMocks();
addItem = vi.fn();
});
it('returns shouldProceed=true immediately for continuations', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, Date.now(), true);
});
expect(res).toEqual({ shouldProceed: true });
expect(addItem).not.toHaveBeenCalled();
});
it('does nothing when authType is not QWEN_OAUTH', async () => {
const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 123, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('does nothing when there are no image parts', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [{ text: 'no images here' }];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 456, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
it('shows guidance and blocks when dialog returns showGuidance', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ showGuidance: true });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
const userTs = 1010;
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, userTs, false);
});
expect(addItem).toHaveBeenCalledWith(
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
userTs,
);
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('applies a one-time override and returns originalModel, then restores', async () => {
const initialModel = 'qwen3-coder-plus';
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 2020, false);
});
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Now restore
act(() => {
result.current.restoreOriginalModel();
});
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
});
it('persists session model when dialog requests persistence', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi
.fn()
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 3030, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
// Restore should be a no-op since no one-time override was used
act(() => {
result.current.restoreOriginalModel();
});
// Last call should still be the persisted model set
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
'qwen-vl-max-latest',
);
});
it('returns shouldProceed=true when dialog returns no special flags', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 4040, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(config.setModel).not.toHaveBeenCalled();
});
it('blocks when dialog throws or is cancelled', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
const { result } = renderHook(() =>
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 5050, false);
});
expect(res).toEqual({ shouldProceed: false });
expect(config.setModel).not.toHaveBeenCalled();
});
it('does nothing when visionModelPreviewEnabled is false', async () => {
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
const onVisionSwitchRequired = vi.fn();
const { result } = renderHook(() =>
useVisionAutoSwitch(
config,
addItem as any,
false,
onVisionSwitchRequired,
),
);
const parts: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: '...' } },
];
let res: any;
await act(async () => {
res = await result.current.handleVisionSwitch(parts, 6060, false);
});
expect(res).toEqual({ shouldProceed: true });
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,304 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { type PartListUnion, type Part } from '@google/genai';
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
import { useCallback, useRef } from 'react';
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
import {
getDefaultVisionModel,
isVisionModel,
} from '../models/availableModels.js';
import { MessageType } from '../types.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import {
isSupportedImageMimeType,
getUnsupportedImageFormatWarning,
} from '@qwen-code/qwen-code-core';
/**
* Checks if a PartListUnion contains image parts
*/
function hasImageParts(parts: PartListUnion): boolean {
if (typeof parts === 'string') {
return false;
}
if (Array.isArray(parts)) {
return parts.some((part) => {
// Skip string parts
if (typeof part === 'string') return false;
return isImagePart(part);
});
}
// If it's a single Part (not a string), check if it's an image
if (typeof parts === 'object') {
return isImagePart(parts);
}
return false;
}
/**
* Checks if a single Part is an image part
*/
function isImagePart(part: Part): boolean {
// Check for inlineData with image mime type
if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
return true;
}
// Check for fileData with image mime type
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
return true;
}
return false;
}
/**
* Checks if image parts have supported formats and returns unsupported ones
*/
function checkImageFormatsSupport(parts: PartListUnion): {
hasImages: boolean;
hasUnsupportedFormats: boolean;
unsupportedMimeTypes: string[];
} {
const unsupportedMimeTypes: string[] = [];
let hasImages = false;
if (typeof parts === 'string') {
return {
hasImages: false,
hasUnsupportedFormats: false,
unsupportedMimeTypes: [],
};
}
const partsArray = Array.isArray(parts) ? parts : [parts];
for (const part of partsArray) {
if (typeof part === 'string') continue;
let mimeType: string | undefined;
// Check inlineData
if (
'inlineData' in part &&
part.inlineData?.mimeType?.startsWith('image/')
) {
hasImages = true;
mimeType = part.inlineData.mimeType;
}
// Check fileData
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
hasImages = true;
mimeType = part.fileData.mimeType;
}
// Check if the mime type is supported
if (mimeType && !isSupportedImageMimeType(mimeType)) {
unsupportedMimeTypes.push(mimeType);
}
}
return {
hasImages,
hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
unsupportedMimeTypes,
};
}
/**
* Determines if we should offer vision switch for the given parts, auth type, and current model
*/
export function shouldOfferVisionSwitch(
parts: PartListUnion,
authType: AuthType,
currentModel: string,
visionModelPreviewEnabled: boolean = false,
): boolean {
// Only trigger for qwen-oauth
if (authType !== AuthType.QWEN_OAUTH) {
return false;
}
// If vision model preview is disabled, never offer vision switch
if (!visionModelPreviewEnabled) {
return false;
}
// If current model is already a vision model, no need to switch
if (isVisionModel(currentModel)) {
return false;
}
// Check if the current message contains image parts
return hasImageParts(parts);
}
/**
* Interface for vision switch result
*/
export interface VisionSwitchResult {
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}
/**
* Processes the vision switch outcome and returns the appropriate result
*/
export function processVisionSwitchOutcome(
outcome: VisionSwitchOutcome,
): VisionSwitchResult {
const vlModelId = getDefaultVisionModel();
switch (outcome) {
case VisionSwitchOutcome.SwitchOnce:
return { modelOverride: vlModelId };
case VisionSwitchOutcome.SwitchSessionToVL:
return { persistSessionModel: vlModelId };
case VisionSwitchOutcome.DisallowWithGuidance:
return { showGuidance: true };
default:
return { showGuidance: true };
}
}
/**
* Gets the guidance message for when vision switch is disallowed
*/
export function getVisionSwitchGuidanceMessage(): string {
const vlModelId = getDefaultVisionModel();
return `To use images with your query, you can:
• Use /model set ${vlModelId} to switch to a vision-capable model
• Or remove the image and provide a text description instead`;
}
/**
* Interface for vision switch handling result
*/
export interface VisionSwitchHandlingResult {
shouldProceed: boolean;
originalModel?: string;
}
/**
* Custom hook for handling vision model auto-switching
*/
export function useVisionAutoSwitch(
config: Config,
addItem: UseHistoryManagerReturn['addItem'],
visionModelPreviewEnabled: boolean = false,
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
modelOverride?: string;
persistSessionModel?: string;
showGuidance?: boolean;
}>,
) {
const originalModelRef = useRef<string | null>(null);
const handleVisionSwitch = useCallback(
async (
query: PartListUnion,
userMessageTimestamp: number,
isContinuation: boolean,
): Promise<VisionSwitchHandlingResult> => {
// Skip vision switch handling for continuations or if no handler provided
if (isContinuation || !onVisionSwitchRequired) {
return { shouldProceed: true };
}
const contentGeneratorConfig = config.getContentGeneratorConfig();
// Only handle qwen-oauth auth type
if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
return { shouldProceed: true };
}
// Check image format support first
const formatCheck = checkImageFormatsSupport(query);
// If there are unsupported image formats, show warning
if (formatCheck.hasUnsupportedFormats) {
addItem(
{
type: MessageType.INFO,
text: getUnsupportedImageFormatWarning(),
},
userMessageTimestamp,
);
// Continue processing but with warning shown
}
// Check if vision switch is needed
if (
!shouldOfferVisionSwitch(
query,
contentGeneratorConfig.authType,
config.getModel(),
visionModelPreviewEnabled,
)
) {
return { shouldProceed: true };
}
try {
const visionSwitchResult = await onVisionSwitchRequired(query);
if (visionSwitchResult.showGuidance) {
// Show guidance and don't proceed with the request
addItem(
{
type: MessageType.INFO,
text: getVisionSwitchGuidanceMessage(),
},
userMessageTimestamp,
);
return { shouldProceed: false };
}
if (visionSwitchResult.modelOverride) {
// One-time model override
originalModelRef.current = config.getModel();
config.setModel(visionSwitchResult.modelOverride);
return {
shouldProceed: true,
originalModel: originalModelRef.current,
};
} else if (visionSwitchResult.persistSessionModel) {
// Persistent session model change
config.setModel(visionSwitchResult.persistSessionModel);
return { shouldProceed: true };
}
return { shouldProceed: true };
} catch (_error) {
// If vision switch dialog was cancelled or errored, don't proceed
return { shouldProceed: false };
}
},
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
);
const restoreOriginalModel = useCallback(() => {
if (originalModelRef.current) {
config.setModel(originalModelRef.current);
originalModelRef.current = null;
}
}, [config]);
return {
handleVisionSwitch,
restoreOriginalModel,
};
}
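A minimal usage sketch of this hook in a submit path. Only the hook API itself (`useVisionAutoSwitch`, `handleVisionSwitch`, `restoreOriginalModel`) comes from this diff; `openVisionSwitchDialog`, `submitToModel`, and the settings lookup are placeholders for illustration.
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
  config,
  addItem,
  visionModelPreviewEnabled, // read from settings; exact path assumed
  openVisionSwitchDialog,    // hypothetical: resolves to { modelOverride | persistSessionModel | showGuidance }
);
async function onSubmit(query: PartListUnion) {
  const timestamp = Date.now();
  const { shouldProceed } = await handleVisionSwitch(query, timestamp, false);
  if (!shouldProceed) return; // guidance was shown, or the dialog was cancelled
  try {
    await submitToModel(query); // hypothetical send
  } finally {
    restoreOriginalModel(); // undoes a one-time override; no-op otherwise
  }
}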

View File

@@ -0,0 +1,52 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export type AvailableModel = {
id: string;
label: string;
isVision?: boolean;
};
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
];
/**
* Get available Qwen models filtered by vision model preview setting
*/
export function getFilteredQwenModels(
visionModelPreviewEnabled: boolean,
): AvailableModel[] {
if (visionModelPreviewEnabled) {
return AVAILABLE_MODELS_QWEN;
}
return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
}
/**
* Currently we use the single model specified by `OPENAI_MODEL` in the env.
* In the future, once settings.json supports it, users will be able to configure this themselves.
*/
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
const id = process.env['OPENAI_MODEL']?.trim();
return id ? { id, label: id } : null;
}
/**
* Hard code the default vision model as a string literal,
* until our coding model supports multimodal.
*/
export function getDefaultVisionModel(): string {
return 'qwen-vl-max-latest';
}
export function isVisionModel(modelId: string): boolean {
return AVAILABLE_MODELS_QWEN.some(
(model) => model.id === modelId && model.isVision,
);
}
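A quick sketch of how these helpers behave (expected return values shown as comments):
getFilteredQwenModels(true);
// -> both entries, including { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true }
getFilteredQwenModels(false);
// -> [{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }]
isVisionModel('qwen-vl-max-latest'); // true
isVisionModel('qwen3-coder-plus');   // false
getDefaultVisionModel();             // 'qwen-vl-max-latest'
getOpenAIAvailableModelFromEnv();    // null when OPENAI_MODEL is unset or empty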

View File

@@ -19,3 +19,4 @@ export {
} from './src/telemetry/types.js';
export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js';
export * from './src/utils/request-tokenizer/supportedImageFormats.js';

View File

@@ -521,6 +521,18 @@ export class Config {
if (this.contentGeneratorConfig) {
this.contentGeneratorConfig.model = newModel;
}
// Reinitialize chat with updated configuration while preserving history
const geminiClient = this.getGeminiClient();
if (geminiClient && geminiClient.isInitialized()) {
// Use async operation but don't await to avoid blocking
geminiClient.reinitialize().catch((error) => {
console.error(
'Failed to reinitialize chat with updated config:',
error,
);
});
}
}
isInFallbackMode(): boolean {

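Note that `setModel` schedules the reinitialize without awaiting it, so the setter stays synchronous. A caller that needs deterministic ordering can await the client itself; a sketch under that assumption (a second `reinitialize()` is assumed benign, since it re-derives history from the current chat):
async function switchModelAndWait(config: Config, modelId: string): Promise<void> {
  config.setModel(modelId); // synchronous; kicks off a background reinitialize
  const client = config.getGeminiClient();
  if (client?.isInitialized()) {
    await client.reinitialize(); // explicit await when the next call must see the new model
  }
}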
File diff suppressed because it is too large

View File

@@ -5,9 +5,10 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
import OpenAI from 'openai';
// Mock OpenAI
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
let mockConfig: Config;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockOpenAIClient: any;
let mockProvider: OpenAICompatibleProvider;
beforeEach(() => {
// Reset mocks
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
mockConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'openai',
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
create: vi.fn(),
},
},
embeddings: {
create: vi.fn(),
},
};
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
// Create mock provider
mockProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
// Create generator instance
const contentGeneratorConfig = {
model: 'gpt-4',
apiKey: 'test-key',
authType: AuthType.USE_OPENAI,
enableOpenAILogging: false,
};
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
generator = new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
});
afterEach(() => {
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
await expect(
generator.generateContentStream(request, 'test-prompt-id'),
).rejects.toThrow(
/Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
/Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
);
});
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
} catch (error: unknown) {
const errorMessage =
error instanceof Error ? error.message : String(error);
expect(errorMessage).toContain(
'Streaming setup timeout troubleshooting:',
);
expect(errorMessage).toContain(
'Check network connectivity and firewall settings',
);
expect(errorMessage).toContain('Streaming timeout troubleshooting:');
expect(errorMessage).toContain('Check network connectivity');
expect(errorMessage).toContain('Consider using non-streaming mode');
}
});
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
new OpenAIContentGenerator(
contentGeneratorConfig,
mockConfig,
mockProvider,
);
// Verify OpenAI client was created with timeout config
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Verify provider buildClient was called
expect(mockProvider.buildClient).toHaveBeenCalled();
});
it('should use custom timeout from config', () => {
const customConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
timeout: 300000,
maxRetries: 5,
};
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const customMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
customConfig,
customMockProvider,
);
// Verify provider buildClient was called
expect(customMockProvider.buildClient).toHaveBeenCalled();
});
it('should handle missing timeout config gracefully', () => {
const noTimeoutConfig = {
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
enableOpenAILogging: false,
}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
authType: AuthType.USE_OPENAI,
baseUrl: 'http://localhost:8080',
};
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
expect(OpenAI).toHaveBeenCalledWith({
apiKey: 'test-key',
baseURL: 'http://localhost:8080',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
// Create a custom mock provider for this test
const noTimeoutMockProvider: OpenAICompatibleProvider = {
buildHeaders: vi.fn().mockReturnValue({
'User-Agent': 'QwenCode/1.0.0 (test; test)',
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
};
new OpenAIContentGenerator(
contentGeneratorConfig,
noTimeoutConfig,
noTimeoutMockProvider,
);
// Verify provider buildClient was called
expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
});
});

View File

@@ -441,7 +441,8 @@ describe('Gemini Client (client.ts)', () => {
);
});
it('should allow overriding model and config', async () => {
/* We now use model in contentGeneratorConfig in most cases. */
it.skip('should allow overriding model and config', async () => {
const contents: Content[] = [
{ role: 'user', parts: [{ text: 'hello' }] },
];
@@ -2549,4 +2550,82 @@ ${JSON.stringify(
expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts);
});
});
describe('initialize', () => {
it('should accept extraHistory parameter and pass it to startChat', async () => {
const mockStartChat = vi.fn().mockResolvedValue({});
client['startChat'] = mockStartChat;
const extraHistory = [
{ role: 'user', parts: [{ text: 'Previous message' }] },
{ role: 'model', parts: [{ text: 'Previous response' }] },
];
const contentGeneratorConfig = {
model: 'test-model',
apiKey: 'test-key',
vertexai: false,
authType: AuthType.USE_GEMINI,
};
await client.initialize(contentGeneratorConfig, extraHistory);
expect(mockStartChat).toHaveBeenCalledWith(extraHistory, 'test-model');
});
it('should use empty array when no extraHistory is provided', async () => {
const mockStartChat = vi.fn().mockResolvedValue({});
client['startChat'] = mockStartChat;
const contentGeneratorConfig = {
model: 'test-model',
apiKey: 'test-key',
vertexai: false,
authType: AuthType.USE_GEMINI,
};
await client.initialize(contentGeneratorConfig);
expect(mockStartChat).toHaveBeenCalledWith([], 'test-model');
});
});
describe('reinitialize', () => {
it('should reinitialize with preserved user history', async () => {
// Mock the initialize method
const mockInitialize = vi.fn().mockResolvedValue(undefined);
client['initialize'] = mockInitialize;
// Set up initial history with environment context + user messages
const mockHistory = [
{ role: 'user', parts: [{ text: 'Environment context' }] },
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
{ role: 'user', parts: [{ text: 'User message 1' }] },
{ role: 'model', parts: [{ text: 'Model response 1' }] },
];
const mockChat = {
getHistory: vi.fn().mockReturnValue(mockHistory),
};
client['chat'] = mockChat as unknown as GeminiChat;
client['getHistory'] = vi.fn().mockReturnValue(mockHistory);
await client.reinitialize();
// Should call initialize with preserved user history (excluding first 2 env messages)
expect(mockInitialize).toHaveBeenCalledWith(
expect.any(Object), // contentGeneratorConfig
[
{ role: 'user', parts: [{ text: 'User message 1' }] },
{ role: 'model', parts: [{ text: 'Model response 1' }] },
],
);
});
it('should not throw error when chat is not initialized', async () => {
client['chat'] = undefined;
await expect(client.reinitialize()).resolves.not.toThrow();
});
});
});

View File

@@ -138,13 +138,24 @@ export class GeminiClient {
this.lastPromptId = this.config.getSessionId();
}
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
async initialize(
contentGeneratorConfig: ContentGeneratorConfig,
extraHistory?: Content[],
) {
this.contentGenerator = await createContentGenerator(
contentGeneratorConfig,
this.config,
this.config.getSessionId(),
);
this.chat = await this.startChat();
/**
* Always take the model from the incoming contentGeneratorConfig when initializing,
* even though `this.config.contentGeneratorConfig` may not be updated yet: `Config`
* only updates it once initialization has succeeded.
*/
this.chat = await this.startChat(
extraHistory || [],
contentGeneratorConfig.model,
);
}
getContentGenerator(): ContentGenerator {
@@ -217,6 +228,28 @@ export class GeminiClient {
this.chat = await this.startChat();
}
/**
* Reinitializes the chat with the current contentGeneratorConfig while preserving chat history.
* This creates a new chat object using the existing history and updated configuration.
* Should be called when configuration changes (model, auth, etc.) to ensure consistency.
*/
async reinitialize(): Promise<void> {
if (!this.chat) {
return;
}
// Preserve the current chat history (excluding environment context)
const currentHistory = this.getHistory();
// Remove the initial environment context (first 2 messages: user env + model acknowledgment)
const userHistory = currentHistory.slice(2);
// Get current content generator config and reinitialize with preserved history
const contentGeneratorConfig = this.config.getContentGeneratorConfig();
if (contentGeneratorConfig) {
await this.initialize(contentGeneratorConfig, userHistory);
}
}
async addDirectoryContext(): Promise<void> {
if (!this.chat) {
return;
@@ -228,7 +261,10 @@ export class GeminiClient {
});
}
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
async startChat(
extraHistory?: Content[],
model?: string,
): Promise<GeminiChat> {
this.forceFullIdeContext = true;
this.hasFailedCompressionAttempt = false;
const envParts = await getEnvironmentContext(this.config);
@@ -248,9 +284,13 @@ export class GeminiClient {
];
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
const systemInstruction = getCoreSystemPrompt(
userMemory,
{},
model || this.config.getModel(),
);
const generateContentConfigWithThinking = isThinkingSupported(
this.config.getModel(),
model || this.config.getModel(),
)
? {
...this.generateContentConfig,
@@ -490,7 +530,11 @@ export class GeminiClient {
// Get all the content that would be sent in an API call
const currentHistory = this.getChat().getHistory(true);
const userMemory = this.config.getUserMemory();
const systemPrompt = getCoreSystemPrompt(userMemory);
const systemPrompt = getCoreSystemPrompt(
userMemory,
{},
this.config.getModel(),
);
const environment = await getEnvironmentContext(this.config);
// Create a mock request content to count total tokens
@@ -644,14 +688,18 @@ export class GeminiClient {
model?: string,
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
// Use current model from config instead of hardcoded Flash model
const modelToUse =
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
/**
* TODO: ensure `model` consistency among GeminiClient, GeminiChat, and ContentGenerator.
* The `model` passed to generateContent is not respected, since we always use the
* contentGenerator's model. We ignore the parameter for now because some callers pass
* `DEFAULT_GEMINI_FLASH_MODEL`, which is not available as `qwen3-coder-flash`.
*/
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
try {
const userMemory = this.config.getUserMemory();
const finalSystemInstruction = config.systemInstruction
? getCustomSystemPrompt(config.systemInstruction, userMemory)
: getCoreSystemPrompt(userMemory);
: getCoreSystemPrompt(userMemory, {}, modelToUse);
const requestConfig = {
abortSignal,
@@ -742,7 +790,7 @@ export class GeminiClient {
const userMemory = this.config.getUserMemory();
const finalSystemInstruction = generationConfig.systemInstruction
? getCustomSystemPrompt(generationConfig.systemInstruction, userMemory)
: getCoreSystemPrompt(userMemory);
: getCoreSystemPrompt(userMemory, {}, this.config.getModel());
const requestConfig: GenerateContentConfig = {
abortSignal,

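The history contract `reinitialize()` relies on, spelled out (`client` here stands for an initialized GeminiClient; the index layout follows `startChat` above):
// getHistory() layout assumed by the slice(2) in reinitialize():
//   [0]  user  -> environment context
//   [1]  model -> 'Got it. Thanks for the context!'
//   [2+] real conversation turns
const history = client.getHistory();
const userHistory = history.slice(2); // only these survive into the new chat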
View File

@@ -500,7 +500,7 @@ export class GeminiChat {
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
if (error.message.match(/^5\d{2}/)) return true;
}
return false;
},

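The one-character fix above matters because the unanchored pattern matched a 5xx-looking digit run anywhere in the message. An illustration:
const unanchored = /5\d{2}/;
const anchored = /^5\d{2}/;
unanchored.test('503 Service Unavailable');  // true
unanchored.test('quota id q-1502 exceeded'); // true  — "502" inside an id: spurious retry
anchored.test('503 Service Unavailable');    // true
anchored.test('quota id q-1502 exceeded');   // false — only messages starting with 5xx retry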
File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -5,6 +5,37 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
// Mock the request tokenizer module BEFORE importing the class that uses it
const mockTokenizer = {
calculateTokens: vi.fn().mockResolvedValue({
totalTokens: 50,
breakdown: {
textTokens: 50,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: 1,
}),
dispose: vi.fn(),
};
vi.mock('../../../utils/request-tokenizer/index.js', () => ({
getDefaultTokenizer: vi.fn(() => mockTokenizer),
DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
disposeDefaultTokenizer: vi.fn(),
}));
// Mock tiktoken as well for completeness
vi.mock('tiktoken', () => ({
get_encoding: vi.fn(() => ({
encode: vi.fn(() => new Array(50)), // Mock 50 tokens
free: vi.fn(),
})),
}));
// Now import the modules that depend on the mocked modules
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
@@ -15,14 +46,6 @@ import type {
import type { OpenAICompatibleProvider } from './provider/index.js';
import type OpenAI from 'openai';
// Mock tiktoken
vi.mock('tiktoken', () => ({
get_encoding: vi.fn().mockReturnValue({
encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
free: vi.fn(),
}),
}));
describe('OpenAIContentGenerator (Refactored)', () => {
let generator: OpenAIContentGenerator;
let mockConfig: Config;

View File

@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
import { ContentGenerationPipeline } from './pipeline.js';
import { DefaultTelemetryService } from './telemetryService.js';
import { EnhancedErrorHandler } from './errorHandler.js';
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
export class OpenAIContentGenerator implements ContentGenerator {
@@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
// Use tiktoken for accurate token counting
const content = JSON.stringify(request.contents);
let totalTokens = 0;
try {
const { get_encoding } = await import('tiktoken');
const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
totalTokens = encoding.encode(content).length;
encoding.free();
// Use the new high-performance request tokenizer
const tokenizer = getDefaultTokenizer();
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
});
return {
totalTokens: result.totalTokens,
};
} catch (error) {
console.warn(
'Failed to load tiktoken, falling back to character approximation:',
'Failed to calculate tokens with new tokenizer, falling back to simple method:',
error,
);
// Fallback: rough approximation using character count
totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
}
return {
totalTokens,
};
// Fallback to original simple method
const content = JSON.stringify(request.contents);
const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
return {
totalTokens,
};
}
}
async embedContent(

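The fallback path in isolation, as a standalone sketch (the 4-characters-per-token ratio is a rough heuristic, not a tokenizer guarantee):
function approximateTokenCount(contents: unknown): number {
  const text = JSON.stringify(contents);
  return Math.ceil(text.length / 4); // rough estimate: ~1 token per 4 characters
}
approximateTokenCount([{ role: 'user', parts: [{ text: 'hello' }] }]); // small, deterministic estimate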
View File

@@ -10,14 +10,11 @@ import {
GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../../config/config.js';
import { type ContentGeneratorConfig } from '../contentGenerator.js';
import { type OpenAICompatibleProvider } from './provider/index.js';
import type { ContentGeneratorConfig } from '../contentGenerator.js';
import type { OpenAICompatibleProvider } from './provider/index.js';
import { OpenAIContentConverter } from './converter.js';
import {
type TelemetryService,
type RequestContext,
} from './telemetryService.js';
import { type ErrorHandler } from './errorHandler.js';
import type { TelemetryService, RequestContext } from './telemetryService.js';
import type { ErrorHandler } from './errorHandler.js';
export interface PipelineConfig {
cliConfig: Config;
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
* 2. Filter empty responses
* 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
* 4. Collect both formats for logging
* 5. Handle success/error logging with original OpenAI format
* 5. Handle success/error logging
*/
private async *processStreamWithLogging(
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
@@ -169,19 +166,11 @@ export class ContentGenerationPipeline {
collectedOpenAIChunks,
);
} catch (error) {
// Stage 2e: Stream failed - handle error and logging
context.duration = Date.now() - context.startTime;
// Clear streaming tool calls on error to prevent data pollution
this.converter.resetStreamingToolCalls();
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
this.config.errorHandler.handle(error, context, request);
// Use shared error handling logic
await this.handleError(error, context, request);
}
}
@@ -365,25 +354,59 @@ export class ContentGenerationPipeline {
context.duration = Date.now() - context.startTime;
return result;
} catch (error) {
context.duration = Date.now() - context.startTime;
// Log error
const openaiRequest = await this.buildRequest(
// Use shared error handling logic
return await this.handleError(
error,
context,
request,
userPromptId,
isStreaming,
);
await this.config.telemetryService.logError(
context,
error,
openaiRequest,
);
// Handle and throw enhanced error
this.config.errorHandler.handle(error, context, request);
}
}
/**
* Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
* This centralizes the common error processing steps to avoid duplication
*/
private async handleError(
error: unknown,
context: RequestContext,
request: GenerateContentParameters,
userPromptId?: string,
isStreaming?: boolean,
): Promise<never> {
context.duration = Date.now() - context.startTime;
// Build request for logging (may fail, but we still want to log the error)
let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
try {
if (userPromptId !== undefined && isStreaming !== undefined) {
openaiRequest = await this.buildRequest(
request,
userPromptId,
isStreaming,
);
} else {
// For processStreamWithLogging, we don't have userPromptId/isStreaming,
// so create a minimal request
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
} catch (_buildError) {
// If we can't build the request, create a minimal one for logging
openaiRequest = {
model: this.contentGeneratorConfig.model,
messages: [],
};
}
await this.config.telemetryService.logError(context, error, openaiRequest);
this.config.errorHandler.handle(error, context, request);
}
/**
* Create request context with common properties
*/

View File

@@ -79,6 +79,16 @@ export class DashScopeOpenAICompatibleProvider
messages = this.addDashScopeCacheControl(messages, cacheTarget);
}
if (request.model.startsWith('qwen-vl')) {
return {
...request,
messages,
...(this.buildMetadata(userPromptId) || {}),
/* @ts-expect-error dashscope exclusive */
vl_high_resolution_images: true,
};
}
return {
...request, // Preserve all original parameters including sampling params
messages,

View File
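The resulting request for a vision model, sketched. Only `vl_high_resolution_images` is taken verbatim from this diff (it is a DashScope-only extension, absent from OpenAI's types, hence the `@ts-expect-error`); the rest of the shape is assumed.
import type OpenAI from 'openai';
const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
  { role: 'user', content: 'Describe this image.' },
];
const request = {
  model: 'qwen-vl-max-latest',
  messages,
  vl_high_resolution_images: true, // DashScope-only flag: requests high-res image processing
};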

@@ -364,6 +364,120 @@ describe('URL matching with trailing slash compatibility', () => {
});
});
describe('Model-specific tool call formats', () => {
beforeEach(() => {
vi.resetAllMocks();
vi.stubEnv('SANDBOX', undefined);
});
it('should use XML format for qwen3-coder model', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen3-coder-7b');
// Should contain XML-style tool calls
expect(prompt).toContain('<tool_call>');
expect(prompt).toContain('<function=run_shell_command>');
expect(prompt).toContain('<parameter=command>');
expect(prompt).toContain('</function>');
expect(prompt).toContain('</tool_call>');
// Should NOT contain bracket-style tool calls
expect(prompt).not.toContain('[tool_call: run_shell_command for');
// Should NOT contain JSON-style tool calls
expect(prompt).not.toContain('{"name": "run_shell_command"');
expect(prompt).toMatchSnapshot();
});
it('should use JSON format for qwen-vl model', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen-vl-max');
// Should contain JSON-style tool calls
expect(prompt).toContain('<tool_call>');
expect(prompt).toContain('{"name": "run_shell_command"');
expect(prompt).toContain('"arguments": {"command": "node server.js &"}');
expect(prompt).toContain('</tool_call>');
// Should NOT contain bracket-style tool calls
expect(prompt).not.toContain('[tool_call: run_shell_command for');
// Should NOT contain XML-style tool calls with parameters
expect(prompt).not.toContain('<function=run_shell_command>');
expect(prompt).not.toContain('<parameter=command>');
expect(prompt).toMatchSnapshot();
});
it('should use bracket format for generic models', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
const prompt = getCoreSystemPrompt(undefined, undefined, 'gpt-4');
// Should contain bracket-style tool calls
expect(prompt).toContain('[tool_call: run_shell_command for');
expect(prompt).toContain('because it must run in the background]');
// Should NOT contain XML-style tool calls
expect(prompt).not.toContain('<function=run_shell_command>');
expect(prompt).not.toContain('<parameter=command>');
// Should NOT contain JSON-style tool calls
expect(prompt).not.toContain('{"name": "run_shell_command"');
expect(prompt).toMatchSnapshot();
});
it('should use bracket format when no model is specified', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
const prompt = getCoreSystemPrompt();
// Should contain bracket-style tool calls (default behavior)
expect(prompt).toContain('[tool_call: run_shell_command for');
expect(prompt).toContain('because it must run in the background]');
// Should NOT contain XML or JSON formats
expect(prompt).not.toContain('<function=run_shell_command>');
expect(prompt).not.toContain('{"name": "run_shell_command"');
expect(prompt).toMatchSnapshot();
});
it('should preserve model-specific formats with user memory', () => {
vi.mocked(isGitRepository).mockReturnValue(false);
const userMemory = 'User prefers concise responses.';
const prompt = getCoreSystemPrompt(
userMemory,
undefined,
'qwen3-coder-14b',
);
// Should contain XML-style tool calls
expect(prompt).toContain('<tool_call>');
expect(prompt).toContain('<function=run_shell_command>');
// Should contain user memory with separator
expect(prompt).toContain('---');
expect(prompt).toContain('User prefers concise responses.');
expect(prompt).toMatchSnapshot();
});
it('should preserve model-specific formats with sandbox environment', () => {
vi.stubEnv('SANDBOX', 'true');
vi.mocked(isGitRepository).mockReturnValue(false);
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen-vl-plus');
// Should contain JSON-style tool calls
expect(prompt).toContain('{"name": "run_shell_command"');
// Should contain sandbox instructions
expect(prompt).toContain('# Sandbox');
expect(prompt).toMatchSnapshot();
});
});
describe('getCustomSystemPrompt', () => {
it('should handle string custom instruction without user memory', () => {
const customInstruction =

View File

@@ -7,18 +7,10 @@
import path from 'node:path';
import fs from 'node:fs';
import os from 'node:os';
import { EditTool } from '../tools/edit.js';
import { GlobTool } from '../tools/glob.js';
import { GrepTool } from '../tools/grep.js';
import { ReadFileTool } from '../tools/read-file.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { ShellTool } from '../tools/shell.js';
import { WriteFileTool } from '../tools/write-file.js';
import { ToolNames } from '../tools/tool-names.js';
import process from 'node:process';
import { isGitRepository } from '../utils/gitUtils.js';
import { MemoryTool, GEMINI_CONFIG_DIR } from '../tools/memoryTool.js';
import { TodoWriteTool } from '../tools/todoWrite.js';
import { TaskTool } from '../tools/task.js';
import { GEMINI_CONFIG_DIR } from '../tools/memoryTool.js';
import type { GenerateContentConfig } from '@google/genai';
export interface ModelTemplateMapping {
@@ -91,6 +83,7 @@ export function getCustomSystemPrompt(
export function getCoreSystemPrompt(
userMemory?: string,
config?: SystemPromptConfig,
model?: string,
): string {
// if GEMINI_SYSTEM_MD is set (and not 0|false), override system prompt from file
// default path is .gemini/system.md but can be modified via custom path in GEMINI_SYSTEM_MD
@@ -177,11 +170,11 @@ You are Qwen Code, an interactive CLI agent developed by Alibaba Group, speciali
- **Proactiveness:** Fulfill the user's request thoroughly, including reasonable, directly implied follow-up actions.
- **Confirm Ambiguity/Expansion:** Do not take significant actions beyond the clear scope of the request without confirming with the user. If asked *how* to do something, explain first, don't just do it.
- **Explaining Changes:** After completing a code modification or file operation *do not* provide summaries unless asked.
- **Path Construction:** Before using any file system tool (e.g., ${ReadFileTool.Name}' or '${WriteFileTool.Name}'), you must construct the full absolute path for the file_path argument. Always combine the absolute path of the project's root directory with the file's path relative to the root. For example, if the project root is /path/to/project/ and the file is foo/bar/baz.txt, the final path you must use is /path/to/project/foo/bar/baz.txt. If the user provides a relative path, you must resolve it against the root directory to create an absolute path.
- **Path Construction:** Before using any file system tool (e.g., ${ToolNames.READ_FILE}' or '${ToolNames.WRITE_FILE}'), you must construct the full absolute path for the file_path argument. Always combine the absolute path of the project's root directory with the file's path relative to the root. For example, if the project root is /path/to/project/ and the file is foo/bar/baz.txt, the final path you must use is /path/to/project/foo/bar/baz.txt. If the user provides a relative path, you must resolve it against the root directory to create an absolute path.
- **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes.
# Task Management
You have access to the ${TodoWriteTool.Name} tool to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.
You have access to the ${ToolNames.TODO_WRITE} tool to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.
These tools are also EXTREMELY helpful for planning tasks, and for breaking down larger complex tasks into smaller steps. If you do not use this tool when planning, you may forget to do important tasks - and that is unacceptable.
It is critical that you mark todos as completed as soon as you are done with a task. Do not batch up multiple tasks before marking them as completed.
@@ -190,13 +183,13 @@ Examples:
<example>
user: Run the build and fix any type errors
assistant: I'm going to use the ${TodoWriteTool.Name} tool to write the following items to the todo list:
assistant: I'm going to use the ${ToolNames.TODO_WRITE} tool to write the following items to the todo list:
- Run the build
- Fix any type errors
I'm now going to run the build using Bash.
Looks like I found 10 type errors. I'm going to use the ${TodoWriteTool.Name} tool to write 10 items to the todo list.
Looks like I found 10 type errors. I'm going to use the ${ToolNames.TODO_WRITE} tool to write 10 items to the todo list.
marking the first todo as in_progress
@@ -211,7 +204,7 @@ In the above example, the assistant completes all the tasks, including the 10 er
<example>
user: Help me write a new feature that allows users to track their usage metrics and export them to various formats
A: I'll help you implement a usage metrics tracking and export feature. Let me first use the ${TodoWriteTool.Name} tool to plan this task.
A: I'll help you implement a usage metrics tracking and export feature. Let me first use the ${ToolNames.TODO_WRITE} tool to plan this task.
Adding the following todos to the todo list:
1. Research existing metrics tracking in the codebase
2. Design the metrics collection system
@@ -232,8 +225,8 @@ I've found some existing telemetry code. Let me mark the first todo as in_progre
## Software Engineering Tasks
When requested to perform tasks like fixing bugs, adding features, refactoring, or explaining code, follow this iterative approach:
- **Plan:** After understanding the user's request, create an initial plan based on your existing knowledge and any immediately obvious context. Use the '${TodoWriteTool.Name}' tool to capture this rough plan for complex or multi-step work. Don't wait for complete understanding - start with what you know.
- **Implement:** Begin implementing the plan while gathering additional context as needed. Use '${GrepTool.Name}', '${GlobTool.Name}', '${ReadFileTool.Name}', and '${ReadManyFilesTool.Name}' tools strategically when you encounter specific unknowns during implementation. Use the available tools (e.g., '${EditTool.Name}', '${WriteFileTool.Name}' '${ShellTool.Name}' ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates').
- **Plan:** After understanding the user's request, create an initial plan based on your existing knowledge and any immediately obvious context. Use the '${ToolNames.TODO_WRITE}' tool to capture this rough plan for complex or multi-step work. Don't wait for complete understanding - start with what you know.
- **Implement:** Begin implementing the plan while gathering additional context as needed. Use '${ToolNames.GREP}', '${ToolNames.GLOB}', '${ToolNames.READ_FILE}', and '${ToolNames.READ_MANY_FILES}' tools strategically when you encounter specific unknowns during implementation. Use the available tools (e.g., '${ToolNames.EDIT}', '${ToolNames.WRITE_FILE}' '${ToolNames.SHELL}' ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates').
- **Adapt:** As you discover new information or encounter obstacles, update your plan and todos accordingly. Mark todos as in_progress when starting and completed when finishing each task. Add new todos if the scope expands. Refine your approach based on what you learn.
- **Verify (Tests):** If applicable and feasible, verify the changes using the project's testing procedures. Identify the correct test commands and frameworks by examining 'README' files, build/package configuration (e.g., 'package.json'), or existing test execution patterns. NEVER assume standard test commands.
- **Verify (Standards):** VERY IMPORTANT: After making code changes, execute the project-specific build, linting and type-checking commands (e.g., 'tsc', 'npm run lint', 'ruff check .') that you have identified for this project (or obtained from the user). This ensures code quality and adherence to standards. If unsure about these commands, you can ask the user if they'd like you to run them and if so how to.
@@ -242,11 +235,11 @@ When requested to perform tasks like fixing bugs, adding features, refactoring,
- Tool results and user messages may include <system-reminder> tags. <system-reminder> tags contain useful information and reminders. They are NOT part of the user's provided input or the tool result.
IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks throughout the conversation.
IMPORTANT: Always use the ${ToolNames.TODO_WRITE} tool to plan and track tasks throughout the conversation.
## New Applications
**Goal:** Autonomously implement and deliver a visually appealing, substantially complete, and functional prototype. Utilize all tools at your disposal to implement the application. Some tools you may especially find useful are '${WriteFileTool.Name}', '${EditTool.Name}' and '${ShellTool.Name}'.
**Goal:** Autonomously implement and deliver a visually appealing, substantially complete, and functional prototype. Utilize all tools at your disposal to implement the application. Some tools you may especially find useful are '${ToolNames.WRITE_FILE}', '${ToolNames.EDIT}' and '${ToolNames.SHELL}'.
1. **Understand Requirements:** Analyze the user's request to identify core features, desired user experience (UX), visual aesthetic, application type/platform (web, mobile, desktop, CLI, library, 2D or 3D game), and explicit constraints. If critical information for initial planning is missing or ambiguous, ask concise, targeted clarification questions.
2. **Propose Plan:** Formulate an internal development plan. Present a clear, concise, high-level summary to the user. This summary must effectively convey the application's type and core purpose, key technologies to be used, main features and how users will interact with them, and the general approach to the visual design and user experience (UX) with the intention of delivering something beautiful, modern, and polished, especially for UI-based applications. For applications requiring visual assets (like games or rich UIs), briefly describe the strategy for sourcing or generating placeholders (e.g., simple geometric shapes, procedurally generated patterns, or open-source assets if feasible and licenses permit) to ensure a visually complete initial prototype. Ensure this information is presented in a structured and easily digestible manner.
@@ -259,7 +252,7 @@ IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks thr
- **3d Games:** HTML/CSS/JavaScript with Three.js.
- **2d Games:** HTML/CSS/JavaScript.
3. **User Approval:** Obtain user approval for the proposed plan.
4. **Implementation:** Use the '${TodoWriteTool.Name}' tool to convert the approved plan into a structured todo list with specific, actionable tasks, then autonomously implement each task utilizing all available tools. When starting ensure you scaffold the application using '${ShellTool.Name}' for commands like 'npm init', 'npx create-react-app'. Aim for full scope completion. Proactively create or source necessary placeholder assets (e.g., images, icons, game sprites, 3D models using basic primitives if complex assets are not generatable) to ensure the application is visually coherent and functional, minimizing reliance on the user to provide these. If the model can generate simple assets (e.g., a uniformly colored square sprite, a simple 3D cube), it should do so. Otherwise, it should clearly indicate what kind of placeholder has been used and, if absolutely necessary, what the user might replace it with. Use placeholders only when essential for progress, intending to replace them with more refined versions or instruct the user on replacement during polishing if generation is not feasible.
4. **Implementation:** Use the '${ToolNames.TODO_WRITE}' tool to convert the approved plan into a structured todo list with specific, actionable tasks, then autonomously implement each task utilizing all available tools. When starting ensure you scaffold the application using '${ToolNames.SHELL}' for commands like 'npm init', 'npx create-react-app'. Aim for full scope completion. Proactively create or source necessary placeholder assets (e.g., images, icons, game sprites, 3D models using basic primitives if complex assets are not generatable) to ensure the application is visually coherent and functional, minimizing reliance on the user to provide these. If the model can generate simple assets (e.g., a uniformly colored square sprite, a simple 3D cube), it should do so. Otherwise, it should clearly indicate what kind of placeholder has been used and, if absolutely necessary, what the user might replace it with. Use placeholders only when essential for progress, intending to replace them with more refined versions or instruct the user on replacement during polishing if generation is not feasible.
5. **Verify:** Review work against the original request, the approved plan. Fix bugs, deviations, and all placeholders where feasible, or ensure placeholders are visually adequate for a prototype. Ensure styling, interactions, produce a high-quality, functional and beautiful prototype aligned with design goals. Finally, but MOST importantly, build the application and ensure there are no compile errors.
6. **Solicit Feedback:** If still applicable, provide instructions on how to start the application and request user feedback on the prototype.
@@ -275,18 +268,18 @@ IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks thr
- **Handling Inability:** If unable/unwilling to fulfill a request, state so briefly (1-2 sentences) without excessive justification. Offer alternatives if appropriate.
## Security and Safety Rules
- **Explain Critical Commands:** Before executing commands with '${ShellTool.Name}' that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. You should not ask permission to use the tool; the user will be presented with a confirmation dialogue upon use (you do not need to tell them this).
- **Explain Critical Commands:** Before executing commands with '${ToolNames.SHELL}' that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. You should not ask permission to use the tool; the user will be presented with a confirmation dialogue upon use (you do not need to tell them this).
- **Security First:** Always apply security best practices. Never introduce code that exposes, logs, or commits secrets, API keys, or other sensitive information.
## Tool Usage
- **File Paths:** Always use absolute paths when referring to files with tools like '${ReadFileTool.Name}' or '${WriteFileTool.Name}'. Relative paths are not supported. You must provide an absolute path.
- **File Paths:** Always use absolute paths when referring to files with tools like '${ToolNames.READ_FILE}' or '${ToolNames.WRITE_FILE}'. Relative paths are not supported. You must provide an absolute path.
- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase).
- **Command Execution:** Use the '${ShellTool.Name}' tool for running shell commands, remembering the safety rule to explain modifying commands first.
- **Command Execution:** Use the '${ToolNames.SHELL}' tool for running shell commands, remembering the safety rule to explain modifying commands first.
- **Background Processes:** Use background processes (via \`&\`) for commands that are unlikely to stop on their own, e.g. \`node server.js &\`. If unsure, ask the user.
- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. \`git rebase -i\`). Use non-interactive versions of commands (e.g. \`npm init -y\` instead of \`npm init\`) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user.
- **Task Management:** Use the '${TodoWriteTool.Name}' tool proactively for complex, multi-step tasks to track progress and provide visibility to users. This tool helps organize work systematically and ensures no requirements are missed.
- **Subagent Delegation:** When doing file search, prefer to use the '${TaskTool.Name}' tool in order to reduce context usage. You should proactively use the '${TaskTool.Name}' tool with specialized agents when the task at hand matches the agent's description.
- **Remembering Facts:** Use the '${MemoryTool.Name}' tool to remember specific, *user-related* facts or preferences when the user explicitly asks, or when they state a clear, concise piece of information that would help personalize or streamline *your future interactions with them* (e.g., preferred coding style, common project paths they use, personal tool aliases). This tool is for user-specific information that should persist across sessions. Do *not* use it for general project context or information. If unsure whether to save something, you can ask the user, "Should I remember that for you?"
- **Task Management:** Use the '${ToolNames.TODO_WRITE}' tool proactively for complex, multi-step tasks to track progress and provide visibility to users. This tool helps organize work systematically and ensures no requirements are missed.
- **Subagent Delegation:** When doing file search, prefer to use the '${ToolNames.TASK}' tool in order to reduce context usage. You should proactively use the '${ToolNames.TASK}' tool with specialized agents when the task at hand matches the agent's description.
- **Remembering Facts:** Use the '${ToolNames.MEMORY}' tool to remember specific, *user-related* facts or preferences when the user explicitly asks, or when they state a clear, concise piece of information that would help personalize or streamline *your future interactions with them* (e.g., preferred coding style, common project paths they use, personal tool aliases). This tool is for user-specific information that should persist across sessions. Do *not* use it for general project context or information. If unsure whether to save something, you can ask the user, "Should I remember that for you?"
- **Respect User Confirmations:** Most tool calls (also denoted as 'function calls') will first require confirmation from the user, where they will either approve or cancel the function call. If a user cancels a function call, respect their choice and do _not_ try to make the function call again. It is okay to request the tool call again _only_ if the user requests that same tool call on a subsequent prompt. When a user cancels a function call, assume best intentions from the user and consider inquiring if they prefer any alternative paths forward.
## Interaction Details
@@ -338,157 +331,10 @@ ${(function () {
return '';
})()}
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>
<example>
user: is 13 a prime number?
model: true
</example>
<example>
user: start the server implemented in server.js
model:
<tool_call>
<function=run_shell_command>
<parameter=command>
node server.js &
</parameter>
</function>
</tool_call>
</example>
<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
<tool_call>
<function=glob>
<parameter=path>
tests/test_auth.py
</parameter>
</function>
</tool_call>
<tool_call>
<function=read_file>
<parameter=path>
/path/to/tests/test_auth.py
</parameter>
<parameter=offset>
0
</parameter>
<parameter=limit>
10
</parameter>
</function>
</tool_call>
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
<tool_call>
<function=read_file>
<parameter=path>
/path/to/requirements.txt
</parameter>
</function>
</tool_call>
(After analysis)
Looks good, 'requests' is available.
Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.
<tool_call>
<function=replace>
<parameter=path>
src/auth.py
</parameter>
<parameter=old_content>
(old code content)
</parameter>
<parameter=new_content>
(new code content)
</parameter>
</function>
</tool_call>
Refactoring complete. Running verification...
<tool_call>
<function=run_shell_command
<parameter=command>
ruff check src/auth.py && pytest
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>
<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>
<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
<tool_call>
<function=read_file>
<parameter=path>
/path/to/someFile.ts
</parameter>
</function>
</tool_call>
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
<tool_call>
<function>read_many_files for paths ['**/*.test.ts', 'src/**/*.spec.ts'] assuming someFile.ts is in the src directory]
</tool_call>
(After reviewing existing tests and the file content)
<tool_call>
<function=write_file>
<parameter=path>
/path/to/someFile.test.ts
</parameter>
</function>
</tool_call>
I've written the tests. Now I'll run the project's test command to verify them.
<tool_call>
<function=run_shell_command>
<parameter=command>
npm run test
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>
<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
<tool_call>
<function=glob>
<parameter=pattern>
./**/app.config
</parameter>
</function>
</tool_call>
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
${getToolCallExamples(model || '')}
# Final Reminder
Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use '${ReadFileTool.Name}' or '${ReadManyFilesTool.Name}' to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved.
Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use '${ToolNames.READ_FILE}' or '${ToolNames.READ_MANY_FILES}' to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved.
`.trim();
// if GEMINI_WRITE_SYSTEM_MD is set (and not 0|false), write base system prompt to file
@@ -615,3 +461,366 @@ You are a specialized context summarizer that creates a comprehensive markdown s
`.trim();
}
const generalToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>
<example>
user: is 13 a prime number?
model: true
</example>
<example>
user: start the server implemented in server.js
model: [tool_call: ${ToolNames.SHELL} for 'node server.js &' because it must run in the background]
</example>
<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
[tool_call: ${ToolNames.GLOB} for path 'tests/test_auth.py']
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/tests/test_auth.py' with offset 0 and limit 10]
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/requirements.txt']
(After analysis)
Looks good, 'requests' is available.
Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.
[tool_call: ${ToolNames.EDIT} for path 'src/auth.py' replacing old content with new content]
Refactoring complete. Running verification...
[tool_call: ${ToolNames.SHELL} for 'ruff check src/auth.py && pytest']
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>
<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>
<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/someFile.ts']
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
[tool_call: ${ToolNames.READ_MANY_FILES} for paths ['**/*.test.ts', 'src/**/*.spec.ts']]
(After reviewing existing tests and the file content)
[tool_call: ${ToolNames.WRITE_FILE} for path '/path/to/someFile.test.ts']
I've written the tests. Now I'll run the project's test command to verify them.
[tool_call: ${ToolNames.SHELL} for 'npm run test']
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>
<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
[tool_call: ${ToolNames.GLOB} for pattern './**/app.config']
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
const qwenCoderToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>
<example>
user: is 13 a prime number?
model: true
</example>
<example>
user: start the server implemented in server.js
model:
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
node server.js &
</parameter>
</function>
</tool_call>
</example>
<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
<tool_call>
<function=${ToolNames.GLOB}>
<parameter=path>
tests/test_auth.py
</parameter>
</function>
</tool_call>
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/tests/test_auth.py
</parameter>
<parameter=offset>
0
</parameter>
<parameter=limit>
10
</parameter>
</function>
</tool_call>
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/requirements.txt
</parameter>
</function>
</tool_call>
(After analysis)
Looks good, 'requests' is available.
Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.
<tool_call>
<function=${ToolNames.EDIT}>
<parameter=path>
src/auth.py
</parameter>
<parameter=old_content>
(old code content)
</parameter>
<parameter=new_content>
(new code content)
</parameter>
</function>
</tool_call>
Refactoring complete. Running verification...
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
ruff check src/auth.py && pytest
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>
<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>
<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/someFile.ts
</parameter>
</function>
</tool_call>
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
<tool_call>
<function=${ToolNames.READ_MANY_FILES}>
<parameter=paths>
['**/*.test.ts', 'src/**/*.spec.ts']
</parameter>
</function>
</tool_call>
(After reviewing existing tests and the file content)
<tool_call>
<function=${ToolNames.WRITE_FILE}>
<parameter=path>
/path/to/someFile.test.ts
</parameter>
</function>
</tool_call>
I've written the tests. Now I'll run the project's test command to verify them.
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
npm run test
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>
<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
<tool_call>
<function=${ToolNames.GLOB}>
<parameter=pattern>
./**/app.config
</parameter>
</function>
</tool_call>
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
const qwenVlToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>
<example>
user: is 13 a prime number?
model: true
</example>
<example>
user: start the server implemented in server.js
model:
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "node server.js &"}}
</tool_call>
</example>
<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
<tool_call>
{"name": "${ToolNames.GLOB}", "arguments": {"path": "tests/test_auth.py"}}
</tool_call>
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/tests/test_auth.py", "offset": 0, "limit": 10}}
</tool_call>
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/requirements.txt"}}
</tool_call>
(After analysis)
Looks good, 'requests' is available.
Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.
<tool_call>
{"name": "${ToolNames.EDIT}", "arguments": {"path": "src/auth.py", "old_content": "(old code content)", "new_content": "(new code content)"}}
</tool_call>
Refactoring complete. Running verification...
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "ruff check src/auth.py && pytest"}}
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>
<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>
<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/someFile.ts"}}
</tool_call>
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
<tool_call>
{"name": "${ToolNames.READ_MANY_FILES}", "arguments": {"paths": ["**/*.test.ts", "src/**/*.spec.ts"]}}
</tool_call>
(After reviewing existing tests and the file content)
<tool_call>
{"name": "${ToolNames.WRITE_FILE}", "arguments": {"path": "/path/to/someFile.test.ts"}}
</tool_call>
I've written the tests. Now I'll run the project's test command to verify them.
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "npm run test"}}
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>
<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
<tool_call>
{"name": "${ToolNames.GLOB}", "arguments": {"pattern": "./**/app.config"}}
</tool_call>
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
function getToolCallExamples(model?: string): string {
// Check for environment variable override first
const toolCallStyle = process.env['QWEN_CODE_TOOL_CALL_STYLE'];
if (toolCallStyle) {
switch (toolCallStyle.toLowerCase()) {
case 'qwen-coder':
return qwenCoderToolCallExamples;
case 'qwen-vl':
return qwenVlToolCallExamples;
case 'general':
return generalToolCallExamples;
default:
console.warn(
`Unknown QWEN_CODE_TOOL_CALL_STYLE value: ${toolCallStyle}. Using model-based detection.`,
);
break;
}
}
// Regex-based model detection; the length guard keeps pathologically long
// model strings from ever reaching the regexes
if (model && model.length < 100) {
// Match qwen*-coder patterns (e.g., qwen3-coder, qwen2.5-coder, qwen-coder)
if (/qwen[^-]*-coder/i.test(model)) {
return qwenCoderToolCallExamples;
}
// Match qwen*-vl patterns (e.g., qwen-vl, qwen2-vl, qwen3-vl)
if (/qwen[^-]*-vl/i.test(model)) {
return qwenVlToolCallExamples;
}
}
return generalToolCallExamples;
}
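For reference, a few sample model IDs and the example style the detection above selects (illustrative calls, assuming no QWEN_CODE_TOOL_CALL_STYLE override is set):

getToolCallExamples('qwen3-coder-plus'); // qwenCoderToolCallExamples (XML-style)
getToolCallExamples('qwen2.5-coder-32b-instruct'); // qwenCoderToolCallExamples
getToolCallExamples('qwen-vl-max-latest'); // qwenVlToolCallExamples (JSON-style)
getToolCallExamples('gpt-4o'); // generalToolCallExamples (bracket-style fallback)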

View File

@@ -116,6 +116,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
[/^qwen-flash-latest$/, LIMITS['1m']],
[/^qwen-turbo.*$/, LIMITS['128k']],
// Qwen Vision Models
[/^qwen-vl-max.*$/, LIMITS['128k']],
// -------------------
// ByteDance Seed-OSS (512K)
// -------------------
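A minimal sketch of how a pattern table like this is consulted, first match wins; the helper and constant values here are assumed for illustration, not taken from the diff:

const LIMITS = { '1m': 1_048_576, '128k': 131_072 } as const;
const PATTERNS: Array<[RegExp, number]> = [
[/^qwen-flash-latest$/, LIMITS['1m']],
[/^qwen-turbo.*$/, LIMITS['128k']],
[/^qwen-vl-max.*$/, LIMITS['128k']], // the entry added in this diff
];
function tokenLimit(model: string, fallback = LIMITS['128k']): number {
for (const [re, limit] of PATTERNS) {
if (re.test(model)) return limit;
}
return fallback;
}
tokenLimit('qwen-vl-max-latest'); // 131072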

View File

@@ -242,7 +242,7 @@ describe('Turn', () => {
expect(turn.getDebugResponses().length).toBe(0);
expect(reportError).toHaveBeenCalledWith(
error,
'Error when talking to Gemini API',
'Error when talking to API',
[...historyContent, reqParts],
'Turn.run-sendMessageStream',
);

View File

@@ -310,7 +310,7 @@ export class Turn {
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError(
error,
'Error when talking to Gemini API',
'Error when talking to API',
contextForReport,
'Turn.run-sendMessageStream',
);

View File

@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
});
it('should count tokens with valid token', async () => {
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
token: 'valid-token',
});
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
it('should count tokens without requiring authentication', async () => {
// Clear any previous mock calls
vi.clearAllMocks();
const request: CountTokensParameters = {
model: 'qwen-turbo',
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
const result = await qwenContentGenerator.countTokens(request);
expect(result.totalTokens).toBe(15);
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
// countTokens is a local operation and should not require OAuth credentials
expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
});
it('should embed content with valid token', async () => {
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
SharedTokenManager.getInstance = originalGetInstance;
});
it('should handle all method types with token failure', async () => {
it('should handle method types with token failure (except countTokens)', async () => {
const mockTokenManager = {
getValidCredentials: vi
.fn()
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
contents: [{ parts: [{ text: 'Embed' }] }],
};
// All methods should fail with the same error
// Methods requiring authentication should fail
await expect(
newGenerator.generateContent(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
newGenerator.generateContentStream(generateRequest, 'test-id'),
).rejects.toThrow('Failed to obtain valid Qwen access token');
await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
'Failed to obtain valid Qwen access token',
);
// countTokens should succeed as it's a local operation
const countResult = await newGenerator.countTokens(countRequest);
expect(countResult.totalTokens).toBe(15);
SharedTokenManager.getInstance = originalGetInstance;
});
});
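These updated tests encode the new contract: counting tokens is a purely local operation (tiktoken plus the image tokenizer) and no longer touches OAuth, so it can run before any credentials exist. A minimal call sketch against an already-constructed generator:

// No getAccessToken() call happens on this path anymore:
const { totalTokens } = await qwenContentGenerator.countTokens({
model: 'qwen-turbo',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
});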

View File

@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
override async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.executeWithCredentialManagement(() =>
super.countTokens(request),
);
return super.countTokens(request);
}
/**

View File

@@ -23,7 +23,11 @@ import {
} from 'vitest';
import { Config, type ConfigParameters } from '../config/config.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
import { createContentGenerator } from '../core/contentGenerator.js';
import {
createContentGenerator,
createContentGeneratorConfig,
AuthType,
} from '../core/contentGenerator.js';
import { GeminiChat } from '../core/geminiChat.js';
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
@@ -56,8 +60,7 @@ async function createMockConfig(
};
const config = new Config(configParams);
await config.initialize();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await config.refreshAuth('test-auth' as any);
await config.refreshAuth(AuthType.USE_GEMINI);
// Mock ToolRegistry
const mockToolRegistry = {
@@ -164,6 +167,10 @@ describe('subagent.ts', () => {
getGenerativeModel: vi.fn(),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
vi.mocked(createContentGeneratorConfig).mockReturnValue({
model: DEFAULT_GEMINI_MODEL,
authType: undefined,
});
mockSendMessageStream = vi.fn();
// We mock the implementation of the constructor.

View File

@@ -24,6 +24,7 @@ import { ApprovalMode } from '../config/config.js';
import { ensureCorrectEdit } from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
import { ReadFileTool } from './read-file.js';
import { ToolNames } from './tool-names.js';
import type {
ModifiableDeclarativeTool,
ModifyContext,
@@ -461,7 +462,7 @@ export class EditTool
extends BaseDeclarativeTool<EditToolParams, ToolResult>
implements ModifiableDeclarativeTool<EditToolParams>
{
static readonly Name = 'edit';
static readonly Name = ToolNames.EDIT;
constructor(private readonly config: Config) {
super(
EditTool.Name,

View File

@@ -9,6 +9,7 @@ import path from 'node:path';
import { glob, escape } from 'glob';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames } from './tool-names.js';
import { shortenPath, makeRelative } from '../utils/paths.js';
import type { Config } from '../config/config.js';
import { ToolErrorType } from './tool-error.js';
@@ -252,7 +253,7 @@ class GlobToolInvocation extends BaseToolInvocation<
* Implementation of the Glob tool logic
*/
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
static readonly Name = 'glob';
static readonly Name = ToolNames.GLOB;
constructor(private config: Config) {
super(

View File

@@ -12,6 +12,7 @@ import { spawn } from 'node:child_process';
import { globStream } from 'glob';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames } from './tool-names.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js';
import { isGitRepository } from '../utils/gitUtils.js';
@@ -597,7 +598,7 @@ class GrepToolInvocation extends BaseToolInvocation<
* Implementation of the Grep tool logic (moved from CLI)
*/
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
static readonly Name = 'search_file_content'; // Keep static name
static readonly Name = ToolNames.GREP;
constructor(private readonly config: Config) {
super(

View File

@@ -8,6 +8,7 @@ import path from 'node:path';
import { makeRelative, shortenPath } from '../utils/paths.js';
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames } from './tool-names.js';
import type { PartUnion } from '@google/genai';
import {
@@ -136,7 +137,7 @@ export class ReadFileTool extends BaseDeclarativeTool<
ReadFileToolParams,
ToolResult
> {
static readonly Name: string = 'read_file';
static readonly Name: string = ToolNames.READ_FILE;
constructor(private config: Config) {
super(

View File

@@ -6,6 +6,7 @@
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames } from './tool-names.js';
import { getErrorMessage } from '../utils/errors.js';
import * as fs from 'node:fs';
import * as path from 'node:path';
@@ -526,7 +527,7 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
ReadManyFilesParams,
ToolResult
> {
static readonly Name: string = 'read_many_files';
static readonly Name: string = ToolNames.READ_MANY_FILES;
constructor(private config: Config) {
const parameterSchema = {

View File

@@ -9,6 +9,7 @@ import path from 'node:path';
import os, { EOL } from 'node:os';
import crypto from 'node:crypto';
import type { Config } from '../config/config.js';
import { ToolNames } from './tool-names.js';
import { ToolErrorType } from './tool-error.js';
import type {
ToolInvocation,
@@ -403,7 +404,7 @@ export class ShellTool extends BaseDeclarativeTool<
ShellToolParams,
ToolResult
> {
static Name: string = 'run_shell_command';
static Name: string = ToolNames.SHELL;
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {

View File

@@ -5,6 +5,7 @@
*/
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames } from './tool-names.js';
import type {
ToolResult,
ToolResultDisplay,
@@ -46,7 +47,7 @@ export interface TaskParams {
* for the model to choose from.
*/
export class TaskTool extends BaseDeclarativeTool<TaskParams, ToolResult> {
static readonly Name: string = 'task';
static readonly Name: string = ToolNames.TASK;
private subagentManager: SubagentManager;
private availableSubagents: SubagentConfig[] = [];

View File

@@ -0,0 +1,23 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Tool name constants to avoid circular dependencies.
* These constants are used across multiple files and should be kept in sync
* with the actual tool class names.
*/
export const ToolNames = {
EDIT: 'edit',
WRITE_FILE: 'write_file',
READ_FILE: 'read_file',
READ_MANY_FILES: 'read_many_files',
GREP: 'search_file_content',
GLOB: 'glob',
SHELL: 'run_shell_command',
TODO_WRITE: 'todo_write',
MEMORY: 'save_memory',
TASK: 'task',
} as const;
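The cycle this file breaks is visible in the surrounding diffs: tools/edit.ts imports ensureCorrectEdit from utils/editCorrector.ts, while editCorrector.ts previously imported EditTool (and the other tool classes) back from the tool modules. A minimal sketch of the before/after shape, with file contents abridged:

// Before (circular):
// tools/edit.ts:          import { ensureCorrectEdit } from '../utils/editCorrector.js';
// utils/editCorrector.ts: import { EditTool } from '../tools/edit.js'; // completes the cycle

// After: both sides depend only on the leaf module tools/tool-names.ts
// utils/editCorrector.ts:
import { ToolNames } from '../tools/tool-names.js';
const toolsInResp = new Set<string>([ToolNames.WRITE_FILE, ToolNames.EDIT]);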

View File

@@ -31,6 +31,7 @@ import {
ensureCorrectFileContent,
} from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
import { ToolNames } from './tool-names.js';
import type {
ModifiableDeclarativeTool,
ModifyContext,
@@ -403,7 +404,7 @@ export class WriteFileTool
extends BaseDeclarativeTool<WriteFileToolParams, ToolResult>
implements ModifiableDeclarativeTool<WriteFileToolParams>
{
static readonly Name: string = 'write_file';
static readonly Name: string = ToolNames.WRITE_FILE;
constructor(private readonly config: Config) {
super(

View File

@@ -7,11 +7,7 @@
import type { Content, GenerateContentConfig } from '@google/genai';
import type { GeminiClient } from '../core/client.js';
import type { EditToolParams } from '../tools/edit.js';
import { EditTool } from '../tools/edit.js';
import { WriteFileTool } from '../tools/write-file.js';
import { ReadFileTool } from '../tools/read-file.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { GrepTool } from '../tools/grep.js';
import { ToolNames } from '../tools/tool-names.js';
import { LruCache } from './LruCache.js';
import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js';
import {
@@ -85,14 +81,14 @@ async function findLastEditTimestamp(
const history = (await client.getHistory()) ?? [];
// Tools that may reference the file path in their FunctionResponse `output`.
const toolsInResp = new Set([
WriteFileTool.Name,
EditTool.Name,
ReadManyFilesTool.Name,
GrepTool.Name,
const toolsInResp = new Set<string>([
ToolNames.WRITE_FILE,
ToolNames.EDIT,
ToolNames.READ_MANY_FILES,
ToolNames.GREP,
]);
// Tools that may reference the file path in their FunctionCall `args`.
const toolsInCall = new Set([...toolsInResp, ReadFileTool.Name]);
const toolsInCall = new Set<string>([...toolsInResp, ToolNames.READ_FILE]);
// Iterate backwards to find the most recent relevant action.
for (const entry of history.slice().reverse()) {

View File

@@ -0,0 +1,157 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { ImageTokenizer } from './imageTokenizer.js';
describe('ImageTokenizer', () => {
const tokenizer = new ImageTokenizer();
describe('token calculation', () => {
it('should calculate tokens based on image dimensions with reference logic', () => {
const metadata = {
width: 28,
height: 28,
mimeType: 'image/png',
dataSize: 1000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
// But minimum scaling may apply for small images
expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
});
it('should calculate tokens for larger images', () => {
const metadata = {
width: 512,
height: 512,
mimeType: 'image/png',
dataSize: 10000,
};
const tokens = tokenizer.calculateTokens(metadata);
// 512x512 with reference logic: rounded dimensions + scaling + special tokens
expect(tokens).toBeGreaterThan(300);
expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
});
it('should enforce minimum tokens per image with scaling', () => {
const metadata = {
width: 1,
height: 1,
mimeType: 'image/png',
dataSize: 100,
};
const tokens = tokenizer.calculateTokens(metadata);
// Tiny images get scaled up to minimum pixels + special tokens
expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
});
it('should handle very large images with scaling', () => {
const metadata = {
width: 8192,
height: 8192,
mimeType: 'image/png',
dataSize: 100000,
};
const tokens = tokenizer.calculateTokens(metadata);
// Very large images should be scaled down to max limit + special tokens
expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
});
});
describe('PNG dimension extraction', () => {
it('should extract dimensions from valid PNG', async () => {
// 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
'image/png',
);
expect(metadata.width).toBe(1);
expect(metadata.height).toBe(1);
expect(metadata.mimeType).toBe('image/png');
});
it('should handle invalid PNG gracefully', async () => {
const invalidBase64 = 'invalid-png-data';
const metadata = await tokenizer.extractImageMetadata(
invalidBase64,
'image/png',
);
// Should return default dimensions
expect(metadata.width).toBe(512);
expect(metadata.height).toBe(512);
expect(metadata.mimeType).toBe('image/png');
});
});
describe('batch processing', () => {
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const images = [
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
{ data: pngBase64, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(3);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
});
it('should handle mixed valid and invalid images', async () => {
const validPng =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const invalidPng = 'invalid-data';
const images = [
{ data: validPng, mimeType: 'image/png' },
{ data: invalidPng, mimeType: 'image/png' },
];
const tokens = await tokenizer.calculateTokensBatch(images);
expect(tokens).toHaveLength(2);
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
});
});
describe('different image formats', () => {
it('should handle different MIME types', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];
for (const mimeType of formats) {
const metadata = await tokenizer.extractImageMetadata(
pngBase64,
mimeType,
);
expect(metadata.mimeType).toBe(mimeType);
expect(metadata.width).toBeGreaterThan(0);
expect(metadata.height).toBeGreaterThan(0);
}
});
});
});

View File

@@ -0,0 +1,505 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type { ImageMetadata } from './types.js';
import { isSupportedImageMimeType } from './supportedImageFormats.js';
/**
* Image tokenizer for calculating image tokens based on dimensions
*
* Key rules:
* - 28x28 pixels = 1 token
* - Minimum: 4 tokens per image
* - Maximum: 16384 tokens per image
* - Additional: 2 special tokens (vision_bos + vision_eos)
* - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
*/
export class ImageTokenizer {
/** 28x28 pixels = 1 token */
private static readonly PIXELS_PER_TOKEN = 28 * 28;
/** Minimum tokens per image */
private static readonly MIN_TOKENS_PER_IMAGE = 4;
/** Maximum tokens per image */
private static readonly MAX_TOKENS_PER_IMAGE = 16384;
/** Special tokens for vision markers */
private static readonly VISION_SPECIAL_TOKENS = 2;
/**
* Extract image metadata from base64 data
*
* @param base64Data Base64-encoded image data (with or without data URL prefix)
* @param mimeType MIME type of the image
* @returns Promise resolving to ImageMetadata with dimensions and format info
*/
async extractImageMetadata(
base64Data: string,
mimeType: string,
): Promise<ImageMetadata> {
try {
// Check if the MIME type is supported
if (!isSupportedImageMimeType(mimeType)) {
console.warn(`Unsupported image format: ${mimeType}`);
// Return default metadata for unsupported formats
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
const buffer = Buffer.from(cleanBase64, 'base64');
const dimensions = await this.extractDimensions(buffer, mimeType);
return {
width: dimensions.width,
height: dimensions.height,
mimeType,
dataSize: buffer.length,
};
} catch (error) {
console.warn('Failed to extract image metadata:', error);
// Return default metadata for fallback
return {
width: 512,
height: 512,
mimeType,
dataSize: Math.floor(base64Data.length * 0.75),
};
}
}
/**
* Extract image dimensions from buffer based on format
*
* @param buffer Binary image data buffer
* @param mimeType MIME type to determine parsing strategy
* @returns Promise resolving to width and height dimensions
*/
private async extractDimensions(
buffer: Buffer,
mimeType: string,
): Promise<{ width: number; height: number }> {
if (mimeType.includes('png')) {
return this.extractPngDimensions(buffer);
}
if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
return this.extractJpegDimensions(buffer);
}
if (mimeType.includes('webp')) {
return this.extractWebpDimensions(buffer);
}
if (mimeType.includes('gif')) {
return this.extractGifDimensions(buffer);
}
if (mimeType.includes('bmp')) {
return this.extractBmpDimensions(buffer);
}
if (mimeType.includes('tiff')) {
return this.extractTiffDimensions(buffer);
}
if (mimeType.includes('heic')) {
return this.extractHeicDimensions(buffer);
}
return { width: 512, height: 512 };
}
/**
* Extract PNG dimensions from IHDR chunk
* PNG signature: 89 50 4E 47 0D 0A 1A 0A
* Width/height at bytes 16-19 and 20-23 (big-endian)
*/
private extractPngDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 24) {
throw new Error('Invalid PNG: buffer too short');
}
// Verify PNG signature
const signature = buffer.subarray(0, 8);
const expectedSignature = Buffer.from([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
]);
if (!signature.equals(expectedSignature)) {
throw new Error('Invalid PNG signature');
}
const width = buffer.readUInt32BE(16);
const height = buffer.readUInt32BE(20);
return { width, height };
}
/**
* Extract JPEG dimensions from SOF (Start of Frame) markers
* JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
* Dimensions at offset +5 (height) and +7 (width) from SOF marker
*/
private extractJpegDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
throw new Error('Invalid JPEG signature');
}
let offset = 2;
while (offset < buffer.length - 8) {
if (buffer[offset] !== 0xff) {
offset++;
continue;
}
const marker = buffer[offset + 1];
// SOF markers
if (
(marker >= 0xc0 && marker <= 0xc3) ||
(marker >= 0xc5 && marker <= 0xc7) ||
(marker >= 0xc9 && marker <= 0xcb) ||
(marker >= 0xcd && marker <= 0xcf)
) {
const height = buffer.readUInt16BE(offset + 5);
const width = buffer.readUInt16BE(offset + 7);
return { width, height };
}
const segmentLength = buffer.readUInt16BE(offset + 2);
offset += 2 + segmentLength;
}
throw new Error('Could not find JPEG dimensions');
}
/**
* Extract WebP dimensions from RIFF container
* Supports VP8, VP8L, and VP8X formats
*/
private extractWebpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 30) {
throw new Error('Invalid WebP: too short');
}
const riffSignature = buffer.subarray(0, 4).toString('ascii');
const webpSignature = buffer.subarray(8, 12).toString('ascii');
if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
throw new Error('Invalid WebP signature');
}
const format = buffer.subarray(12, 16).toString('ascii');
if (format === 'VP8 ') {
const width = buffer.readUInt16LE(26) & 0x3fff;
const height = buffer.readUInt16LE(28) & 0x3fff;
return { width, height };
} else if (format === 'VP8L') {
const bits = buffer.readUInt32LE(21);
const width = (bits & 0x3fff) + 1;
const height = ((bits >> 14) & 0x3fff) + 1;
return { width, height };
} else if (format === 'VP8X') {
// VP8X stores (canvas size - 1) as 24-bit little-endian values:
// width at bytes 24-26, height at bytes 27-29
const width = buffer.readUIntLE(24, 3) + 1;
const height = buffer.readUIntLE(27, 3) + 1;
return { width, height };
}
throw new Error('Unsupported WebP format');
}
/**
* Extract GIF dimensions from header
* Supports GIF87a and GIF89a formats
*/
private extractGifDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 10) {
throw new Error('Invalid GIF: too short');
}
const signature = buffer.subarray(0, 6).toString('ascii');
if (signature !== 'GIF87a' && signature !== 'GIF89a') {
throw new Error('Invalid GIF signature');
}
const width = buffer.readUInt16LE(6);
const height = buffer.readUInt16LE(8);
return { width, height };
}
/**
* Calculate tokens for an image based on its metadata
*
* @param metadata Image metadata containing width, height, and format info
* @returns Total token count including base image tokens and special tokens
*/
calculateTokens(metadata: ImageMetadata): number {
return this.calculateTokensWithScaling(metadata.width, metadata.height);
}
/**
* Calculate tokens with scaling logic
*
* Steps:
* 1. Normalize to 28-pixel multiples
* 2. Scale large images down, small images up
* 3. Calculate tokens: pixels / 784 + 2 special tokens
*
* @param width Original image width in pixels
* @param height Original image height in pixels
* @returns Total token count for the image
*/
private calculateTokensWithScaling(width: number, height: number): number {
// Normalize to 28-pixel multiples
let hBar = Math.round(height / 28) * 28;
let wBar = Math.round(width / 28) * 28;
// Define pixel boundaries
const minPixels =
ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
const maxPixels =
ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
// Apply scaling
if (hBar * wBar > maxPixels) {
// Scale down large images
const beta = Math.sqrt((height * width) / maxPixels);
hBar = Math.floor(height / beta / 28) * 28;
wBar = Math.floor(width / beta / 28) * 28;
} else if (hBar * wBar < minPixels) {
// Scale up small images
const beta = Math.sqrt(minPixels / (height * width));
hBar = Math.ceil((height * beta) / 28) * 28;
wBar = Math.ceil((width * beta) / 28) * 28;
}
// Calculate tokens
const imageTokens = Math.floor(
(hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
);
return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
}
/**
* Calculate tokens for multiple images serially
*
* @param base64DataArray Array of image data with MIME type information
* @returns Promise resolving to array of token counts in same order as input
*/
async calculateTokensBatch(
base64DataArray: Array<{ data: string; mimeType: string }>,
): Promise<number[]> {
const results: number[] = [];
for (const { data, mimeType } of base64DataArray) {
try {
const metadata = await this.extractImageMetadata(data, mimeType);
results.push(this.calculateTokens(metadata));
} catch (error) {
console.warn('Error calculating tokens for image:', error);
// Return minimum tokens as fallback
results.push(
ImageTokenizer.MIN_TOKENS_PER_IMAGE +
ImageTokenizer.VISION_SPECIAL_TOKENS,
);
}
}
return results;
}
/**
* Extract BMP dimensions from header
* BMP signature: 42 4D (BM)
* Width/height at bytes 18-21 and 22-25 (little-endian)
*/
private extractBmpDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 26) {
throw new Error('Invalid BMP: buffer too short');
}
// Verify BMP signature
if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
throw new Error('Invalid BMP signature');
}
const width = buffer.readUInt32LE(18);
// Height is signed: a negative value marks a top-down BMP
const height = buffer.readInt32LE(22);
return { width, height: Math.abs(height) };
}
/**
* Extract TIFF dimensions from IFD (Image File Directory)
* TIFF can be little-endian (II) or big-endian (MM)
* Width/height are stored in IFD entries with tags 0x0100 and 0x0101
*/
private extractTiffDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 8) {
throw new Error('Invalid TIFF: buffer too short');
}
// Check byte order
const byteOrder = buffer.subarray(0, 2).toString('ascii');
const isLittleEndian = byteOrder === 'II';
const isBigEndian = byteOrder === 'MM';
if (!isLittleEndian && !isBigEndian) {
throw new Error('Invalid TIFF byte order');
}
// Read magic number (should be 42)
const magic = isLittleEndian
? buffer.readUInt16LE(2)
: buffer.readUInt16BE(2);
if (magic !== 42) {
throw new Error('Invalid TIFF magic number');
}
// Read IFD offset
const ifdOffset = isLittleEndian
? buffer.readUInt32LE(4)
: buffer.readUInt32BE(4);
if (ifdOffset >= buffer.length) {
throw new Error('Invalid TIFF IFD offset');
}
// Read number of directory entries
const numEntries = isLittleEndian
? buffer.readUInt16LE(ifdOffset)
: buffer.readUInt16BE(ifdOffset);
let width = 0;
let height = 0;
// Parse IFD entries
for (let i = 0; i < numEntries; i++) {
const entryOffset = ifdOffset + 2 + i * 12;
if (entryOffset + 12 > buffer.length) break;
const tag = isLittleEndian
? buffer.readUInt16LE(entryOffset)
: buffer.readUInt16BE(entryOffset);
const type = isLittleEndian
? buffer.readUInt16LE(entryOffset + 2)
: buffer.readUInt16BE(entryOffset + 2);
// SHORT (type 3) values occupy the first 2 bytes of the 4-byte value field;
// LONG (type 4) values use all 4 bytes
const value =
type === 3
? isLittleEndian
? buffer.readUInt16LE(entryOffset + 8)
: buffer.readUInt16BE(entryOffset + 8)
: isLittleEndian
? buffer.readUInt32LE(entryOffset + 8)
: buffer.readUInt32BE(entryOffset + 8);
if (tag === 0x0100) {
// ImageWidth
width = value;
} else if (tag === 0x0101) {
// ImageLength (height)
height = value;
}
if (width > 0 && height > 0) break;
}
if (width === 0 || height === 0) {
throw new Error('Could not find TIFF dimensions');
}
return { width, height };
}
/**
* Extract HEIC dimensions from meta box
* HEIC is based on ISO Base Media File Format
* This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
*/
private extractHeicDimensions(buffer: Buffer): {
width: number;
height: number;
} {
if (buffer.length < 12) {
throw new Error('Invalid HEIC: buffer too short');
}
// Check for ftyp box with HEIC brand
const ftypBox = buffer.subarray(4, 8).toString('ascii');
if (ftypBox !== 'ftyp') {
throw new Error('Invalid HEIC: missing ftyp box');
}
const brand = buffer.subarray(8, 12).toString('ascii');
if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
throw new Error('Invalid HEIC brand');
}
// Look for meta box and then ispe box
let offset = 0;
while (offset < buffer.length - 8) {
const boxSize = buffer.readUInt32BE(offset);
const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
if (boxType === 'meta') {
// Look for ispe box inside meta box
const metaOffset = offset + 8;
let innerOffset = metaOffset + 4; // Skip version and flags
while (innerOffset < offset + boxSize - 8) {
const innerBoxSize = buffer.readUInt32BE(innerOffset);
const innerBoxType = buffer
.subarray(innerOffset + 4, innerOffset + 8)
.toString('ascii');
if (innerBoxType === 'ispe') {
// Found Image Spatial Extents box
if (innerOffset + 20 <= buffer.length) {
const width = buffer.readUInt32BE(innerOffset + 12);
const height = buffer.readUInt32BE(innerOffset + 16);
return { width, height };
}
}
if (innerBoxSize === 0) break;
innerOffset += innerBoxSize;
}
}
if (boxSize === 0) break;
offset += boxSize;
}
// Fallback: return default dimensions if we can't parse the structure
console.warn('Could not extract HEIC dimensions, using default');
return { width: 512, height: 512 };
}
}
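To make the scaling rules concrete, here is a small worked example (illustrative, not part of the diff) for a 512x512 image:

import { ImageTokenizer } from './imageTokenizer.js';

const tokenizer = new ImageTokenizer();
// 512 rounds to the nearest 28-pixel multiple: 18 * 28 = 504.
// 504 * 504 = 254016 pixels, within the [3136, 12845056] pixel bounds,
// so imageTokens = 254016 / 784 = 324, plus 2 special tokens = 326.
const tokens = tokenizer.calculateTokens({
width: 512,
height: 512,
mimeType: 'image/png',
dataSize: 10000,
});
console.log(tokens); // 326 -- consistent with the 300-400 range asserted in the tests above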

View File

@@ -0,0 +1,40 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
export { DefaultRequestTokenizer } from './requestTokenizer.js';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
export { TextTokenizer } from './textTokenizer.js';
export { ImageTokenizer } from './imageTokenizer.js';
export type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
ImageMetadata,
} from './types.js';
// Singleton instance for convenient usage
let defaultTokenizer: DefaultRequestTokenizer | null = null;
/**
* Get the default request tokenizer instance
*/
export function getDefaultTokenizer(): DefaultRequestTokenizer {
if (!defaultTokenizer) {
defaultTokenizer = new DefaultRequestTokenizer();
}
return defaultTokenizer;
}
/**
* Dispose of the default tokenizer instance
*/
export async function disposeDefaultTokenizer(): Promise<void> {
if (defaultTokenizer) {
await defaultTokenizer.dispose();
defaultTokenizer = null;
}
}
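A short usage sketch of the singleton helpers above:

import { getDefaultTokenizer, disposeDefaultTokenizer } from './index.js';

const result = await getDefaultTokenizer().calculateTokens({
model: 'qwen-vl-max-latest',
contents: [{ role: 'user', parts: [{ text: 'Describe this image.' }] }],
});
console.log(result.totalTokens, result.breakdown, result.processingTime);
await disposeDefaultTokenizer(); // frees the underlying tiktoken encoding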

View File

@@ -0,0 +1,293 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { DefaultRequestTokenizer } from './requestTokenizer.js';
import type { CountTokensParameters } from '@google/genai';
describe('DefaultRequestTokenizer', () => {
let tokenizer: DefaultRequestTokenizer;
beforeEach(() => {
tokenizer = new DefaultRequestTokenizer();
});
afterEach(async () => {
await tokenizer.dispose();
});
describe('text token calculation', () => {
it('should calculate tokens for simple text content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Hello, world!' }],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBe(0);
expect(result.processingTime).toBeGreaterThan(0);
});
it('should handle multiple text parts', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'First part' },
{ text: 'Second part' },
{ text: 'Third part' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
it('should handle string content', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: ['Simple string content'],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
});
});
describe('image token calculation', () => {
it('should calculate tokens for image content', async () => {
// Create a simple 1x1 PNG image in base64
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
expect(result.breakdown.textTokens).toBe(0);
});
it('should handle multiple images', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
});
});
describe('mixed content', () => {
it('should handle text and image content together', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{ text: 'Here is an image:' },
{
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
},
{ text: 'What do you see?' },
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(4);
expect(result.breakdown.textTokens).toBeGreaterThan(0);
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
});
});
describe('function content', () => {
it('should handle function calls', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
functionCall: {
name: 'test_function',
args: { param1: 'value1', param2: 42 },
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThan(0);
expect(result.breakdown.otherTokens).toBeGreaterThan(0);
});
});
describe('empty content', () => {
it('should handle empty request', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
expect(result.breakdown.textTokens).toBe(0);
expect(result.breakdown.imageTokens).toBe(0);
});
it('should handle undefined contents', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBe(0);
});
});
describe('configuration', () => {
it('should use custom text encoding', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [{ text: 'Test text for encoding' }],
},
],
};
const result = await tokenizer.calculateTokens(request, {
textEncoding: 'cl100k_base',
});
expect(result.totalTokens).toBeGreaterThan(0);
});
it('should process multiple images serially', async () => {
const pngBase64 =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: Array(10).fill({
inlineData: {
mimeType: 'image/png',
data: pngBase64,
},
}),
},
],
};
const result = await tokenizer.calculateTokens(request);
expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
});
});
describe('error handling', () => {
it('should handle malformed image data gracefully', async () => {
const request: CountTokensParameters = {
model: 'test-model',
contents: [
{
role: 'user',
parts: [
{
inlineData: {
mimeType: 'image/png',
data: 'invalid-base64-data',
},
},
],
},
],
};
const result = await tokenizer.calculateTokens(request);
// Should still return some tokens (fallback to minimum)
expect(result.totalTokens).toBeGreaterThanOrEqual(4);
});
});
});

View File

@@ -0,0 +1,341 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
Content,
Part,
PartUnion,
} from '@google/genai';
import type {
RequestTokenizer,
TokenizerConfig,
TokenCalculationResult,
} from './types.js';
import { TextTokenizer } from './textTokenizer.js';
import { ImageTokenizer } from './imageTokenizer.js';
/**
* Simple request tokenizer that handles text and image content serially
*/
export class DefaultRequestTokenizer implements RequestTokenizer {
private textTokenizer: TextTokenizer;
private imageTokenizer: ImageTokenizer;
constructor() {
this.textTokenizer = new TextTokenizer();
this.imageTokenizer = new ImageTokenizer();
}
/**
* Calculate tokens for a request using serial processing
*/
async calculateTokens(
request: CountTokensParameters,
config: TokenizerConfig = {},
): Promise<TokenCalculationResult> {
const startTime = performance.now();
// Apply configuration, freeing the previous encoding before swapping it
if (config.textEncoding) {
this.textTokenizer.dispose();
this.textTokenizer = new TextTokenizer(config.textEncoding);
}
try {
// Process request content and group by type
const { textContents, imageContents, audioContents, otherContents } =
this.processAndGroupContents(request);
if (
textContents.length === 0 &&
imageContents.length === 0 &&
audioContents.length === 0 &&
otherContents.length === 0
) {
return {
totalTokens: 0,
breakdown: {
textTokens: 0,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
// Calculate tokens for each content type serially
const textTokens = await this.calculateTextTokens(textContents);
const imageTokens = await this.calculateImageTokens(imageContents);
const audioTokens = await this.calculateAudioTokens(audioContents);
const otherTokens = await this.calculateOtherTokens(otherContents);
const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
const processingTime = performance.now() - startTime;
return {
totalTokens,
breakdown: {
textTokens,
imageTokens,
audioTokens,
otherTokens,
},
processingTime,
};
} catch (error) {
console.error('Error calculating tokens:', error);
// Fallback calculation
const fallbackTokens = this.calculateFallbackTokens(request);
return {
totalTokens: fallbackTokens,
breakdown: {
textTokens: fallbackTokens,
imageTokens: 0,
audioTokens: 0,
otherTokens: 0,
},
processingTime: performance.now() - startTime,
};
}
}
/**
* Calculate tokens for text contents
*/
private async calculateTextTokens(textContents: string[]): Promise<number> {
if (textContents.length === 0) return 0;
try {
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(textContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating text tokens:', error);
// Fallback: character-based estimation
const totalChars = textContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Calculate tokens for image contents using serial processing
*/
private async calculateImageTokens(
imageContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (imageContents.length === 0) return 0;
try {
const tokenCounts =
await this.imageTokenizer.calculateTokensBatch(imageContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating image tokens:', error);
// Fallback: minimum tokens per image
return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
}
}
/**
* Calculate tokens for audio contents
* TODO: Implement proper audio token calculation
*/
private async calculateAudioTokens(
audioContents: Array<{ data: string; mimeType: string }>,
): Promise<number> {
if (audioContents.length === 0) return 0;
// Placeholder implementation - audio token calculation would depend on
// the specific model's audio processing capabilities
// For now, estimate based on data size
let totalTokens = 0;
for (const audioContent of audioContents) {
try {
const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
// Rough estimate: 1 token per 100 bytes of audio data
totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
} catch (error) {
console.warn('Error calculating audio tokens:', error);
totalTokens += 10; // Fallback minimum
}
}
return totalTokens;
}
/**
* Calculate tokens for other content types (functions, files, etc.)
*/
private async calculateOtherTokens(otherContents: string[]): Promise<number> {
if (otherContents.length === 0) return 0;
try {
// Treat other content as text for token calculation
const tokenCounts =
await this.textTokenizer.calculateTokensBatch(otherContents);
return tokenCounts.reduce((sum, count) => sum + count, 0);
} catch (error) {
console.warn('Error calculating other content tokens:', error);
// Fallback: character-based estimation
const totalChars = otherContents.join('').length;
return Math.ceil(totalChars / 4);
}
}
/**
* Fallback token calculation using simple string serialization
*/
private calculateFallbackTokens(request: CountTokensParameters): number {
try {
const content = JSON.stringify(request.contents);
return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
} catch (error) {
console.warn('Error in fallback token calculation:', error);
return 100; // Conservative fallback
}
}
/**
* Process request contents and group by type
*/
private processAndGroupContents(request: CountTokensParameters): {
textContents: string[];
imageContents: Array<{ data: string; mimeType: string }>;
audioContents: Array<{ data: string; mimeType: string }>;
otherContents: string[];
} {
const textContents: string[] = [];
const imageContents: Array<{ data: string; mimeType: string }> = [];
const audioContents: Array<{ data: string; mimeType: string }> = [];
const otherContents: string[] = [];
if (!request.contents) {
return { textContents, imageContents, audioContents, otherContents };
}
const contents = Array.isArray(request.contents)
? request.contents
: [request.contents];
for (const content of contents) {
this.processContent(
content,
textContents,
imageContents,
audioContents,
otherContents,
);
}
return { textContents, imageContents, audioContents, otherContents };
}
/**
* Process a single content item and add to appropriate arrays
*/
private processContent(
content: Content | string | PartUnion,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof content === 'string') {
if (content.trim()) {
textContents.push(content);
}
return;
}
if ('parts' in content && content.parts) {
for (const part of content.parts) {
this.processPart(
part,
textContents,
imageContents,
audioContents,
otherContents,
);
}
}
}
/**
* Process a single part and add to appropriate arrays
*/
private processPart(
part: Part | string,
textContents: string[],
imageContents: Array<{ data: string; mimeType: string }>,
audioContents: Array<{ data: string; mimeType: string }>,
otherContents: string[],
): void {
if (typeof part === 'string') {
if (part.trim()) {
textContents.push(part);
}
return;
}
if ('text' in part && part.text) {
textContents.push(part.text);
return;
}
if ('inlineData' in part && part.inlineData) {
const { data, mimeType } = part.inlineData;
if (mimeType && mimeType.startsWith('image/')) {
imageContents.push({ data: data || '', mimeType });
return;
}
if (mimeType && mimeType.startsWith('audio/')) {
audioContents.push({ data: data || '', mimeType });
return;
}
}
if ('fileData' in part && part.fileData) {
otherContents.push(JSON.stringify(part.fileData));
return;
}
if ('functionCall' in part && part.functionCall) {
otherContents.push(JSON.stringify(part.functionCall));
return;
}
if ('functionResponse' in part && part.functionResponse) {
otherContents.push(JSON.stringify(part.functionResponse));
return;
}
// Unknown part type - try to serialize
try {
const serialized = JSON.stringify(part);
if (serialized && serialized !== '{}') {
otherContents.push(serialized);
}
} catch (error) {
console.warn('Failed to serialize unknown part type:', error);
}
}
/**
* Dispose of resources
*/
async dispose(): Promise<void> {
try {
// Dispose of the text tokenizer; the image tokenizer holds no native resources
this.textTokenizer.dispose();
} catch (error) {
console.warn('Error disposing request tokenizer:', error);
}
}
}

View File

@@ -0,0 +1,56 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Supported image MIME types for vision models
* These formats are supported by the vision model and can be processed by the image tokenizer
*/
export const SUPPORTED_IMAGE_MIME_TYPES = [
'image/bmp',
'image/jpeg',
'image/jpg', // Alternative MIME type for JPEG
'image/png',
'image/tiff',
'image/webp',
'image/heic',
] as const;
/**
* Type for supported image MIME types
*/
export type SupportedImageMimeType =
(typeof SUPPORTED_IMAGE_MIME_TYPES)[number];
/**
* Check if a MIME type is supported for vision processing
* @param mimeType The MIME type to check
* @returns True if the MIME type is supported
*/
export function isSupportedImageMimeType(
mimeType: string,
): mimeType is SupportedImageMimeType {
return SUPPORTED_IMAGE_MIME_TYPES.includes(
mimeType as SupportedImageMimeType,
);
}
/**
* Get a human-readable list of supported image formats
* @returns Comma-separated string of supported formats
*/
export function getSupportedImageFormatsString(): string {
return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
type.replace('image/', '').toUpperCase(),
).join(', ');
}
/**
* Get warning message for unsupported image formats
* @returns Warning message string
*/
export function getUnsupportedImageFormatWarning(): string {
return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
}
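A few illustrative calls against these helpers:

isSupportedImageMimeType('image/png'); // true
isSupportedImageMimeType('image/avif'); // false -- not in the supported list
getSupportedImageFormatsString(); // 'BMP, JPEG, JPG, PNG, TIFF, WEBP, HEIC'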

View File

@@ -0,0 +1,347 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';
// Mock tiktoken at the top level with hoisted functions
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());
vi.mock('tiktoken', () => ({
get_encoding: mockGetEncoding,
}));

describe('TextTokenizer', () => {
  let tokenizer: TextTokenizer;
  let consoleWarnSpy: ReturnType<typeof vi.spyOn>;

  beforeEach(() => {
    vi.resetAllMocks();
    consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

    // Default mock implementation
    mockGetEncoding.mockReturnValue({
      encode: mockEncode,
      free: mockFree,
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    tokenizer?.dispose();
  });

  describe('constructor', () => {
    it('should create tokenizer with default encoding', () => {
      tokenizer = new TextTokenizer();
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });

    it('should create tokenizer with custom encoding', () => {
      tokenizer = new TextTokenizer('gpt2');
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });
  });

  describe('calculateTokens', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should return 0 for empty text', async () => {
      const result = await tokenizer.calculateTokens('');
      expect(result).toBe(0);
    });

    it('should return 0 for null/undefined text', async () => {
      const result1 = await tokenizer.calculateTokens(
        null as unknown as string,
      );
      const result2 = await tokenizer.calculateTokens(
        undefined as unknown as string,
      );
      expect(result1).toBe(0);
      expect(result2).toBe(0);
    });

    it('should calculate tokens using tiktoken when available', async () => {
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
      expect(mockEncode).toHaveBeenCalledWith(testText);
      expect(result).toBe(5);
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should use fallback calculation when encoding fails', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding text with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should handle very long text', async () => {
      const longText = 'a'.repeat(10000);
      const mockTokens = new Array(2500); // 2500 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(longText);

      expect(result).toBe(2500);
    });

    it('should handle unicode characters', async () => {
      const unicodeText = '你好世界 🌍';
      const mockTokens = [1, 2, 3, 4, 5, 6];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(unicodeText);

      expect(result).toBe(6);
    });

    it('should use custom encoding when specified', async () => {
      tokenizer = new TextTokenizer('gpt2');
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
      expect(result).toBe(3);
    });
  });

  describe('calculateTokensBatch', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should process multiple texts and return token counts', async () => {
      const texts = ['Hello', 'world', 'test'];
      mockEncode
        .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
        .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
        .mockReturnValueOnce([6]); // 1 token for 'test'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([2, 3, 1]);
      expect(mockEncode).toHaveBeenCalledTimes(3);
    });

    it('should handle empty array', async () => {
      const result = await tokenizer.calculateTokensBatch([]);
      expect(result).toEqual([]);
    });

    it('should handle array with empty strings', async () => {
      const texts = ['', 'hello', ''];
      mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0]);
      expect(mockEncode).toHaveBeenCalledTimes(1);
      expect(mockEncode).toHaveBeenCalledWith('hello');
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should use fallback calculation when encoding fails during batch processing', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding texts with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should handle null and undefined values in batch', async () => {
      const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
      mockEncode
        .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
        .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0, 2]);
    });
  });

  describe('dispose', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should free tiktoken encoding when disposing', async () => {
      // Initialize the encoding by calling calculateTokens
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();

      expect(mockFree).toHaveBeenCalled();
    });

    it('should handle disposal when encoding is not initialized', () => {
      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle disposal when encoding is null', async () => {
      // Force encoding to be null by making tiktoken fail
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load');
      });
      await tokenizer.calculateTokens('test');

      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle errors during disposal gracefully', async () => {
      await tokenizer.calculateTokens('test');
      mockFree.mockImplementation(() => {
        throw new Error('Free failed');
      });

      tokenizer.dispose();

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error freeing tiktoken encoding:',
        expect.any(Error),
      );
    });

    it('should allow multiple calls to dispose', async () => {
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();
      tokenizer.dispose(); // Second call should not throw

      expect(mockFree).toHaveBeenCalledTimes(1);
    });
  });

  describe('lazy initialization', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should not initialize tiktoken until first use', () => {
      expect(mockGetEncoding).not.toHaveBeenCalled();
    });

    it('should initialize tiktoken on first calculateTokens call', async () => {
      await tokenizer.calculateTokens('test');
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should not reinitialize tiktoken on subsequent calls', async () => {
      await tokenizer.calculateTokens('test1');
      await tokenizer.calculateTokens('test2');
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should initialize tiktoken on first calculateTokensBatch call', async () => {
      await tokenizer.calculateTokensBatch(['test']);
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should handle very short text', async () => {
      const result = await tokenizer.calculateTokens('a');

      if (mockGetEncoding.mock.calls.length > 0) {
        // If tiktoken was called, use its result
        expect(mockEncode).toHaveBeenCalledWith('a');
      } else {
        // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
        expect(result).toBe(1);
      }
    });

    it('should handle text with only whitespace', async () => {
      const whitespaceText = ' \n\t ';
      const mockTokens = [1];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(whitespaceText);

      expect(result).toBe(1);
    });

    it('should handle special characters and symbols', async () => {
      const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
      const mockTokens = new Array(10);
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(specialText);

      expect(result).toBe(10);
    });
  });
});

View File

@@ -0,0 +1,97 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
import { get_encoding } from 'tiktoken';

/**
 * Text tokenizer for calculating text tokens using tiktoken
 */
export class TextTokenizer {
  private encoding: Tiktoken | null = null;
  private encodingName: string;

  constructor(encodingName: string = 'cl100k_base') {
    this.encodingName = encodingName;
  }

  /**
   * Initialize the tokenizer (lazy loading)
   */
  private async ensureEncoding(): Promise<void> {
    if (this.encoding) return;

    try {
      // The constructor accepts an arbitrary string, so assert it is a
      // TiktokenEncoding; an invalid name throws and is handled below.
      this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
    } catch (error) {
      console.warn(
        `Failed to load tiktoken with encoding ${this.encodingName}:`,
        error,
      );
      this.encoding = null;
    }
  }

  /**
   * Calculate tokens for text content
   */
  async calculateTokens(text: string): Promise<number> {
    if (!text) return 0;

    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return this.encoding.encode(text).length;
      } catch (error) {
        console.warn('Error encoding text with tiktoken:', error);
      }
    }

    // Fallback: rough approximation using character count.
    // This is a conservative estimate: 1 token ≈ 4 characters for most languages.
    return Math.ceil(text.length / 4);
  }

  /**
   * Calculate tokens for multiple text strings in a single batch
   */
  async calculateTokensBatch(texts: string[]): Promise<number[]> {
    await this.ensureEncoding();

    if (this.encoding) {
      try {
        return texts.map((text) => {
          if (!text) return 0;
          // TypeScript cannot see that `this.encoding` is still non-null
          // inside the callback, so re-check it to satisfy the linter.
          return this.encoding ? this.encoding.encode(text).length : 0;
        });
      } catch (error) {
        console.warn('Error encoding texts with tiktoken:', error);
        // On error, fall back to the character-count estimate for all texts
        return texts.map((text) => Math.ceil((text || '').length / 4));
      }
    }

    // Fallback for batch processing
    return texts.map((text) => Math.ceil((text || '').length / 4));
  }

  /**
   * Dispose of resources
   */
  dispose(): void {
    if (this.encoding) {
      try {
        this.encoding.free();
      } catch (error) {
        console.warn('Error freeing tiktoken encoding:', error);
      }
      this.encoding = null;
    }
  }
}
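
For context, a minimal usage sketch of the class above (not part of the change set; the import path is an assumption, everything else mirrors the methods defined here):

    // Hypothetical import path -- adjust to wherever this file lands.
    import { TextTokenizer } from './textTokenizer.js';

    const tokenizer = new TextTokenizer(); // defaults to 'cl100k_base'

    // Uses tiktoken when the encoding loads; otherwise falls back to
    // Math.ceil(text.length / 4).
    const single = await tokenizer.calculateTokens('Hello, world!');
    const batch = await tokenizer.calculateTokensBatch(['a', '', 'Hello']);
    console.log(single, batch); // batch yields 0 for the empty string

    // Release the underlying tiktoken encoding when finished.
    tokenizer.dispose();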

View File

@@ -0,0 +1,64 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import type { CountTokensParameters } from '@google/genai';

/**
 * Token calculation result for different content types
 */
export interface TokenCalculationResult {
  /** Total tokens calculated */
  totalTokens: number;
  /** Breakdown by content type */
  breakdown: {
    textTokens: number;
    imageTokens: number;
    audioTokens: number;
    otherTokens: number;
  };
  /** Processing time in milliseconds */
  processingTime: number;
}

/**
 * Configuration for token calculation
 */
export interface TokenizerConfig {
  /** Custom text tokenizer encoding (defaults to cl100k_base) */
  textEncoding?: string;
}

/**
 * Image metadata extracted from base64 data
 */
export interface ImageMetadata {
  /** Image width in pixels */
  width: number;
  /** Image height in pixels */
  height: number;
  /** MIME type of the image */
  mimeType: string;
  /** Size of the base64 data in bytes */
  dataSize: number;
}

/**
 * Request tokenizer interface
 */
export interface RequestTokenizer {
  /**
   * Calculate tokens for a request
   */
  calculateTokens(
    request: CountTokensParameters,
    config?: TokenizerConfig,
  ): Promise<TokenCalculationResult>;

  /**
   * Dispose of resources (worker threads, etc.)
   */
  dispose(): Promise<void>;
}
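
The concrete implementation of this interface lives elsewhere in the change set and is not shown in this excerpt. As rough orientation only, a text-only sketch satisfying the interface with the TextTokenizer above could look like the following (import paths hypothetical; image and audio accounting deliberately stubbed to 0, whereas the real tokenizer in this PR also sizes images):

    import type { CountTokensParameters } from '@google/genai';
    import { TextTokenizer } from './textTokenizer.js'; // hypothetical path
    import type {
      RequestTokenizer,
      TokenCalculationResult,
      TokenizerConfig,
    } from './types.js'; // hypothetical path

    // Recursively collect plain-text parts from the loosely typed
    // `contents` union (string | Content | Part | arrays of those).
    const collectText = (value: unknown, out: string[]): void => {
      if (typeof value === 'string') {
        out.push(value);
      } else if (Array.isArray(value)) {
        value.forEach((v) => collectText(v, out));
      } else if (value && typeof value === 'object') {
        const obj = value as { text?: unknown; parts?: unknown };
        if (typeof obj.text === 'string') out.push(obj.text);
        if (obj.parts) collectText(obj.parts, out);
      }
    };

    export class TextOnlyRequestTokenizer implements RequestTokenizer {
      async calculateTokens(
        request: CountTokensParameters,
        config?: TokenizerConfig,
      ): Promise<TokenCalculationResult> {
        const start = Date.now();
        const tokenizer = new TextTokenizer(config?.textEncoding);
        try {
          const texts: string[] = [];
          collectText(request.contents, texts);
          const counts = await tokenizer.calculateTokensBatch(texts);
          const textTokens = counts.reduce((sum, n) => sum + n, 0);
          return {
            totalTokens: textTokens,
            breakdown: {
              textTokens,
              imageTokens: 0, // stubbed; the PR's tokenizer also sizes images
              audioTokens: 0,
              otherTokens: 0,
            },
            processingTime: Date.now() - start,
          };
        } finally {
          tokenizer.dispose();
        }
      }

      async dispose(): Promise<void> {
        // Nothing pooled in this sketch; per-call tokenizers are disposed above.
      }
    }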