mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 08:47:44 +00:00
feat: add yolo mode support to auto vision model switch (#652)
* feat: add yolo mode support to auto vision model switch * feat: add cli args & env variables for switch behavoir * fix: use dedicated model names and settings * docs: add vision model instructions * fix: failed test case * fix: setModel failure
This commit is contained in:
@@ -737,4 +737,85 @@ describe('setApprovalMode with folder trust', () => {
|
||||
expect(() => config.setApprovalMode(ApprovalMode.AUTO_EDIT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow();
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch when setModel is called with different model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-model-switch',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model
|
||||
await config.setModel('qwen-vl-max-latest', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
|
||||
// Verify that logModelSwitch was called with correct parameters
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not log when setModel is called with same model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-same-model',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Set the same model
|
||||
await config.setModel('qwen3-coder-plus');
|
||||
|
||||
// Verify that logModelSwitch was not called
|
||||
expect(logModelSwitchSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use default reason when no options provided', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-default-reason',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model without options
|
||||
await config.setModel('qwen-vl-max-latest');
|
||||
|
||||
// Verify that logModelSwitch was called with default reason
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'manual',
|
||||
context: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -56,6 +56,7 @@ import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} from './models.js';
|
||||
import { Storage } from './storage.js';
|
||||
import { Logger, type ModelSwitchEvent } from '../core/logger.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { AnyToolInvocation, MCPOAuthConfig };
|
||||
@@ -239,6 +240,7 @@ export interface ConfigParameters {
|
||||
extensionManagement?: boolean;
|
||||
enablePromptCompletion?: boolean;
|
||||
skipLoopDetection?: boolean;
|
||||
vlmSwitchMode?: string;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
@@ -330,9 +332,11 @@ export class Config {
|
||||
private readonly extensionManagement: boolean;
|
||||
private readonly enablePromptCompletion: boolean = false;
|
||||
private readonly skipLoopDetection: boolean;
|
||||
private readonly vlmSwitchMode: string | undefined;
|
||||
private initialized: boolean = false;
|
||||
readonly storage: Storage;
|
||||
private readonly fileExclusions: FileExclusions;
|
||||
private logger: Logger | null = null;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -424,8 +428,15 @@ export class Config {
|
||||
this.extensionManagement = params.extensionManagement ?? false;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
this.vlmSwitchMode = params.vlmSwitchMode;
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
|
||||
// Initialize logger asynchronously
|
||||
this.logger = new Logger(this.sessionId, this.storage);
|
||||
this.logger.initialize().catch((error) => {
|
||||
console.debug('Failed to initialize logger:', error);
|
||||
});
|
||||
|
||||
if (params.contextFileName) {
|
||||
setGeminiMdFilename(params.contextFileName);
|
||||
}
|
||||
@@ -517,21 +528,47 @@ export class Config {
|
||||
return this.contentGeneratorConfig?.model || this.model;
|
||||
}
|
||||
|
||||
setModel(newModel: string): void {
|
||||
async setModel(
|
||||
newModel: string,
|
||||
options?: {
|
||||
reason?: ModelSwitchEvent['reason'];
|
||||
context?: string;
|
||||
},
|
||||
): Promise<void> {
|
||||
const oldModel = this.getModel();
|
||||
|
||||
if (this.contentGeneratorConfig) {
|
||||
this.contentGeneratorConfig.model = newModel;
|
||||
}
|
||||
|
||||
// Log the model switch if the model actually changed
|
||||
if (oldModel !== newModel && this.logger) {
|
||||
const switchEvent: ModelSwitchEvent = {
|
||||
fromModel: oldModel,
|
||||
toModel: newModel,
|
||||
reason: options?.reason || 'manual',
|
||||
context: options?.context,
|
||||
};
|
||||
|
||||
// Log asynchronously to avoid blocking
|
||||
this.logger.logModelSwitch(switchEvent).catch((error) => {
|
||||
console.debug('Failed to log model switch:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// Reinitialize chat with updated configuration while preserving history
|
||||
const geminiClient = this.getGeminiClient();
|
||||
if (geminiClient && geminiClient.isInitialized()) {
|
||||
// Use async operation but don't await to avoid blocking
|
||||
geminiClient.reinitialize().catch((error) => {
|
||||
// Now await the reinitialize operation to ensure completion
|
||||
try {
|
||||
await geminiClient.reinitialize();
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'Failed to reinitialize chat with updated config:',
|
||||
error,
|
||||
);
|
||||
});
|
||||
throw error; // Re-throw to let callers handle the error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -938,6 +975,10 @@ export class Config {
|
||||
return this.skipLoopDetection;
|
||||
}
|
||||
|
||||
getVlmSwitchMode(): string | undefined {
|
||||
return this.vlmSwitchMode;
|
||||
}
|
||||
|
||||
async getGitService(): Promise<GitService> {
|
||||
if (!this.gitService) {
|
||||
this.gitService = new GitService(this.targetDir, this.storage);
|
||||
|
||||
@@ -41,7 +41,7 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
// with the fallback mechanism. This will be necessary we introduce more
|
||||
// intelligent model routing.
|
||||
describe('setModel', () => {
|
||||
it('should only mark as switched if contentGeneratorConfig exists', () => {
|
||||
it('should only mark as switched if contentGeneratorConfig exists', async () => {
|
||||
// Create config without initializing contentGeneratorConfig
|
||||
const newConfig = new Config({
|
||||
sessionId: 'test-session-2',
|
||||
@@ -52,15 +52,15 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
});
|
||||
|
||||
// Should not crash when contentGeneratorConfig is undefined
|
||||
newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
await newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(newConfig.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return contentGeneratorConfig model if available', () => {
|
||||
it('should return contentGeneratorConfig model if available', async () => {
|
||||
// Simulate initialized content generator config
|
||||
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
@@ -88,8 +88,8 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
|
||||
it('should persist switched state throughout session', () => {
|
||||
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
it('should persist switched state throughout session', async () => {
|
||||
await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
// Setting state for fallback mode as is expected of clients
|
||||
config.setFallbackMode(true);
|
||||
expect(config.isInFallbackMode()).toBe(true);
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export const DEFAULT_QWEN_MODEL = 'qwen3-coder-plus';
|
||||
// We do not have a fallback model for now, but note it here anyway.
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'qwen3-coder-flash';
|
||||
export const DEFAULT_QWEN_MODEL = 'coder-model';
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'coder-model';
|
||||
|
||||
export const DEFAULT_GEMINI_MODEL = 'qwen3-coder-plus';
|
||||
export const DEFAULT_GEMINI_MODEL = 'coder-model';
|
||||
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
|
||||
export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite';
|
||||
|
||||
|
||||
@@ -1053,7 +1053,7 @@ export class GeminiClient {
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
await this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ export class GeminiChat {
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
await this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
|
||||
@@ -755,4 +755,84 @@ describe('Logger', () => {
|
||||
expect(logger['messageId']).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch events correctly', async () => {
|
||||
const testSessionId = 'test-session-model-switch';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
const modelSwitchEvent = {
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch' as const,
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
};
|
||||
|
||||
await logger.logModelSwitch(modelSwitchEvent);
|
||||
|
||||
// Read the log file to verify the entry was written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLog = logs.find(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLog).toBeDefined();
|
||||
expect(modelSwitchLog!.type).toBe(MessageSenderType.MODEL_SWITCH);
|
||||
|
||||
const loggedEvent = JSON.parse(modelSwitchLog!.message);
|
||||
expect(loggedEvent.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(loggedEvent.toModel).toBe('qwen-vl-max-latest');
|
||||
expect(loggedEvent.reason).toBe('vision_auto_switch');
|
||||
expect(loggedEvent.context).toBe(
|
||||
'YOLO mode auto-switch for image content',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple model switch events', async () => {
|
||||
const testSessionId = 'test-session-multiple-switches';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
// Log first switch
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Auto-switch for image',
|
||||
});
|
||||
|
||||
// Log second switch (restore)
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen-vl-max-latest',
|
||||
toModel: 'qwen3-coder-plus',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model',
|
||||
});
|
||||
|
||||
// Read the log file to verify both entries were written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLogs = logs.filter(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLogs).toHaveLength(2);
|
||||
|
||||
const firstSwitch = JSON.parse(modelSwitchLogs[0].message);
|
||||
expect(firstSwitch.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(firstSwitch.toModel).toBe('qwen-vl-max-latest');
|
||||
|
||||
const secondSwitch = JSON.parse(modelSwitchLogs[1].message);
|
||||
expect(secondSwitch.fromModel).toBe('qwen-vl-max-latest');
|
||||
expect(secondSwitch.toModel).toBe('qwen3-coder-plus');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ const LOG_FILE_NAME = 'logs.json';
|
||||
|
||||
export enum MessageSenderType {
|
||||
USER = 'user',
|
||||
MODEL_SWITCH = 'model_switch',
|
||||
}
|
||||
|
||||
export interface LogEntry {
|
||||
@@ -23,6 +24,13 @@ export interface LogEntry {
|
||||
message: string;
|
||||
}
|
||||
|
||||
export interface ModelSwitchEvent {
|
||||
fromModel: string;
|
||||
toModel: string;
|
||||
reason: 'vision_auto_switch' | 'manual' | 'fallback' | 'other';
|
||||
context?: string;
|
||||
}
|
||||
|
||||
// This regex matches any character that is NOT a letter (a-z, A-Z),
|
||||
// a number (0-9), a hyphen (-), an underscore (_), or a dot (.).
|
||||
|
||||
@@ -270,6 +278,17 @@ export class Logger {
|
||||
}
|
||||
}
|
||||
|
||||
async logModelSwitch(event: ModelSwitchEvent): Promise<void> {
|
||||
const message = JSON.stringify({
|
||||
fromModel: event.fromModel,
|
||||
toModel: event.toModel,
|
||||
reason: event.reason,
|
||||
context: event.context,
|
||||
});
|
||||
|
||||
await this.logMessage(MessageSenderType.MODEL_SWITCH, message);
|
||||
}
|
||||
|
||||
private _checkpointPath(tag: string): string {
|
||||
if (!tag.length) {
|
||||
throw new Error('No checkpoint tag specified.');
|
||||
|
||||
@@ -820,6 +820,14 @@ function getToolCallExamples(model?: string): string {
|
||||
if (/qwen[^-]*-vl/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
// Match coder-model pattern (same as qwen3-coder)
|
||||
if (/coder-model/i.test(model)) {
|
||||
return qwenCoderToolCallExamples;
|
||||
}
|
||||
// Match vision-model pattern (same as qwen3-vl)
|
||||
if (/vision-model/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
}
|
||||
|
||||
return generalToolCallExamples;
|
||||
|
||||
@@ -111,6 +111,12 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Commercial Qwen3-Coder-Flash: 1M token context
|
||||
[/^qwen3-coder-flash(-.*)?$/, LIMITS['1m']], // catches "qwen3-coder-flash" and date variants
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (1M token context)
|
||||
[/^coder-model$/, LIMITS['1m']],
|
||||
|
||||
// Commercial Qwen3-Max-Preview: 256K token context
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['256k']], // catches "qwen3-max-preview" and date variants
|
||||
|
||||
// Open-source Qwen3-Coder variants: 256K native
|
||||
[/^qwen3-coder-.*$/, LIMITS['256k']],
|
||||
// Open-source Qwen3 2507 variants: 256K native
|
||||
@@ -131,6 +137,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Qwen Vision Models
|
||||
[/^qwen-vl-max.*$/, LIMITS['128k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max (128K token context)
|
||||
[/^vision-model$/, LIMITS['128k']],
|
||||
|
||||
// -------------------
|
||||
// ByteDance Seed-OSS (512K)
|
||||
// -------------------
|
||||
@@ -166,8 +175,20 @@ const OUTPUT_PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Qwen3-Coder-Plus: 65,536 max output tokens
|
||||
[/^qwen3-coder-plus(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (64K max output tokens)
|
||||
[/^coder-model$/, LIMITS['64k']],
|
||||
|
||||
// Qwen3-Max-Preview: 65,536 max output tokens
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Qwen-VL-Max-Latest: 8,192 max output tokens
|
||||
[/^qwen-vl-max-latest$/, LIMITS['8k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max-latest (8K max output tokens)
|
||||
[/^vision-model$/, LIMITS['8k']],
|
||||
|
||||
// Qwen3-VL-Plus: 8,192 max output tokens
|
||||
[/^qwen3-vl-plus$/, LIMITS['8k']],
|
||||
];
|
||||
|
||||
/**
|
||||
|
||||
@@ -72,6 +72,19 @@ async function createMockConfig(
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(mockToolRegistry);
|
||||
|
||||
// Mock getContentGeneratorConfig to return a valid config
|
||||
vi.spyOn(config, 'getContentGeneratorConfig').mockReturnValue({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
});
|
||||
|
||||
// Mock setModel method
|
||||
vi.spyOn(config, 'setModel').mockResolvedValue();
|
||||
|
||||
// Mock getSessionId method
|
||||
vi.spyOn(config, 'getSessionId').mockReturnValue('test-session');
|
||||
|
||||
return { config, toolRegistry: mockToolRegistry };
|
||||
}
|
||||
|
||||
|
||||
@@ -826,7 +826,7 @@ export class SubAgentScope {
|
||||
);
|
||||
|
||||
if (this.modelConfig.model) {
|
||||
this.runtimeContext.setModel(this.modelConfig.model);
|
||||
await this.runtimeContext.setModel(this.modelConfig.model);
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
|
||||
Reference in New Issue
Block a user