fix: setModel failure

This commit is contained in:
mingholy.lmh
2025-09-23 23:56:02 +08:00
parent e38947a62d
commit 2aa3667d0a
12 changed files with 124 additions and 47 deletions

View File

@@ -566,7 +566,9 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
} }
// Switch model for future use but return false to stop current retry // Switch model for future use but return false to stop current retry
config.setModel(fallbackModel); config.setModel(fallbackModel).catch((error) => {
console.error('Failed to switch to fallback model:', error);
});
config.setFallbackMode(true); config.setFallbackMode(true);
logFlashFallback( logFlashFallback(
config, config,
@@ -650,8 +652,9 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
}, []); }, []);
const handleModelSelect = useCallback( const handleModelSelect = useCallback(
(modelId: string) => { async (modelId: string) => {
config.setModel(modelId); try {
await config.setModel(modelId);
setCurrentModel(modelId); setCurrentModel(modelId);
setIsModelSelectionDialogOpen(false); setIsModelSelectionDialogOpen(false);
addItem( addItem(
@@ -661,6 +664,16 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
}, },
Date.now(), Date.now(),
); );
} catch (error) {
console.error('Failed to switch model:', error);
addItem(
{
type: MessageType.ERROR,
text: `Failed to switch to model \`${modelId}\`. Please try again.`,
},
Date.now(),
);
}
}, },
[config, setCurrentModel, addItem], [config, setCurrentModel, addItem],
); );

View File

