mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
Vision model support for Qwen-OAuth (#525)
* refactor: openaiContentGenerator * refactor: optimize stream handling * refactor: re-organize refactored files * fix: unit test cases * feat: `/model` command for switching to vision model * fix: lint error * feat: add image tokenizer to fit vlm context window * fix: lint and type errors * feat: add `visionModelPreview` to control default visibility of vision models * fix: remove deprecated files * fix: align supported image formats with bailian doc
This commit is contained in:
@@ -741,6 +741,16 @@ export const SETTINGS_SCHEMA = {
|
||||
description: 'Enable extension management features.',
|
||||
showInDialog: false,
|
||||
},
|
||||
visionModelPreview: {
|
||||
type: 'boolean',
|
||||
label: 'Vision Model Preview',
|
||||
category: 'Experimental',
|
||||
requiresRestart: false,
|
||||
default: false,
|
||||
description:
|
||||
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
|
||||
showInDialog: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
|
||||
@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
|
||||
kind: 'BUILT_IN',
|
||||
},
|
||||
}));
|
||||
vi.mock('../ui/commands/modelCommand.js', () => ({
|
||||
modelCommand: {
|
||||
name: 'model',
|
||||
description: 'Model command',
|
||||
kind: 'BUILT_IN',
|
||||
},
|
||||
}));
|
||||
|
||||
describe('BuiltinCommandLoader', () => {
|
||||
let mockConfig: Config;
|
||||
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
|
||||
|
||||
const mcpCmd = commands.find((c) => c.name === 'mcp');
|
||||
expect(mcpCmd).toBeDefined();
|
||||
|
||||
const modelCmd = commands.find((c) => c.name === 'model');
|
||||
expect(modelCmd).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
|
||||
import { vimCommand } from '../ui/commands/vimCommand.js';
|
||||
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
|
||||
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
|
||||
import { modelCommand } from '../ui/commands/modelCommand.js';
|
||||
import { agentsCommand } from '../ui/commands/agentsCommand.js';
|
||||
|
||||
/**
|
||||
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
||||
initCommand,
|
||||
mcpCommand,
|
||||
memoryCommand,
|
||||
modelCommand,
|
||||
privacyCommand,
|
||||
quitCommand,
|
||||
quitConfirmCommand,
|
||||
|
||||
@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
|
||||
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
|
||||
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
|
||||
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
|
||||
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
|
||||
import {
|
||||
ModelSwitchDialog,
|
||||
type VisionSwitchOutcome,
|
||||
} from './components/ModelSwitchDialog.js';
|
||||
import {
|
||||
getOpenAIAvailableModelFromEnv,
|
||||
getFilteredQwenModels,
|
||||
type AvailableModel,
|
||||
} from './models/availableModels.js';
|
||||
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
|
||||
import {
|
||||
AgentCreationWizard,
|
||||
AgentsManagerDialog,
|
||||
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
onWorkspaceMigrationDialogClose,
|
||||
} = useWorkspaceMigration(settings);
|
||||
|
||||
// Model selection dialog states
|
||||
const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
|
||||
useState(false);
|
||||
const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
|
||||
useState(false);
|
||||
const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
|
||||
resolve: (result: {
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}) => void;
|
||||
reject: () => void;
|
||||
} | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
|
||||
// Set the initial value
|
||||
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
openAuthDialog();
|
||||
}, [openAuthDialog, setAuthError]);
|
||||
|
||||
// Vision switch handler for auto-switch functionality
|
||||
const handleVisionSwitchRequired = useCallback(
|
||||
async (_query: unknown) =>
|
||||
new Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>((resolve, reject) => {
|
||||
setVisionSwitchResolver({ resolve, reject });
|
||||
setIsVisionSwitchDialogOpen(true);
|
||||
}),
|
||||
[],
|
||||
);
|
||||
|
||||
const handleVisionSwitchSelect = useCallback(
|
||||
(outcome: VisionSwitchOutcome) => {
|
||||
setIsVisionSwitchDialogOpen(false);
|
||||
if (visionSwitchResolver) {
|
||||
const result = processVisionSwitchOutcome(outcome);
|
||||
visionSwitchResolver.resolve(result);
|
||||
setVisionSwitchResolver(null);
|
||||
}
|
||||
},
|
||||
[visionSwitchResolver],
|
||||
);
|
||||
|
||||
const handleModelSelectionOpen = useCallback(() => {
|
||||
setIsModelSelectionDialogOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleModelSelectionClose = useCallback(() => {
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
}, []);
|
||||
|
||||
const handleModelSelect = useCallback(
|
||||
(modelId: string) => {
|
||||
config.setModel(modelId);
|
||||
setCurrentModel(modelId);
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Switched model to \`${modelId}\` for this session.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
},
|
||||
[config, setCurrentModel, addItem],
|
||||
);
|
||||
|
||||
const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
if (!contentGeneratorConfig) return [];
|
||||
|
||||
const visionModelPreviewEnabled =
|
||||
settings.merged.experimental?.visionModelPreview ?? false;
|
||||
|
||||
switch (contentGeneratorConfig.authType) {
|
||||
case AuthType.QWEN_OAUTH:
|
||||
return getFilteredQwenModels(visionModelPreviewEnabled);
|
||||
case AuthType.USE_OPENAI: {
|
||||
const openAIModel = getOpenAIAvailableModelFromEnv();
|
||||
return openAIModel ? [openAIModel] : [];
|
||||
}
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
}, [config, settings.merged.experimental?.visionModelPreview]);
|
||||
|
||||
// Core hooks and processors
|
||||
const {
|
||||
vimEnabled: vimModeEnabled,
|
||||
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
setQuittingMessages,
|
||||
openPrivacyNotice,
|
||||
openSettingsDialog,
|
||||
handleModelSelectionOpen,
|
||||
openSubagentCreateDialog,
|
||||
openAgentsManagerDialog,
|
||||
toggleVimEnabled,
|
||||
@@ -664,6 +759,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
setModelSwitchedFromQuotaError,
|
||||
refreshStatic,
|
||||
() => cancelHandlerRef.current(),
|
||||
settings.merged.experimental?.visionModelPreview ?? false,
|
||||
handleVisionSwitchRequired,
|
||||
);
|
||||
|
||||
const pendingHistoryItems = useMemo(
|
||||
@@ -1034,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
!isAuthDialogOpen &&
|
||||
!isThemeDialogOpen &&
|
||||
!isEditorDialogOpen &&
|
||||
!isModelSelectionDialogOpen &&
|
||||
!isVisionSwitchDialogOpen &&
|
||||
!isSubagentCreateDialogOpen &&
|
||||
!showPrivacyNotice &&
|
||||
!showWelcomeBackDialog &&
|
||||
@@ -1055,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
showWelcomeBackDialog,
|
||||
welcomeBackChoice,
|
||||
geminiClient,
|
||||
isModelSelectionDialogOpen,
|
||||
isVisionSwitchDialogOpen,
|
||||
]);
|
||||
|
||||
if (quittingMessages) {
|
||||
@@ -1322,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
onExit={exitEditorDialog}
|
||||
/>
|
||||
</Box>
|
||||
) : isModelSelectionDialogOpen ? (
|
||||
<ModelSelectionDialog
|
||||
availableModels={getAvailableModelsForCurrentAuth()}
|
||||
currentModel={currentModel}
|
||||
onSelect={handleModelSelect}
|
||||
onCancel={handleModelSelectionClose}
|
||||
/>
|
||||
) : isVisionSwitchDialogOpen ? (
|
||||
<ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
|
||||
) : showPrivacyNotice ? (
|
||||
<PrivacyNotice
|
||||
onExit={() => setShowPrivacyNotice(false)}
|
||||
|
||||
179
packages/cli/src/ui/commands/modelCommand.test.ts
Normal file
179
packages/cli/src/ui/commands/modelCommand.test.ts
Normal file
@@ -0,0 +1,179 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { modelCommand } from './modelCommand.js';
|
||||
import { type CommandContext } from './types.js';
|
||||
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
|
||||
import {
|
||||
AuthType,
|
||||
type ContentGeneratorConfig,
|
||||
type Config,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import * as availableModelsModule from '../models/availableModels.js';
|
||||
|
||||
// Mock the availableModels module
|
||||
vi.mock('../models/availableModels.js', () => ({
|
||||
AVAILABLE_MODELS_QWEN: [
|
||||
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
|
||||
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
|
||||
],
|
||||
getOpenAIAvailableModelFromEnv: vi.fn(),
|
||||
}));
|
||||
|
||||
// Helper function to create a mock config
|
||||
function createMockConfig(
|
||||
contentGeneratorConfig: ContentGeneratorConfig | null,
|
||||
): Partial<Config> {
|
||||
return {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
|
||||
};
|
||||
}
|
||||
|
||||
describe('modelCommand', () => {
|
||||
let mockContext: CommandContext;
|
||||
const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
|
||||
availableModelsModule.getOpenAIAvailableModelFromEnv,
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
mockContext = createMockCommandContext();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should have the correct name and description', () => {
|
||||
expect(modelCommand.name).toBe('model');
|
||||
expect(modelCommand.description).toBe('Switch the model for this session');
|
||||
});
|
||||
|
||||
it('should return error when config is not available', async () => {
|
||||
mockContext.services.config = null;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Configuration not available.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error when content generator config is not available', async () => {
|
||||
const mockConfig = createMockConfig(null);
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Content generator configuration not available.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error when auth type is not available', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: undefined,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Authentication type not available.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return dialog action for QWEN_OAUTH auth type', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'dialog',
|
||||
dialog: 'model',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
|
||||
mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
|
||||
id: 'gpt-4',
|
||||
label: 'gpt-4',
|
||||
});
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: AuthType.USE_OPENAI,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'dialog',
|
||||
dialog: 'model',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error for USE_OPENAI auth type when no model is available', async () => {
|
||||
mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: AuthType.USE_OPENAI,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content:
|
||||
'No models available for the current authentication type (openai).',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return error for unsupported auth types', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content:
|
||||
'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle undefined auth type', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
model: 'test-model',
|
||||
authType: undefined,
|
||||
});
|
||||
mockContext.services.config = mockConfig as Config;
|
||||
|
||||
const result = await modelCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Authentication type not available.',
|
||||
});
|
||||
});
|
||||
});
|
||||
88
packages/cli/src/ui/commands/modelCommand.ts
Normal file
88
packages/cli/src/ui/commands/modelCommand.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { AuthType } from '@qwen-code/qwen-code-core';
|
||||
import type {
|
||||
SlashCommand,
|
||||
CommandContext,
|
||||
OpenDialogActionReturn,
|
||||
MessageActionReturn,
|
||||
} from './types.js';
|
||||
import { CommandKind } from './types.js';
|
||||
import {
|
||||
AVAILABLE_MODELS_QWEN,
|
||||
getOpenAIAvailableModelFromEnv,
|
||||
type AvailableModel,
|
||||
} from '../models/availableModels.js';
|
||||
|
||||
function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
|
||||
switch (authType) {
|
||||
case AuthType.QWEN_OAUTH:
|
||||
return AVAILABLE_MODELS_QWEN;
|
||||
case AuthType.USE_OPENAI: {
|
||||
const openAIModel = getOpenAIAvailableModelFromEnv();
|
||||
return openAIModel ? [openAIModel] : [];
|
||||
}
|
||||
default:
|
||||
// For other auth types, return empty array for now
|
||||
// This can be expanded later according to the design doc
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export const modelCommand: SlashCommand = {
|
||||
name: 'model',
|
||||
description: 'Switch the model for this session',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
): Promise<OpenDialogActionReturn | MessageActionReturn> => {
|
||||
const { services } = context;
|
||||
const { config } = services;
|
||||
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
if (!contentGeneratorConfig) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Content generator configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const authType = contentGeneratorConfig.authType;
|
||||
if (!authType) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Authentication type not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const availableModels = getAvailableModelsForAuthType(authType);
|
||||
|
||||
if (availableModels.length === 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `No models available for the current authentication type (${authType}).`,
|
||||
};
|
||||
}
|
||||
|
||||
// Trigger model selection dialog
|
||||
return {
|
||||
type: 'dialog',
|
||||
dialog: 'model',
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
|
||||
| 'editor'
|
||||
| 'privacy'
|
||||
| 'settings'
|
||||
| 'model'
|
||||
| 'subagent_create'
|
||||
| 'subagent_list';
|
||||
}
|
||||
|
||||
246
packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
Normal file
246
packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
Normal file
@@ -0,0 +1,246 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink-testing-library';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
|
||||
import type { AvailableModel } from '../models/availableModels.js';
|
||||
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
|
||||
|
||||
// Mock the useKeypress hook
|
||||
const mockUseKeypress = vi.hoisted(() => vi.fn());
|
||||
vi.mock('../hooks/useKeypress.js', () => ({
|
||||
useKeypress: mockUseKeypress,
|
||||
}));
|
||||
|
||||
// Mock the RadioButtonSelect component
|
||||
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./shared/RadioButtonSelect.js', () => ({
|
||||
RadioButtonSelect: mockRadioButtonSelect,
|
||||
}));
|
||||
|
||||
describe('ModelSelectionDialog', () => {
|
||||
const mockAvailableModels: AvailableModel[] = [
|
||||
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
|
||||
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
|
||||
{ id: 'gpt-4', label: 'GPT-4' },
|
||||
];
|
||||
|
||||
const mockOnSelect = vi.fn();
|
||||
const mockOnCancel = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock RadioButtonSelect to return a simple div
|
||||
mockRadioButtonSelect.mockReturnValue(
|
||||
React.createElement('div', { 'data-testid': 'radio-select' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should setup escape key handler to call onCancel', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
|
||||
isActive: true,
|
||||
});
|
||||
|
||||
// Simulate escape key press
|
||||
const keypressHandler = mockUseKeypress.mock.calls[0][0];
|
||||
keypressHandler({ name: 'escape' });
|
||||
|
||||
expect(mockOnCancel).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not call onCancel for non-escape keys', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const keypressHandler = mockUseKeypress.mock.calls[0][0];
|
||||
keypressHandler({ name: 'enter' });
|
||||
|
||||
expect(mockOnCancel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should set correct initial index for current model', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen-vl-max-latest"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
|
||||
});
|
||||
|
||||
it('should set initial index to 0 when current model is not found', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="non-existent-model"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.initialIndex).toBe(0);
|
||||
});
|
||||
|
||||
it('should call onSelect when a model is selected', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(typeof callArgs.onSelect).toBe('function');
|
||||
|
||||
// Simulate selection
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
onSelectCallback('qwen-vl-max-latest');
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
});
|
||||
|
||||
it('should handle empty models array', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={[]}
|
||||
currentModel=""
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.items).toEqual([]);
|
||||
expect(callArgs.initialIndex).toBe(0);
|
||||
});
|
||||
|
||||
it('should create correct option items with proper labels', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const expectedItems = [
|
||||
{
|
||||
label: 'qwen3-coder-plus (current)',
|
||||
value: 'qwen3-coder-plus',
|
||||
},
|
||||
{
|
||||
label: 'qwen-vl-max [Vision]',
|
||||
value: 'qwen-vl-max-latest',
|
||||
},
|
||||
{
|
||||
label: 'GPT-4',
|
||||
value: 'gpt-4',
|
||||
},
|
||||
];
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.items).toEqual(expectedItems);
|
||||
});
|
||||
|
||||
it('should show vision indicator for vision models', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="gpt-4"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
const visionModelItem = callArgs.items.find(
|
||||
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
|
||||
);
|
||||
|
||||
expect(visionModelItem?.label).toContain('[Vision]');
|
||||
});
|
||||
|
||||
it('should show current indicator for the current model', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen-vl-max-latest"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
const currentModelItem = callArgs.items.find(
|
||||
(item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
|
||||
);
|
||||
|
||||
expect(currentModelItem?.label).toContain('(current)');
|
||||
});
|
||||
|
||||
it('should pass isFocused prop to RadioButtonSelect', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.isFocused).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle multiple onSelect calls correctly', () => {
|
||||
render(
|
||||
<ModelSelectionDialog
|
||||
availableModels={mockAvailableModels}
|
||||
currentModel="qwen3-coder-plus"
|
||||
onSelect={mockOnSelect}
|
||||
onCancel={mockOnCancel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
|
||||
// Call multiple times
|
||||
onSelectCallback('qwen3-coder-plus');
|
||||
onSelectCallback('qwen-vl-max-latest');
|
||||
onSelectCallback('gpt-4');
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledTimes(3);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
|
||||
});
|
||||
});
|
||||
87
packages/cli/src/ui/components/ModelSelectionDialog.tsx
Normal file
87
packages/cli/src/ui/components/ModelSelectionDialog.tsx
Normal file
@@ -0,0 +1,87 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { Colors } from '../colors.js';
|
||||
import {
|
||||
RadioButtonSelect,
|
||||
type RadioSelectItem,
|
||||
} from './shared/RadioButtonSelect.js';
|
||||
import { useKeypress } from '../hooks/useKeypress.js';
|
||||
import type { AvailableModel } from '../models/availableModels.js';
|
||||
|
||||
export interface ModelSelectionDialogProps {
|
||||
availableModels: AvailableModel[];
|
||||
currentModel: string;
|
||||
onSelect: (modelId: string) => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
|
||||
availableModels,
|
||||
currentModel,
|
||||
onSelect,
|
||||
onCancel,
|
||||
}) => {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name === 'escape') {
|
||||
onCancel();
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
const options: Array<RadioSelectItem<string>> = availableModels.map(
|
||||
(model) => {
|
||||
const visionIndicator = model.isVision ? ' [Vision]' : '';
|
||||
const currentIndicator = model.id === currentModel ? ' (current)' : '';
|
||||
return {
|
||||
label: `${model.label}${visionIndicator}${currentIndicator}`,
|
||||
value: model.id,
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const initialIndex = Math.max(
|
||||
0,
|
||||
availableModels.findIndex((model) => model.id === currentModel),
|
||||
);
|
||||
|
||||
const handleSelect = (modelId: string) => {
|
||||
onSelect(modelId);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
flexDirection="column"
|
||||
borderStyle="round"
|
||||
borderColor={Colors.AccentBlue}
|
||||
padding={1}
|
||||
width="100%"
|
||||
marginLeft={1}
|
||||
>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold>Select Model</Text>
|
||||
<Text>Choose a model for this session:</Text>
|
||||
</Box>
|
||||
|
||||
<Box marginBottom={1}>
|
||||
<RadioButtonSelect
|
||||
items={options}
|
||||
initialIndex={initialIndex}
|
||||
onSelect={handleSelect}
|
||||
isFocused
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<Box>
|
||||
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
185
packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
Normal file
185
packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
Normal file
@@ -0,0 +1,185 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink-testing-library';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
|
||||
|
||||
// Mock the useKeypress hook
|
||||
const mockUseKeypress = vi.hoisted(() => vi.fn());
|
||||
vi.mock('../hooks/useKeypress.js', () => ({
|
||||
useKeypress: mockUseKeypress,
|
||||
}));
|
||||
|
||||
// Mock the RadioButtonSelect component
|
||||
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./shared/RadioButtonSelect.js', () => ({
|
||||
RadioButtonSelect: mockRadioButtonSelect,
|
||||
}));
|
||||
|
||||
describe('ModelSwitchDialog', () => {
|
||||
const mockOnSelect = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock RadioButtonSelect to return a simple div
|
||||
mockRadioButtonSelect.mockReturnValue(
|
||||
React.createElement('div', { 'data-testid': 'radio-select' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should setup RadioButtonSelect with correct options', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const expectedItems = [
|
||||
{
|
||||
label: 'Switch for this request only',
|
||||
value: VisionSwitchOutcome.SwitchOnce,
|
||||
},
|
||||
{
|
||||
label: 'Switch session to vision model',
|
||||
value: VisionSwitchOutcome.SwitchSessionToVL,
|
||||
},
|
||||
{
|
||||
label: 'Do not switch, show guidance',
|
||||
value: VisionSwitchOutcome.DisallowWithGuidance,
|
||||
},
|
||||
];
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.items).toEqual(expectedItems);
|
||||
expect(callArgs.initialIndex).toBe(0);
|
||||
expect(callArgs.isFocused).toBe(true);
|
||||
});
|
||||
|
||||
it('should call onSelect when an option is selected', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(typeof callArgs.onSelect).toBe('function');
|
||||
|
||||
// Simulate selection of "Switch for this request only"
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
|
||||
});
|
||||
|
||||
it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.SwitchSessionToVL,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
);
|
||||
});
|
||||
|
||||
it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
|
||||
isActive: true,
|
||||
});
|
||||
|
||||
// Simulate escape key press
|
||||
const keypressHandler = mockUseKeypress.mock.calls[0][0];
|
||||
keypressHandler({ name: 'escape' });
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not call onSelect for non-escape keys', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const keypressHandler = mockUseKeypress.mock.calls[0][0];
|
||||
keypressHandler({ name: 'enter' });
|
||||
|
||||
expect(mockOnSelect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should set initial index to 0 (first option)', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.initialIndex).toBe(0);
|
||||
});
|
||||
|
||||
describe('VisionSwitchOutcome enum', () => {
|
||||
it('should have correct enum values', () => {
|
||||
expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
|
||||
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
|
||||
'switch_session_to_vl',
|
||||
);
|
||||
expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
|
||||
'disallow_with_guidance',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple onSelect calls correctly', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
|
||||
// Call multiple times
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
|
||||
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledTimes(3);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
VisionSwitchOutcome.SwitchOnce,
|
||||
);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
VisionSwitchOutcome.SwitchSessionToVL,
|
||||
);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass isFocused prop to RadioButtonSelect', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const callArgs = mockRadioButtonSelect.mock.calls[0][0];
|
||||
expect(callArgs.isFocused).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle escape key multiple times', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const keypressHandler = mockUseKeypress.mock.calls[0][0];
|
||||
|
||||
// Call escape multiple times
|
||||
keypressHandler({ name: 'escape' });
|
||||
keypressHandler({ name: 'escape' });
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledTimes(2);
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
);
|
||||
});
|
||||
});
|
||||
89
packages/cli/src/ui/components/ModelSwitchDialog.tsx
Normal file
89
packages/cli/src/ui/components/ModelSwitchDialog.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { Colors } from '../colors.js';
|
||||
import {
|
||||
RadioButtonSelect,
|
||||
type RadioSelectItem,
|
||||
} from './shared/RadioButtonSelect.js';
|
||||
import { useKeypress } from '../hooks/useKeypress.js';
|
||||
|
||||
export enum VisionSwitchOutcome {
|
||||
SwitchOnce = 'switch_once',
|
||||
SwitchSessionToVL = 'switch_session_to_vl',
|
||||
DisallowWithGuidance = 'disallow_with_guidance',
|
||||
}
|
||||
|
||||
export interface ModelSwitchDialogProps {
|
||||
onSelect: (outcome: VisionSwitchOutcome) => void;
|
||||
}
|
||||
|
||||
export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
|
||||
onSelect,
|
||||
}) => {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name === 'escape') {
|
||||
onSelect(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
|
||||
{
|
||||
label: 'Switch for this request only',
|
||||
value: VisionSwitchOutcome.SwitchOnce,
|
||||
},
|
||||
{
|
||||
label: 'Switch session to vision model',
|
||||
value: VisionSwitchOutcome.SwitchSessionToVL,
|
||||
},
|
||||
{
|
||||
label: 'Do not switch, show guidance',
|
||||
value: VisionSwitchOutcome.DisallowWithGuidance,
|
||||
},
|
||||
];
|
||||
|
||||
const handleSelect = (outcome: VisionSwitchOutcome) => {
|
||||
onSelect(outcome);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
flexDirection="column"
|
||||
borderStyle="round"
|
||||
borderColor={Colors.AccentYellow}
|
||||
padding={1}
|
||||
width="100%"
|
||||
marginLeft={1}
|
||||
>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold>Vision Model Switch Required</Text>
|
||||
<Text>
|
||||
Your message contains an image, but the current model doesn't
|
||||
support vision.
|
||||
</Text>
|
||||
<Text>How would you like to proceed?</Text>
|
||||
</Box>
|
||||
|
||||
<Box marginBottom={1}>
|
||||
<RadioButtonSelect
|
||||
items={options}
|
||||
initialIndex={0}
|
||||
onSelect={handleSelect}
|
||||
isFocused
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<Box>
|
||||
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const mockLoadHistory = vi.fn();
|
||||
const mockOpenThemeDialog = vi.fn();
|
||||
const mockOpenAuthDialog = vi.fn();
|
||||
const mockOpenModelSelectionDialog = vi.fn();
|
||||
const mockSetQuittingMessages = vi.fn();
|
||||
|
||||
const mockConfig = makeFakeConfig({});
|
||||
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockBuiltinLoadCommands.mockResolvedValue([]);
|
||||
mockFileLoadCommands.mockResolvedValue([]);
|
||||
mockMcpLoadCommands.mockResolvedValue([]);
|
||||
mockOpenModelSelectionDialog.mockClear();
|
||||
});
|
||||
|
||||
const setupProcessorHook = (
|
||||
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockSetQuittingMessages,
|
||||
vi.fn(), // openPrivacyNotice
|
||||
vi.fn(), // openSettingsDialog
|
||||
mockOpenModelSelectionDialog,
|
||||
vi.fn(), // openSubagentCreateDialog
|
||||
vi.fn(), // openAgentsManagerDialog
|
||||
vi.fn(), // toggleVimEnabled
|
||||
setIsProcessing,
|
||||
vi.fn(), // setGeminiMdFileCount
|
||||
vi.fn(), // _showQuitConfirmation
|
||||
),
|
||||
);
|
||||
|
||||
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
|
||||
expect(mockOpenThemeDialog).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle "dialog: model" action', async () => {
|
||||
const command = createTestCommand({
|
||||
name: 'modelcmd',
|
||||
action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
|
||||
});
|
||||
const result = setupProcessorHook([command]);
|
||||
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleSlashCommand('/modelcmd');
|
||||
});
|
||||
|
||||
expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle "load_history" action', async () => {
|
||||
const command = createTestCommand({
|
||||
name: 'load',
|
||||
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockSetQuittingMessages,
|
||||
vi.fn(), // openPrivacyNotice
|
||||
vi.fn(), // openSettingsDialog
|
||||
vi.fn(), // openModelSelectionDialog
|
||||
vi.fn(), // openSubagentCreateDialog
|
||||
vi.fn(), // openAgentsManagerDialog
|
||||
vi.fn(), // toggleVimEnabled
|
||||
vi.fn(), // setIsProcessing
|
||||
vi.fn(), // setGeminiMdFileCount
|
||||
vi.fn(), // _showQuitConfirmation
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
|
||||
setQuittingMessages: (message: HistoryItem[]) => void,
|
||||
openPrivacyNotice: () => void,
|
||||
openSettingsDialog: () => void,
|
||||
openModelSelectionDialog: () => void,
|
||||
openSubagentCreateDialog: () => void,
|
||||
openAgentsManagerDialog: () => void,
|
||||
toggleVimEnabled: () => Promise<boolean>,
|
||||
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
|
||||
case 'settings':
|
||||
openSettingsDialog();
|
||||
return { type: 'handled' };
|
||||
case 'model':
|
||||
openModelSelectionDialog();
|
||||
return { type: 'handled' };
|
||||
case 'subagent_create':
|
||||
openSubagentCreateDialog();
|
||||
return { type: 'handled' };
|
||||
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
|
||||
setSessionShellAllowlist,
|
||||
setIsProcessing,
|
||||
setConfirmationRequest,
|
||||
openModelSelectionDialog,
|
||||
session.stats,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
|
||||
);
|
||||
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
|
||||
|
||||
// Vision auto-switch mocks (hoisted)
|
||||
const mockHandleVisionSwitch = vi.hoisted(() =>
|
||||
vi.fn().mockResolvedValue({ shouldProceed: true }),
|
||||
);
|
||||
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const actualCoreModule = (await importOriginal()) as any;
|
||||
return {
|
||||
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('./useVisionAutoSwitch.js', () => ({
|
||||
useVisionAutoSwitch: vi.fn(() => ({
|
||||
handleVisionSwitch: mockHandleVisionSwitch,
|
||||
restoreOriginalModel: mockRestoreOriginalModel,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./useKeypress.js', () => ({
|
||||
useKeypress: vi.fn(),
|
||||
}));
|
||||
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue(contentGeneratorConfig),
|
||||
getMaxSessionTurns: vi.fn(() => 50),
|
||||
} as unknown as Config;
|
||||
mockOnDebugMessage = vi.fn();
|
||||
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
|
||||
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
|
||||
expect.any(String), // Argument 3: The prompt_id string
|
||||
);
|
||||
});
|
||||
|
||||
describe('Thought Reset', () => {
|
||||
it('should reset thought to null when starting a new prompt', async () => {
|
||||
// First, simulate a response with a thought
|
||||
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// --- New tests focused on recent modifications ---
|
||||
describe('Vision Auto Switch Integration', () => {
|
||||
it('should call handleVisionSwitch and proceed to send when allowed', async () => {
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'ok' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('image prompt');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleVisionSwitch).toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('vision-gated');
|
||||
});
|
||||
|
||||
// No call to API, no restoreOriginalModel needed since no override occurred
|
||||
expect(mockSendMessageStream).not.toHaveBeenCalled();
|
||||
expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
|
||||
|
||||
// Next call allowed (flag reset path)
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'ok' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('after-gate');
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model restore on completion and errors', () => {
|
||||
it('should restore model after successful stream completion', async () => {
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'content' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('restore-success');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it('should restore model when an error occurs during streaming', async () => {
|
||||
const testError = new Error('stream failure');
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'content' };
|
||||
throw testError;
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('restore-error');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -42,6 +42,7 @@ import type {
|
||||
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
|
||||
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
|
||||
import { useShellCommandProcessor } from './shellCommandProcessor.js';
|
||||
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
|
||||
import { handleAtCommand } from './atCommandProcessor.js';
|
||||
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
@@ -88,6 +89,12 @@ export const useGeminiStream = (
|
||||
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
onEditorClose: () => void,
|
||||
onCancelSubmit: () => void,
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>,
|
||||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
@@ -155,6 +162,13 @@ export const useGeminiStream = (
|
||||
geminiClient,
|
||||
);
|
||||
|
||||
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
|
||||
config,
|
||||
addItem,
|
||||
visionModelPreviewEnabled,
|
||||
onVisionSwitchRequired,
|
||||
);
|
||||
|
||||
const streamingState = useMemo(() => {
|
||||
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
|
||||
return StreamingState.WaitingForConfirmation;
|
||||
@@ -715,6 +729,20 @@ export const useGeminiStream = (
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle vision switch requirement
|
||||
const visionSwitchResult = await handleVisionSwitch(
|
||||
queryToSend,
|
||||
userMessageTimestamp,
|
||||
options?.isContinuation || false,
|
||||
);
|
||||
|
||||
if (!visionSwitchResult.shouldProceed) {
|
||||
isSubmittingQueryRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const finalQueryToSend = queryToSend;
|
||||
|
||||
if (!options?.isContinuation) {
|
||||
startNewPrompt();
|
||||
setThought(null); // Reset thought when starting a new prompt
|
||||
@@ -725,7 +753,7 @@ export const useGeminiStream = (
|
||||
|
||||
try {
|
||||
const stream = geminiClient.sendMessageStream(
|
||||
queryToSend,
|
||||
finalQueryToSend,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
@@ -736,6 +764,8 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
isSubmittingQueryRef.current = false;
|
||||
return;
|
||||
}
|
||||
@@ -748,7 +778,13 @@ export const useGeminiStream = (
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
} catch (error: unknown) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError();
|
||||
} else if (!isNodeError(error) || error.name !== 'AbortError') {
|
||||
@@ -786,6 +822,8 @@ export const useGeminiStream = (
|
||||
startNewPrompt,
|
||||
getPromptCount,
|
||||
handleLoopDetectedEvent,
|
||||
handleVisionSwitch,
|
||||
restoreOriginalModel,
|
||||
],
|
||||
);
|
||||
|
||||
|
||||
374
packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
Normal file
374
packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
Normal file
@@ -0,0 +1,374 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import type { Part, PartListUnion } from '@google/genai';
|
||||
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
shouldOfferVisionSwitch,
|
||||
processVisionSwitchOutcome,
|
||||
getVisionSwitchGuidanceMessage,
|
||||
useVisionAutoSwitch,
|
||||
} from './useVisionAutoSwitch.js';
|
||||
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
|
||||
import { MessageType } from '../types.js';
|
||||
import { getDefaultVisionModel } from '../models/availableModels.js';
|
||||
|
||||
describe('useVisionAutoSwitch helpers', () => {
|
||||
describe('shouldOfferVisionSwitch', () => {
|
||||
it('returns false when authType is not QWEN_OAUTH', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.USE_GEMINI,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when current model is already a vision model', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen-vl-max-latest',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ text: 'hello' },
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('detects image when provided as a single Part object (non-array)', () => {
|
||||
const singleImagePart: PartListUnion = {
|
||||
fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
|
||||
} as Part;
|
||||
const result = shouldOfferVisionSwitch(
|
||||
singleImagePart,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false when parts contain no images', () => {
|
||||
const parts: PartListUnion = [{ text: 'just text' }];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when parts is a plain string', () => {
|
||||
const parts: PartListUnion = 'plain text';
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when visionModelPreviewEnabled is false', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
false,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('processVisionSwitchOutcome', () => {
|
||||
it('maps SwitchOnce to a one-time model override', () => {
|
||||
const vl = getDefaultVisionModel();
|
||||
const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
|
||||
expect(result).toEqual({ modelOverride: vl });
|
||||
});
|
||||
|
||||
it('maps SwitchSessionToVL to a persistent session model', () => {
|
||||
const vl = getDefaultVisionModel();
|
||||
const result = processVisionSwitchOutcome(
|
||||
VisionSwitchOutcome.SwitchSessionToVL,
|
||||
);
|
||||
expect(result).toEqual({ persistSessionModel: vl });
|
||||
});
|
||||
|
||||
it('maps DisallowWithGuidance to showGuidance', () => {
|
||||
const result = processVisionSwitchOutcome(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
);
|
||||
expect(result).toEqual({ showGuidance: true });
|
||||
});
|
||||
});
|
||||
|
||||
describe('getVisionSwitchGuidanceMessage', () => {
|
||||
it('returns the expected guidance message', () => {
|
||||
const vl = getDefaultVisionModel();
|
||||
const expected =
|
||||
'To use images with your query, you can:\n' +
|
||||
`• Use /model set ${vl} to switch to a vision-capable model\n` +
|
||||
'• Or remove the image and provide a text description instead';
|
||||
expect(getVisionSwitchGuidanceMessage()).toBe(expected);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('useVisionAutoSwitch hook', () => {
|
||||
type AddItemFn = (
|
||||
item: { type: MessageType; text: string },
|
||||
ts: number,
|
||||
) => any;
|
||||
|
||||
const createMockConfig = (authType: AuthType, initialModel: string) => {
|
||||
let currentModel = initialModel;
|
||||
const mockConfig: Partial<Config> = {
|
||||
getModel: vi.fn(() => currentModel),
|
||||
setModel: vi.fn((m: string) => {
|
||||
currentModel = m;
|
||||
}),
|
||||
getContentGeneratorConfig: vi.fn(() => ({
|
||||
authType,
|
||||
model: currentModel,
|
||||
apiKey: 'test-key',
|
||||
vertexai: false,
|
||||
})),
|
||||
};
|
||||
return mockConfig as Config;
|
||||
};
|
||||
|
||||
let addItem: AddItemFn;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
addItem = vi.fn();
|
||||
});
|
||||
|
||||
it('returns shouldProceed=true immediately for continuations', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, Date.now(), true);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(addItem).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when authType is not QWEN_OAUTH', async () => {
|
||||
const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 123, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when there are no image parts', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [{ text: 'no images here' }];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 456, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('shows guidance and blocks when dialog returns showGuidance', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ showGuidance: true });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
const userTs = 1010;
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, userTs, false);
|
||||
});
|
||||
|
||||
expect(addItem).toHaveBeenCalledWith(
|
||||
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
|
||||
userTs,
|
||||
);
|
||||
expect(res).toEqual({ shouldProceed: false });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('applies a one-time override and returns originalModel, then restores', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 2020, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
|
||||
// Now restore
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
|
||||
});
|
||||
|
||||
it('persists session model when dialog requests persistence', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 3030, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
|
||||
// Restore should be a no-op since no one-time override was used
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
// Last call should still be the persisted model set
|
||||
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
|
||||
'qwen-vl-max-latest',
|
||||
);
|
||||
});
|
||||
|
||||
it('returns shouldProceed=true when dialog returns no special flags', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 4040, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('blocks when dialog throws or is cancelled', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 5050, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: false });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when visionModelPreviewEnabled is false', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
false,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 6060, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
304
packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
Normal file
304
packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
Normal file
@@ -0,0 +1,304 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type PartListUnion, type Part } from '@google/genai';
|
||||
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
|
||||
import {
|
||||
getDefaultVisionModel,
|
||||
isVisionModel,
|
||||
} from '../models/availableModels.js';
|
||||
import { MessageType } from '../types.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import {
|
||||
isSupportedImageMimeType,
|
||||
getUnsupportedImageFormatWarning,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
|
||||
/**
|
||||
* Checks if a PartListUnion contains image parts
|
||||
*/
|
||||
function hasImageParts(parts: PartListUnion): boolean {
|
||||
if (typeof parts === 'string') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Array.isArray(parts)) {
|
||||
return parts.some((part) => {
|
||||
// Skip string parts
|
||||
if (typeof part === 'string') return false;
|
||||
return isImagePart(part);
|
||||
});
|
||||
}
|
||||
|
||||
// If it's a single Part (not a string), check if it's an image
|
||||
if (typeof parts === 'object') {
|
||||
return isImagePart(parts);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a single Part is an image part
|
||||
*/
|
||||
function isImagePart(part: Part): boolean {
|
||||
// Check for inlineData with image mime type
|
||||
if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for fileData with image mime type
|
||||
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if image parts have supported formats and returns unsupported ones
|
||||
*/
|
||||
function checkImageFormatsSupport(parts: PartListUnion): {
|
||||
hasImages: boolean;
|
||||
hasUnsupportedFormats: boolean;
|
||||
unsupportedMimeTypes: string[];
|
||||
} {
|
||||
const unsupportedMimeTypes: string[] = [];
|
||||
let hasImages = false;
|
||||
|
||||
if (typeof parts === 'string') {
|
||||
return {
|
||||
hasImages: false,
|
||||
hasUnsupportedFormats: false,
|
||||
unsupportedMimeTypes: [],
|
||||
};
|
||||
}
|
||||
|
||||
const partsArray = Array.isArray(parts) ? parts : [parts];
|
||||
|
||||
for (const part of partsArray) {
|
||||
if (typeof part === 'string') continue;
|
||||
|
||||
let mimeType: string | undefined;
|
||||
|
||||
// Check inlineData
|
||||
if (
|
||||
'inlineData' in part &&
|
||||
part.inlineData?.mimeType?.startsWith('image/')
|
||||
) {
|
||||
hasImages = true;
|
||||
mimeType = part.inlineData.mimeType;
|
||||
}
|
||||
|
||||
// Check fileData
|
||||
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
|
||||
hasImages = true;
|
||||
mimeType = part.fileData.mimeType;
|
||||
}
|
||||
|
||||
// Check if the mime type is supported
|
||||
if (mimeType && !isSupportedImageMimeType(mimeType)) {
|
||||
unsupportedMimeTypes.push(mimeType);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
hasImages,
|
||||
hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
|
||||
unsupportedMimeTypes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if we should offer vision switch for the given parts, auth type, and current model
|
||||
*/
|
||||
export function shouldOfferVisionSwitch(
|
||||
parts: PartListUnion,
|
||||
authType: AuthType,
|
||||
currentModel: string,
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
): boolean {
|
||||
// Only trigger for qwen-oauth
|
||||
if (authType !== AuthType.QWEN_OAUTH) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If vision model preview is disabled, never offer vision switch
|
||||
if (!visionModelPreviewEnabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If current model is already a vision model, no need to switch
|
||||
if (isVisionModel(currentModel)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the current message contains image parts
|
||||
return hasImageParts(parts);
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for vision switch result
|
||||
*/
|
||||
export interface VisionSwitchResult {
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes the vision switch outcome and returns the appropriate result
|
||||
*/
|
||||
export function processVisionSwitchOutcome(
|
||||
outcome: VisionSwitchOutcome,
|
||||
): VisionSwitchResult {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
|
||||
switch (outcome) {
|
||||
case VisionSwitchOutcome.SwitchOnce:
|
||||
return { modelOverride: vlModelId };
|
||||
|
||||
case VisionSwitchOutcome.SwitchSessionToVL:
|
||||
return { persistSessionModel: vlModelId };
|
||||
|
||||
case VisionSwitchOutcome.DisallowWithGuidance:
|
||||
return { showGuidance: true };
|
||||
|
||||
default:
|
||||
return { showGuidance: true };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the guidance message for when vision switch is disallowed
|
||||
*/
|
||||
export function getVisionSwitchGuidanceMessage(): string {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
return `To use images with your query, you can:
|
||||
• Use /model set ${vlModelId} to switch to a vision-capable model
|
||||
• Or remove the image and provide a text description instead`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for vision switch handling result
|
||||
*/
|
||||
export interface VisionSwitchHandlingResult {
|
||||
shouldProceed: boolean;
|
||||
originalModel?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom hook for handling vision model auto-switching
|
||||
*/
|
||||
export function useVisionAutoSwitch(
|
||||
config: Config,
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>,
|
||||
) {
|
||||
const originalModelRef = useRef<string | null>(null);
|
||||
|
||||
const handleVisionSwitch = useCallback(
|
||||
async (
|
||||
query: PartListUnion,
|
||||
userMessageTimestamp: number,
|
||||
isContinuation: boolean,
|
||||
): Promise<VisionSwitchHandlingResult> => {
|
||||
// Skip vision switch handling for continuations or if no handler provided
|
||||
if (isContinuation || !onVisionSwitchRequired) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
|
||||
// Only handle qwen-oauth auth type
|
||||
if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// Check image format support first
|
||||
const formatCheck = checkImageFormatsSupport(query);
|
||||
|
||||
// If there are unsupported image formats, show warning
|
||||
if (formatCheck.hasUnsupportedFormats) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: getUnsupportedImageFormatWarning(),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
// Continue processing but with warning shown
|
||||
}
|
||||
|
||||
// Check if vision switch is needed
|
||||
if (
|
||||
!shouldOfferVisionSwitch(
|
||||
query,
|
||||
contentGeneratorConfig.authType,
|
||||
config.getModel(),
|
||||
visionModelPreviewEnabled,
|
||||
)
|
||||
) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
try {
|
||||
const visionSwitchResult = await onVisionSwitchRequired(query);
|
||||
|
||||
if (visionSwitchResult.showGuidance) {
|
||||
// Show guidance and don't proceed with the request
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: getVisionSwitchGuidanceMessage(),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
return { shouldProceed: false };
|
||||
}
|
||||
|
||||
if (visionSwitchResult.modelOverride) {
|
||||
// One-time model override
|
||||
originalModelRef.current = config.getModel();
|
||||
config.setModel(visionSwitchResult.modelOverride);
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
} else if (visionSwitchResult.persistSessionModel) {
|
||||
// Persistent session model change
|
||||
config.setModel(visionSwitchResult.persistSessionModel);
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
return { shouldProceed: true };
|
||||
} catch (_error) {
|
||||
// If vision switch dialog was cancelled or errored, don't proceed
|
||||
return { shouldProceed: false };
|
||||
}
|
||||
},
|
||||
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
|
||||
);
|
||||
|
||||
const restoreOriginalModel = useCallback(() => {
|
||||
if (originalModelRef.current) {
|
||||
config.setModel(originalModelRef.current);
|
||||
originalModelRef.current = null;
|
||||
}
|
||||
}, [config]);
|
||||
|
||||
return {
|
||||
handleVisionSwitch,
|
||||
restoreOriginalModel,
|
||||
};
|
||||
}
|
||||
52
packages/cli/src/ui/models/availableModels.ts
Normal file
52
packages/cli/src/ui/models/availableModels.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export type AvailableModel = {
|
||||
id: string;
|
||||
label: string;
|
||||
isVision?: boolean;
|
||||
};
|
||||
|
||||
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
|
||||
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
|
||||
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
|
||||
];
|
||||
|
||||
/**
|
||||
* Get available Qwen models filtered by vision model preview setting
|
||||
*/
|
||||
export function getFilteredQwenModels(
|
||||
visionModelPreviewEnabled: boolean,
|
||||
): AvailableModel[] {
|
||||
if (visionModelPreviewEnabled) {
|
||||
return AVAILABLE_MODELS_QWEN;
|
||||
}
|
||||
return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently we use the single model of `OPENAI_MODEL` in the env.
|
||||
* In the future, after settings.json is updated, we will allow users to configure this themselves.
|
||||
*/
|
||||
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
|
||||
const id = process.env['OPENAI_MODEL']?.trim();
|
||||
return id ? { id, label: id } : null;
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
* Hard code the default vision model as a string literal,
|
||||
* until our coding model supports multimodal.
|
||||
*/
|
||||
export function getDefaultVisionModel(): string {
|
||||
return 'qwen-vl-max-latest';
|
||||
}
|
||||
|
||||
export function isVisionModel(modelId: string): boolean {
|
||||
return AVAILABLE_MODELS_QWEN.some(
|
||||
(model) => model.id === modelId && model.isVision,
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user