@@ -60,7 +60,9 @@ const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
const mockHandleVisionSwitch = vi.hoisted(() => const mockHandleVisionSwitch = vi.hoisted(() =>
vi.fn().mockResolvedValue({ shouldProceed: true }), vi.fn().mockResolvedValue({ shouldProceed: true }),
); );
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn()); const mockRestoreOriginalModel = vi.hoisted(() =>
vi.fn().mockResolvedValue(undefined),
);
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => { vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any; const actualCoreModule = (await importOriginal()) as any;
@@ -301,6 +303,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
); );
}, },
{ {
@@ -462,6 +466,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -541,6 +547,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -649,6 +657,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -758,6 +768,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -887,6 +899,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
cancelSubmitSpy, cancelSubmitSpy,
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1198,6 +1212,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1251,6 +1267,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1301,6 +1319,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1349,6 +1369,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1398,6 +1420,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1487,6 +1511,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1537,6 +1563,8 @@ describe('useGeminiStream', () => {
vi.fn(), // setModelSwitched vi.fn(), // setModelSwitched
vi.fn(), // onEditorClose vi.fn(), // onEditorClose
vi.fn(), // onCancelSubmit vi.fn(), // onCancelSubmit
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1602,6 +1630,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1680,6 +1710,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1734,6 +1766,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1943,6 +1977,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -1975,6 +2011,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -2028,6 +2066,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );
@@ -2065,6 +2105,8 @@ describe('useGeminiStream', () => {
() => {}, () => {},
() => {}, () => {},
() => {}, () => {},
false, // visionModelPreviewEnabled
undefined, // onVisionSwitchRequired (optional)
), ),
); );

View File

@@ -765,7 +765,9 @@ export const useGeminiStream = (
if (processingStatus === StreamProcessingStatus.UserCancelled) { if (processingStatus === StreamProcessingStatus.UserCancelled) {
// Restore original model if it was temporarily overridden // Restore original model if it was temporarily overridden
restoreOriginalModel(); restoreOriginalModel().catch((error) => {
console.error('Failed to restore original model:', error);
});
isSubmittingQueryRef.current = false; isSubmittingQueryRef.current = false;
return; return;
} }
@@ -780,10 +782,14 @@ export const useGeminiStream = (
} }
// Restore original model if it was temporarily overridden // Restore original model if it was temporarily overridden
restoreOriginalModel(); restoreOriginalModel().catch((error) => {
console.error('Failed to restore original model:', error);
});
} catch (error: unknown) { } catch (error: unknown) {
// Restore original model if it was temporarily overridden // Restore original model if it was temporarily overridden
restoreOriginalModel(); restoreOriginalModel().catch((error) => {
console.error('Failed to restore original model:', error);
});
if (error instanceof UnauthorizedError) { if (error instanceof UnauthorizedError) {
onAuthError(); onAuthError();

View File

@@ -210,7 +210,7 @@ describe('useVisionAutoSwitch hook', () => {
let currentModel = initialModel; let currentModel = initialModel;
const mockConfig: Partial<Config> = { const mockConfig: Partial<Config> = {
getModel: vi.fn(() => currentModel), getModel: vi.fn(() => currentModel),
setModel: vi.fn((m: string) => { setModel: vi.fn(async (m: string) => {
currentModel = m; currentModel = m;
}), }),
getApprovalMode: vi.fn(() => approvalMode), getApprovalMode: vi.fn(() => approvalMode),
@@ -335,8 +335,8 @@ describe('useVisionAutoSwitch hook', () => {
}); });
// Now restore // Now restore
act(() => { await act(async () => {
result.current.restoreOriginalModel(); await result.current.restoreOriginalModel();
}); });
expect(config.setModel).toHaveBeenLastCalledWith(initialModel, { expect(config.setModel).toHaveBeenLastCalledWith(initialModel, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
@@ -369,8 +369,8 @@ describe('useVisionAutoSwitch hook', () => {
}); });
// Restore should be a no-op since no one-time override was used // Restore should be a no-op since no one-time override was used
act(() => { await act(async () => {
result.current.restoreOriginalModel(); await result.current.restoreOriginalModel();
}); });
// Last call should still be the persisted model set // Last call should still be the persisted model set
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe('coder-model'); expect((config.setModel as any).mock.calls.pop()?.[0]).toBe('coder-model');
@@ -565,8 +565,8 @@ describe('useVisionAutoSwitch hook', () => {
}); });
// Now restore the original model // Now restore the original model
act(() => { await act(async () => {
result.current.restoreOriginalModel(); await result.current.restoreOriginalModel();
}); });
// Verify model was restored // Verify model was restored

View File

@@ -256,7 +256,7 @@ export function useVisionAutoSwitch(
if (config.getApprovalMode() === ApprovalMode.YOLO) { if (config.getApprovalMode() === ApprovalMode.YOLO) {
const vlModelId = getDefaultVisionModel(); const vlModelId = getDefaultVisionModel();
originalModelRef.current = config.getModel(); originalModelRef.current = config.getModel();
config.setModel(vlModelId, { await config.setModel(vlModelId, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: 'YOLO mode auto-switch for image content', context: 'YOLO mode auto-switch for image content',
}); });
@@ -292,7 +292,7 @@ export function useVisionAutoSwitch(
if (visionSwitchResult.modelOverride) { if (visionSwitchResult.modelOverride) {
// One-time model override // One-time model override
originalModelRef.current = config.getModel(); originalModelRef.current = config.getModel();
config.setModel(visionSwitchResult.modelOverride, { await config.setModel(visionSwitchResult.modelOverride, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (one-time override)`, context: `Default VLM switch mode: ${defaultVlmSwitchMode} (one-time override)`,
}); });
@@ -302,7 +302,7 @@ export function useVisionAutoSwitch(
}; };
} else if (visionSwitchResult.persistSessionModel) { } else if (visionSwitchResult.persistSessionModel) {
// Persistent session model change // Persistent session model change
config.setModel(visionSwitchResult.persistSessionModel, { await config.setModel(visionSwitchResult.persistSessionModel, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (session persistent)`, context: `Default VLM switch mode: ${defaultVlmSwitchMode} (session persistent)`,
}); });
@@ -319,7 +319,7 @@ export function useVisionAutoSwitch(
if (visionSwitchResult.modelOverride) { if (visionSwitchResult.modelOverride) {
// One-time model override // One-time model override
originalModelRef.current = config.getModel(); originalModelRef.current = config.getModel();
config.setModel(visionSwitchResult.modelOverride, { await config.setModel(visionSwitchResult.modelOverride, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: 'User-prompted vision switch (one-time override)', context: 'User-prompted vision switch (one-time override)',
}); });
@@ -329,7 +329,7 @@ export function useVisionAutoSwitch(
}; };
} else if (visionSwitchResult.persistSessionModel) { } else if (visionSwitchResult.persistSessionModel) {
// Persistent session model change // Persistent session model change
config.setModel(visionSwitchResult.persistSessionModel, { await config.setModel(visionSwitchResult.persistSessionModel, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: 'User-prompted vision switch (session persistent)', context: 'User-prompted vision switch (session persistent)',
}); });
@@ -346,9 +346,9 @@ export function useVisionAutoSwitch(
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired], [config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
); );
const restoreOriginalModel = useCallback(() => { const restoreOriginalModel = useCallback(async () => {
if (originalModelRef.current) { if (originalModelRef.current) {
config.setModel(originalModelRef.current, { await config.setModel(originalModelRef.current, {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: 'Restoring original model after vision switch', context: 'Restoring original model after vision switch',
}); });

View File

@@ -755,7 +755,7 @@ describe('setApprovalMode with folder trust', () => {
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch'); const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
// Change the model // Change the model
config.setModel('qwen-vl-max-latest', { await config.setModel('qwen-vl-max-latest', {
reason: 'vision_auto_switch', reason: 'vision_auto_switch',
context: 'Test model switch', context: 'Test model switch',
}); });
@@ -785,7 +785,7 @@ describe('setApprovalMode with folder trust', () => {
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch'); const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
// Set the same model // Set the same model
config.setModel('qwen3-coder-plus'); await config.setModel('qwen3-coder-plus');
// Verify that logModelSwitch was not called // Verify that logModelSwitch was not called
expect(logModelSwitchSpy).not.toHaveBeenCalled(); expect(logModelSwitchSpy).not.toHaveBeenCalled();
@@ -807,7 +807,7 @@ describe('setApprovalMode with folder trust', () => {
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch'); const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
// Change the model without options // Change the model without options
config.setModel('qwen-vl-max-latest'); await config.setModel('qwen-vl-max-latest');
// Verify that logModelSwitch was called with default reason // Verify that logModelSwitch was called with default reason
expect(logModelSwitchSpy).toHaveBeenCalledWith({ expect(logModelSwitchSpy).toHaveBeenCalledWith({

View File

@@ -528,13 +528,13 @@ export class Config {
return this.contentGeneratorConfig?.model || this.model; return this.contentGeneratorConfig?.model || this.model;
} }
setModel( async setModel(
newModel: string, newModel: string,
options?: { options?: {
reason?: ModelSwitchEvent['reason']; reason?: ModelSwitchEvent['reason'];
context?: string; context?: string;
}, },
): void { ): Promise<void> {
const oldModel = this.getModel(); const oldModel = this.getModel();
if (this.contentGeneratorConfig) { if (this.contentGeneratorConfig) {
@@ -559,13 +559,16 @@ export class Config {
// Reinitialize chat with updated configuration while preserving history // Reinitialize chat with updated configuration while preserving history
const geminiClient = this.getGeminiClient(); const geminiClient = this.getGeminiClient();
if (geminiClient && geminiClient.isInitialized()) { if (geminiClient && geminiClient.isInitialized()) {
// Use async operation but don't await to avoid blocking // Now await the reinitialize operation to ensure completion
geminiClient.reinitialize().catch((error) => { try {
await geminiClient.reinitialize();
} catch (error) {
console.error( console.error(
'Failed to reinitialize chat with updated config:', 'Failed to reinitialize chat with updated config:',
error, error,
); );
}); throw error; // Re-throw to let callers handle the error
}
} }
} }

View File

@@ -41,7 +41,7 @@ describe('Flash Model Fallback Configuration', () => {
// with the fallback mechanism. This will be necessary we introduce more // with the fallback mechanism. This will be necessary we introduce more
// intelligent model routing. // intelligent model routing.
describe('setModel', () => { describe('setModel', () => {
it('should only mark as switched if contentGeneratorConfig exists', () => { it('should only mark as switched if contentGeneratorConfig exists', async () => {
// Create config without initializing contentGeneratorConfig // Create config without initializing contentGeneratorConfig
const newConfig = new Config({ const newConfig = new Config({
sessionId: 'test-session-2', sessionId: 'test-session-2',
@@ -52,15 +52,15 @@ describe('Flash Model Fallback Configuration', () => {
}); });
// Should not crash when contentGeneratorConfig is undefined // Should not crash when contentGeneratorConfig is undefined
newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL); await newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(newConfig.isInFallbackMode()).toBe(false); expect(newConfig.isInFallbackMode()).toBe(false);
}); });
}); });
describe('getModel', () => { describe('getModel', () => {
it('should return contentGeneratorConfig model if available', () => { it('should return contentGeneratorConfig model if available', async () => {
// Simulate initialized content generator config // Simulate initialized content generator config
config.setModel(DEFAULT_GEMINI_FLASH_MODEL); await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL); expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
}); });
@@ -88,8 +88,8 @@ describe('Flash Model Fallback Configuration', () => {
expect(config.isInFallbackMode()).toBe(false); expect(config.isInFallbackMode()).toBe(false);
}); });
it('should persist switched state throughout session', () => { it('should persist switched state throughout session', async () => {
config.setModel(DEFAULT_GEMINI_FLASH_MODEL); await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
// Setting state for fallback mode as is expected of clients // Setting state for fallback mode as is expected of clients
config.setFallbackMode(true); config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true); expect(config.isInFallbackMode()).toBe(true);

View File

@@ -1053,7 +1053,7 @@ export class GeminiClient {
error, error,
); );
if (accepted !== false && accepted !== null) { if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel); await this.config.setModel(fallbackModel);
this.config.setFallbackMode(true); this.config.setFallbackMode(true);
return fallbackModel; return fallbackModel;
} }

View File

@@ -224,7 +224,7 @@ export class GeminiChat {
error, error,
); );
if (accepted !== false && accepted !== null) { if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel); await this.config.setModel(fallbackModel);
this.config.setFallbackMode(true); this.config.setFallbackMode(true);
return fallbackModel; return fallbackModel;
} }

View File

@@ -72,6 +72,19 @@ async function createMockConfig(
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
vi.spyOn(config, 'getToolRegistry').mockReturnValue(mockToolRegistry); vi.spyOn(config, 'getToolRegistry').mockReturnValue(mockToolRegistry);
// Mock getContentGeneratorConfig to return a valid config
vi.spyOn(config, 'getContentGeneratorConfig').mockReturnValue({
model: DEFAULT_GEMINI_MODEL,
authType: AuthType.USE_GEMINI,
});
// Mock setModel method
vi.spyOn(config, 'setModel').mockResolvedValue();
// Mock getSessionId method
vi.spyOn(config, 'getSessionId').mockReturnValue('test-session');
return { config, toolRegistry: mockToolRegistry }; return { config, toolRegistry: mockToolRegistry };
} }

View File

@@ -826,7 +826,7 @@ export class SubAgentScope {
); );
if (this.modelConfig.model) { if (this.modelConfig.model) {
this.runtimeContext.setModel(this.modelConfig.model); await this.runtimeContext.setModel(this.modelConfig.model);
} }
return new GeminiChat( return new GeminiChat(