Compare commits

..

10 Commits

Author SHA1 Message Date
github-actions[bot]
63579e71f7 chore(release): v0.5.1-preview.1 2025-12-24 00:16:16 +00:00
tanzhenxin
bc2a7efcb3 Merge pull request #1297 from QwenLM/feat/gemini-3-integration
Add Gemini provider, remove legacy Google OAuth, and tune generation …
2025-12-23 22:16:39 +08:00
tanzhenxin
1dfd880e17 reset default topP to 0.95 as claude modes does not allow topP smaller than 0.95 2025-12-23 21:58:28 +08:00
tanzhenxin
10a0c843c1 fix flaky tests 2025-12-23 14:52:03 +08:00
tanzhenxin
955547d523 minor updates to address review comments 2025-12-23 14:35:41 +08:00
tanzhenxin
3bc862df89 unset temperature, and set topP=0.8 for default provider 2025-12-23 13:56:06 +08:00
tanzhenxin
87d8d82be7 special handling for summarized thinking 2025-12-22 14:07:23 +08:00
tanzhenxin
fefc138485 Merge branch 'main' into feat/gemini-3-integration 2025-12-22 10:08:15 +08:00
tanzhenxin
b8a16d362a Merge branch 'main' into feat/gemini-3-integration 2025-12-19 16:39:42 +08:00
tanzhenxin
17129024f4 Add Gemini provider, remove legacy Google OAuth, and tune generation defaults 2025-12-19 16:26:54 +08:00
169 changed files with 2613 additions and 16745 deletions

2021
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
{
"name": "@qwen-code/qwen-code",
"version": "0.6.0",
"version": "0.5.1-preview.1",
"engines": {
"node": ">=20.0.0"
},
@@ -13,14 +13,11 @@
"url": "git+https://github.com/QwenLM/qwen-code.git"
},
"config": {
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.6.0"
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.5.1-preview.1"
},
"scripts": {
"start": "cross-env node scripts/start.js",
"debug": "cross-env DEBUG=1 node --inspect-brk scripts/start.js",
"auth:npm": "npx google-artifactregistry-auth",
"auth:docker": "gcloud auth configure-docker us-west1-docker.pkg.dev",
"auth": "npm run auth:npm && npm run auth:docker",
"generate": "node scripts/generate-git-commit-info.js",
"build": "node scripts/build.js",
"build-and-start": "npm run build && npm run start",
@@ -95,7 +92,6 @@
"eslint-plugin-react-hooks": "^5.2.0",
"glob": "^10.5.0",
"globals": "^16.0.0",
"google-artifactregistry-auth": "^3.4.0",
"husky": "^9.1.7",
"json": "^11.0.0",
"lint-staged": "^16.1.6",

View File

@@ -1,6 +1,6 @@
{
"name": "@qwen-code/qwen-code",
"version": "0.6.0",
"version": "0.5.1-preview.1",
"description": "Qwen Code",
"repository": {
"type": "git",
@@ -33,13 +33,13 @@
"dist"
],
"config": {
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.6.0"
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.5.1-preview.1"
},
"dependencies": {
"@google/genai": "1.16.0",
"@google/genai": "1.30.0",
"@iarna/toml": "^2.2.5",
"@qwen-code/qwen-code-core": "file:../core",
"@modelcontextprotocol/sdk": "^1.15.1",
"@modelcontextprotocol/sdk": "^1.25.1",
"@types/update-notifier": "^6.0.8",
"ansi-regex": "^6.2.2",
"command-exists": "^1.2.9",

View File

@@ -26,5 +26,23 @@ export function validateAuthMethod(authMethod: string): string | null {
return null;
}
if (authMethod === AuthType.USE_GEMINI) {
const hasApiKey = process.env['GEMINI_API_KEY'];
if (!hasApiKey) {
return 'GEMINI_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
}
return null;
}
if (authMethod === AuthType.USE_VERTEX_AI) {
const hasApiKey = process.env['GOOGLE_API_KEY'];
if (!hasApiKey) {
return 'GOOGLE_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
}
process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true';
return null;
}
return 'Invalid auth method selected.';
}

View File

@@ -460,7 +460,12 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
})
.option('auth-type', {
type: 'string',
choices: [AuthType.USE_OPENAI, AuthType.QWEN_OAUTH],
choices: [
AuthType.USE_OPENAI,
AuthType.QWEN_OAUTH,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
],
description: 'Authentication type',
})
.deprecateOption(

View File

@@ -56,6 +56,17 @@ vi.mock('simple-git', () => ({
}),
}));
// Partially mock './extensions/github.js': every real export is kept via
// importOriginal, except downloadFromGitHubRelease, which is replaced by a
// mock that always rejects with a fixed error.
// NOTE(review): presumably this keeps the tests from attempting a real GitHub
// release download — confirm against the tests that rely on it.
vi.mock('./extensions/github.js', async (importOriginal) => {
const actual =
await importOriginal<typeof import('./extensions/github.js')>();
return {
...actual,
downloadFromGitHubRelease: vi
.fn()
.mockRejectedValue(new Error('Mocked GitHub release download failure')),
};
});
vi.mock('os', async (importOriginal) => {
const mockedOs = await importOriginal<typeof os>();
return {

View File

@@ -41,6 +41,17 @@ vi.mock('simple-git', () => ({
}),
}));
// Partially mock '../extensions/github.js': every real export is kept via
// importOriginal, except downloadFromGitHubRelease, which is replaced by a
// mock that always rejects with a fixed error.
// NOTE(review): presumably this keeps the tests from attempting a real GitHub
// release download — confirm against the tests that rely on it.
vi.mock('../extensions/github.js', async (importOriginal) => {
const actual =
await importOriginal<typeof import('../extensions/github.js')>();
return {
...actual,
downloadFromGitHubRelease: vi
.fn()
.mockRejectedValue(new Error('Mocked GitHub release download failure')),
};
});
vi.mock('os', async (importOriginal) => {
const mockedOs = await importOriginal<typeof os>();
return {

View File

@@ -4,13 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '@qwen-code/qwen-code-core';
import {
AuthType,
getOauthClient,
InputFormat,
logUserPrompt,
} from '@qwen-code/qwen-code-core';
import type { Config, AuthType } from '@qwen-code/qwen-code-core';
import { InputFormat, logUserPrompt } from '@qwen-code/qwen-code-core';
import { render } from 'ink';
import dns from 'node:dns';
import os from 'node:os';
@@ -399,15 +394,6 @@ export async function main() {
initializationResult = await initializeApp(config, settings);
}
if (
settings.merged.security?.auth?.selectedType ===
AuthType.LOGIN_WITH_GOOGLE &&
config.isBrowserLaunchSuppressed()
) {
// Do oauth before app renders to make copying the link possible.
await getOauthClient(settings.merged.security.auth.selectedType, config);
}
if (config.getExperimentalZedIntegration()) {
return runAcpAgent(config, settings, extensions, argv);
}

View File

@@ -610,8 +610,6 @@ export abstract class BaseJsonOutputAdapter {
const errorText = parseAndFormatApiError(
event.value.error,
this.config.getContentGeneratorConfig()?.authType,
undefined,
this.config.getModel(),
);
this.appendText(state, errorText, null);
break;

View File

@@ -221,8 +221,6 @@ export async function runNonInteractive(
const errorText = parseAndFormatApiError(
event.value.error,
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
);
process.stderr.write(`${errorText}\n`);
}

View File

@@ -28,7 +28,7 @@ const mockPrompt = {
{ name: 'trail', required: false, description: "The animal's trail." },
],
invoke: vi.fn().mockResolvedValue({
messages: [{ content: { text: 'Hello, world!' } }],
messages: [{ content: { type: 'text', text: 'Hello, world!' } }],
}),
};

View File

@@ -123,7 +123,10 @@ export class McpPromptLoader implements ICommandLoader {
};
}
if (!result.messages?.[0]?.content?.['text']) {
const firstMessage = result.messages?.[0];
const content = firstMessage?.content;
if (content?.type !== 'text') {
return {
type: 'message',
messageType: 'error',
@@ -134,7 +137,7 @@ export class McpPromptLoader implements ICommandLoader {
return {
type: 'submit_prompt',
content: JSON.stringify(result.messages[0].content.text),
content: JSON.stringify(content.text),
};
} catch (error) {
return {

View File

@@ -23,7 +23,6 @@ import {
} from '@qwen-code/qwen-code-core';
import type { LoadedSettings } from '../config/settings.js';
import type { InitializationResult } from '../core/initializer.js';
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
import { UIStateContext, type UIState } from './contexts/UIStateContext.js';
import {
UIActionsContext,
@@ -56,7 +55,6 @@ vi.mock('./App.js', () => ({
App: TestContextConsumer,
}));
vi.mock('./hooks/useQuotaAndFallback.js');
vi.mock('./hooks/useHistoryManager.js');
vi.mock('./hooks/useThemeCommand.js');
vi.mock('./auth/useAuth.js');
@@ -122,7 +120,6 @@ describe('AppContainer State Management', () => {
let mockInitResult: InitializationResult;
// Create typed mocks for all hooks
const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock;
const mockedUseHistory = useHistory as Mock;
const mockedUseThemeCommand = useThemeCommand as Mock;
const mockedUseAuthCommand = useAuthCommand as Mock;
@@ -164,10 +161,6 @@ describe('AppContainer State Management', () => {
capturedUIActions = null!;
// **Provide a default return value for EVERY mocked hook.**
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: null,
handleProQuotaChoice: vi.fn(),
});
mockedUseHistory.mockReturnValue({
history: [],
addItem: vi.fn(),
@@ -567,75 +560,6 @@ describe('AppContainer State Management', () => {
});
});
describe('Quota and Fallback Integration', () => {
it('passes a null proQuotaRequest to UIStateContext by default', () => {
// The default mock from beforeEach already sets proQuotaRequest to null
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert that the context value is as expected
expect(capturedUIState.proQuotaRequest).toBeNull();
});
it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => {
// Arrange: Create a mock request object that a UI dialog would receive
const mockRequest = {
failedModel: 'gemini-pro',
fallbackModel: 'gemini-flash',
resolve: vi.fn(),
};
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: mockRequest,
handleProQuotaChoice: vi.fn(),
});
// Act: Render the container
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert: The mock request is correctly passed through the context
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
});
it('passes the handleProQuotaChoice function to UIActionsContext', () => {
// Arrange: Create a mock handler function
const mockHandler = vi.fn();
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: null,
handleProQuotaChoice: mockHandler,
});
// Act: Render the container
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert: The action in the context is the mock handler we provided
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
// You can even verify that the plumbed function is callable
capturedUIActions.handleProQuotaChoice('auth');
expect(mockHandler).toHaveBeenCalledWith('auth');
});
});
describe('Terminal Title Update Feature', () => {
beforeEach(() => {
// Reset mock stdout for each test

View File

@@ -32,7 +32,6 @@ import {
type Config,
type IdeInfo,
type IdeContext,
type UserTierId,
DEFAULT_GEMINI_FLASH_MODEL,
IdeClient,
ideContextStore,
@@ -48,7 +47,6 @@ import { useHistory } from './hooks/useHistoryManager.js';
import { useMemoryMonitor } from './hooks/useMemoryMonitor.js';
import { useThemeCommand } from './hooks/useThemeCommand.js';
import { useAuthCommand } from './auth/useAuth.js';
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
import { useEditorSettings } from './hooks/useEditorSettings.js';
import { useSettingsCommand } from './hooks/useSettingsCommand.js';
import { useModelCommand } from './hooks/useModelCommand.js';
@@ -192,8 +190,6 @@ export const AppContainer = (props: AppContainerProps) => {
const [currentModel, setCurrentModel] = useState(getEffectiveModel());
const [userTier] = useState<UserTierId | undefined>(undefined);
const [isConfigInitialized, setConfigInitialized] = useState(false);
const [userMessages, setUserMessages] = useState<string[]>([]);
@@ -367,14 +363,6 @@ export const AppContainer = (props: AppContainerProps) => {
cancelAuthentication,
} = useAuthCommand(settings, config, historyManager.addItem);
const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({
config,
historyManager,
userTier,
setAuthState,
setModelSwitchedFromQuotaError,
});
useInitializationAuthError(initializationResult.authError, onAuthError);
// Sync user tier from config when authentication changes
@@ -752,8 +740,7 @@ export const AppContainer = (props: AppContainerProps) => {
!initError &&
!isProcessing &&
(streamingState === StreamingState.Idle ||
streamingState === StreamingState.Responding) &&
!proQuotaRequest;
streamingState === StreamingState.Responding);
const [controlsHeight, setControlsHeight] = useState(0);
@@ -1206,7 +1193,6 @@ export const AppContainer = (props: AppContainerProps) => {
isAuthenticating ||
isEditorDialogOpen ||
showIdeRestartPrompt ||
!!proQuotaRequest ||
isSubagentCreateDialogOpen ||
isAgentsManagerDialogOpen ||
isApprovalModeDialogOpen ||
@@ -1277,8 +1263,6 @@ export const AppContainer = (props: AppContainerProps) => {
showWorkspaceMigrationDialog,
workspaceExtensions,
currentModel,
userTier,
proQuotaRequest,
contextFileNames,
errorCount,
availableTerminalHeight,
@@ -1367,8 +1351,6 @@ export const AppContainer = (props: AppContainerProps) => {
showAutoAcceptIndicator,
showWorkspaceMigrationDialog,
workspaceExtensions,
userTier,
proQuotaRequest,
contextFileNames,
errorCount,
availableTerminalHeight,
@@ -1430,7 +1412,6 @@ export const AppContainer = (props: AppContainerProps) => {
handleClearScreen,
onWorkspaceMigrationDialogOpen,
onWorkspaceMigrationDialogClose,
handleProQuotaChoice,
// Vision switch dialog
handleVisionSwitchSelect,
// Welcome back dialog
@@ -1468,7 +1449,6 @@ export const AppContainer = (props: AppContainerProps) => {
handleClearScreen,
onWorkspaceMigrationDialogOpen,
onWorkspaceMigrationDialogClose,
handleProQuotaChoice,
handleVisionSwitchSelect,
handleWelcomeBackSelection,
handleWelcomeBackClose,

View File

@@ -168,7 +168,7 @@ describe('AuthDialog', () => {
it('should not show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to something else', () => {
process.env['GEMINI_API_KEY'] = 'foobar';
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.LOGIN_WITH_GOOGLE;
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
const settings: LoadedSettings = new LoadedSettings(
{
@@ -212,7 +212,7 @@ describe('AuthDialog', () => {
it('should show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to use api key', () => {
process.env['GEMINI_API_KEY'] = 'foobar';
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_GEMINI;
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
const settings: LoadedSettings = new LoadedSettings(
{
@@ -504,12 +504,12 @@ describe('AuthDialog', () => {
},
{
settings: {
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
security: { auth: { selectedType: AuthType.USE_OPENAI } },
ui: { customThemes: {} },
mcpServers: {},
},
originalSettings: {
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
security: { auth: { selectedType: AuthType.USE_OPENAI } },
ui: { customThemes: {} },
mcpServers: {},
},

View File

@@ -225,16 +225,24 @@ export const useAuthCommand = (
const defaultAuthType = process.env['QWEN_DEFAULT_AUTH_TYPE'];
if (
defaultAuthType &&
![AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].includes(
defaultAuthType as AuthType,
)
![
AuthType.QWEN_OAUTH,
AuthType.USE_OPENAI,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
].includes(defaultAuthType as AuthType)
) {
onAuthError(
t(
'Invalid QWEN_DEFAULT_AUTH_TYPE value: "{{value}}". Valid values are: {{validValues}}',
{
value: defaultAuthType,
validValues: [AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].join(', '),
validValues: [
AuthType.QWEN_OAUTH,
AuthType.USE_OPENAI,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
].join(', '),
},
),
);

View File

@@ -15,7 +15,6 @@ vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const original = await importOriginal<typeof core>();
return {
...original,
getOauthClient: vi.fn(original.getOauthClient),
getIdeInstaller: vi.fn(original.getIdeInstaller),
IdeClient: {
getInstance: vi.fn(),

View File

@@ -17,7 +17,6 @@ import { AuthDialog } from '../auth/AuthDialog.js';
import { OpenAIKeyPrompt } from './OpenAIKeyPrompt.js';
import { EditorSettingsDialog } from './EditorSettingsDialog.js';
import { WorkspaceMigrationDialog } from './WorkspaceMigrationDialog.js';
import { ProQuotaDialog } from './ProQuotaDialog.js';
import { PermissionsModifyTrustDialog } from './PermissionsModifyTrustDialog.js';
import { ModelDialog } from './ModelDialog.js';
import { ApprovalModeDialog } from './ApprovalModeDialog.js';
@@ -87,15 +86,6 @@ export const DialogManager = ({
/>
);
}
if (uiState.proQuotaRequest) {
return (
<ProQuotaDialog
failedModel={uiState.proQuotaRequest.failedModel}
fallbackModel={uiState.proQuotaRequest.fallbackModel}
onChoice={uiActions.handleProQuotaChoice}
/>
);
}
if (uiState.shouldShowIdePrompt) {
return (
<IdeIntegrationNudge

View File

@@ -1,91 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { ProQuotaDialog } from './ProQuotaDialog.js';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
// Mock the child component to make it easier to test the parent
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: vi.fn(),
}));
describe('ProQuotaDialog', () => {
  // Renders the dialog with the fixed model pair shared by every test case.
  const renderDialog = (onChoice: (choice: 'auth' | 'continue') => void) =>
    render(
      <ProQuotaDialog
        failedModel="gemini-2.5-pro"
        fallbackModel="gemini-2.5-flash"
        onChoice={onChoice}
      />,
    );

  // Reads the onSelect prop handed to the mocked RadioButtonSelect on its
  // first invocation.
  const capturedOnSelect = () =>
    (RadioButtonSelect as Mock).mock.calls[0][0].onSelect;

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it('should render with correct title and options', () => {
    const { lastFrame } = renderDialog(() => {});

    expect(lastFrame()).toContain('Pro quota limit reached for gemini-2.5-pro.');

    // The mocked RadioButtonSelect must have received exactly these items.
    const expectedItems = [
      {
        label: 'Change auth (executes the /auth command)',
        value: 'auth',
        key: 'auth',
      },
      {
        label: `Continue with gemini-2.5-flash`,
        value: 'continue',
        key: 'continue',
      },
    ];
    expect(RadioButtonSelect).toHaveBeenCalledWith(
      expect.objectContaining({ items: expectedItems }),
      undefined,
    );
  });

  it('should call onChoice with "auth" when "Change auth" is selected', () => {
    const onChoiceSpy = vi.fn();
    renderDialog(onChoiceSpy);

    // Simulate the selection by invoking the captured callback directly.
    capturedOnSelect()('auth');

    expect(onChoiceSpy).toHaveBeenCalledWith('auth');
  });

  it('should call onChoice with "continue" when "Continue with flash" is selected', () => {
    const onChoiceSpy = vi.fn();
    renderDialog(onChoiceSpy);

    // Simulate the selection by invoking the captured callback directly.
    capturedOnSelect()('continue');

    expect(onChoiceSpy).toHaveBeenCalledWith('continue');
  });
});

View File

@@ -1,55 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
import { theme } from '../semantic-colors.js';
import { t } from '../../i18n/index.js';
/** Props accepted by the ProQuotaDialog component. */
interface ProQuotaDialogProps {
/** Model name interpolated into the "Pro quota limit reached for …" heading. */
failedModel: string;
/** Model name interpolated into the "Continue with …" option label. */
fallbackModel: string;
/** Called with the value of the option the user selects. */
onChoice: (choice: 'auth' | 'continue') => void;
}
/**
 * Bordered dialog announcing that the Pro quota limit was reached for
 * `failedModel`, rendered in the theme's warning color.
 *
 * Offers two radio options, 'auth' and 'continue' (the latter labelled with
 * `fallbackModel`), and forwards the selected value to `onChoice`.
 *
 * @param failedModel - Model name shown in the heading.
 * @param fallbackModel - Model name shown in the "continue" option label.
 * @param onChoice - Receives 'auth' or 'continue' when an option is selected.
 * @returns The rendered dialog element.
 */
export function ProQuotaDialog({
  failedModel,
  fallbackModel,
  onChoice,
}: ProQuotaDialogProps): React.JSX.Element {
  const authOption = {
    label: t('Change auth (executes the /auth command)'),
    value: 'auth' as const,
    key: 'auth',
  };
  const continueOption = {
    label: t('Continue with {{model}}', { model: fallbackModel }),
    value: 'continue' as const,
    key: 'continue',
  };
  const heading = t('Pro quota limit reached for {{model}}.', {
    model: failedModel,
  });

  return (
    <Box borderStyle="round" flexDirection="column" paddingX={1}>
      <Text bold color={theme.status.warning}>
        {heading}
      </Text>
      <Box marginTop={1}>
        {/* initialIndex={1} points at the "continue" entry of the list. */}
        <RadioButtonSelect
          items={[authOption, continueOption]}
          initialIndex={1}
          onSelect={(choice: 'auth' | 'continue') => onChoice(choice)}
        />
      </Box>
    </Box>
  );
}

View File

@@ -55,7 +55,6 @@ export interface UIActions {
handleClearScreen: () => void;
onWorkspaceMigrationDialogOpen: () => void;
onWorkspaceMigrationDialogClose: () => void;
handleProQuotaChoice: (choice: 'auth' | 'continue') => void;
// Vision switch dialog
handleVisionSwitchSelect: (outcome: VisionSwitchOutcome) => void;
// Welcome back dialog

View File

@@ -22,21 +22,13 @@ import type {
AuthType,
IdeContext,
ApprovalMode,
UserTierId,
IdeInfo,
FallbackIntent,
} from '@qwen-code/qwen-code-core';
import type { DOMElement } from 'ink';
import type { SessionStatsState } from '../contexts/SessionContext.js';
import type { ExtensionUpdateState } from '../state/extensions.js';
import type { UpdateObject } from '../utils/updateCheck.js';
export interface ProQuotaDialogRequest {
failedModel: string;
fallbackModel: string;
resolve: (intent: FallbackIntent) => void;
}
import { type UseHistoryManagerReturn } from '../hooks/useHistoryManager.js';
import { type RestartReason } from '../hooks/useIdeTrustListener.js';
@@ -99,8 +91,6 @@ export interface UIState {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
workspaceExtensions: any[]; // Extension[]
// Quota-related state
userTier: UserTierId | undefined;
proQuotaRequest: ProQuotaDialogRequest | null;
currentModel: string;
contextFileNames: string[];
errorCount: number;

View File

@@ -1323,7 +1323,7 @@ describe('useGeminiStream', () => {
it('should call parseAndFormatApiError with the correct authType on stream initialization failure', async () => {
// 1. Setup
const mockError = new Error('Rate limit exceeded');
const mockAuthType = AuthType.LOGIN_WITH_GOOGLE;
const mockAuthType = AuthType.USE_VERTEX_AI;
mockParseAndFormatApiError.mockClear();
mockSendMessageStream.mockReturnValue(
(async function* () {
@@ -1374,9 +1374,6 @@ describe('useGeminiStream', () => {
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
'Rate limit exceeded',
mockAuthType,
undefined,
'gemini-2.5-pro',
'gemini-2.5-flash',
);
});
});
@@ -2493,9 +2490,6 @@ describe('useGeminiStream', () => {
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
{ message: 'Test error' },
expect.any(String),
undefined,
'gemini-2.5-pro',
'gemini-2.5-flash',
);
});
});

View File

@@ -26,7 +26,6 @@ import {
GitService,
UnauthorizedError,
UserPromptEvent,
DEFAULT_GEMINI_FLASH_MODEL,
logConversationFinishedEvent,
ConversationFinishedEvent,
ApprovalMode,
@@ -600,9 +599,6 @@ export const useGeminiStream = (
text: parseAndFormatApiError(
eventValue.error,
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,
@@ -654,6 +650,9 @@ export const useGeminiStream = (
'Response stopped due to image safety violations.',
[FinishReason.UNEXPECTED_TOOL_CALL]:
'Response stopped due to unexpected tool call.',
[FinishReason.IMAGE_PROHIBITED_CONTENT]:
'Response stopped due to image prohibited content.',
[FinishReason.NO_IMAGE]: 'Response stopped due to no image.',
};
const message = finishReasonMessages[finishReason];
@@ -770,11 +769,17 @@ export const useGeminiStream = (
for await (const event of stream) {
switch (event.type) {
case ServerGeminiEventType.Thought:
thoughtBuffer = handleThoughtEvent(
event.value,
thoughtBuffer,
userMessageTimestamp,
);
// If the thought has a subject, it's a discrete status update rather than
// a streamed textual thought, so we update the thought state directly.
if (event.value.subject) {
setThought(event.value);
} else {
thoughtBuffer = handleThoughtEvent(
event.value,
thoughtBuffer,
userMessageTimestamp,
);
}
break;
case ServerGeminiEventType.Content:
geminiMessageBuffer = handleContentEvent(
@@ -845,6 +850,7 @@ export const useGeminiStream = (
handleMaxSessionTurnsEvent,
handleSessionTokenLimitExceededEvent,
handleCitationEvent,
setThought,
],
);
@@ -987,9 +993,6 @@ export const useGeminiStream = (
text: parseAndFormatApiError(
getErrorMessage(error) || 'Unknown error',
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,

View File

@@ -1,391 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import { act, renderHook } from '@testing-library/react';
import {
type Config,
type FallbackModelHandler,
UserTierId,
AuthType,
isGenericQuotaExceededError,
isProQuotaExceededError,
makeFakeConfig,
} from '@qwen-code/qwen-code-core';
import { useQuotaAndFallback } from './useQuotaAndFallback.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import { AuthState, MessageType } from '../types.js';
// Mock the error checking functions from the core package to control test scenarios
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@qwen-code/qwen-code-core')>();
return {
...original,
isGenericQuotaExceededError: vi.fn(),
isProQuotaExceededError: vi.fn(),
};
});
// Use a type alias for SpyInstance as it's not directly exported
type SpyInstance = ReturnType<typeof vi.spyOn>;
describe('useQuotaAndFallback', () => {
let mockConfig: Config;
let mockHistoryManager: UseHistoryManagerReturn;
let mockSetAuthState: Mock;
let mockSetModelSwitchedFromQuotaError: Mock;
let setFallbackHandlerSpy: SpyInstance;
const mockedIsGenericQuotaExceededError = isGenericQuotaExceededError as Mock;
const mockedIsProQuotaExceededError = isProQuotaExceededError as Mock;
beforeEach(() => {
mockConfig = makeFakeConfig();
// Spy on the method that requires the private field and mock its return.
// This is cleaner than modifying the config class for tests.
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
model: 'test-model',
authType: AuthType.LOGIN_WITH_GOOGLE,
});
mockHistoryManager = {
addItem: vi.fn(),
history: [],
updateItem: vi.fn(),
clearItems: vi.fn(),
loadHistory: vi.fn(),
};
mockSetAuthState = vi.fn();
mockSetModelSwitchedFromQuotaError = vi.fn();
setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler');
vi.spyOn(mockConfig, 'setQuotaErrorOccurred');
mockedIsGenericQuotaExceededError.mockReturnValue(false);
mockedIsProQuotaExceededError.mockReturnValue(false);
});
afterEach(() => {
vi.clearAllMocks();
});
it('should register a fallback handler on initialization', () => {
renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1);
expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function);
});
describe('Fallback Handler Logic', () => {
// Helper function to render the hook and extract the registered handler
const getRegisteredHandler = (
userTier: UserTierId = UserTierId.FREE,
): FallbackModelHandler => {
renderHook(
(props) =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: props.userTier,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
{ initialProps: { userTier } },
);
return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler;
};
it('should return null and take no action if already in fallback mode', async () => {
vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true);
const handler = getRegisteredHandler();
const result = await handler('gemini-pro', 'gemini-flash', new Error());
expect(result).toBeNull();
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
});
it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => {
// Override the default mock from beforeEach for this specific test
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
model: 'test-model',
authType: AuthType.USE_GEMINI,
});
const handler = getRegisteredHandler();
const result = await handler('gemini-pro', 'gemini-flash', new Error());
expect(result).toBeNull();
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
});
describe('Automatic Fallback Scenarios', () => {
const testCases = [
{
errorType: 'generic',
tier: UserTierId.FREE,
expectedMessageSnippets: [
'Automatically switching from model-A to model-B',
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
],
},
{
errorType: 'generic',
tier: UserTierId.STANDARD, // Paid tier
expectedMessageSnippets: [
'Automatically switching from model-A to model-B',
'switch to using a paid API key from AI Studio',
],
},
{
errorType: 'other',
tier: UserTierId.FREE,
expectedMessageSnippets: [
'Automatically switching from model-A to model-B for faster responses',
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
],
},
{
errorType: 'other',
tier: UserTierId.LEGACY, // Paid tier
expectedMessageSnippets: [
'Automatically switching from model-A to model-B for faster responses',
'switch to using a paid API key from AI Studio',
],
},
];
for (const { errorType, tier, expectedMessageSnippets } of testCases) {
it(`should handle ${errorType} error for ${tier} tier correctly`, async () => {
mockedIsGenericQuotaExceededError.mockReturnValue(
errorType === 'generic',
);
const handler = getRegisteredHandler(tier);
const result = await handler(
'model-A',
'model-B',
new Error('quota exceeded'),
);
// Automatic fallbacks should return 'stop'
expect(result).toBe('stop');
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
expect.objectContaining({ type: MessageType.INFO }),
expect.any(Number),
);
const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0]
.text;
for (const snippet of expectedMessageSnippets) {
expect(message).toContain(snippet);
}
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true);
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
});
}
});
describe('Interactive Fallback (Pro Quota Error)', () => {
beforeEach(() => {
mockedIsProQuotaExceededError.mockReturnValue(true);
});
it('should set an interactive request and wait for user choice', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
// Call the handler but do not await it, to check the intermediate state
const promise = handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota'),
);
await act(async () => {});
// The hook should now have a pending request for the UI to handle
expect(result.current.proQuotaRequest).not.toBeNull();
expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro');
// Simulate the user choosing to continue with the fallback model
act(() => {
result.current.handleProQuotaChoice('continue');
});
// The original promise from the handler should now resolve
const intent = await promise;
expect(intent).toBe('retry');
// The pending request should be cleared from the state
expect(result.current.proQuotaRequest).toBeNull();
});
it('should handle race conditions by stopping subsequent requests', async () => {
  const { result: hook } = renderHook(() =>
    useQuotaAndFallback({
      config: mockConfig,
      historyManager: mockHistoryManager,
      userTier: UserTierId.FREE,
      setAuthState: mockSetAuthState,
      setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
    }),
  );
  const registered = setFallbackHandlerSpy.mock.calls[0][0];
  const fallbackHandler = registered as FallbackModelHandler;

  // First failure: opens the dialog and stays pending.
  const firstIntent = fallbackHandler(
    'gemini-pro',
    'gemini-flash',
    new Error('pro quota 1'),
  );
  await act(async () => {});
  const requestBefore = hook.current.proQuotaRequest;
  expect(requestBefore).not.toBeNull();

  // Second failure while the dialog is still open: must be rejected right
  // away and must not replace the request that is already on screen.
  const secondIntent = await fallbackHandler(
    'gemini-pro',
    'gemini-flash',
    new Error('pro quota 2'),
  );
  expect(secondIntent).toBe('stop');
  expect(hook.current.proQuotaRequest).toBe(requestBefore);

  // Answering the dialog settles only the first call.
  act(() => {
    hook.current.handleProQuotaChoice('continue');
  });
  const settledFirstIntent = await firstIntent;
  expect(settledFirstIntent).toBe('retry');
  expect(hook.current.proQuotaRequest).toBeNull();
});
});
});
describe('handleProQuotaChoice', () => {
  /** Mounts the hook with the shared mocks and returns its live result ref. */
  const mountHook = () =>
    renderHook(() =>
      useQuotaAndFallback({
        config: mockConfig,
        historyManager: mockHistoryManager,
        userTier: UserTierId.FREE,
        setAuthState: mockSetAuthState,
        setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
      }),
    ).result;

  /**
   * Invokes the handler captured by the spy with a pro-quota failure and
   * returns the still-pending intent promise.
   */
  const failWithProQuota = () => {
    const fallbackHandler = setFallbackHandlerSpy.mock
      .calls[0][0] as FallbackModelHandler;
    return fallbackHandler(
      'gemini-pro',
      'gemini-flash',
      new Error('pro quota'),
    );
  };

  beforeEach(() => {
    // Every error in this suite is classified as a pro-quota error.
    mockedIsProQuotaExceededError.mockReturnValue(true);
  });

  it('should do nothing if there is no pending pro quota request', () => {
    const hook = mountHook();

    act(() => {
      hook.current.handleProQuotaChoice('auth');
    });

    expect(mockSetAuthState).not.toHaveBeenCalled();
    expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
  });

  it('should resolve intent to "auth" and trigger auth state update', async () => {
    const hook = mountHook();
    const pendingIntent = failWithProQuota();
    await act(async () => {}); // let the pending request land in state

    act(() => {
      hook.current.handleProQuotaChoice('auth');
    });

    const resolvedIntent = await pendingIntent;
    expect(resolvedIntent).toBe('auth');
    expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating);
    expect(hook.current.proQuotaRequest).toBeNull();
  });

  it('should resolve intent to "retry" and add info message on continue', async () => {
    const hook = mountHook();
    // Calling the handler produces the first `addItem` call (the quota notice).
    const pendingIntent = failWithProQuota();
    await act(async () => {}); // let the pending request land in state

    act(() => {
      hook.current.handleProQuotaChoice('continue');
    });

    const resolvedIntent = await pendingIntent;
    expect(resolvedIntent).toBe('retry');
    expect(hook.current.proQuotaRequest).toBeNull();

    // The second `addItem` call is the "Switched to fallback model" notice.
    expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2);
    const secondItem = (mockHistoryManager.addItem as Mock).mock.calls[1][0];
    expect(secondItem.type).toBe(MessageType.INFO);
    expect(secondItem.text).toContain('Switched to fallback model.');
  });
});
});

View File

@@ -1,175 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
AuthType,
type Config,
type FallbackModelHandler,
type FallbackIntent,
isGenericQuotaExceededError,
isProQuotaExceededError,
UserTierId,
} from '@qwen-code/qwen-code-core';
import { useCallback, useEffect, useRef, useState } from 'react';
import { type UseHistoryManagerReturn } from './useHistoryManager.js';
import { AuthState, MessageType } from '../types.js';
import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js';
/** Dependencies injected into {@link useQuotaAndFallback}. */
interface UseQuotaAndFallbackArgs {
// Receives the fallback handler and is queried for fallback mode / auth type.
config: Config;
// Used to append INFO messages to the UI history.
historyManager: UseHistoryManagerReturn;
// Only LEGACY and STANDARD are treated as paid; anything else (including
// undefined) gets the free-tier wording.
userTier: UserTierId | undefined;
// Called with AuthState.Updating when the user picks 'auth' in the dialog.
setAuthState: (state: AuthState) => void;
// Set to true every time the handler reports a quota/fallback event.
setModelSwitchedFromQuotaError: (value: boolean) => void;
}
/**
 * React hook that registers a fallback-model handler on `config` and exposes
 * the state needed to drive an interactive "pro quota exceeded" dialog.
 *
 * The handler:
 *  - returns `null` when the config is already in fallback mode or the auth
 *    type is not LOGIN_WITH_GOOGLE;
 *  - otherwise appends a tier- and error-specific INFO message to history and
 *    flags the quota error on both the UI setter and the config;
 *  - for pro-quota errors, publishes a `proQuotaRequest` and waits until
 *    `handleProQuotaChoice` resolves it ('auth' or 'retry'), or returns
 *    'stop' immediately if a dialog is already pending;
 *  - for every other error returns 'stop'.
 *
 * @returns `proQuotaRequest` (the pending dialog request or `null`) and
 *   `handleProQuotaChoice` (callback that answers it).
 */
export function useQuotaAndFallback({
config,
historyManager,
userTier,
setAuthState,
setModelSwitchedFromQuotaError,
}: UseQuotaAndFallbackArgs) {
const [proQuotaRequest, setProQuotaRequest] =
useState<ProQuotaDialogRequest | null>(null);
// Lock kept in a ref so the handler can read/write it synchronously without
// waiting for a re-render.
const isDialogPending = useRef(false);
// Set up Flash fallback handler
useEffect(() => {
const fallbackHandler: FallbackModelHandler = async (
failedModel,
fallbackModel,
error,
): Promise<FallbackIntent | null> => {
// Already in fallback mode: nothing to decide.
if (config.isInFallbackMode()) {
return null;
}
// Fallbacks are currently only handled for OAuth users.
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (
!contentGeneratorConfig ||
contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE
) {
return null;
}
// Use actual user tier if available; otherwise, default to FREE tier behavior (safe default)
const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
// Pick the message by error kind (pro quota / generic quota / other) and tier.
let message: string;
if (error && isProQuotaExceededError(error)) {
// Pro Quota specific messages (Interactive)
if (isPaidTier) {
message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
} else if (error && isGenericQuotaExceededError(error)) {
// Generic Quota (Automatic fallback)
const actionMessage = `⚡ You have reached your daily quota limit.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`;
if (isPaidTier) {
message = `${actionMessage}
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `${actionMessage}
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
} else {
// Consecutive 429s or other errors (Automatic fallback)
const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`;
if (isPaidTier) {
message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
}
// Add message to UI history
// NOTE(review): this runs before the pending-dialog check below, so a second
// pro-quota failure that arrives while a dialog is open still appends a message.
historyManager.addItem(
{
type: MessageType.INFO,
text: message,
},
Date.now(),
);
setModelSwitchedFromQuotaError(true);
config.setQuotaErrorOccurred(true);
// Interactive Fallback for Pro quota
if (error && isProQuotaExceededError(error)) {
if (isDialogPending.current) {
return 'stop'; // A dialog is already active, so just stop this request.
}
isDialogPending.current = true;
// Hand `resolve` to the UI through state; the promise settles only when
// handleProQuotaChoice calls it.
const intent: FallbackIntent = await new Promise<FallbackIntent>(
(resolve) => {
setProQuotaRequest({
failedModel,
fallbackModel,
resolve,
});
},
);
return intent;
}
// Non-interactive cases always end the current request.
return 'stop';
};
config.setFallbackModelHandler(fallbackHandler);
}, [config, historyManager, userTier, setModelSwitchedFromQuotaError]);
// Answers the pending dialog: maps the UI choice to a FallbackIntent
// ('auth' -> 'auth', 'continue' -> 'retry'), resolves the waiting handler,
// clears the request and releases the lock. No-op when nothing is pending.
const handleProQuotaChoice = useCallback(
(choice: 'auth' | 'continue') => {
if (!proQuotaRequest) return;
const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry';
proQuotaRequest.resolve(intent);
setProQuotaRequest(null);
isDialogPending.current = false; // Reset the flag here
if (choice === 'auth') {
setAuthState(AuthState.Updating);
} else {
historyManager.addItem(
{
type: MessageType.INFO,
text: 'Switched to fallback model. Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.',
},
Date.now(),
);
}
},
[proQuotaRequest, setAuthState, historyManager],
);
return {
proQuotaRequest,
handleProQuotaChoice,
};
}

View File

@@ -411,7 +411,7 @@ describe('useQwenAuth', () => {
expect(geminiResult.current.qwenAuthState.authStatus).toBe('idle');
const { result: oauthResult } = renderHook(() =>
useQwenAuth(AuthType.LOGIN_WITH_GOOGLE, true),
useQwenAuth(AuthType.USE_OPENAI, true),
);
expect(oauthResult.current.qwenAuthState.authStatus).toBe('idle');
});

View File

@@ -62,7 +62,7 @@ const mockConfig = {
getAllowedTools: vi.fn(() => []),
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getUseSmartEdit: () => false,
getUseModelRouter: () => false,

View File

@@ -21,6 +21,13 @@ function getAuthTypeFromEnv(): AuthType | undefined {
return AuthType.QWEN_OAUTH;
}
if (process.env['GEMINI_API_KEY']) {
return AuthType.USE_GEMINI;
}
if (process.env['GOOGLE_API_KEY']) {
return AuthType.USE_VERTEX_AI;
}
return undefined;
}

View File

@@ -1,6 +1,6 @@
{
"name": "@qwen-code/qwen-code-core",
"version": "0.6.0",
"version": "0.5.1-preview.1",
"description": "Qwen Code Core",
"repository": {
"type": "git",
@@ -23,8 +23,8 @@
"scripts/postinstall.js"
],
"dependencies": {
"@google/genai": "1.16.0",
"@modelcontextprotocol/sdk": "^1.11.0",
"@google/genai": "1.30.0",
"@modelcontextprotocol/sdk": "^1.25.1",
"@opentelemetry/api": "^1.9.0",
"async-mutex": "^0.5.0",
"@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0",
@@ -34,7 +34,6 @@
"@opentelemetry/exporter-trace-otlp-grpc": "^0.203.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.203.0",
"@opentelemetry/instrumentation-http": "^0.203.0",
"@opentelemetry/resource-detector-gcp": "^0.40.0",
"@opentelemetry/sdk-node": "^0.203.0",
"@types/html-to-text": "^9.0.4",
"@xterm/headless": "5.5.0",
@@ -48,7 +47,7 @@
"fdir": "^6.4.6",
"fzf": "^0.5.2",
"glob": "^10.5.0",
"google-auth-library": "^9.11.0",
"google-auth-library": "^10.5.0",
"html-to-text": "^9.0.5",
"https-proxy-agent": "^7.0.6",
"ignore": "^7.0.0",

View File

@@ -1,54 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ContentGenerator } from '../core/contentGenerator.js';
import { AuthType } from '../core/contentGenerator.js';
import { getOauthClient } from './oauth2.js';
import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
/**
 * Builds a Code Assist backed {@link ContentGenerator}.
 *
 * Only `LOGIN_WITH_GOOGLE` and `CLOUD_SHELL` are accepted. For those, an OAuth
 * client is obtained, the user is set up, and a `CodeAssistServer` is created
 * from the resulting project id and user tier.
 *
 * @param httpOptions Passed through to the `CodeAssistServer`.
 * @param authType Authentication method selected by the caller.
 * @param config Passed to `getOauthClient`.
 * @param sessionId Optional session id passed through to the server.
 * @returns The constructed `CodeAssistServer`.
 * @throws Error with message `Unsupported authType: …` for any other auth type.
 */
export async function createCodeAssistContentGenerator(
  httpOptions: HttpOptions,
  authType: AuthType,
  config: Config,
  sessionId?: string,
): Promise<ContentGenerator> {
  // Reject unsupported auth types up front so the happy path stays flat.
  if (
    authType !== AuthType.LOGIN_WITH_GOOGLE &&
    authType !== AuthType.CLOUD_SHELL
  ) {
    throw new Error(`Unsupported authType: ${authType}`);
  }

  const oauthClient = await getOauthClient(authType, config);
  const { projectId, userTier } = await setupUser(oauthClient);
  return new CodeAssistServer(
    oauthClient,
    projectId,
    httpOptions,
    sessionId,
    userTier,
  );
}
/**
 * Returns the `CodeAssistServer` behind the config's content generator.
 *
 * A `LoggingContentGenerator` wrapper, if present, is peeled off first.
 *
 * @returns The server, or `undefined` when the (unwrapped) generator is not a
 *   `CodeAssistServer`.
 */
export function getCodeAssistServer(
  config: Config,
): CodeAssistServer | undefined {
  const generator = config.getContentGenerator();
  const unwrapped =
    generator instanceof LoggingContentGenerator
      ? generator.getWrapped()
      : generator;
  return unwrapped instanceof CodeAssistServer ? unwrapped : undefined;
}

View File

@@ -1,456 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import type { CaGenerateContentResponse } from './converter.js';
import {
toGenerateContentRequest,
fromGenerateContentResponse,
toContents,
} from './converter.js';
import type {
ContentListUnion,
GenerateContentParameters,
} from '@google/genai';
import {
GenerateContentResponse,
FinishReason,
BlockedReason,
type Part,
} from '@google/genai';
describe('converter', () => {
describe('toCodeAssistRequest', () => {
  /** Converts using the prompt / project / session ids most cases share. */
  const convert = (req: GenerateContentParameters) =>
    toGenerateContentRequest(req, 'my-prompt', 'my-project', 'my-session');

  /** Fresh minimal request holding a single user turn. */
  const helloRequest = (): GenerateContentParameters => ({
    model: 'gemini-pro',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  });

  /** Envelope expected for `helloRequest()` with the given project/session. */
  const helloEnvelope = (project: string | undefined, sessionId: string) => ({
    model: 'gemini-pro',
    project,
    request: {
      contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
      systemInstruction: undefined,
      cachedContent: undefined,
      tools: undefined,
      toolConfig: undefined,
      labels: undefined,
      safetySettings: undefined,
      generationConfig: undefined,
      session_id: sessionId,
    },
    user_prompt_id: 'my-prompt',
  });

  it('should convert a simple request with project', () => {
    expect(convert(helloRequest())).toEqual(
      helloEnvelope('my-project', 'my-session'),
    );
  });

  it('should convert a request without a project', () => {
    const converted = toGenerateContentRequest(
      helloRequest(),
      'my-prompt',
      undefined,
      'my-session',
    );
    expect(converted).toEqual(helloEnvelope(undefined, 'my-session'));
  });

  it('should convert a request with sessionId', () => {
    const converted = toGenerateContentRequest(
      helloRequest(),
      'my-prompt',
      'my-project',
      'session-123',
    );
    expect(converted).toEqual(helloEnvelope('my-project', 'session-123'));
  });

  it('should handle string content', () => {
    const input: GenerateContentParameters = {
      model: 'gemini-pro',
      contents: 'Hello',
    };
    // A bare string becomes one user turn.
    expect(convert(input).request.contents).toEqual([
      { role: 'user', parts: [{ text: 'Hello' }] },
    ]);
  });

  it('should handle Part[] content', () => {
    const input: GenerateContentParameters = {
      model: 'gemini-pro',
      contents: [{ text: 'Hello' }, { text: 'World' }],
    };
    // Each top-level part becomes its own user turn.
    expect(convert(input).request.contents).toEqual([
      { role: 'user', parts: [{ text: 'Hello' }] },
      { role: 'user', parts: [{ text: 'World' }] },
    ]);
  });

  it('should handle system instructions', () => {
    const input: GenerateContentParameters = {
      model: 'gemini-pro',
      contents: 'Hello',
      config: {
        systemInstruction: 'You are a helpful assistant.',
      },
    };
    expect(convert(input).request.systemInstruction).toEqual({
      role: 'user',
      parts: [{ text: 'You are a helpful assistant.' }],
    });
  });

  it('should handle generation config', () => {
    const input: GenerateContentParameters = {
      model: 'gemini-pro',
      contents: 'Hello',
      config: {
        temperature: 0.8,
        topK: 40,
      },
    };
    expect(convert(input).request.generationConfig).toEqual({
      temperature: 0.8,
      topK: 40,
    });
  });

  it('should handle all generation config fields', () => {
    const input: GenerateContentParameters = {
      model: 'gemini-pro',
      contents: 'Hello',
      config: {
        temperature: 0.1,
        topP: 0.2,
        topK: 3,
        candidateCount: 4,
        maxOutputTokens: 5,
        stopSequences: ['a'],
        responseLogprobs: true,
        logprobs: 6,
        presencePenalty: 0.7,
        frequencyPenalty: 0.8,
        seed: 9,
        responseMimeType: 'application/json',
      },
    };
    expect(convert(input).request.generationConfig).toEqual({
      temperature: 0.1,
      topP: 0.2,
      topK: 3,
      candidateCount: 4,
      maxOutputTokens: 5,
      stopSequences: ['a'],
      responseLogprobs: true,
      logprobs: 6,
      presencePenalty: 0.7,
      frequencyPenalty: 0.8,
      seed: 9,
      responseMimeType: 'application/json',
    });
  });
});
describe('fromCodeAssistResponse', () => {
  it('should convert a simple response', () => {
    const wrapped: CaGenerateContentResponse = {
      response: {
        candidates: [
          {
            index: 0,
            content: {
              role: 'model',
              parts: [{ text: 'Hi there!' }],
            },
            finishReason: FinishReason.STOP,
            safetyRatings: [],
          },
        ],
      },
    };

    const unwrapped = fromGenerateContentResponse(wrapped);

    // The result is a real SDK response object carrying the same candidates.
    expect(unwrapped).toBeInstanceOf(GenerateContentResponse);
    expect(unwrapped.candidates).toEqual(wrapped.response.candidates);
  });

  it('should handle prompt feedback and usage metadata', () => {
    const wrapped: CaGenerateContentResponse = {
      response: {
        candidates: [],
        promptFeedback: {
          blockReason: BlockedReason.SAFETY,
          safetyRatings: [],
        },
        usageMetadata: {
          promptTokenCount: 10,
          candidatesTokenCount: 20,
          totalTokenCount: 30,
        },
      },
    };

    const unwrapped = fromGenerateContentResponse(wrapped);

    expect(unwrapped.promptFeedback).toEqual(wrapped.response.promptFeedback);
    expect(unwrapped.usageMetadata).toEqual(wrapped.response.usageMetadata);
  });

  it('should handle automatic function calling history', () => {
    const wrapped: CaGenerateContentResponse = {
      response: {
        candidates: [],
        automaticFunctionCallingHistory: [
          {
            role: 'model',
            parts: [
              {
                functionCall: {
                  name: 'test_function',
                  args: {
                    foo: 'bar',
                  },
                },
              },
            ],
          },
        ],
      },
    };

    const unwrapped = fromGenerateContentResponse(wrapped);

    expect(unwrapped.automaticFunctionCallingHistory).toEqual(
      wrapped.response.automaticFunctionCallingHistory,
    );
  });

  it('should handle modelVersion', () => {
    const wrapped: CaGenerateContentResponse = {
      response: {
        candidates: [],
        modelVersion: 'qwen3-coder-plus',
      },
    };

    const unwrapped = fromGenerateContentResponse(wrapped);

    expect(unwrapped.modelVersion).toEqual('qwen3-coder-plus');
  });
});
describe('toContents', () => {
  // --- Shape normalisation: every accepted input form ends up as Content[] ---

  it('should handle Content', () => {
    const single: ContentListUnion = {
      role: 'user',
      parts: [{ text: 'hello' }],
    };
    const normalised = toContents(single);
    expect(normalised).toEqual([{ role: 'user', parts: [{ text: 'hello' }] }]);
  });

  it('should handle array of Contents', () => {
    const turns: ContentListUnion = [
      { role: 'user', parts: [{ text: 'hello' }] },
      { role: 'model', parts: [{ text: 'hi' }] },
    ];
    const normalised = toContents(turns);
    expect(normalised).toEqual([
      { role: 'user', parts: [{ text: 'hello' }] },
      { role: 'model', parts: [{ text: 'hi' }] },
    ]);
  });

  it('should handle Part', () => {
    const lonePart: ContentListUnion = { text: 'a part' };
    const normalised = toContents(lonePart);
    expect(normalised).toEqual([{ role: 'user', parts: [{ text: 'a part' }] }]);
  });

  it('should handle array of Parts', () => {
    // Mixed object/string parts; left un-annotated on purpose.
    const mixedParts = [{ text: 'part 1' }, 'part 2'];
    const normalised = toContents(mixedParts);
    expect(normalised).toEqual([
      { role: 'user', parts: [{ text: 'part 1' }] },
      { role: 'user', parts: [{ text: 'part 2' }] },
    ]);
  });

  it('should handle string', () => {
    const plain: ContentListUnion = 'a string';
    const normalised = toContents(plain);
    expect(normalised).toEqual([
      { role: 'user', parts: [{ text: 'a string' }] },
    ]);
  });

  it('should handle array of strings', () => {
    const plainList: ContentListUnion = ['string 1', 'string 2'];
    const normalised = toContents(plainList);
    expect(normalised).toEqual([
      { role: 'user', parts: [{ text: 'string 1' }] },
      { role: 'user', parts: [{ text: 'string 2' }] },
    ]);
  });

  // --- Thought handling: `thought` never survives on an output part ---

  it('should convert thought parts to text parts for API compatibility', () => {
    const input: ContentListUnion = {
      role: 'model',
      parts: [
        { text: 'regular text' },
        { thought: 'thinking about the problem' } as Part & {
          thought: string;
        },
        { text: 'more text' },
      ],
    };
    const normalised = toContents(input);
    expect(normalised).toEqual([
      {
        role: 'model',
        parts: [
          { text: 'regular text' },
          { text: '[Thought: thinking about the problem]' },
          { text: 'more text' },
        ],
      },
    ]);
  });

  it('should combine text and thought for text parts with thoughts', () => {
    const input: ContentListUnion = {
      role: 'model',
      parts: [
        {
          text: 'Here is my response',
          thought: 'I need to be careful here',
        } as Part & { thought: string },
      ],
    };
    const normalised = toContents(input);
    expect(normalised).toEqual([
      {
        role: 'model',
        parts: [
          {
            text: 'Here is my response\n[Thought: I need to be careful here]',
          },
        ],
      },
    ]);
  });

  it('should preserve non-thought properties while removing thought', () => {
    const input: ContentListUnion = {
      role: 'model',
      parts: [
        {
          functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
          thought: 'Performing calculation',
        } as Part & { thought: string },
      ],
    };
    const normalised = toContents(input);
    expect(normalised).toEqual([
      {
        role: 'model',
        parts: [
          {
            functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
          },
        ],
      },
    ]);
  });

  it('should convert invalid text content to valid text part with thought', () => {
    const input: ContentListUnion = {
      role: 'model',
      parts: [
        {
          text: 123, // Deliberately not a string
          thought: 'Processing number',
        } as Part & { thought: string; text: number },
      ],
    };
    const normalised = toContents(input);
    expect(normalised).toEqual([
      {
        role: 'model',
        parts: [
          {
            text: '123\n[Thought: Processing number]',
          },
        ],
      },
    ]);
  });
});
});

View File

@@ -1,285 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Content,
ContentListUnion,
ContentUnion,
GenerateContentConfig,
GenerateContentParameters,
CountTokensParameters,
CountTokensResponse,
GenerationConfigRoutingConfig,
MediaResolution,
Candidate,
ModelSelectionConfig,
GenerateContentResponsePromptFeedback,
GenerateContentResponseUsageMetadata,
Part,
SafetySetting,
PartUnion,
SpeechConfigUnion,
ThinkingConfig,
ToolListUnion,
ToolConfig,
} from '@google/genai';
import { GenerateContentResponse } from '@google/genai';
/**
 * Envelope produced by `toGenerateContentRequest`: routing fields at the top
 * level with the actual generate-content payload nested under `request`.
 */
export interface CAGenerateContentRequest {
model: string;
project?: string;
user_prompt_id?: string;
request: VertexGenerateContentRequest;
}
/**
 * Inner generate-content payload. Apart from `contents`, `systemInstruction`
 * and `generationConfig` (which are converted), the fields are copied straight
 * from the SDK request's `config`; `session_id` is supplied by the caller.
 */
interface VertexGenerateContentRequest {
contents: Content[];
systemInstruction?: Content;
cachedContent?: string;
tools?: ToolListUnion;
toolConfig?: ToolConfig;
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig;
session_id?: string;
}
/**
 * Generation settings forwarded from the SDK's `GenerateContentConfig`.
 * Every field is optional and is populated one-for-one by
 * `toVertexGenerationConfig`.
 */
interface VertexGenerationConfig {
temperature?: number;
topP?: number;
topK?: number;
candidateCount?: number;
maxOutputTokens?: number;
stopSequences?: string[];
responseLogprobs?: boolean;
logprobs?: number;
presencePenalty?: number;
frequencyPenalty?: number;
seed?: number;
responseMimeType?: string;
responseJsonSchema?: unknown;
responseSchema?: unknown;
routingConfig?: GenerationConfigRoutingConfig;
modelSelectionConfig?: ModelSelectionConfig;
responseModalities?: string[];
mediaResolution?: MediaResolution;
speechConfig?: SpeechConfigUnion;
audioTimestamp?: boolean;
thinkingConfig?: ThinkingConfig;
}
/** Response envelope; `fromGenerateContentResponse` unwraps `response`. */
export interface CaGenerateContentResponse {
response: VertexGenerateContentResponse;
}
/** Fields copied onto an SDK `GenerateContentResponse` when unwrapping. */
interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
modelVersion?: string;
}
/** Envelope produced by `toCountTokenRequest`. */
export interface CaCountTokenRequest {
request: VertexCountTokenRequest;
}
/** Count-token payload; `model` is written with a `models/` prefix. */
interface VertexCountTokenRequest {
model: string;
contents: Content[];
}
/** Count-token reply; only `totalTokens` is read by `fromCountTokenResponse`. */
export interface CaCountTokenResponse {
totalTokens: number;
}
/**
 * Wraps SDK count-token parameters in the Code Assist envelope.
 *
 * The model name is prefixed with `models/` and the contents are normalised
 * through `toContents`.
 */
export function toCountTokenRequest(
  req: CountTokensParameters,
): CaCountTokenRequest {
  const model = `models/${req.model}`;
  const contents = toContents(req.contents);
  return { request: { model, contents } };
}
/**
 * Converts a Code Assist count-token reply to the SDK shape.
 *
 * Only `totalTokens` is carried over; any other field on `res` is dropped.
 */
export function fromCountTokenResponse(
  res: CaCountTokenResponse,
): CountTokensResponse {
  const { totalTokens } = res;
  return { totalTokens };
}
/**
 * Builds the Code Assist generate-content envelope from an SDK request.
 *
 * @param req SDK request; `model` is copied as-is, everything else is
 *   converted into the nested `request` payload.
 * @param userPromptId Stored as `user_prompt_id`.
 * @param project Optional project, stored as `project`.
 * @param sessionId Optional session id, stored inside the nested payload.
 */
export function toGenerateContentRequest(
  req: GenerateContentParameters,
  userPromptId: string,
  project?: string,
  sessionId?: string,
): CAGenerateContentRequest {
  const envelope: CAGenerateContentRequest = {
    model: req.model,
    project,
    user_prompt_id: userPromptId,
    request: toVertexGenerateContentRequest(req, sessionId),
  };
  return envelope;
}
/**
 * Unwraps a Code Assist reply into a fresh SDK `GenerateContentResponse`.
 *
 * Candidates, automatic function-calling history, prompt feedback, usage
 * metadata and model version are copied across by reference.
 */
export function fromGenerateContentResponse(
  res: CaGenerateContentResponse,
): GenerateContentResponse {
  const {
    candidates,
    automaticFunctionCallingHistory,
    promptFeedback,
    usageMetadata,
    modelVersion,
  } = res.response;

  const sdkResponse = new GenerateContentResponse();
  sdkResponse.candidates = candidates;
  sdkResponse.automaticFunctionCallingHistory = automaticFunctionCallingHistory;
  sdkResponse.promptFeedback = promptFeedback;
  sdkResponse.usageMetadata = usageMetadata;
  sdkResponse.modelVersion = modelVersion;
  return sdkResponse;
}
/**
 * Builds the nested generate-content payload.
 *
 * `contents`, `systemInstruction` and `generationConfig` go through their
 * converters; the remaining fields are read straight off `req.config`
 * (all `undefined` when no config was supplied).
 */
function toVertexGenerateContentRequest(
  req: GenerateContentParameters,
  sessionId?: string,
): VertexGenerateContentRequest {
  const cfg = req.config;
  return {
    contents: toContents(req.contents),
    systemInstruction: maybeToContent(cfg?.systemInstruction),
    cachedContent: cfg?.cachedContent,
    tools: cfg?.tools,
    toolConfig: cfg?.toolConfig,
    labels: cfg?.labels,
    safetySettings: cfg?.safetySettings,
    generationConfig: toVertexGenerationConfig(cfg),
    session_id: sessionId,
  };
}
/**
 * Normalises any accepted content input into a `Content[]`.
 *
 * An array is converted element by element; any other value becomes a
 * single-element list.
 */
export function toContents(contents: ContentListUnion): Content[] {
  // Array: Content[] or PartUnion[]. Otherwise: one Content or one PartUnion.
  return Array.isArray(contents)
    ? contents.map((entry) => toContent(entry))
    : [toContent(contents)];
}
/** Like `toContent`, but passes a missing (falsy) input through as `undefined`. */
function maybeToContent(content?: ContentUnion): Content | undefined {
  return content ? toContent(content) : undefined;
}
/**
 * Converts one content-like value into a `Content`.
 *
 * - string       -> user turn with a single text part
 * - PartUnion[]  -> user turn with the converted parts
 * - Content      -> same content with nullish parts dropped and the rest
 *                   converted (an absent `parts` becomes `[]`)
 * - single Part  -> user turn wrapping the converted part
 */
function toContent(content: ContentUnion): Content {
  // Strings must be handled before the `in` check below, which needs an object.
  if (typeof content === 'string') {
    return { role: 'user', parts: [{ text: content }] };
  }

  if (Array.isArray(content)) {
    return { role: 'user', parts: toParts(content) };
  }

  if ('parts' in content) {
    // Already a Content: keep its role/other fields, only rework the parts.
    const { parts: rawParts } = content;
    const parts = rawParts
      ? toParts(rawParts.filter((candidate) => candidate != null))
      : [];
    return { ...content, parts };
  }

  // Anything left is a lone Part.
  return { role: 'user', parts: [toPart(content as Part)] };
}
/** Converts every entry of `parts` with `toPart`, preserving order. */
export function toParts(parts: PartUnion[]): Part[] {
  return parts.map((entry) => toPart(entry));
}
/**
 * Part fields that carry a payload other than plain text. A part holding any
 * of these must not also be given a `text` field.
 */
const NON_TEXT_PART_FIELDS = [
  'functionCall',
  'functionResponse',
  'inlineData',
  'fileData',
  'executableCode',
  'codeExecutionResult',
] as const;

/**
 * Converts one part-like value into a `Part` with no `thought` field.
 *
 * - A string becomes `{ text }`.
 * - A part without a truthy `thought` is returned unchanged (same object).
 * - A part with a truthy `thought` is copied without that field. If the copy
 *   carries a non-text payload (see `NON_TEXT_PART_FIELDS`) it is returned
 *   as-is; otherwise the thought is folded into `text` as
 *   `[Thought: …]`, appended on a new line after any existing text.
 *
 * The `thought` field is stripped because, per the original authors, the
 * CountToken API rejects parts that carry it.
 *
 * @param part String or part object to convert.
 * @returns A part that never has a `thought` field.
 */
function toPart(part: PartUnion): Part {
  if (typeof part === 'string') {
    return { text: part };
  }

  if (!('thought' in part) || !part.thought) {
    return part;
  }

  // Shallow copy without `thought`; the caller's object is left untouched.
  const { thought, ...withoutThought } = part;

  const carriesNonTextPayload = NON_TEXT_PART_FIELDS.some(
    (field) => field in withoutThought,
  );
  if (carriesNonTextPayload) {
    // Adding `text` here would give the part two payloads, so only strip.
    return withoutThought;
  }

  // Text-only part: merge the thought into the text. Existing text may not be
  // a string at runtime, so it is coerced.
  const thoughtText = `[Thought: ${thought}]`;
  const rawText = (withoutThought as { text?: unknown }).text;
  const existingText = rawText ? String(rawText) : '';
  return {
    ...withoutThought,
    text: existingText ? `${existingText}\n${thoughtText}` : thoughtText,
  };
}
/**
 * Projects the SDK generation config onto the fields of
 * `VertexGenerationConfig`.
 *
 * @returns `undefined` when no config is given; otherwise an object holding
 *   exactly the listed fields (unset ones are present with value `undefined`).
 */
function toVertexGenerationConfig(
  config?: GenerateContentConfig,
): VertexGenerationConfig | undefined {
  if (!config) {
    return undefined;
  }

  const {
    temperature,
    topP,
    topK,
    candidateCount,
    maxOutputTokens,
    stopSequences,
    responseLogprobs,
    logprobs,
    presencePenalty,
    frequencyPenalty,
    seed,
    responseMimeType,
    responseSchema,
    responseJsonSchema,
    routingConfig,
    modelSelectionConfig,
    responseModalities,
    mediaResolution,
    speechConfig,
    audioTimestamp,
    thinkingConfig,
  } = config;

  // Anything else on `config` (tools, systemInstruction, …) is left out here.
  return {
    temperature,
    topP,
    topK,
    candidateCount,
    maxOutputTokens,
    stopSequences,
    responseLogprobs,
    logprobs,
    presencePenalty,
    frequencyPenalty,
    seed,
    responseMimeType,
    responseSchema,
    responseJsonSchema,
    routingConfig,
    modelSelectionConfig,
    responseModalities,
    mediaResolution,
    speechConfig,
    audioTimestamp,
    thinkingConfig,
  };
}

View File

@@ -1,217 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
// Mock external dependencies
// Stub storage object shared by all tests. It is created through vi.hoisted
// so that it is already defined when the vi.mock factory below refers to it.
const mockHybridTokenStorage = vi.hoisted(() => ({
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
}));
// Replace the HybridTokenStorage constructor so it yields the stub above.
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
}));
// Only the two fs.promises functions the tests stub are provided.
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
rm: vi.fn(),
},
}));
// Mocked without a factory; the suite's beforeEach stubs os.homedir and path.join.
vi.mock('node:os');
vi.mock('node:path');
describe('OAuthCredentialStorage', () => {
const mockCredentials: Credentials = {
access_token: 'mock_access_token',
refresh_token: 'mock_refresh_token',
expiry_date: Date.now() + 3600 * 1000,
token_type: 'Bearer',
scope: 'email profile',
};
const mockMcpCredentials: OAuthCredentials = {
serverName: 'main-account',
token: {
accessToken: 'mock_access_token',
refreshToken: 'mock_refresh_token',
tokenType: 'Bearer',
scope: 'email profile',
expiresAt: mockCredentials.expiry_date!,
},
updatedAt: expect.any(Number),
};
const oldFilePath = '/mock/home/.qwen/oauth.json';
beforeEach(() => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
vi.spyOn(os, 'homedir').mockReturnValue('/mock/home');
vi.spyOn(path, 'join').mockReturnValue(oldFilePath);
});
afterEach(() => {
vi.restoreAllMocks();
});
// Covers OAuthCredentialStorage.loadCredentials: the direct storage lookup,
// the fallback migration from the legacy file, and the error paths.
describe('loadCredentials', () => {
// Credentials found in storage are returned in the google-auth-library shape.
it('should load credentials from HybridTokenStorage if available', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
mockMcpCredentials,
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(result).toEqual(mockCredentials);
});
// With empty storage the legacy file is read, saved into storage and removed.
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
expect(result).toEqual(mockCredentials);
});
// A read failure carrying code 'ENOENT' means "nothing to migrate", not an error.
it('should return null if no credentials found and no old file to migrate', async () => {
vi.spyOn(fs, 'readFile').mockRejectedValue({
message: 'File not found',
code: 'ENOENT',
});
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toBeNull();
});
// Storage failures surface as the generic load error.
it('should throw an error if loading fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
new Error('Loading error'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
// Read failures other than ENOENT also surface as the generic load error.
it('should throw an error if read file fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(
new Error('Permission denied'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
// Removing the legacy file is best effort: a failed delete must not fail the load.
it('should not throw error if migration file removal failed', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toEqual(mockCredentials);
});
});
// Covers OAuthCredentialStorage.saveCredentials: conversion to the storage
// shape and rejection of credentials that lack an access token.
describe('saveCredentials', () => {
// The Google-style credentials are written in the OAuthCredentials shape.
it('should save credentials to HybridTokenStorage', async () => {
await OAuthCredentialStorage.saveCredentials(mockCredentials);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
mockMcpCredentials,
);
});
// Saving without an access token is refused with a dedicated error message.
it('should throw an error if access_token is missing', async () => {
const invalidCredentials: Credentials = {
...mockCredentials,
access_token: undefined,
};
await expect(
OAuthCredentialStorage.saveCredentials(invalidCredentials),
).rejects.toThrow(
'Attempted to save credentials without an access token.',
);
});
});
// Covers OAuthCredentialStorage.clearCredentials: deletion from storage plus
// best-effort removal of the legacy credential file.
describe('clearCredentials', () => {
// The main account entry is deleted from token storage.
it('should delete credentials from HybridTokenStorage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
'main-account',
);
});
// The legacy file is removed with `force: true`.
it('should attempt to remove the old file-based storage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
});
// A failing legacy-file delete is swallowed.
it('should not throw an error if deleting old file fails', async () => {
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
await expect(
OAuthCredentialStorage.clearCredentials(),
).resolves.toBeUndefined();
});
// A failing storage delete surfaces as the generic clear error.
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
new Error('Deletion error'),
);
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
'Failed to clear OAuth credentials',
);
});
});
});

View File

@@ -1,130 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
import { OAUTH_FILE } from '../config/storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
// Directory under the user's home directory that holds the legacy credential file.
const QWEN_DIR = '.qwen';
// Service name handed to the HybridTokenStorage constructor.
const KEYCHAIN_SERVICE_NAME = 'qwen-code-oauth';
// Key under which the single set of credentials is read, written and deleted.
const MAIN_ACCOUNT_KEY = 'main-account';
/**
 * Stores the main account's OAuth credentials through HybridTokenStorage,
 * converting between the google-auth-library `Credentials` shape and the
 * `OAuthCredentials` shape the token storage expects. Also migrates
 * credentials out of the legacy JSON file under the user's home directory.
 */
export class OAuthCredentialStorage {
  private static storage: HybridTokenStorage = new HybridTokenStorage(
    KEYCHAIN_SERVICE_NAME,
  );

  /**
   * Reads the stored credentials, falling back to a one-time migration from
   * the legacy file when the storage has no token.
   *
   * @returns The credentials, or `null` when neither source has any.
   * @throws Error 'Failed to load OAuth credentials' on any failure (the
   *   underlying error is written to `console.error`).
   */
  static async loadCredentials(): Promise<Credentials | null> {
    try {
      const stored = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
      if (!stored?.token) {
        // Nothing in storage: attempt the migration from the legacy file.
        return await this.migrateFromFileStorage();
      }

      // Translate the storage shape back into Google's field names.
      const token = stored.token;
      const restored: Credentials = {
        access_token: token.accessToken,
        refresh_token: token.refreshToken || undefined,
        token_type: token.tokenType || undefined,
        scope: token.scope || undefined,
      };
      if (token.expiresAt) {
        restored.expiry_date = token.expiresAt;
      }
      return restored;
    } catch (error: unknown) {
      console.error(error);
      throw new Error('Failed to load OAuth credentials');
    }
  }

  /**
   * Writes the given credentials to storage under the main account key.
   *
   * @param credentials Google-style credentials; must carry an access token.
   * @throws Error when `access_token` is missing or empty.
   */
  static async saveCredentials(credentials: Credentials): Promise<void> {
    const { access_token, refresh_token, token_type, scope, expiry_date } =
      credentials;
    if (!access_token) {
      throw new Error('Attempted to save credentials without an access token.');
    }

    const record: OAuthCredentials = {
      serverName: MAIN_ACCOUNT_KEY,
      token: {
        accessToken: access_token,
        refreshToken: refresh_token || undefined,
        tokenType: token_type || 'Bearer',
        scope: scope || undefined,
        expiresAt: expiry_date || undefined,
      },
      updatedAt: Date.now(),
    };
    await this.storage.setCredentials(record);
  }

  /**
   * Deletes the stored credentials and, best effort, the legacy file.
   *
   * @throws Error 'Failed to clear OAuth credentials' when the storage
   *   deletion fails (the underlying error is written to `console.error`).
   */
  static async clearCredentials(): Promise<void> {
    try {
      await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);

      // A failure to remove the legacy file is deliberately ignored.
      const legacyPath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
      const removal = fs.rm(legacyPath, { force: true });
      await removal.catch(() => {});
    } catch (error: unknown) {
      console.error(error);
      throw new Error('Failed to clear OAuth credentials');
    }
  }

  /** True when `error` is an object whose `code` property equals 'ENOENT'. */
  private static isMissingFileError(error: unknown): boolean {
    if (typeof error !== 'object' || error === null) {
      return false;
    }
    return 'code' in error && error.code === 'ENOENT';
  }

  /**
   * Moves credentials from the legacy JSON file into storage.
   *
   * @returns The migrated credentials, or `null` when the file does not exist.
   * @throws Any read error other than ENOENT, JSON parse errors, and errors
   *   raised by `saveCredentials`.
   */
  private static async migrateFromFileStorage(): Promise<Credentials | null> {
    const legacyPath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);

    let rawJson: string;
    try {
      rawJson = await fs.readFile(legacyPath, 'utf-8');
    } catch (error: unknown) {
      if (this.isMissingFileError(error)) {
        // No legacy file, hence nothing to migrate.
        return null;
      }
      // Every other read failure is passed on to the caller.
      throw error;
    }

    const migrated = JSON.parse(rawJson) as Credentials;
    await this.saveCredentials(migrated);

    // The file is only removed once the save succeeded; failures are ignored.
    const removal = fs.rm(legacyPath, { force: true });
    await removal.catch(() => {});
    return migrated;
  }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,563 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Credentials } from 'google-auth-library';
import {
CodeChallengeMethod,
Compute,
OAuth2Client,
} from 'google-auth-library';
import crypto from 'node:crypto';
import { promises as fs } from 'node:fs';
import * as http from 'node:http';
import * as net from 'node:net';
import path from 'node:path';
import readline from 'node:readline';
import url from 'node:url';
import open from 'open';
import type { Config } from '../config/config.js';
import { Storage } from '../config/storage.js';
import { AuthType } from '../core/contentGenerator.js';
import { FatalAuthenticationError, getErrorMessage } from '../utils/errors.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
// Shared manager used by this module to read, cache and clear the signed-in
// Google account.
const userAccountManager = new UserAccountManager();
// OAuth Client ID used to initiate OAuth2Client class.
const OAUTH_CLIENT_ID =
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
// OAuth Secret value used to initiate OAuth2Client class.
// Note: It's ok to save this in git because this is an installed application
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
// "The process results in a client ID and, in some cases, a client secret,
// which you embed in the source code of your application. (In this context,
// the client secret is obviously not treated as a secret.)"
const OAUTH_CLIENT_SECRET = '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';
// OAuth Scopes for Cloud Code authorization.
const OAUTH_SCOPE = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
];
// HTTP status code used when the local callback server redirects the browser.
const HTTP_REDIRECT = 301;
// Page the browser is redirected to after a successful sign-in.
const SIGN_IN_SUCCESS_URL =
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
// Page the browser is redirected to when the sign-in fails.
const SIGN_IN_FAILURE_URL =
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
/**
* The result of starting a browser-based login: the authentication URL the
* user must visit to update the credentials of an OAuth2Client, together with
* a promise that resolves once the credentials have been set on the client
* (or rejects when the login fails).
*/
export interface OauthWebLogin {
authUrl: string;
loginCompletePromise: Promise<void>;
}
// Per-auth-type cache of client initialisation promises; getOauthClient reuses
// an entry instead of starting a second initialisation for the same type.
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
/**
 * Reports whether the environment variable named by
 * FORCE_ENCRYPTED_FILE_ENV_VAR is set to exactly the string 'true'.
 */
function getUseEncryptedStorageFlag(): boolean {
  const rawFlag = process.env[FORCE_ENCRYPTED_FILE_ENV_VAR];
  return rawFlag === 'true';
}
/**
 * Creates an authenticated OAuth2 client, trying these sources in order:
 *
 * 1. an access token supplied through GOOGLE_CLOUD_ACCESS_TOKEN (only when
 *    GOOGLE_GENAI_USE_GCA is also set),
 * 2. previously cached credentials,
 * 3. Cloud Shell application default credentials (for AuthType.CLOUD_SHELL),
 * 4. an interactive login: a pasted authorization code when browser launch is
 *    suppressed, otherwise a browser flow with a local callback server.
 *
 * @param authType Selects the Cloud Shell path; otherwise unused here.
 * @param config Supplies the proxy setting and the browser-suppression flag.
 * @returns The authenticated client.
 * @throws FatalAuthenticationError when the interactive login fails or times
 *   out; Error when Cloud Shell authentication fails.
 */
async function initOauthClient(
  authType: AuthType,
  config: Config,
): Promise<OAuth2Client> {
  const client = new OAuth2Client({
    clientId: OAUTH_CLIENT_ID,
    clientSecret: OAUTH_CLIENT_SECRET,
    transporterOptions: {
      proxy: config.getProxy(),
    },
  });
  const useEncryptedStorage = getUseEncryptedStorageFlag();

  if (
    process.env['GOOGLE_GENAI_USE_GCA'] &&
    process.env['GOOGLE_CLOUD_ACCESS_TOKEN']
  ) {
    client.setCredentials({
      access_token: process.env['GOOGLE_CLOUD_ACCESS_TOKEN'],
    });
    await fetchAndCacheUserInfo(client);
    return client;
  }

  // Persist every token set the client emits, in the configured store.
  client.on('tokens', async (tokens: Credentials) => {
    if (useEncryptedStorage) {
      await OAuthCredentialStorage.saveCredentials(tokens);
    } else {
      await cacheCredentials(tokens);
    }
  });

  // If there are cached creds on disk, they always take precedence
  if (await loadCachedCredentials(client)) {
    // Found valid cached credentials.
    // Check if we need to retrieve Google Account ID or Email
    if (!userAccountManager.getCachedGoogleAccount()) {
      try {
        await fetchAndCacheUserInfo(client);
      } catch (error) {
        // Non-fatal, continue with existing auth.
        console.warn('Failed to fetch user info:', getErrorMessage(error));
      }
    }
    console.log('Loaded cached credentials.');
    return client;
  }

  // In Google Cloud Shell, we can use Application Default Credentials (ADC)
  // provided via its metadata server to authenticate non-interactively using
  // the identity of the user logged into Cloud Shell.
  if (authType === AuthType.CLOUD_SHELL) {
    try {
      console.log("Attempting to authenticate via Cloud Shell VM's ADC.");
      const computeClient = new Compute({
        // We can leave this empty, since the metadata server will provide
        // the service account email.
      });
      await computeClient.getAccessToken();
      console.log('Authentication successful.');
      // Do not cache creds in this case; note that Compute client will handle its own refresh
      return computeClient;
    } catch (e) {
      throw new Error(
        `Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ${getErrorMessage(e)}`,
      );
    }
  }

  if (config.isBrowserLaunchSuppressed()) {
    let success = false;
    const maxRetries = 2;
    for (let i = 0; !success && i < maxRetries; i++) {
      success = await authWithUserCode(client);
      if (!success) {
        console.error(
          '\nFailed to authenticate with user code.',
          i === maxRetries - 1 ? '' : 'Retrying...\n',
        );
      }
    }
    if (!success) {
      throw new FatalAuthenticationError(
        'Failed to authenticate with user code.',
      );
    }
  } else {
    const webLogin = await authWithWeb(client);
    console.log(
      `\n\nCode Assist login required.\n` +
        `Attempting to open authentication page in your browser.\n` +
        `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
    );
    try {
      // Attempt to open the authentication URL in the default browser.
      // We do not use the `wait` option here because the main script's execution
      // is already paused by `loginCompletePromise`, which awaits the server callback.
      const childProcess = await open(webLogin.authUrl);
      // IMPORTANT: Attach an error handler to the returned child process.
      // Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
      // in a minimal Docker container), it will emit an unhandled 'error' event,
      // causing the entire Node.js process to crash.
      childProcess.on('error', (error) => {
        console.error(
          'Failed to open browser automatically. Please try running again with NO_BROWSER=true set.',
        );
        console.error('Browser error details:', getErrorMessage(error));
      });
    } catch (err) {
      console.error(
        'An unexpected error occurred while trying to open the browser:',
        getErrorMessage(err),
        '\nThis might be due to browser compatibility issues or system configuration.',
        '\nPlease try running again with NO_BROWSER=true set for manual authentication.',
      );
      throw new FatalAuthenticationError(
        `Failed to open browser: ${getErrorMessage(err)}`,
      );
    }
    console.log('Waiting for authentication...');

    // Add timeout to prevent infinite waiting when browser tab gets stuck
    const authTimeout = 5 * 60 * 1000; // 5 minutes timeout
    let timeoutHandle: ReturnType<typeof setTimeout> | undefined;
    const timeoutPromise = new Promise<never>((_, reject) => {
      timeoutHandle = setTimeout(() => {
        reject(
          new FatalAuthenticationError(
            'Authentication timed out after 5 minutes. The browser tab may have gotten stuck in a loading state. ' +
              'Please try again or use NO_BROWSER=true for manual authentication.',
          ),
        );
      }, authTimeout);
    });
    try {
      await Promise.race([webLogin.loginCompletePromise, timeoutPromise]);
    } finally {
      // Cancel the timer once the race settles; otherwise it stays pending
      // for the full five minutes even after a successful login.
      clearTimeout(timeoutHandle);
    }
  }
  return client;
}
/**
 * Returns the OAuth client for `authType`, starting its initialisation on the
 * first request and handing the same cached promise to later callers.
 *
 * @param authType Cache key and authentication strategy selector.
 * @param config Passed to the initialisation when it has to be started.
 */
export async function getOauthClient(
  authType: AuthType,
  config: Config,
): Promise<OAuth2Client> {
  let pendingClient = oauthClientPromises.get(authType);
  if (pendingClient === undefined) {
    pendingClient = initOauthClient(authType, config);
    oauthClientPromises.set(authType, pendingClient);
  }
  return pendingClient;
}
/**
 * Runs the manual login flow: prints an authorization URL, reads the
 * authorization code the user pastes on stdin, and exchanges it for tokens.
 *
 * @param client Client used to build the URL and to receive the credentials.
 * @returns `true` when tokens were obtained and set on the client, `false`
 *   when no code was entered or the exchange failed.
 */
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
  const redirectUri = 'https://codeassist.google.com/authcode';
  const pkce = await client.generateCodeVerifierAsync();
  const csrfState = crypto.randomBytes(32).toString('hex');
  const authUrl: string = client.generateAuthUrl({
    redirect_uri: redirectUri,
    access_type: 'offline',
    scope: OAUTH_SCOPE,
    code_challenge_method: CodeChallengeMethod.S256,
    code_challenge: pkce.codeChallenge,
    state: csrfState,
  });

  // One console.log call per line of output.
  const instructions = [
    'Please visit the following URL to authorize the application:',
    '',
    authUrl,
    '',
  ];
  for (const line of instructions) {
    console.log(line);
  }

  const enteredCode = await new Promise<string>((resolve) => {
    const prompt = readline.createInterface({
      input: process.stdin,
      output: process.stdout,
    });
    prompt.question('Enter the authorization code: ', (answer) => {
      prompt.close();
      resolve(answer.trim());
    });
  });
  if (!enteredCode) {
    console.error('Authorization code is required.');
    return false;
  }

  try {
    const { tokens } = await client.getToken({
      code: enteredCode,
      codeVerifier: pkce.codeVerifier,
      redirect_uri: redirectUri,
    });
    client.setCredentials(tokens);
    return true;
  } catch (error) {
    console.error(
      'Failed to authenticate with authorization code:',
      getErrorMessage(error),
    );
    return false;
  }
}
/**
 * Prepares the browser-based login flow. Starts a one-shot local HTTP server
 * that waits for the OAuth callback and builds the URL the user has to visit.
 *
 * The server handles exactly one request and is closed afterwards. Every
 * branch finishes the HTTP response before settling the promise, so the
 * browser never waits on a half-handled request.
 *
 * @param client Client used to build the URL and to receive the credentials.
 * @returns The authorization URL and a promise that resolves once the tokens
 *   are set on `client`, or rejects with a FatalAuthenticationError.
 */
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
  const port = await getAvailablePort();
  // The hostname used for the HTTP server binding (e.g., '0.0.0.0' in Docker).
  const host = process.env['OAUTH_CALLBACK_HOST'] || 'localhost';
  // The `redirectUri` sent to Google's authorization server MUST use a loopback IP literal
  // (i.e., 'localhost' or '127.0.0.1'). This is a strict security policy for credentials of
  // type 'Desktop app' or 'Web application' (when using loopback flow) to mitigate
  // authorization code interception attacks.
  const redirectUri = `http://localhost:${port}/oauth2callback`;
  const state = crypto.randomBytes(32).toString('hex');
  const authUrl = client.generateAuthUrl({
    redirect_uri: redirectUri,
    access_type: 'offline',
    scope: OAUTH_SCOPE,
    state,
  });

  const loginCompletePromise = new Promise<void>((resolve, reject) => {
    const server = http.createServer(async (req, res) => {
      try {
        if (req.url!.indexOf('/oauth2callback') === -1) {
          res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
          res.end();
          reject(
            new FatalAuthenticationError(
              'OAuth callback not received. Unexpected request: ' + req.url,
            ),
          );
          // The response is finished and the promise settled; stop here
          // instead of falling through to the callback handling below.
          return;
        }
        // acquire the code from the querystring, and close the web server.
        const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
        if (qs.get('error')) {
          res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
          res.end();
          const errorCode = qs.get('error');
          const errorDescription =
            qs.get('error_description') || 'No additional details provided';
          reject(
            new FatalAuthenticationError(
              `Google OAuth error: ${errorCode}. ${errorDescription}`,
            ),
          );
        } else if (qs.get('state') !== state) {
          res.end('State mismatch. Possible CSRF attack');
          reject(
            new FatalAuthenticationError(
              'OAuth state mismatch. Possible CSRF attack or browser session issue.',
            ),
          );
        } else if (qs.get('code')) {
          try {
            const { tokens } = await client.getToken({
              code: qs.get('code')!,
              redirect_uri: redirectUri,
            });
            client.setCredentials(tokens);
            // Retrieve and cache Google Account ID during authentication
            try {
              await fetchAndCacheUserInfo(client);
            } catch (error) {
              console.warn(
                'Failed to retrieve Google Account ID during authentication:',
                getErrorMessage(error),
              );
              // Don't fail the auth flow if Google Account ID retrieval fails
            }
            res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
            res.end();
            resolve();
          } catch (error) {
            res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
            res.end();
            reject(
              new FatalAuthenticationError(
                `Failed to exchange authorization code for tokens: ${getErrorMessage(error)}`,
              ),
            );
          }
        } else {
          // Finish the response here as well, so the browser is sent to the
          // failure page instead of waiting on a request that never ends.
          res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
          res.end();
          reject(
            new FatalAuthenticationError(
              'No authorization code received from Google OAuth. Please try authenticating again.',
            ),
          );
        }
      } catch (e) {
        // Provide more specific error message for unexpected errors during OAuth flow
        if (e instanceof FatalAuthenticationError) {
          reject(e);
        } else {
          reject(
            new FatalAuthenticationError(
              `Unexpected error during OAuth authentication: ${getErrorMessage(e)}`,
            ),
          );
        }
      } finally {
        // One request is all this server ever handles.
        server.close();
      }
    });
    server.listen(port, host, () => {
      // Server started successfully
    });
    server.on('error', (err) => {
      reject(
        new FatalAuthenticationError(
          `OAuth callback server error: ${getErrorMessage(err)}`,
        ),
      );
    });
  });

  return {
    authUrl,
    loginCompletePromise,
  };
}
/**
 * Picks the port for the local OAuth callback server.
 *
 * When OAUTH_CALLBACK_PORT is set, its value is parsed as a base-10 integer
 * and returned if it lies in 1..65535; otherwise the promise rejects.
 * Without the variable, a temporary server is bound to port 0, the assigned
 * port is recorded, the server is closed, and the port is returned.
 */
export function getAvailablePort(): Promise<number> {
  return new Promise((resolve, reject) => {
    try {
      const configuredPort = process.env['OAUTH_CALLBACK_PORT'];
      if (configuredPort) {
        const parsedPort = parseInt(configuredPort, 10);
        const isUsable =
          !isNaN(parsedPort) && parsedPort > 0 && parsedPort <= 65535;
        if (!isUsable) {
          reject(
            new Error(
              `Invalid value for OAUTH_CALLBACK_PORT: "${configuredPort}"`,
            ),
          );
          return;
        }
        resolve(parsedPort);
        return;
      }

      let assignedPort = 0;
      const probe = net.createServer();
      // The listen callback is registered first, so it records the port
      // before the 'listening' handler below closes the server.
      probe.listen(0, () => {
        const bound = probe.address()! as net.AddressInfo;
        assignedPort = bound.port;
      });
      probe.on('listening', () => {
        probe.close();
        probe.unref();
      });
      probe.on('error', (e) => reject(e));
      probe.on('close', () => resolve(assignedPort));
    } catch (e) {
      reject(e);
    }
  });
}
/**
 * Tries to restore previously saved credentials onto `client`.
 *
 * With encrypted storage enabled, only OAuthCredentialStorage is consulted.
 * Otherwise each candidate file is read in turn and accepted once an access
 * token can be obtained from it and that token passes `getTokenInfo`.
 *
 * @returns `true` when credentials were loaded, `false` otherwise.
 */
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
  if (getUseEncryptedStorageFlag()) {
    const stored = await OAuthCredentialStorage.loadCredentials();
    if (!stored) {
      return false;
    }
    client.setCredentials(stored);
    return true;
  }

  const candidateFiles = [
    Storage.getOAuthCredsPath(),
    process.env['GOOGLE_APPLICATION_CREDENTIALS'],
  ].filter((p): p is string => !!p);

  for (const candidate of candidateFiles) {
    try {
      const rawCreds = await fs.readFile(candidate, 'utf-8');
      client.setCredentials(JSON.parse(rawCreds));
      // Local check: can the client produce an access token at all?
      const { token } = await client.getAccessToken();
      if (token) {
        // Server-side check: throws if the token is no longer accepted.
        await client.getTokenInfo(token);
        return true;
      }
    } catch (error) {
      // Record why this file was unusable, then move on to the next one.
      console.debug(
        `Failed to load credentials from ${candidate}:`,
        getErrorMessage(error),
      );
    }
  }
  return false;
}
/**
 * Writes the credentials as pretty-printed JSON to the OAuth credentials
 * file, creating the parent directory when needed and restricting the file
 * mode to 0o600.
 */
async function cacheCredentials(credentials: Credentials): Promise<void> {
  const credsPath = Storage.getOAuthCredsPath();
  const parentDir = path.dirname(credsPath);
  await fs.mkdir(parentDir, { recursive: true });

  const serialized = JSON.stringify(credentials, null, 2);
  await fs.writeFile(credsPath, serialized, { mode: 0o600 });

  try {
    // Apply the mode again explicitly; a failure here is ignored.
    await fs.chmod(credsPath, 0o600);
  } catch {
    /* empty */
  }
}
/** Empties the in-memory cache of OAuth client promises. */
export function clearOauthClientCache(): void {
  oauthClientPromises.clear();
}
/**
 * Removes every cached trace of the current login: the stored credentials
 * (encrypted storage or plain file), the cached Google account, the
 * in-memory client cache, and, best effort, the Qwen token cache and
 * credentials. Failures are logged and never thrown.
 */
export async function clearCachedCredentialFile() {
  try {
    if (getUseEncryptedStorageFlag()) {
      await OAuthCredentialStorage.clearCredentials();
    } else {
      await fs.rm(Storage.getOAuthCredsPath(), { force: true });
    }

    // The cached Google account belongs to the credentials just removed.
    await userAccountManager.clearCachedGoogleAccount();
    // Dropping the cached clients forces re-authentication on next use.
    clearOauthClientCache();

    // Also reset the Qwen SharedTokenManager cache and credentials file so
    // that switching auth types does not pick up stale credentials.
    // TODO: We do not depend on code_assist, we'll have to build an independent auth-cleaning procedure.
    try {
      const sharedTokenModule = await import('../qwen/sharedTokenManager.js');
      const qwenOAuthModule = await import('../qwen/qwenOAuth2.js');
      sharedTokenModule.SharedTokenManager.getInstance().clearCache();
      await qwenOAuthModule.clearQwenCredentials();
    } catch (qwenError) {
      console.debug('Could not clear Qwen credentials:', qwenError);
    }
  } catch (e) {
    console.error('Failed to clear cached credentials:', e);
  }
}
/**
 * Looks up the signed-in user's profile with the client's access token and
 * caches the returned email address. Does nothing when no token is available;
 * HTTP and other failures are logged and never thrown.
 */
async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
  try {
    const { token } = await client.getAccessToken();
    if (!token) {
      return;
    }

    const requestInit = {
      headers: {
        Authorization: `Bearer ${token}`,
      },
    };
    const response = await fetch(
      'https://www.googleapis.com/oauth2/v2/userinfo',
      requestInit,
    );

    if (response.ok) {
      const userInfo = await response.json();
      await userAccountManager.cacheGoogleAccount(userInfo.email);
      return;
    }

    console.error(
      'Failed to fetch user info:',
      response.status,
      response.statusText,
    );
  } catch (error) {
    console.error('Error retrieving user info:', error);
  }
}
/**
 * Test helper: empties the cached OAuth client promises so that each test
 * starts without a previously initialised client.
 */
export function resetOauthClientForTesting(): void {
  oauthClientPromises.clear();
}

View File

@@ -1,255 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, describe, it, expect, vi } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
vi.mock('google-auth-library');
// Unit tests for CodeAssistServer. The HTTP layer is never exercised: each
// test replaces `requestPost` / `requestStreamingPost` with a spy.
describe('CodeAssistServer', () => {
  /** Builds a server with the fixed arguments every test in this suite uses. */
  const createServer = (): CodeAssistServer =>
    new CodeAssistServer(
      new OAuth2Client(),
      'test-project',
      {},
      'test-session',
      UserTierId.FREE,
    );

  beforeEach(() => {
    vi.resetAllMocks();
  });

  it('should be able to be constructed', () => {
    const server = createServer();
    expect(server).toBeInstanceOf(CodeAssistServer);
  });

  it('should call the generateContent endpoint', async () => {
    const server = createServer();
    const mockResponse = {
      response: {
        candidates: [
          {
            index: 0,
            content: {
              role: 'model',
              parts: [{ text: 'response' }],
            },
            finishReason: 'STOP',
            safetyRatings: [],
          },
        ],
      },
    };
    vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);

    const response = await server.generateContent(
      {
        model: 'test-model',
        contents: [{ role: 'user', parts: [{ text: 'request' }] }],
      },
      'user-prompt-id',
    );

    expect(server.requestPost).toHaveBeenCalledWith(
      'generateContent',
      expect.any(Object),
      undefined,
    );
    expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
      'response',
    );
  });

  it('should call the generateContentStream endpoint', async () => {
    const server = createServer();
    const mockResponse = (async function* () {
      yield {
        response: {
          candidates: [
            {
              index: 0,
              content: {
                role: 'model',
                parts: [{ text: 'response' }],
              },
              finishReason: 'STOP',
              safetyRatings: [],
            },
          ],
        },
      };
    })();
    vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);

    const stream = await server.generateContentStream(
      {
        model: 'test-model',
        contents: [{ role: 'user', parts: [{ text: 'request' }] }],
      },
      'user-prompt-id',
    );

    // Count the chunks so the test fails, instead of passing vacuously, when
    // the stream yields nothing.
    let chunkCount = 0;
    for await (const res of stream) {
      expect(server.requestStreamingPost).toHaveBeenCalledWith(
        'streamGenerateContent',
        expect.any(Object),
        undefined,
      );
      expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
      chunkCount++;
    }
    expect(chunkCount).toBe(1);
  });

  it('should call the onboardUser endpoint', async () => {
    const server = createServer();
    const mockResponse = {
      name: 'operations/123',
      done: true,
    };
    vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);

    const response = await server.onboardUser({
      tierId: 'test-tier',
      cloudaicompanionProject: 'test-project',
      metadata: {},
    });

    expect(server.requestPost).toHaveBeenCalledWith(
      'onboardUser',
      expect.any(Object),
    );
    expect(response.name).toBe('operations/123');
  });

  it('should call the loadCodeAssist endpoint', async () => {
    const server = createServer();
    const mockResponse = {
      currentTier: {
        id: UserTierId.FREE,
        name: 'Free',
        description: 'free tier',
      },
      allowedTiers: [],
      ineligibleTiers: [],
      cloudaicompanionProject: 'projects/test',
    };
    vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);

    const response = await server.loadCodeAssist({
      metadata: {},
    });

    expect(server.requestPost).toHaveBeenCalledWith(
      'loadCodeAssist',
      expect.any(Object),
    );
    expect(response).toEqual(mockResponse);
  });

  // Title corrected: the assertion checks that the mocked total (100) is
  // passed through, not that 0 is returned.
  it('should return the token count from the countTokens endpoint', async () => {
    const server = createServer();
    const mockResponse = {
      totalTokens: 100,
    };
    vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);

    const response = await server.countTokens({
      model: 'test-model',
      contents: [{ role: 'user', parts: [{ text: 'request' }] }],
    });

    expect(response.totalTokens).toBe(100);
  });

  it('should throw an error for embedContent', async () => {
    const server = createServer();
    await expect(
      server.embedContent({
        model: 'test-model',
        contents: [{ role: 'user', parts: [{ text: 'request' }] }],
      }),
    ).rejects.toThrow();
  });

  it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
    const server = createServer();
    const mockVpcScError = {
      response: {
        data: {
          error: {
            details: [
              {
                reason: 'SECURITY_POLICY_VIOLATED',
              },
            ],
          },
        },
      },
    };
    vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);

    const response = await server.loadCodeAssist({
      metadata: {},
    });

    expect(server.requestPost).toHaveBeenCalledWith(
      'loadCodeAssist',
      expect.any(Object),
    );
    expect(response).toEqual({
      currentTier: { id: UserTierId.STANDARD },
    });
  });
});

View File

@@ -1,253 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuth2Client } from 'google-auth-library';
import type {
CodeAssistGlobalUserSettingResponse,
GoogleRpcResponse,
LoadCodeAssistRequest,
LoadCodeAssistResponse,
LongRunningOperationResponse,
OnboardUserRequest,
SetCodeAssistGlobalUserSettingRequest,
} from './types.js';
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import * as readline from 'node:readline';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { UserTierId } from './types.js';
import type {
CaCountTokenResponse,
CaGenerateContentResponse,
} from './converter.js';
import {
fromCountTokenResponse,
fromGenerateContentResponse,
toCountTokenRequest,
toGenerateContentRequest,
} from './converter.js';
/**
* HTTP options to be used in each of the requests. CodeAssistServer spreads
* `headers` into every request after its own `Content-Type` header.
*/
export interface HttpOptions {
/** Additional HTTP headers to be sent with the request. */
headers?: Record<string, string>;
}
// Default base URL for Code Assist requests. It is used only when the
// CODE_ASSIST_ENDPOINT environment variable is not set (see getMethodUrl).
export const CODE_ASSIST_ENDPOINT = 'https://localhost:0'; // Disable Google Code Assist API Request
// API version segment placed between the base URL and the method name.
export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator {
constructor(
readonly client: OAuth2Client,
readonly projectId?: string,
readonly httpOptions: HttpOptions = {},
readonly sessionId?: string,
readonly userTier?: UserTierId,
) {}
async generateContentStream(
req: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
yield fromGenerateContentResponse(resp);
}
})();
}
async generateContent(
req: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
}
async onboardUser(
req: OnboardUserRequest,
): Promise<LongRunningOperationResponse> {
return await this.requestPost<LongRunningOperationResponse>(
'onboardUser',
req,
);
}
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
try {
return await this.requestPost<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
);
} catch (e) {
if (isVpcScAffectedUser(e)) {
return {
currentTier: { id: UserTierId.STANDARD },
};
} else {
throw e;
}
}
}
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
'getCodeAssistGlobalUserSetting',
);
}
async setCodeAssistGlobalUserSetting(
req: SetCodeAssistGlobalUserSettingRequest,
): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
'setCodeAssistGlobalUserSetting',
req,
);
}
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
const resp = await this.requestPost<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
return fromCountTokenResponse(resp);
}
async embedContent(
_req: EmbedContentParameters,
): Promise<EmbedContentResponse> {
throw Error();
}
async requestPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'POST',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
body: JSON.stringify(req),
signal,
});
return res.data as T;
}
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'GET',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
signal,
});
return res.data as T;
}
/**
 * POSTs `req` with `alt=sse` and exposes the streamed response as an async
 * generator of parsed JSON payloads.
 *
 * The response is read line by line. Lines starting with `data: ` are buffered
 * (prefix removed, trimmed); a blank line ends the current event, at which point
 * the buffered lines are joined with `\n`, parsed as JSON and yielded as `T`.
 * Data still buffered when the stream ends without a closing blank line is not
 * yielded.
 *
 * @param method API method name appended to the endpoint URL.
 * @param req object serialized with JSON.stringify as the request body.
 * @param signal optional abort signal forwarded to the HTTP client.
 * @returns an async generator; iterating it throws Error on any non-empty line
 *   that does not start with `data: `, and JSON.parse errors propagate.
 */
async requestStreamingPost<T>(
  method: string,
  req: object,
  signal?: AbortSignal,
): Promise<AsyncGenerator<T>> {
  const response = await this.client.request({
    url: this.getMethodUrl(method),
    method: 'POST',
    params: {
      alt: 'sse',
    },
    headers: {
      'Content-Type': 'application/json',
      ...this.httpOptions.headers,
    },
    responseType: 'stream',
    body: JSON.stringify(req),
    signal,
  });

  async function* parseEvents(): AsyncGenerator<T> {
    const lineReader = readline.createInterface({
      input: response.data as NodeJS.ReadableStream,
      crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
    });
    let pending: string[] = [];
    for await (const line of lineReader) {
      if (line.startsWith('data: ')) {
        pending.push(line.slice(6).trim());
        continue;
      }
      if (line !== '') {
        throw new Error(`Unexpected line format in response: ${line}`);
      }
      // Blank line: event separator. Nothing to emit if no data was buffered.
      if (pending.length > 0) {
        yield JSON.parse(pending.join('\n')) as T;
        pending = [];
      }
    }
  }

  return parseEvents();
}
/**
 * Builds the request URL for an API method as `<endpoint>/<version>:<method>`.
 * The endpoint comes from the CODE_ASSIST_ENDPOINT environment variable when it
 * is defined (even if empty), otherwise from the CODE_ASSIST_ENDPOINT constant.
 *
 * @param method API method name.
 * @returns the full URL string.
 */
getMethodUrl(method: string): string {
  const base = process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
  const versionedPath = `${CODE_ASSIST_API_VERSION}:${method}`;
  return `${base}/${versionedPath}`;
}
}
/**
 * Reports whether `error` looks like an HTTP error whose response body carries a
 * Google RPC error detail with reason `SECURITY_POLICY_VIOLATED`.
 *
 * @param error any thrown value; non-objects and objects without a `response`
 *   property yield `false`.
 * @returns true only when `error.response.data.error.details` is an array with
 *   at least one entry whose `reason` equals `SECURITY_POLICY_VIOLATED`.
 */
function isVpcScAffectedUser(error: unknown): boolean {
  if (!error || typeof error !== 'object' || !('response' in error)) {
    return false;
  }
  const { response: httpResponse } = error as {
    response?: {
      data?: unknown;
    };
  };
  const rpcBody = httpResponse?.data as GoogleRpcResponse | undefined;
  const details = rpcBody?.error?.details;
  if (!Array.isArray(details)) {
    return false;
  }
  return details.some((entry) => entry.reason === 'SECURITY_POLICY_VIOLATED');
}

View File

@@ -1,224 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { setupUser, ProjectIdRequiredError } from './setup.js';
import { CodeAssistServer } from '../code_assist/server.js';
import type { OAuth2Client } from 'google-auth-library';
import type { GeminiUserTier } from './types.js';
import { UserTierId } from './types.js';
// Mock the server module; each suite below installs its own CodeAssistServer
// implementation through vi.mocked(CodeAssistServer).mockImplementation(...).
vi.mock('../code_assist/server.js');
// Fixture: a STANDARD tier flagged as the default tier.
const mockPaidTier: GeminiUserTier = {
id: UserTierId.STANDARD,
name: 'paid',
description: 'Paid tier',
isDefault: true,
};
// Fixture: a FREE tier flagged as the default tier.
const mockFreeTier: GeminiUserTier = {
id: UserTierId.FREE,
name: 'free',
description: 'Free tier',
isDefault: true,
};
// Cases where loadCodeAssist resolves with `currentTier`, i.e. setupUser takes its
// already-has-a-tier path.
describe('setupUser for existing user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
// Reset before re-installing the mocks below, so every test starts from the same
// CodeAssistServer implementation.
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
// Undo every vi.stubEnv() made by a test.
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
currentTier: mockPaidTier,
});
await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
});
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
cloudaicompanionProject: 'server-project',
currentTier: mockPaidTier,
});
// Note: despite its name, `projectId` holds the whole object returned by setupUser.
const projectId = await setupUser({} as OAuth2Client);
// The env value is still what the server is constructed with; only the returned
// projectId comes from the loadCodeAssist response.
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(projectId).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
// And the server itself requires a project ID internally
vi.mocked(CodeAssistServer).mockImplementation(() => {
throw new ProjectIdRequiredError();
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});
// Cases where loadCodeAssist resolves without `currentTier` (only `allowedTiers`),
// so setupUser has to call onboardUser.
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
// Reset before re-installing the mocks below, so every test starts from the same
// CodeAssistServer implementation.
vi.resetAllMocks();
mockLoad = vi.fn();
// Default: onboarding completes on the first call and reports 'server-project'.
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
// Undo every vi.stubEnv() made by a test.
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
// Non-free tier: the env project is sent both as cloudaicompanionProject and as
// metadata.duetProject.
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'standard-tier',
cloudaicompanionProject: 'test-project',
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: 'test-project',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
// An empty value is treated as unset: the server is constructed with `undefined`.
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockFreeTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
// Free tier: no project and no duetProject in the onboarding request.
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'free-tier',
cloudaicompanionProject: undefined,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'free-tier',
});
});
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
// Override the default so the onboarding response carries no project.
mockOnboardUser.mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: undefined,
},
});
const userData = await setupUser({} as OAuth2Client);
expect(userData).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
// Neither the env var nor the onboarding response supplies a project id.
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {},
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});

View File

@@ -1,124 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
ClientMetadata,
GeminiUserTier,
LoadCodeAssistResponse,
OnboardUserRequest,
} from './types.js';
import { UserTierId } from './types.js';
import { CodeAssistServer } from './server.js';
import type { OAuth2Client } from 'google-auth-library';
/**
 * Thrown when no project id can be determined: neither the server response nor
 * the GOOGLE_CLOUD_PROJECT environment variable provides one.
 */
export class ProjectIdRequiredError extends Error {
  constructor() {
    super(
      'This account requires setting the GOOGLE_CLOUD_PROJECT env var. See https://goo.gle/gemini-cli-auth-docs#workspace-gca',
    );
    // Without this the error reports itself as a plain "Error" in logs and in
    // `error.name` checks.
    this.name = 'ProjectIdRequiredError';
  }
}
/**
 * Result of `setupUser`: the project id to use and the id of the user's tier.
 */
export interface UserData {
// Project id taken from the server response, or from GOOGLE_CLOUD_PROJECT as a fallback.
projectId: string;
// Id of the user's current tier, or of the tier they were onboarded with.
userTier: UserTierId;
}
/**
 * Determines the project id and tier for the authenticated user, onboarding the
 * user first when `loadCodeAssist` reports no current tier.
 *
 * The project id reported by the server takes precedence; the
 * GOOGLE_CLOUD_PROJECT environment variable is used only as a fallback.
 *
 * @param client OAuth2 client handed to the CodeAssistServer.
 * @returns the resolved project id together with the user's tier id.
 * @throws ProjectIdRequiredError when neither the server nor
 *   GOOGLE_CLOUD_PROJECT supplies a project id.
 */
export async function setupUser(client: OAuth2Client): Promise<UserData> {
  // `||` rather than `??` on purpose: an empty GOOGLE_CLOUD_PROJECT counts as unset.
  const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
  const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
  const coreClientMetadata: ClientMetadata = {
    ideType: 'IDE_UNSPECIFIED',
    platform: 'PLATFORM_UNSPECIFIED',
    pluginType: 'GEMINI',
  };

  // Shared by both exit paths: server-provided project, else env project, else fail.
  const resolveProjectId = (
    serverProjectId: string | null | undefined,
  ): string => {
    const resolved = serverProjectId || projectId;
    if (!resolved) {
      throw new ProjectIdRequiredError();
    }
    return resolved;
  };

  const loadRes = await caServer.loadCodeAssist({
    cloudaicompanionProject: projectId,
    metadata: {
      ...coreClientMetadata,
      duetProject: projectId,
    },
  });

  // A current tier is present: no onboarding call is needed.
  if (loadRes.currentTier) {
    return {
      projectId: resolveProjectId(loadRes.cloudaicompanionProject),
      userTier: loadRes.currentTier.id,
    };
  }

  const tier = getOnboardTier(loadRes);
  const onboardReq: OnboardUserRequest =
    tier.id === UserTierId.FREE
      ? {
          // The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
          tierId: tier.id,
          cloudaicompanionProject: undefined,
          metadata: coreClientMetadata,
        }
      : {
          tierId: tier.id,
          cloudaicompanionProject: projectId,
          metadata: {
            ...coreClientMetadata,
            duetProject: projectId,
          },
        };

  // Poll onboardUser until long running operation is complete.
  let lroRes = await caServer.onboardUser(onboardReq);
  while (!lroRes.done) {
    await new Promise((f) => setTimeout(f, 5000));
    lroRes = await caServer.onboardUser(onboardReq);
  }

  return {
    projectId: resolveProjectId(lroRes.response?.cloudaicompanionProject?.id),
    userTier: tier.id,
  };
}
/**
 * Picks the tier to onboard a user with.
 *
 * @param res the loadCodeAssist response; `allowedTiers` may be missing or null.
 * @returns the first allowed tier whose `isDefault` is truthy, or a synthetic
 *   LEGACY tier (empty name/description, user-defined project) when none is.
 */
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
  const defaultTier = (res.allowedTiers || []).find(
    (candidate) => candidate.isDefault,
  );
  if (defaultTier) {
    return defaultTier;
  }
  const legacyFallback: GeminiUserTier = {
    name: '',
    description: '',
    id: UserTierId.LEGACY,
    userDefinedCloudaicompanionProject: true,
  };
  return legacyFallback;
}

View File

@@ -1,201 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
 * Client identification metadata attached to Code Assist requests (it is the type
 * of `LoadCodeAssistRequest.metadata` and `OnboardUserRequest.metadata`).
 * Every field is optional.
 */
export interface ClientMetadata {
ideType?: ClientMetadataIdeType;
ideVersion?: string;
pluginVersion?: string;
platform?: ClientMetadataPlatform;
updateChannel?: string;
// NOTE(review): presumably a project id; exact server-side meaning is not visible here — confirm.
duetProject?: string;
pluginType?: ClientMetadataPluginType;
ideName?: string;
}
/** Allowed values for `ClientMetadata.ideType`. */
export type ClientMetadataIdeType =
| 'IDE_UNSPECIFIED'
| 'VSCODE'
| 'INTELLIJ'
| 'VSCODE_CLOUD_WORKSTATION'
| 'INTELLIJ_CLOUD_WORKSTATION'
| 'CLOUD_SHELL';
/** Allowed values for `ClientMetadata.platform`. */
export type ClientMetadataPlatform =
| 'PLATFORM_UNSPECIFIED'
| 'DARWIN_AMD64'
| 'DARWIN_ARM64'
| 'LINUX_AMD64'
| 'LINUX_ARM64'
| 'WINDOWS_AMD64';
/** Allowed values for `ClientMetadata.pluginType`. */
export type ClientMetadataPluginType =
| 'PLUGIN_UNSPECIFIED'
| 'CLOUD_CODE'
| 'GEMINI'
| 'AIPLUGIN_INTELLIJ'
| 'AIPLUGIN_STUDIO';
/**
 * Request carrying an optional project id plus the required client metadata.
 * NOTE(review): presumably the payload of the `loadCodeAssist` call — confirm against the server class.
 */
export interface LoadCodeAssistRequest {
cloudaicompanionProject?: string;
metadata: ClientMetadata;
}
/**
 * Represents LoadCodeAssistResponse proto json field
 * http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
 *
 * Every field may be absent or null.
 */
export interface LoadCodeAssistResponse {
// The tier the user currently has, if any.
currentTier?: GeminiUserTier | null;
// Tiers the user may be onboarded with.
allowedTiers?: GeminiUserTier[] | null;
// Tiers the user is not eligible for, each with a reason.
ineligibleTiers?: IneligibleTier[] | null;
cloudaicompanionProject?: string | null;
}
/**
 * GeminiUserTier reflects the structure received from CodeAssist when calling LoadCodeAssist.
 * Only `id` is required.
 */
export interface GeminiUserTier {
id: UserTierId;
name?: string;
description?: string;
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
userDefinedCloudaicompanionProject?: boolean | null;
// Marks the tier as the default one among a list of tiers.
isDefault?: boolean;
privacyNotice?: PrivacyNotice;
hasAcceptedTos?: boolean;
hasOnboardedPreviously?: boolean;
}
/**
 * Includes information specifying the reasons for a user's ineligibility for a specific tier.
 */
export interface IneligibleTier {
// Mnemonic code representing the reason for ineligibility.
reasonCode: IneligibleTierReasonCode;
// Message to display to the user.
reasonMessage: string;
// Id of the tier.
tierId: UserTierId;
// Name of the tier.
tierName: string;
}
/**
 * List of predefined reason codes used when a user is blocked from a specific tier.
 * https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
 *
 * Each member's value is the same string as its name.
 */
export enum IneligibleTierReasonCode {
// go/keep-sorted start
DASHER_USER = 'DASHER_USER',
INELIGIBLE_ACCOUNT = 'INELIGIBLE_ACCOUNT',
NON_USER_ACCOUNT = 'NON_USER_ACCOUNT',
RESTRICTED_AGE = 'RESTRICTED_AGE',
RESTRICTED_NETWORK = 'RESTRICTED_NETWORK',
UNKNOWN = 'UNKNOWN',
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
// go/keep-sorted end
}
/**
 * UserTierId represents the IDs returned from the Cloud Code Private API that identify a user's tier.
 *
 * //depot/google3/cloud/developer_experience/cloudcode/pa/service/usertier.go;l=16
 *
 * The string values are the wire values (e.g. 'free-tier').
 */
export enum UserTierId {
FREE = 'free-tier',
LEGACY = 'legacy-tier',
STANDARD = 'standard-tier',
}
/**
 * PrivacyNotice reflects the structure received from CodeAssist regarding a tier's
 * privacy notice (see `GeminiUserTier.privacyNotice`).
 */
export interface PrivacyNotice {
showNotice: boolean;
// Optional text of the notice.
noticeText?: string;
}
/**
 * Proto signature of OnboardUserRequest as payload to OnboardUser call.
 * All three keys are required to be present, but each value may be undefined.
 */
export interface OnboardUserRequest {
// Id of the tier to onboard the user with.
tierId: string | undefined;
cloudaicompanionProject: string | undefined;
metadata: ClientMetadata | undefined;
}
/**
 * Represents LongRunningOperation proto
 * http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
 */
export interface LongRunningOperationResponse {
name: string;
// Set once the operation has finished; absent/false while it is still running.
done?: boolean;
// Result payload of the operation, when available.
response?: OnboardUserResponse;
}
/**
 * Represents OnboardUserResponse proto
 * http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
 */
export interface OnboardUserResponse {
// Project associated with the onboarded user; may be absent.
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
/**
 * Status code of the user's license status.
 * It does not strictly correspond to the proto: `Error` is an additional value
 * assigned to error responses from OnboardUser.
 */
export enum OnboardUserStatusCode {
Default = 'DEFAULT',
Notice = 'NOTICE',
Warning = 'WARNING',
Error = 'ERROR',
}
/**
 * Status of a user onboarded to Gemini.
 */
export interface OnboardUserStatus {
statusCode: OnboardUserStatusCode;
displayMessage: string;
// Required key; the value is undefined when there is no help link.
helpLink: HelpLinkUrl | undefined;
}
/**
 * A URL paired with a description; the type of `OnboardUserStatus.helpLink`.
 */
export interface HelpLinkUrl {
description: string;
url: string;
}
/**
 * Request body for the `setCodeAssistGlobalUserSetting` call.
 */
export interface SetCodeAssistGlobalUserSettingRequest {
cloudaicompanionProject?: string;
// NOTE(review): presumably the user's opt-in to free-tier data collection — confirm.
freeTierDataCollectionOptin: boolean;
}
/**
 * Response type of the `getCodeAssistGlobalUserSetting` and
 * `setCodeAssistGlobalUserSetting` calls; it has the same fields as the set request.
 */
export interface CodeAssistGlobalUserSettingResponse {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
/**
 * Relevant fields that can be returned from a Google RPC response.
 * Only the error details are modelled; everything is optional.
 */
export interface GoogleRpcResponse {
error?: {
details?: GoogleRpcErrorInfo[];
};
}
/**
 * Relevant fields that can be returned in the details of an error returned from Google RPCs.
 * Not exported: it is only reachable through `GoogleRpcResponse`.
 */
interface GoogleRpcErrorInfo {
// Reason string of the detail entry (e.g. 'SECURITY_POLICY_VIOLATED').
reason?: string;
}

View File

@@ -283,23 +283,6 @@ describe('Server Config (config.ts)', () => {
expect(config.isInFallbackMode()).toBe(false);
});
it('should strip thoughts when switching from GenAI to Vertex', async () => {
const config = new Config(baseParams);
vi.mocked(createContentGeneratorConfig).mockImplementation(
(_: Config, authType: AuthType | undefined) =>
({ authType }) as unknown as ContentGeneratorConfig,
);
await config.refreshAuth(AuthType.USE_GEMINI);
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
expect(
config.getGeminiClient().stripThoughtsFromHistory,
).toHaveBeenCalledWith();
});
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
const config = new Config(baseParams);

View File

@@ -16,6 +16,7 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
import type {
ContentGenerator,
ContentGeneratorConfig,
AuthType,
} from '../core/contentGenerator.js';
import type { FallbackModelHandler } from '../fallback/types.js';
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
@@ -26,7 +27,6 @@ import type { AnyToolInvocation } from '../tools/tools.js';
import { BaseLlmClient } from '../core/baseLlmClient.js';
import { GeminiClient } from '../core/client.js';
import {
AuthType,
createContentGenerator,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
@@ -684,16 +684,6 @@ export class Config {
}
async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) {
// Vertex and GenAI use incompatible encryption, so sending history that contains a
// thoughtSignature from GenAI to Vertex will fail; we need to strip the thoughts first.
if (
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE
) {
// Restore the conversation history to the new client
this.geminiClient.stripThoughtsFromHistory();
}
const newContentGeneratorConfig = createContentGeneratorConfig(
this,
authMethod,

View File

@@ -31,7 +31,7 @@ describe('Flash Model Fallback Configuration', () => {
config as unknown as { contentGeneratorConfig: unknown }
).contentGeneratorConfig = {
model: DEFAULT_GEMINI_MODEL,
authType: 'oauth-personal',
authType: 'gemini-api-key',
};
});

View File

@@ -73,6 +73,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Create generator instance
@@ -299,6 +300,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(
@@ -333,6 +335,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(

View File

@@ -146,12 +146,11 @@ describe('BaseLlmClient', () => {
// Validate the parameters passed to the underlying generator
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
expect(mockGenerateContent).toHaveBeenCalledWith(
{
expect.objectContaining({
model: 'test-model',
contents: defaultOptions.contents,
config: {
config: expect.objectContaining({
abortSignal: defaultOptions.abortSignal,
topP: 0.8,
tools: [
{
functionDeclarations: [
@@ -163,9 +162,8 @@ describe('BaseLlmClient', () => {
],
},
],
// Crucial: systemInstruction should NOT be in the config object if not provided
},
},
}),
}),
'test-prompt-id',
);
});
@@ -188,7 +186,6 @@ describe('BaseLlmClient', () => {
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.8,
topP: 0.8, // Default should remain if not overridden
topK: 10,
tools: expect.any(Array),
}),

View File

@@ -64,11 +64,6 @@ export interface GenerateJsonOptions {
* A client dedicated to stateless, utility-focused LLM calls.
*/
export class BaseLlmClient {
// Default configuration for utility tasks
private readonly defaultUtilityConfig: GenerateContentConfig = {
topP: 0.8,
};
constructor(
private readonly contentGenerator: ContentGenerator,
private readonly config: Config,
@@ -89,7 +84,6 @@ export class BaseLlmClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...this.defaultUtilityConfig,
...options.config,
...(systemInstruction && { systemInstruction }),
};

View File

@@ -15,11 +15,7 @@ import {
} from 'vitest';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import {
isThinkingDefault,
isThinkingSupported,
GeminiClient,
} from './client.js';
import { GeminiClient } from './client.js';
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
import {
AuthType,
@@ -247,40 +243,6 @@ describe('findCompressSplitPoint', () => {
});
});
describe('isThinkingSupported', () => {
it('should return true for gemini-2.5', () => {
expect(isThinkingSupported('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingSupported('gemini-1.5-flash')).toBe(false);
expect(isThinkingSupported('some-other-model')).toBe(false);
});
});
describe('isThinkingDefault', () => {
it('should return false for gemini-2.5-flash-lite', () => {
expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
});
it('should return true for gemini-2.5', () => {
expect(isThinkingDefault('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingDefault('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingDefault('gemini-1.5-flash')).toBe(false);
expect(isThinkingDefault('some-other-model')).toBe(false);
});
});
describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
@@ -2304,16 +2266,15 @@ ${JSON.stringify(
);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
config: expect.objectContaining({
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
temperature: 0.5,
topP: 0.8,
},
}),
contents,
},
}),
'test-session-id',
);
});

View File

@@ -15,11 +15,7 @@ import type {
// Config
import { ApprovalMode, type Config } from '../config/config.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_THINKING_MODE,
} from '../config/models.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
// Core modules
import type { ContentGenerator } from './contentGenerator.js';
@@ -78,24 +74,10 @@ import { type File, type IdeContext } from '../ide/types.js';
// Fallback handling
import { handleFallback } from '../fallback/handler.js';
export function isThinkingSupported(model: string) {
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
export function isThinkingDefault(model: string) {
if (model.startsWith('gemini-2.5-flash-lite')) {
return false;
}
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
const MAX_TURNS = 100;
export class GeminiClient {
private chat?: GeminiChat;
private readonly generateContentConfig: GenerateContentConfig = {
topP: 0.8,
};
private sessionTurnCount = 0;
private readonly loopDetector: LoopDetectionService;
@@ -207,20 +189,10 @@ export class GeminiClient {
const model = this.config.getModel();
const systemInstruction = getCoreSystemPrompt(userMemory, model);
const config: GenerateContentConfig = { ...this.generateContentConfig };
if (isThinkingSupported(model)) {
config.thinkingConfig = {
includeThoughts: true,
thinkingBudget: DEFAULT_THINKING_MODE,
};
}
return new GeminiChat(
this.config,
{
systemInstruction,
...config,
tools,
},
history,
@@ -617,11 +589,6 @@ export class GeminiClient {
): Promise<GenerateContentResponse> {
let currentAttemptModel: string = model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
};
try {
const userMemory = this.config.getUserMemory();
const finalSystemInstruction = generationConfig.systemInstruction
@@ -630,7 +597,7 @@ export class GeminiClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...configToUse,
...generationConfig,
systemInstruction: finalSystemInstruction,
};
@@ -671,7 +638,7 @@ export class GeminiClient {
`Error generating content via API with model ${currentAttemptModel}.`,
{
requestContents: contents,
requestConfig: configToUse,
requestConfig: generationConfig,
},
'generateContent-api',
);

View File

@@ -5,42 +5,19 @@
*/
import { describe, it, expect, vi } from 'vitest';
import type { ContentGenerator } from './contentGenerator.js';
import { createContentGenerator, AuthType } from './contentGenerator.js';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { GoogleGenAI } from '@google/genai';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { LoggingContentGenerator } from './geminiContentGenerator/loggingContentGenerator.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
const mockConfig = {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
describe('createContentGenerator', () => {
it('should create a CodeAssistContentGenerator', async () => {
const mockGenerator = {} as unknown as ContentGenerator;
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
mockGenerator as never,
);
const generator = await createContentGenerator(
{
model: 'test-model',
authType: AuthType.LOGIN_WITH_GOOGLE,
},
mockConfig,
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
);
});
it('should create a GoogleGenAI content generator', async () => {
it('should create a Gemini content generator', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => true,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
@@ -65,17 +42,17 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
// We expect it to be a LoggingContentGenerator wrapping a GeminiContentGenerator
expect(generator).toBeInstanceOf(LoggingContentGenerator);
const wrapped = (generator as LoggingContentGenerator).getWrapped();
expect(wrapped).toBeDefined();
});
it('should create a GoogleGenAI content generator with client install id logging disabled', async () => {
it('should create a Gemini content generator with client install id logging disabled', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => false,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
models: {},
@@ -98,11 +75,6 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
expect(generator).toBeInstanceOf(LoggingContentGenerator);
});
});

View File

@@ -12,15 +12,9 @@ import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js';
import { InstallationManager } from '../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
*/
@@ -39,14 +33,12 @@ export interface ContentGenerator {
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
userTier?: UserTierId;
useSummarizedThinking(): boolean;
}
export enum AuthType {
LOGIN_WITH_GOOGLE = 'oauth-personal',
USE_GEMINI = 'gemini-api-key',
USE_VERTEX_AI = 'vertex-ai',
CLOUD_SHELL = 'cloud-shell',
USE_OPENAI = 'openai',
QWEN_OAUTH = 'qwen-oauth',
}
@@ -59,12 +51,9 @@ export type ContentGeneratorConfig = {
authType?: AuthType | undefined;
enableOpenAILogging?: boolean;
openAILoggingDir?: string;
// Timeout configuration in milliseconds
timeout?: number;
// Maximum retries for failed requests
maxRetries?: number;
// Disable cache control for DashScope providers
disableCacheControl?: boolean;
timeout?: number; // Timeout configuration in milliseconds
maxRetries?: number; // Maximum retries for failed requests
disableCacheControl?: boolean; // Disable cache control for DashScope providers
samplingParams?: {
top_p?: number;
top_k?: number;
@@ -74,6 +63,9 @@ export type ContentGeneratorConfig = {
temperature?: number;
max_tokens?: number;
};
reasoning?: {
effort?: 'low' | 'medium' | 'high';
};
proxy?: string | undefined;
userAgent?: string;
// Schema compliance mode for tool definitions
@@ -123,48 +115,14 @@ export async function createContentGenerator(
gcConfig: Config,
isInitialAuth?: boolean,
): Promise<ContentGenerator> {
const version = process.env['CLI_VERSION'] || process.version;
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
),
gcConfig,
);
}
if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI
) {
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
const { createGeminiContentGenerator } = await import(
'./geminiContentGenerator/index.js'
);
return createGeminiContentGenerator(config, gcConfig);
}
if (config.authType === AuthType.USE_OPENAI) {

View File

@@ -240,7 +240,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -318,7 +318,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -497,7 +497,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit', 'run_shell_command'],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -584,7 +584,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit'], // Different excluded tools
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -674,7 +674,7 @@ describe('CoreToolScheduler with payload', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1001,7 +1001,7 @@ describe('CoreToolScheduler edit cancellation', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1108,7 +1108,7 @@ describe('CoreToolScheduler YOLO mode', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1258,7 +1258,7 @@ describe('CoreToolScheduler cancellation during executing with live output', ()
getApprovalMode: () => ApprovalMode.DEFAULT,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getToolRegistry: () => mockToolRegistry,
getShellExecutionConfig: () => ({
@@ -1350,7 +1350,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1482,7 +1482,7 @@ describe('CoreToolScheduler request queueing', () => {
getToolRegistry: () => toolRegistry,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 80,
@@ -1586,7 +1586,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1854,7 +1854,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1975,7 +1975,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -100,6 +100,7 @@ describe('GeminiChat', () => {
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
useSummarizedThinking: vi.fn().mockReturnValue(false),
} as unknown as ContentGenerator;
mockHandleFallback.mockClear();
@@ -111,7 +112,7 @@ describe('GeminiChat', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'oauth-personal', // Ensure this is set for fallback tests
authType: 'gemini-api-key', // Ensure this is set for fallback tests
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -718,6 +719,99 @@ describe('GeminiChat', () => {
1,
);
});
// Verifies that thought parts are recorded in chat history only when the
// content generator reports that it does NOT use summarized thinking.
it('should handle summarized thinking by conditionally including thoughts in history', async () => {
// Case 1: useSummarizedThinking is true -> thoughts NOT in history
vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
true,
);
// Single-chunk stream carrying one thought part (T1) and one answer part (A1).
const stream1 = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ thought: true, text: 'T1' }, { text: 'A1' }],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream1,
);
const res1 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
// Drain the stream before inspecting the history.
for await (const _ of res1);
const history1 = chat.getHistory();
// Index 1 is presumably the model turn that follows the user message at index 0.
expect(history1[1].parts).toEqual([{ text: 'A1' }]);
// Case 2: useSummarizedThinking is false -> thoughts ARE in history
chat.clearHistory();
vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
false,
);
// Same shape as stream1, with distinct texts so the two cases cannot be confused.
const stream2 = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ thought: true, text: 'T2' }, { text: 'A2' }],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream2,
);
const res2 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p2');
// Drain the stream before inspecting the history.
for await (const _ of res2);
const history2 = chat.getHistory();
// The thought part is expected ahead of the regular answer text.
expect(history2[1].parts).toEqual([
{ text: 'T2', thought: true },
{ text: 'A2' },
]);
});
// Verifies that a streamed text part carrying a `thoughtSignature` reaches the
// chat history with both its text and its signature intact.
it('should keep parts with thoughtSignature when consolidating history', async () => {
// Single-chunk stream whose only part has text plus a thoughtSignature.
const stream = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [
{
text: 'p1',
thoughtSignature: 's1',
} as unknown as { text: string; thoughtSignature: string },
],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream,
);
const res = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
// Drain the stream before inspecting the history.
for await (const _ of res);
const history = chat.getHistory();
// Index 1 is presumably the model turn that follows the user message at index 0.
expect(history[1].parts![0]).toEqual({
text: 'p1',
thoughtSignature: 's1',
});
});
});
describe('addHistory', () => {
@@ -1382,7 +1476,7 @@ describe('GeminiChat', () => {
});
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
const authType = AuthType.LOGIN_WITH_GOOGLE;
const authType = AuthType.USE_GEMINI;
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
model: 'test-model',
authType,
@@ -1532,7 +1626,7 @@ describe('GeminiChat', () => {
});
describe('stripThoughtsFromHistory', () => {
it('should strip thought signatures', () => {
it('should strip thoughts and thought signatures, and remove empty content objects', () => {
chat.setHistory([
{
role: 'user',
@@ -1544,10 +1638,15 @@ describe('GeminiChat', () => {
{ text: 'thinking...', thought: true },
{ text: 'hi' },
{
functionCall: { name: 'test', args: {} },
},
text: 'hidden metadata',
thoughtSignature: 'abc',
} as unknown as { text: string; thoughtSignature: string },
],
},
{
role: 'model',
parts: [{ text: 'only thinking', thought: true }],
},
]);
chat.stripThoughtsFromHistory();
@@ -1559,7 +1658,7 @@ describe('GeminiChat', () => {
},
{
role: 'model',
parts: [{ text: 'hi' }, { functionCall: { name: 'test', args: {} } }],
parts: [{ text: 'hi' }, { text: 'hidden metadata' }],
},
]);
});

View File

@@ -92,6 +92,7 @@ export function isValidNonThoughtTextPart(part: Part): boolean {
return (
typeof part.text === 'string' &&
!part.thought &&
!part.thoughtSignature &&
// Technically, the model should never generate parts that have text and
// any of these but we don't trust them so check anyways.
!part.functionCall &&
@@ -109,18 +110,24 @@ function isValidContent(content: Content): boolean {
if (part === undefined || Object.keys(part).length === 0) {
return false;
}
if (
!part.thought &&
part.text !== undefined &&
part.text === '' &&
part.functionCall === undefined
) {
if (!isValidContentPart(part)) {
return false;
}
}
return true;
}
/**
 * Decides whether a single part carries anything worth keeping.
 *
 * A part is rejected only when it is a plain text part whose text is the empty
 * string: it is not flagged as a thought, has no thought signature, and has no
 * function call. Every other part is accepted.
 *
 * @param part The part to inspect.
 * @returns `false` for an empty plain-text part, `true` otherwise.
 */
function isValidContentPart(part: Part): boolean {
  // Thought parts and signed parts are kept even when their text is empty.
  if (part.thought || part.thoughtSignature) {
    return true;
  }
  // A function call makes the part meaningful regardless of its text.
  if (part.functionCall !== undefined) {
    return true;
  }
  // Only an explicitly empty string is rejected; a missing `text` is fine.
  return part.text !== '';
}
/**
* Validates the history contains the correct roles.
*
@@ -448,15 +455,29 @@ export class GeminiChat {
if (!content.parts) return content;
// Filter out thought parts entirely
const filteredParts = content.parts.filter(
(part) =>
!(
const filteredParts = content.parts
.filter(
(part) =>
!(
part &&
typeof part === 'object' &&
'thought' in part &&
part.thought
),
)
.map((part) => {
if (
part &&
typeof part === 'object' &&
'thought' in part &&
part.thought
),
);
'thoughtSignature' in part
) {
const newPart = { ...part };
delete (newPart as { thoughtSignature?: string })
.thoughtSignature;
return newPart;
}
return part;
});
return {
...content,
@@ -538,11 +559,15 @@ export class GeminiChat {
yield chunk; // Yield every chunk to the UI immediately.
}
const thoughtParts = allModelParts.filter((part) => part.thought);
const thoughtText = thoughtParts
.map((part) => part.text)
.join('')
.trim();
let thoughtText = '';
// Only include thoughts if not using summarized thinking.
if (!this.config.getContentGenerator().useSummarizedThinking()) {
thoughtText = allModelParts
.filter((part) => part.thought)
.map((part) => part.text)
.join('')
.trim();
}
const contentParts = allModelParts.filter((part) => !part.thought);
const consolidatedHistoryParts: Part[] = [];
@@ -555,7 +580,7 @@ export class GeminiChat {
isValidNonThoughtTextPart(part)
) {
lastPart.text += part.text;
} else {
} else if (isValidContentPart(part)) {
consolidatedHistoryParts.push(part);
}
}

View File

@@ -0,0 +1,173 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { GoogleGenAI } from '@google/genai';
// Replace the `@google/genai` module with a stub. The four mock functions are
// created once inside the factory, so every `GoogleGenAI` instance constructed
// during the tests exposes the same `models.*` mocks.
vi.mock('@google/genai', () => {
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();
const mockCountTokens = vi.fn();
const mockEmbedContent = vi.fn();
return {
GoogleGenAI: vi.fn().mockImplementation(() => ({
models: {
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
countTokens: mockCountTokens,
embedContent: mockEmbedContent,
},
})),
};
});
// Unit tests for GeminiContentGenerator, run against the mocked `@google/genai`
// client. They cover pass-through of each API call, the default sampling
// configuration, the precedence of configured sampling params over the request,
// and the mapping of reasoning effort to a thinking level.
describe('GeminiContentGenerator', () => {
let generator: GeminiContentGenerator;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockGoogleGenAI: any;
beforeEach(() => {
vi.clearAllMocks();
generator = new GeminiContentGenerator({
apiKey: 'test-api-key',
});
// Grab the stub client returned by the first GoogleGenAI construction since
// the mocks were cleared, i.e. the one created for `generator` above.
mockGoogleGenAI = vi.mocked(GoogleGenAI).mock.results[0].value;
});
// Without configured sampling params, the request is forwarded with the
// defaults asserted below (temperature 1, topP 0.95, unspecified thinking level).
it('should call generateContent on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { responseId: 'test-id' };
mockGoogleGenAI.models.generateContent.mockResolvedValue(expectedResponse);
const response = await generator.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
},
}),
}),
);
expect(response).toBe(expectedResponse);
});
// Streaming variant: same default config, and the stream object returned by
// the client is handed back unchanged.
it('should call generateContentStream on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const mockStream = (async function* () {
yield { responseId: '1' };
})();
mockGoogleGenAI.models.generateContentStream.mockResolvedValue(mockStream);
const stream = await generator.generateContentStream(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
},
}),
}),
);
expect(stream).toBe(mockStream);
});
// countTokens forwards the request object as-is.
it('should call countTokens on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { totalTokens: 10 };
mockGoogleGenAI.models.countTokens.mockResolvedValue(expectedResponse);
const response = await generator.countTokens(request);
expect(mockGoogleGenAI.models.countTokens).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
// embedContent forwards the request object as-is.
it('should call embedContent on the underlying model', async () => {
const request = { model: 'embedding-model', contents: [] };
const expectedResponse = { embeddings: [] };
mockGoogleGenAI.models.embedContent.mockResolvedValue(expectedResponse);
const response = await generator.embedContent(request);
expect(mockGoogleGenAI.models.embedContent).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
// Values from the generator's samplingParams (temperature/top_p) must win over
// the temperature/topP supplied in the request's own config.
it('should prioritize contentGeneratorConfig samplingParams over request config', async () => {
const generatorWithParams = new GeminiContentGenerator({ apiKey: 'test' }, {
model: 'gemini-1.5-flash',
samplingParams: {
temperature: 0.1,
top_p: 0.2,
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const request = {
model: 'gemini-1.5-flash',
contents: [],
config: {
temperature: 0.9,
topP: 0.9,
},
};
await generatorWithParams.generateContent(request, 'prompt-id');
// Both generators share the same mocked `models.*` functions, so the call is
// observable through `mockGoogleGenAI`.
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.1,
topP: 0.2,
}),
}),
);
});
// A configured reasoning effort of 'high' must surface as thinkingLevel 'HIGH'.
it('should map reasoning effort to thinkingConfig', async () => {
const generatorWithReasoning = new GeminiContentGenerator(
{ apiKey: 'test' },
{
model: 'gemini-2.5-pro',
reasoning: {
effort: 'high',
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any,
);
const request = {
model: 'gemini-2.5-pro',
contents: [],
};
await generatorWithReasoning.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'HIGH',
},
}),
}),
);
});
});

View File

@@ -0,0 +1,144 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
GenerateContentConfig,
ThinkingLevel,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
/**
 * ContentGenerator implementation that delegates to a `GoogleGenAI` client.
 *
 * Generation requests are forwarded with a resolved `config`: each sampling
 * parameter is taken from the generator's own configuration first, then from
 * the request's config, then from a built-in default (where one exists).
 * Token counting and embedding requests are forwarded untouched.
 */
export class GeminiContentGenerator implements ContentGenerator {
  private readonly googleGenAI: GoogleGenAI;
  private readonly contentGeneratorConfig?: ContentGeneratorConfig;

  /**
   * @param options Options handed to the `GoogleGenAI` constructor as-is.
   * @param contentGeneratorConfig Optional generator configuration; its
   *   `samplingParams` and `reasoning` entries influence every generation
   *   request.
   */
  constructor(
    options: {
      apiKey?: string;
      vertexai?: boolean;
      httpOptions?: { headers: Record<string, string> };
    },
    contentGeneratorConfig?: ContentGeneratorConfig,
  ) {
    this.googleGenAI = new GoogleGenAI(options);
    this.contentGeneratorConfig = contentGeneratorConfig;
  }

  /**
   * Translates a configured reasoning effort into a thinking level.
   * Only 'low' and 'high' have dedicated levels; anything else is unspecified.
   */
  private static toThinkingLevel(effort: string | undefined): ThinkingLevel {
    switch (effort) {
      case 'low':
        return 'LOW' as ThinkingLevel;
      case 'high':
        return 'HIGH' as ThinkingLevel;
      default:
        return 'THINKING_LEVEL_UNSPECIFIED' as ThinkingLevel;
    }
  }

  /**
   * Produces the generation config for a request.
   *
   * All fields of the request's config are carried over; the sampling fields
   * listed below are then resolved with priority config > request > default.
   */
  private buildSamplingParameters(
    request: GenerateContentParameters,
  ): GenerateContentConfig {
    const fromRequest = request.config || {};
    const fromConfig = this.contentGeneratorConfig?.samplingParams;
    const reasoning = this.contentGeneratorConfig?.reasoning;

    // First defined value wins: configured value, request value, fallback.
    const resolve = <T>(
      configured: T | undefined,
      requestKey: keyof GenerateContentConfig,
      fallback?: T,
    ): T | undefined => {
      if (configured !== undefined) {
        return configured;
      }
      const requested = fromRequest[requestKey] as T | undefined;
      return requested !== undefined ? requested : fallback;
    };

    // A configured reasoning setting always enables thoughts and pins the level.
    const configuredThinking = reasoning
      ? {
          includeThoughts: true,
          thinkingLevel: GeminiContentGenerator.toThinkingLevel(
            reasoning.effort,
          ),
        }
      : undefined;

    return {
      ...fromRequest,
      temperature: resolve<number>(fromConfig?.temperature, 'temperature', 1),
      topP: resolve<number>(fromConfig?.top_p, 'topP', 0.95),
      topK: resolve<number>(fromConfig?.top_k, 'topK', 64),
      maxOutputTokens: resolve<number>(
        fromConfig?.max_tokens,
        'maxOutputTokens',
      ),
      presencePenalty: resolve<number>(
        fromConfig?.presence_penalty,
        'presencePenalty',
      ),
      frequencyPenalty: resolve<number>(
        fromConfig?.frequency_penalty,
        'frequencyPenalty',
      ),
      thinkingConfig: resolve(configuredThinking, 'thinkingConfig', {
        includeThoughts: true,
        thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED' as ThinkingLevel,
      }),
    };
  }

  /** Copies the request and swaps in the resolved generation config. */
  private withSamplingConfig(
    request: GenerateContentParameters,
  ): GenerateContentParameters {
    return {
      ...request,
      config: this.buildSamplingParameters(request),
    };
  }

  /**
   * Sends a non-streaming generation request.
   *
   * @param request The generation request; its config is re-resolved.
   * @param _userPromptId Unused by this implementation.
   */
  async generateContent(
    request: GenerateContentParameters,
    _userPromptId: string,
  ): Promise<GenerateContentResponse> {
    return this.googleGenAI.models.generateContent(
      this.withSamplingConfig(request),
    );
  }

  /**
   * Sends a streaming generation request.
   *
   * @param request The generation request; its config is re-resolved.
   * @param _userPromptId Unused by this implementation.
   */
  async generateContentStream(
    request: GenerateContentParameters,
    _userPromptId: string,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    return this.googleGenAI.models.generateContentStream(
      this.withSamplingConfig(request),
    );
  }

  /** Forwards a token-counting request unchanged. */
  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    return this.googleGenAI.models.countTokens(request);
  }

  /** Forwards an embedding request unchanged. */
  async embedContent(
    request: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    return this.googleGenAI.models.embedContent(request);
  }

  /** This generator always reports that it uses summarized thinking. */
  useSummarizedThinking(): boolean {
    return true;
  }
}

View File

@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createGeminiContentGenerator } from './index.js';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
// Stub both collaborators so the factory can be tested in isolation:
// GeminiContentGenerator constructs an empty object, and
// LoggingContentGenerator simply returns whatever it was asked to wrap.
vi.mock('./geminiContentGenerator.js', () => ({
GeminiContentGenerator: vi.fn().mockImplementation(() => ({})),
}));
vi.mock('./loggingContentGenerator.js', () => ({
LoggingContentGenerator: vi.fn().mockImplementation((wrapped) => wrapped),
}));
// Smoke test for the createGeminiContentGenerator factory: it must construct a
// GeminiContentGenerator and wrap it in a LoggingContentGenerator.
describe('createGeminiContentGenerator', () => {
let mockConfig: Config;
beforeEach(() => {
vi.clearAllMocks();
// Minimal Config stub. Usage statistics are disabled here.
// NOTE(review): getContentGeneratorConfig and getCliVersion are stubbed as well;
// whether the factory path needs them is not visible from this test — confirm.
mockConfig = {
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
});
it('should create a GeminiContentGenerator wrapped in LoggingContentGenerator', () => {
const config = {
model: 'gemini-1.5-flash',
apiKey: 'test-key',
authType: AuthType.USE_GEMINI,
};
const generator = createGeminiContentGenerator(config, mockConfig);
// Only construction is asserted; constructor arguments are not inspected.
expect(GeminiContentGenerator).toHaveBeenCalled();
expect(LoggingContentGenerator).toHaveBeenCalled();
expect(generator).toBeDefined();
});
});

View File

@@ -0,0 +1,55 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
export { GeminiContentGenerator } from './geminiContentGenerator.js';
export { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
 * Builds a Gemini content generator wrapped in request/response logging.
 *
 * The underlying client always receives a `User-Agent` header: the configured
 * `userAgent` when set, otherwise a `QwenCode/<version>` string built from the
 * `CLI_VERSION` environment variable (or `process.version` as a fallback),
 * the platform and the architecture. When usage statistics are enabled, the
 * installation id is sent as `x-gemini-api-privileged-user-id`.
 *
 * @param config Generator configuration; an empty-string `apiKey` is passed
 *   on as `undefined`.
 * @param gcConfig Global configuration, consulted for the usage-statistics
 *   flag and handed to the logging wrapper.
 * @returns A LoggingContentGenerator wrapping a GeminiContentGenerator.
 */
export function createGeminiContentGenerator(
  config: ContentGeneratorConfig,
  gcConfig: Config,
): ContentGenerator {
  const cliVersion = process.env['CLI_VERSION'] || process.version;
  const fallbackUserAgent = `QwenCode/${cliVersion} (${process.platform}; ${process.arch})`;

  const headers: Record<string, string> = {
    'User-Agent': config.userAgent || fallbackUserAgent,
  };

  // The installation id is only attached when usage statistics are enabled.
  if (gcConfig?.getUsageStatisticsEnabled()) {
    const installationId = new InstallationManager().getInstallationId();
    headers['x-gemini-api-privileged-user-id'] = `${installationId}`;
  }

  // Treat an empty key as "no key provided".
  const apiKey = config.apiKey === '' ? undefined : config.apiKey;

  const innerGenerator = new GeminiContentGenerator(
    {
      apiKey,
      vertexai: config.vertexai,
      httpOptions: { headers },
    },
    config,
  );

  return new LoggingContentGenerator(innerGenerator, gcConfig);
}

View File

@@ -13,21 +13,24 @@ import type {
GenerateContentParameters,
GenerateContentResponseUsageMetadata,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
Part,
PartUnion,
} from '@google/genai';
import {
ApiRequestEvent,
ApiResponseEvent,
ApiErrorEvent,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
} from '../../telemetry/types.js';
import type { Config } from '../../config/config.js';
import {
logApiError,
logApiRequest,
logApiResponse,
} from '../telemetry/loggers.js';
import type { ContentGenerator } from './contentGenerator.js';
import { toContents } from '../code_assist/converter.js';
import { isStructuredError } from '../utils/quotaErrorDetection.js';
} from '../../telemetry/loggers.js';
import type { ContentGenerator } from '../contentGenerator.js';
import { isStructuredError } from '../../utils/quotaErrorDetection.js';
interface StructuredError {
status: number;
@@ -112,7 +115,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
try {
const response = await this.wrapped.generateContent(req, userPromptId);
const durationMs = Date.now() - startTime;
@@ -137,7 +140,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
let stream: AsyncGenerator<GenerateContentResponse>;
try {
@@ -205,4 +208,95 @@ export class LoggingContentGenerator implements ContentGenerator {
): Promise<EmbedContentResponse> {
return this.wrapped.embedContent(req);
}
/**
 * Reports whether summarized thinking is in use by asking the wrapped
 * generator; the logging layer adds no behavior of its own here.
 */
useSummarizedThinking(): boolean {
return this.wrapped.useSummarizedThinking();
}
/**
 * Normalizes any accepted `contents` value into an array of Content objects.
 *
 * An array input yields one Content per element; a single value yields a
 * one-element array.
 */
private toContents(contents: ContentListUnion): Content[] {
  // Wrap a lone value so both shapes go through the same conversion.
  const items: ContentUnion[] = Array.isArray(contents) ? contents : [contents];
  return items.map((item) => this.toContent(item));
}
/**
 * Normalizes a single content-like value into a Content object.
 *
 * Strings, part arrays and bare parts are wrapped as a 'user' turn. A value
 * that already has a `parts` property keeps its other fields; its parts are
 * stripped of null/undefined entries and normalized individually.
 */
private toContent(content: ContentUnion): Content {
  // Plain string: a single text part.
  if (typeof content === 'string') {
    return { role: 'user', parts: [{ text: content }] };
  }

  // Array of parts (or strings): normalize each element.
  if (Array.isArray(content)) {
    return { role: 'user', parts: this.toParts(content) };
  }

  // No `parts` property: treat the value as a single Part.
  if (!('parts' in content)) {
    return { role: 'user', parts: [this.toPart(content as Part)] };
  }

  // Already a Content: keep its fields and run its parts through toParts so
  // thought parts are handled.
  const normalizedParts = content.parts
    ? this.toParts(content.parts.filter((p) => p != null))
    : [];
  return { ...content, parts: normalizedParts };
}
/** Normalizes every element of `parts` with `toPart`, preserving order. */
private toParts(parts: PartUnion[]): Part[] {
return parts.map((p) => this.toPart(p));
}
/**
 * Normalizes a single part-like value into a Part.
 *
 * - A string becomes `{ text }`.
 * - A part with a truthy `thought` property is copied without that property.
 *   If the copy carries a functionCall, functionResponse, inlineData or
 *   fileData it is returned as-is; otherwise a `[Thought: …]` marker is
 *   appended to (or used as) its text.
 * - Any other part is returned unchanged.
 *
 * NOTE(review): the marker interpolates `part.thought` itself, so if that
 * property is a boolean flag the marker reads `[Thought: true]` — confirm
 * this is the intended text.
 */
private toPart(part: PartUnion): Part {
if (typeof part === 'string') {
// it's a string
return { text: part };
}
// Handle thought parts for CountToken API compatibility
// The CountToken API expects parts to have certain required "oneof" fields initialized,
// but thought parts don't conform to this schema and cause API failures
// NOTE(review): within the code shown here this helper feeds request logging
// only; the CountToken rationale above may be inherited — confirm.
if ('thought' in part && part.thought) {
const thoughtText = `[Thought: ${part.thought}]`;
// Work on a shallow copy so the caller's part is not mutated.
const newPart = { ...part };
delete (newPart as Record<string, unknown>)['thought'];
const hasApiContent =
'functionCall' in newPart ||
'functionResponse' in newPart ||
'inlineData' in newPart ||
'fileData' in newPart;
if (hasApiContent) {
// It's a functionCall or other non-text part. Just strip the thought.
return newPart;
}
// If no other valid API content, this must be a text part.
// Combine existing text (if any) with the thought, preserving other properties.
const text = (newPart as { text?: unknown }).text;
// Falsy text (missing or empty) is treated as "no existing text".
const existingText = text ? String(text) : '';
const combinedText = existingText
? `${existingText}\n${thoughtText}`
: thoughtText;
return {
...newPart,
text: combinedText,
};
}
return part;
}
}

View File

@@ -47,7 +47,7 @@ describe('executeToolCall', () => {
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -99,6 +99,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
generator = new OpenAIContentGenerator(
@@ -211,6 +212,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(
@@ -277,6 +279,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(

View File

@@ -154,4 +154,8 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
}
/**
 * The OpenAI-compatible generator never reports summarized thinking, so
 * callers that branch on this flag take their "not summarized" path.
 */
useSummarizedThinking(): boolean {
return false;
}
}

View File

@@ -60,6 +60,7 @@ describe('ContentGenerationPipeline', () => {
buildClient: vi.fn().mockReturnValue(mockClient),
buildRequest: vi.fn().mockImplementation((req) => req),
buildHeaders: vi.fn().mockReturnValue({}),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Mock telemetry service

View File

@@ -283,16 +283,22 @@ export class ContentGenerationPipeline {
private buildSamplingParameters(
request: GenerateContentParameters,
): Record<string, unknown> {
const defaultSamplingParams =
this.config.provider.getDefaultGenerationConfig();
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
// Helper function to get parameter value with priority: config > request > default
const getParameterValue = <T>(
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey: keyof NonNullable<typeof request.config>,
defaultValue?: T,
requestKey?: keyof NonNullable<typeof request.config>,
): T | undefined => {
const configValue = configSamplingParams?.[configKey] as T | undefined;
const requestValue = request.config?.[requestKey] as T | undefined;
const requestValue = requestKey
? (request.config?.[requestKey] as T | undefined)
: undefined;
const defaultValue = requestKey
? (defaultSamplingParams[requestKey] as T)
: undefined;
if (configValue !== undefined) return configValue;
if (requestValue !== undefined) return requestValue;
@@ -304,12 +310,8 @@ export class ContentGenerationPipeline {
key: string,
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey?: keyof NonNullable<typeof request.config>,
defaultValue?: T,
): Record<string, T> | Record<string, never> => {
const value = requestKey
? getParameterValue(configKey, requestKey, defaultValue)
: ((configSamplingParams?.[configKey] as T | undefined) ??
defaultValue);
): Record<string, T | undefined> => {
const value = getParameterValue<T>(configKey, requestKey);
return value !== undefined ? { [key]: value } : {};
};
@@ -323,10 +325,18 @@ export class ContentGenerationPipeline {
...addParameterIfDefined('max_tokens', 'max_tokens', 'maxOutputTokens'),
// Additional sampling parameters (a request/provider-default fallback applies only where a request key is given)
...addParameterIfDefined('top_k', 'top_k'),
...addParameterIfDefined('top_k', 'top_k', 'topK'),
...addParameterIfDefined('repetition_penalty', 'repetition_penalty'),
...addParameterIfDefined('presence_penalty', 'presence_penalty'),
...addParameterIfDefined('frequency_penalty', 'frequency_penalty'),
...addParameterIfDefined(
'presence_penalty',
'presence_penalty',
'presencePenalty',
),
...addParameterIfDefined(
'frequency_penalty',
'frequency_penalty',
'frequencyPenalty',
),
};
return params;

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { AuthType } from '../../contentGenerator.js';
@@ -141,6 +142,14 @@ export class DashScopeOpenAICompatibleProvider
};
}
/**
 * Default sampling values for DashScope requests: temperature 0.7, topP 0.8
 * and topK 20.
 */
getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0.7,
topP: 0.8,
topK: 20,
};
}
/**
* Add cache control flag to specified message(s) for DashScope providers
*/

View File

@@ -8,6 +8,7 @@ import type OpenAI from 'openai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DefaultOpenAICompatibleProvider } from './default.js';
import type { GenerateContentConfig } from '@google/genai';
export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
constructor(
@@ -76,4 +77,10 @@ export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatiblePro
messages,
};
}
/**
 * Default sampling values for DeepSeek requests: only temperature is set
 * (to 0). This override replaces the inherited defaults rather than
 * extending them, so no topP default is returned here.
 */
override getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0,
};
}
}

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
@@ -55,4 +56,10 @@ export class DefaultOpenAICompatibleProvider
...request, // Preserve all original parameters including sampling params
};
}
/**
 * Default sampling values for the generic OpenAI-compatible provider:
 * only topP is set (0.95); temperature and topK are left unset.
 */
getDefaultGenerationConfig(): GenerateContentConfig {
return {
topP: 0.95,
};
}
}

View File

@@ -1,3 +1,4 @@
import type { GenerateContentConfig } from '@google/genai';
import type OpenAI from 'openai';
// Extended types to support cache_control for DashScope
@@ -22,6 +23,7 @@ export interface OpenAICompatibleProvider {
request: OpenAI.Chat.ChatCompletionCreateParams,
userPromptId: string,
): OpenAI.Chat.ChatCompletionCreateParams;
getDefaultGenerationConfig(): GenerateContentConfig;
}
export type DashScopeRequestMetadata = {

View File

@@ -36,13 +36,6 @@ vi.mock('../utils/errorReporting', () => ({
reportError: vi.fn(),
}));
// Use the actual implementation from partUtils now that it's provided.
vi.mock('../utils/generateContentResponseUtilities', () => ({
getResponseText: (resp: GenerateContentResponse) =>
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
undefined,
}));
describe('Turn', () => {
let turn: Turn;
// Define a type for the mocked Chat instance for clarity
@@ -156,6 +149,7 @@ describe('Turn', () => {
type: GeminiEventType.Thought,
value: { subject: '', description: 'reasoning...' },
},
{ type: GeminiEventType.Content, value: 'final answer' },
]);
});

View File

@@ -27,7 +27,11 @@ import {
toFriendlyError,
} from '../utils/errors.js';
import type { GeminiChat } from './geminiChat.js';
import { getThoughtText, type ThoughtSummary } from '../utils/thoughtUtils.js';
import {
getThoughtText,
parseThought,
type ThoughtSummary,
} from '../utils/thoughtUtils.js';
// Define a structure for tools passed to the server
export interface ServerTool {
@@ -266,13 +270,12 @@ export class Turn {
this.currentResponseId = resp.responseId;
}
const thoughtPart = getThoughtText(resp);
if (thoughtPart) {
const thoughtText = getThoughtText(resp);
if (thoughtText) {
yield {
type: GeminiEventType.Thought,
value: { subject: '', description: thoughtPart },
value: parseThought(thoughtText),
};
continue;
}
const text = getResponseText(resp);

View File

@@ -4,36 +4,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
type Mock,
type MockInstance,
afterEach,
} from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';
// Mock the telemetry logger and event class
vi.mock('../telemetry/index.js', () => ({
logFlashFallback: vi.fn(),
FlashFallbackEvent: class {},
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
@@ -45,174 +19,28 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
describe('handleFallback', () => {
let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks();
mockHandler = vi.fn();
// Default setup: OAuth user, Pro model failed, handler injected
mockConfig = createMockConfig({
fallbackModelHandler: mockHandler,
});
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
mockConfig = createMockConfig();
});
afterEach(() => {
consoleErrorSpy.mockRestore();
});
it('should return null immediately if authType is not OAuth', async () => {
it('should return null for unknown auth types', async () => {
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_API_KEY,
'test-model',
'unknown-auth',
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
it('should return null if the failed model is already the fallback model', async () => {
it('should handle Qwen OAuth error', async () => {
const result = await handleFallback(
mockConfig,
FALLBACK_MODEL, // Failed model is Flash
AUTH_OAUTH,
'test-model',
AuthType.QWEN_OAUTH,
new Error('unauthorized'),
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
});
it('should return null if no fallbackHandler is injected in config', async () => {
const configWithoutHandler = createMockConfig({
fallbackModelHandler: undefined,
});
const result = await handleFallback(
configWithoutHandler,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
});
describe('when handler returns "retry"', () => {
it('should activate fallback mode, log telemetry, and return true', async () => {
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(true);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "stop"', () => {
it('should activate fallback mode, log telemetry, and return false', async () => {
mockHandler.mockResolvedValue('stop');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "auth"', () => {
it('should NOT activate fallback mode and return false', async () => {
mockHandler.mockResolvedValue('auth');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
});
describe('when handler returns an unexpected value', () => {
it('should log an error and return null', async () => {
mockHandler.mockResolvedValue(null);
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
new Error(
'Unexpected fallback intent received from fallbackModelHandler: "null"',
),
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
const mockError = new Error('Quota Exceeded');
mockHandler.mockResolvedValue('retry');
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
expect(mockHandler).toHaveBeenCalledWith(
MOCK_PRO_MODEL,
FALLBACK_MODEL,
mockError,
);
});
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
// Setup config where fallback mode is already active
const activeFallbackConfig = createMockConfig({
fallbackModelHandler: mockHandler,
isInFallbackMode: vi.fn(() => true), // Already active
setFallbackMode: vi.fn(),
});
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
activeFallbackConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
// Should still return true to allow the retry (which will use the active fallback mode)
expect(result).toBe(true);
// Should still consult the handler
expect(mockHandler).toHaveBeenCalled();
// But should not mutate state or log telemetry again
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
const handlerError = new Error('UI interaction failed');
mockHandler.mockRejectedValue(handlerError);
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
handlerError,
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});

View File

@@ -6,8 +6,6 @@
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
export async function handleFallback(
config: Config,
@@ -20,48 +18,7 @@ export async function handleFallback(
return handleQwenOAuthError(error);
}
// Applicability Checks
if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
if (failedModel === fallbackModel) return null;
// Consult UI Handler for Intent
const fallbackModelHandler = config.fallbackModelHandler;
if (typeof fallbackModelHandler !== 'function') return null;
try {
// Pass the specific failed model to the UI handler.
const intent = await fallbackModelHandler(
failedModel,
fallbackModel,
error,
);
// Process Intent and Update State
switch (intent) {
case 'retry':
// Activate fallback mode. The NEXT retry attempt will pick this up.
activateFallbackMode(config, authType);
return true; // Signal retryWithBackoff to continue.
case 'stop':
activateFallbackMode(config, authType);
return false;
case 'auth':
return false;
default:
throw new Error(
`Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
);
}
} catch (handlerError) {
console.error('Fallback UI handler failed:', handlerError);
return null;
}
return null;
}
/**
@@ -118,12 +75,3 @@ async function handleQwenOAuthError(error?: unknown): Promise<string | null> {
// For other errors, don't handle them specially
return null;
}
/**
 * Switches the config into fallback mode unless it is already active.
 *
 * The flash-fallback telemetry event is logged only on the transition into
 * fallback mode, and only when an auth type is supplied.
 *
 * @param config - Configuration whose fallback flag is checked and set.
 * @param authType - Auth type attached to the telemetry event; when falsy,
 *   the flag is still set but no event is logged.
 */
function activateFallbackMode(config: Config, authType: string | undefined) {
  // Already in fallback mode: no state change and no duplicate telemetry.
  if (config.isInFallbackMode()) {
    return;
  }

  config.setFallbackMode(true);

  if (!authType) {
    return;
  }
  logFlashFallback(config, new FlashFallbackEvent(authType));
}

View File

@@ -12,7 +12,6 @@ export * from './output/json-formatter.js';
// Export Core Logic
export * from './core/client.js';
export * from './core/contentGenerator.js';
export * from './core/loggingContentGenerator.js';
export * from './core/geminiChat.js';
export * from './core/logger.js';
export * from './core/prompts.js';
@@ -24,11 +23,7 @@ export * from './core/nonInteractiveToolExecutor.js';
export * from './fallback/types.js';
export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js';
export * from './qwen/qwenOAuth2.js';
export * from './code_assist/server.js';
export * from './code_assist/types.js';
// Export utilities
export * from './utils/paths.js';

View File

@@ -907,3 +907,5 @@ export async function clearQwenCredentials(): Promise<void> {
/**
 * Builds the path of the cached Qwen credential file: the credential file
 * name inside the Qwen directory under the current user's home directory.
 *
 * @returns The joined path to the credential file.
 */
function getQwenCachedCredentialPath(): string {
  const segments = [os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME];
  return path.join(...segments);
}
export const clearCachedCredentialFile = clearQwenCredentials;

View File

@@ -30,7 +30,6 @@ import {
ToolCallEvent,
} from '../types.js';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
@@ -90,10 +89,8 @@ expect.extend({
},
});
vi.mock('../../utils/userAccountManager.js');
vi.mock('../../utils/installationManager.js');
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
// TODO(richieforeman): Consider moving this to test setup globally.
@@ -128,11 +125,7 @@ describe('ClearcutLogger', () => {
vi.unstubAllEnvs();
});
function setup({
config = {} as Partial<ConfigParameters>,
lifetimeGoogleAccounts = 1,
cachedGoogleAccount = 'test@google.com',
} = {}) {
function setup({ config = {} as Partial<ConfigParameters> } = {}) {
server.resetHandlers(
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
);
@@ -146,10 +139,6 @@ describe('ClearcutLogger', () => {
});
ClearcutLogger.clearInstance();
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
lifetimeGoogleAccounts,
);
mockInstallMgr.getInstallationId = vi
.fn()
.mockReturnValue('test-installation-id');
@@ -195,19 +184,6 @@ describe('ClearcutLogger', () => {
});
describe('createLogEvent', () => {
it('logs the total number of google accounts', () => {
const { logger } = setup({
lifetimeGoogleAccounts: 9001,
});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: '9001',
});
});
it('logs the current surface from a github action', () => {
const { logger } = setup({});
@@ -251,7 +227,6 @@ describe('ClearcutLogger', () => {
// Define expected values
const session_id = 'test-session-id';
const auth_type = AuthType.USE_GEMINI;
const google_accounts = 123;
const surface = 'ide-1234';
const cli_version = CLI_VERSION;
const git_commit_hash = GIT_COMMIT_INFO;
@@ -260,7 +235,6 @@ describe('ClearcutLogger', () => {
// Setup logger with expected values
const { logger, loggerConfig } = setup({
lifetimeGoogleAccounts: google_accounts,
config: {},
});
vi.spyOn(loggerConfig, 'getContentGeneratorConfig').mockReturnValue({
@@ -283,10 +257,6 @@ describe('ClearcutLogger', () => {
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
value: JSON.stringify(auth_type),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${google_accounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
@@ -404,10 +374,14 @@ describe('ClearcutLogger', () => {
vi.stubEnv(key, value);
}
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0][3]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
});
expect(event?.event_metadata[0]).toEqual(
expect.arrayContaining([
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
},
]),
);
},
);
});

View File

@@ -34,7 +34,6 @@ import type {
import { EventMetadataKey } from './event-metadata-key.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
import { FixedDeque } from 'mnemonist';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
@@ -157,7 +156,6 @@ export class ClearcutLogger {
private sessionData: EventValue[] = [];
private promptId: string = '';
private readonly installationManager: InstallationManager;
private readonly userAccountManager: UserAccountManager;
/**
* Queue of pending events that need to be flushed to the server. New events
@@ -186,7 +184,6 @@ export class ClearcutLogger {
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
this.promptId = config?.getSessionId() ?? '';
this.installationManager = new InstallationManager();
this.userAccountManager = new UserAccountManager();
}
static getInstance(config?: Config): ClearcutLogger | undefined {
@@ -233,14 +230,11 @@ export class ClearcutLogger {
}
createLogEvent(eventName: EventNames, data: EventValue[] = []): LogEvent {
const email = this.userAccountManager.getCachedGoogleAccount();
if (eventName !== EventNames.START_SESSION) {
data.push(...this.sessionData);
}
const totalAccounts = this.userAccountManager.getLifetimeGoogleAccounts();
data = this.addDefaultFields(data, totalAccounts);
data = this.addDefaultFields(data);
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
@@ -249,12 +243,7 @@ export class ClearcutLogger {
event_metadata: [data],
};
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
if (email) {
logEvent.client_email = email;
} else {
logEvent.client_install_id = this.installationManager.getInstallationId();
}
logEvent.client_install_id = this.installationManager.getInstallationId();
return logEvent;
}
@@ -1018,7 +1007,7 @@ export class ClearcutLogger {
* Adds default fields to data, and returns a new data array. These fields
* should exist on all log events.
*/
addDefaultFields(data: EventValue[], totalAccounts: number): EventValue[] {
addDefaultFields(data: EventValue[]): EventValue[] {
const surface = determineSurface();
const defaultLogMetadata: EventValue[] = [
@@ -1032,10 +1021,6 @@ export class ClearcutLogger {
this.config?.getContentGeneratorConfig()?.authType,
),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${totalAccounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,

View File

@@ -83,7 +83,6 @@ import type {
} from '@google/genai';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import * as uiTelemetry from './uiTelemetry.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { makeFakeConfig } from '../test-utils/config.js';
describe('loggers', () => {
@@ -101,10 +100,6 @@ describe('loggers', () => {
vi.spyOn(uiTelemetry.uiTelemetryService, 'addEvent').mockImplementation(
mockUiEvent.addEvent,
);
vi.spyOn(
UserAccountManager.prototype,
'getCachedGoogleAccount',
).mockReturnValue('test-user@example.com');
vi.useFakeTimers();
vi.setSystemTime(new Date('2025-01-01T00:00:00.000Z'));
});
@@ -188,7 +183,6 @@ describe('loggers', () => {
body: 'CLI configuration loaded.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_CLI_CONFIG,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -233,7 +227,6 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
@@ -255,7 +248,7 @@ describe('loggers', () => {
const event = new UserPromptEvent(
11,
'prompt-id-9',
AuthType.CLOUD_SHELL,
AuthType.USE_GEMINI,
'test-prompt',
);
@@ -265,12 +258,11 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
prompt_id: 'prompt-id-9',
auth_type: 'cloud-shell',
auth_type: 'gemini-api-key',
},
});
});
@@ -313,7 +305,7 @@ describe('loggers', () => {
'test-model',
100,
'prompt-id-1',
AuthType.LOGIN_WITH_GOOGLE,
AuthType.USE_GEMINI,
usageData,
'test-response',
);
@@ -324,7 +316,6 @@ describe('loggers', () => {
body: 'API response from test-model. Status: 200. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
[SemanticAttributes.HTTP_STATUS_CODE]: 200,
@@ -340,7 +331,7 @@ describe('loggers', () => {
total_token_count: 0,
response_text: 'test-response',
prompt_id: 'prompt-id-1',
auth_type: 'oauth-personal',
auth_type: 'gemini-api-key',
},
});
@@ -386,7 +377,6 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -405,7 +395,6 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -430,7 +419,6 @@ describe('loggers', () => {
body: 'Switching to flash as Fallback.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FLASH_FALLBACK,
'event.timestamp': '2025-01-01T00:00:00.000Z',
auth_type: 'vertex-ai',
@@ -465,7 +453,6 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'ripgrep is not available',
}),
@@ -484,7 +471,6 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'rg not found',
}),
@@ -598,7 +584,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: accept. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -682,7 +667,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: reject. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -759,7 +743,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: modify. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -835,7 +818,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -910,7 +892,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -999,7 +980,6 @@ describe('loggers', () => {
body: 'Tool call: mock_mcp_tool. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'mock_mcp_tool',
@@ -1047,7 +1027,6 @@ describe('loggers', () => {
body: 'Malformed JSON response from test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_MALFORMED_JSON_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -1091,7 +1070,6 @@ describe('loggers', () => {
body: 'File operation: read. Lines: 10.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FILE_OPERATION,
'event.timestamp': '2025-01-01T00:00:00.000Z',
tool_name: 'test-tool',
@@ -1137,7 +1115,6 @@ describe('loggers', () => {
body: 'Tool output truncated for test-tool.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': 'tool_output_truncated',
'event.timestamp': '2025-01-01T00:00:00.000Z',
eventName: 'tool_output_truncated',
@@ -1184,7 +1161,6 @@ describe('loggers', () => {
body: 'Installed extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_INSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1223,7 +1199,6 @@ describe('loggers', () => {
body: 'Uninstalled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_UNINSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1260,7 +1235,6 @@ describe('loggers', () => {
body: 'Enabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_ENABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1297,7 +1271,6 @@ describe('loggers', () => {
body: 'Disabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_DISABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',

View File

@@ -9,7 +9,6 @@ import { logs } from '@opentelemetry/api-logs';
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
import type { Config } from '../config/config.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import {
EVENT_API_ERROR,
EVENT_API_CANCEL,
@@ -93,11 +92,8 @@ const shouldLogUserPrompts = (config: Config): boolean =>
config.getTelemetryLogPromptsEnabled();
/**
 * Builds the attributes attached to every telemetry log record.
 *
 * Always includes the session id; adds `user.email` only when a cached
 * Google account value is available (truthy).
 *
 * @param config - Configuration used to look up the current session id.
 * @returns The common log attributes.
 */
function getCommonAttributes(config: Config): LogAttributes {
  // Look up the account first, matching the original evaluation order.
  const cachedEmail = new UserAccountManager().getCachedGoogleAccount();
  const attributes: LogAttributes = { 'session.id': config.getSessionId() };
  if (cachedEmail) {
    attributes['user.email'] = cachedEmail;
  }
  return attributes;
}

View File

@@ -217,9 +217,9 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
});
it('with headers', async () => {
@@ -232,13 +232,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
});
});
@@ -251,9 +251,9 @@ describe('mcp-client', () => {
},
false,
);
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {}),
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
});
it('with headers', async () => {
@@ -266,13 +266,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
});
});

View File

@@ -6,9 +6,6 @@
import { describe, it, expect } from 'vitest';
import { parseAndFormatApiError } from './errorParsing.js';
import { isProQuotaExceededError } from './quotaErrorDetection.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { AuthType } from '../core/contentGenerator.js';
import type { StructuredError } from '../core/turn.js';
@@ -27,32 +24,10 @@ describe('parseAndFormatApiError', () => {
it('should format a 429 API error with the default message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
undefined,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
const result = parseAndFormatApiError(errorMessage, undefined);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
});
it('should format a 429 API error with the personal message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
'Possible quota limitations in place or slow response times detected. Please wait and try again later.',
);
});
@@ -132,230 +107,4 @@ describe('parseAndFormatApiError', () => {
const expected = '[API Error: An unknown error occurred.]';
expect(parseAndFormatApiError(error)).toBe(expected);
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Free tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain('upgrade to get higher limits');
});
it('should format a regular 429 API error with standard message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
expect(result).not.toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
});
it('should format a 429 API error with generic quota exceeded message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).not.toContain(
'You have reached your daily Gemini 2.5 Pro quota limit',
);
});
it('should prioritize Pro quota message over generic quota message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).not.toContain('You have reached your daily quota limit');
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Legacy tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.LEGACY,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should handle different Gemini 2.5 version strings in Pro quota exceeded errors', () => {
const errorMessage25 =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const errorMessagePreview =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5-preview Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result25 = parseAndFormatApiError(
errorMessage25,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
const resultPreview = parseAndFormatApiError(
errorMessagePreview,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-preview-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result25).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(resultPreview).toContain(
'You have reached your daily gemini-2.5-preview-pro quota limit',
);
expect(result25).toContain('upgrade to get higher limits');
expect(resultPreview).toContain('upgrade to get higher limits');
});
it('should not match non-Pro models with similar version strings', () => {
// Test that Flash models with similar version strings don't match
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Flash Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5-preview Flash Requests' and limit",
),
).toBe(false);
// Test other model types
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Ultra Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Standard Requests' and limit",
),
).toBe(false);
// Test generic quota messages
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'GenerationRequests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'EmbeddingRequests' and limit",
),
).toBe(false);
});
it('should format a generic quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should format a regular 429 API error with standard message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
});

View File

@@ -4,120 +4,36 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isApiError,
isStructuredError,
} from './quotaErrorDetection.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
} from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { isApiError, isStructuredError } from './quotaErrorDetection.js';
import { AuthType } from '../core/contentGenerator.js';
// Free Tier message functions
const getRateLimitErrorMessageGoogleFree = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
const getRateLimitErrorMessageGoogleProQuotaFree = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaFree = () =>
`\nYou have reached your daily quota limit. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
// Legacy/Standard Tier message functions
const getRateLimitErrorMessageGooglePaid = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI.`;
const getRateLimitErrorMessageGoogleProQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
) =>
`\nYou have reached your daily quota limit. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI =
'\nPlease wait and try again later. To increase your limits, request a quota increase through AI Studio, or switch to another /auth method';
const RATE_LIMIT_ERROR_MESSAGE_VERTEX =
'\nPlease wait and try again later. To increase your limits, request a quota increase through Vertex, or switch to another /auth method';
const getRateLimitErrorMessageDefault = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
const RATE_LIMIT_ERROR_MESSAGE_DEFAULT =
'\nPossible quota limitations in place or slow response times detected. Please wait and try again later.';
function getRateLimitMessage(
authType?: AuthType,
error?: unknown,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
function getRateLimitMessage(authType?: AuthType): string {
switch (authType) {
case AuthType.LOGIN_WITH_GOOGLE: {
// Determine if user is on a paid tier (Legacy or Standard) - default to FREE if not specified
const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
if (isProQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleProQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
)
: getRateLimitErrorMessageGoogleProQuotaFree(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
);
} else if (isGenericQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleGenericQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
)
: getRateLimitErrorMessageGoogleGenericQuotaFree();
} else {
return isPaidTier
? getRateLimitErrorMessageGooglePaid(fallbackModel)
: getRateLimitErrorMessageGoogleFree(fallbackModel);
}
}
case AuthType.USE_GEMINI:
return RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI;
case AuthType.USE_VERTEX_AI:
return RATE_LIMIT_ERROR_MESSAGE_VERTEX;
default:
return getRateLimitErrorMessageDefault(fallbackModel);
return RATE_LIMIT_ERROR_MESSAGE_DEFAULT;
}
}
export function parseAndFormatApiError(
error: unknown,
authType?: AuthType,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
if (isStructuredError(error)) {
let text = `[API Error: ${error.message}]`;
if (error.status === 429) {
text += getRateLimitMessage(
authType,
error,
userTier,
currentModel,
fallbackModel,
);
text += getRateLimitMessage(authType);
}
return text;
}
@@ -146,13 +62,7 @@ export function parseAndFormatApiError(
}
let text = `[API Error: ${finalMessage} (Status: ${parsedError.error.status})]`;
if (parsedError.error.code === 429) {
text += getRateLimitMessage(
authType,
parsedError,
userTier,
currentModel,
fallbackModel,
);
text += getRateLimitMessage(authType);
}
return text;
}

View File

@@ -11,12 +11,9 @@ import {
setSimulate429,
disableSimulationAfterFallback,
shouldSimulate429,
createSimulated429Error,
resetRequestCounter,
} from './testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
// Import the new types (Assuming this test file is in packages/core/src/utils/)
import type { FallbackModelHandler } from '../fallback/types.js';
@@ -61,84 +58,6 @@ describe('Retry Utility Fallback Integration', () => {
expect(result).toBe('retry');
});
// This test validates the retry utility's logic for triggering the callback.
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
let fallbackCalled = false;
// Removed fallbackModel variable as it's no longer relevant here.
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
const mockApiCall = vi
.fn()
.mockRejectedValueOnce(createSimulated429Error())
.mockRejectedValueOnce(createSimulated429Error())
.mockResolvedValueOnce('success after fallback');
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
fallbackCalled = true;
// Return true to signal retryWithBackoff to reset attempts and continue.
return true;
});
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
const result = await retryWithBackoff(mockApiCall, {
maxAttempts: 2,
initialDelayMs: 1,
maxDelayMs: 10,
shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockPersistent429Callback,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
// Verify fallback mechanism was triggered
expect(fallbackCalled).toBe(true);
expect(mockPersistent429Callback).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
expect.any(Error),
);
expect(result).toBe('success after fallback');
// Should have: 2 failures, then fallback triggered, then 1 success after retry reset
expect(mockApiCall).toHaveBeenCalledTimes(3);
});
it('should not trigger onPersistent429 for API key users', async () => {
let fallbackCalled = false;
// Mock function that simulates 429 errors
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
// Mock the callback
const mockPersistent429Callback = vi.fn(async () => {
fallbackCalled = true;
return true;
});
// Test with API key auth type - should not trigger fallback
try {
await retryWithBackoff(mockApiCall, {
maxAttempts: 5,
initialDelayMs: 10,
maxDelayMs: 100,
shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockPersistent429Callback,
authType: AuthType.USE_GEMINI, // API key auth type
});
} catch (error) {
// Expected to throw after max attempts
expect((error as Error).message).toContain('Rate limit exceeded');
}
// Verify fallback was NOT triggered for API key users
expect(fallbackCalled).toBe(false);
expect(mockPersistent429Callback).not.toHaveBeenCalled();
});
// This test validates the test utilities themselves.
it('should properly disable simulation state after fallback (Test Utility)', () => {
// Enable simulation

View File

@@ -61,6 +61,7 @@ describe('checkNextSpeaker', () => {
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
useSummarizedThinking: vi.fn().mockReturnValue(false),
} as ContentGenerator,
{} as Config,
);

View File

@@ -81,7 +81,7 @@ export function getResponseText(
candidate.content.parts.length > 0
) {
return candidate.content.parts
.filter((part) => part.text)
.filter((part) => part.text && !part.thought)
.map((part) => part.text)
.join('');
}

View File

@@ -285,173 +285,6 @@ describe('retryWithBackoff', () => {
});
});
describe('Flash model fallback for OAuth users', () => {
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackOccurred) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
// Advance all timers to complete retries
await vi.runAllTimersAsync();
// Should succeed after fallback
await expect(promise).resolves.toBe('success');
// Verify callback was called with correct auth type
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
});
it('should NOT trigger fallback for API key users', async () => {
const fallbackCallback = vi.fn();
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'gemini-api-key',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail after all retries without fallback
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
// Callback should not be called for API key users
expect(fallbackCallback).not.toHaveBeenCalled();
});
it('should reset attempt counter and continue after successful fallback', async () => {
let fallbackCalled = false;
const fallbackCallback = vi.fn().mockImplementation(async () => {
fallbackCalled = true;
return 'gemini-2.5-flash';
});
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackCalled) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
expect(fallbackCallback).toHaveBeenCalledOnce();
});
it('should continue with original error if fallback is rejected', async () => {
const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail with original error when fallback is rejected
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
expect(fallbackCallback).toHaveBeenCalledWith(
'oauth-personal',
expect.any(Error),
);
});
it('should handle mixed error types (only count consecutive 429s)', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let attempts = 0;
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
attempts++;
if (fallbackOccurred) {
return 'success';
}
if (attempts === 1) {
// First attempt: 500 error (resets consecutive count)
const error: HttpError = new Error('Server error');
error.status = 500;
throw error;
} else {
// Remaining attempts: 429 errors
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 5,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should trigger fallback after 2 consecutive 429s (attempts 2-3)
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
describe('Qwen OAuth 429 error handling', () => {
it('should retry for Qwen OAuth 429 errors that are throttling-related', async () => {
const errorWith429: HttpError = new Error('Rate limit exceeded');

View File

@@ -7,8 +7,6 @@
import type { GenerateContentResponse } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js';
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isQwenQuotaExceededError,
isQwenThrottlingError,
} from './quotaErrorDetection.js';
@@ -90,7 +88,6 @@ export async function retryWithBackoff<T>(
maxAttempts,
initialDelayMs,
maxDelayMs,
onPersistent429,
authType,
shouldRetryOnError,
shouldRetryOnContent,
@@ -123,59 +120,6 @@ export async function retryWithBackoff<T>(
} catch (error) {
const errorStatus = getErrorStatus(error);
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
isProQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check for generic quota exceeded error (but not Pro, which was handled above) - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
!isProQuotaExceededError(error) &&
isGenericQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check for Qwen OAuth quota exceeded error - throw immediately without retry
if (authType === AuthType.QWEN_OAUTH && isQwenQuotaExceededError(error)) {
throw new Error(
@@ -197,30 +141,7 @@ export async function retryWithBackoff<T>(
consecutive429Count = 0;
}
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 2 &&
onPersistent429 &&
authType === AuthType.LOGIN_WITH_GOOGLE
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
console.debug('consecutive429Count', consecutive429Count);
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetryOnError(error as Error)) {
@@ -240,7 +161,7 @@ export async function retryWithBackoff<T>(
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
currentDelay = initialDelayMs;
} else {
// Fall back to exponential backoff with jitter
// Fallback to exponential backoff with jitter
logRetryAttempt(attempt, error, errorStatus);
// Add jitter: +/- 30% of currentDelay
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);

View File

@@ -29,7 +29,7 @@ export function parseThought(rawText: string): ThoughtSummary {
const startIndex = rawText.indexOf(START_DELIMITER);
if (startIndex === -1) {
// No start delimiter found, the whole text is the description.
return { subject: '', description: rawText.trim() };
return { subject: '', description: rawText };
}
const endIndex = rawText.indexOf(
@@ -39,7 +39,7 @@ export function parseThought(rawText: string): ThoughtSummary {
if (endIndex === -1) {
// Start delimiter found but no end delimiter, so it's not a valid subject.
// Treat the entire string as the description.
return { subject: '', description: rawText.trim() };
return { subject: '', description: rawText };
}
const subject = rawText

View File

@@ -1,340 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Mock } from 'vitest';
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { UserAccountManager } from './userAccountManager.js';
import * as fs from 'node:fs';
import * as os from 'node:os';
import path from 'node:path';
// Replace the 'os' module with a copy of the real implementation in which
// only homedir() is swapped for a mock function; every other export (for
// example tmpdir()) keeps its original behaviour.
// NOTE(review): the mock is registered for 'os' while this file imports
// 'node:os' — presumably vitest treats both specifiers as the same module; confirm.
vi.mock('os', async (importOriginal) => {
const os = await importOriginal<typeof import('os')>();
return {
...os,
homedir: vi.fn(),
};
});
// Test suite for UserAccountManager. Each test runs against a fresh temporary
// directory that os.homedir() is mocked to return, and reads/writes the
// accounts file at <home>/.qwen/google_accounts.json directly via node:fs.
describe('UserAccountManager', () => {
let tempHomeDir: string;
let userAccountManager: UserAccountManager;
// A function rather than a string so the path is always derived from the
// tempHomeDir of the currently running test.
let accountsFile: () => string;
// Create an isolated home directory and a new manager before every test.
beforeEach(() => {
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'qwen-code-test-home-'),
);
(os.homedir as Mock).mockReturnValue(tempHomeDir);
accountsFile = () =>
path.join(tempHomeDir, '.qwen', 'google_accounts.json');
userAccountManager = new UserAccountManager();
});
// Delete the temporary home directory and clear recorded mock calls.
afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true });
vi.clearAllMocks();
});
// cacheGoogleAccount: sets the active account and maintains the "old" list.
describe('cacheGoogleAccount', () => {
it('should create directory and write initial account file', async () => {
await userAccountManager.cacheGoogleAccount('test1@google.com');
// Verify Google Account ID was cached
expect(fs.existsSync(accountsFile())).toBe(true);
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify({ active: 'test1@google.com', old: [] }, null, 2),
);
});
it('should update active account and move previous to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test2@google.com', old: ['test1@google.com'] },
null,
2,
),
);
await userAccountManager.cacheGoogleAccount('test3@google.com');
// The previously active account is appended after the existing old entries.
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
},
null,
2,
),
);
});
it('should not add a duplicate to the old list', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
// Switch to test2 and back to test1; the file must end up in its
// starting state with each address appearing exactly once.
await userAccountManager.cacheGoogleAccount('test2@google.com');
await userAccountManager.cacheGoogleAccount('test1@google.com');
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
});
it('should handle corrupted JSON by starting fresh', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'not valid json');
// Silence console.log and capture it so the test can assert that the
// problem was reported.
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.cacheGoogleAccount('test1@google.com');
expect(consoleLogSpy).toHaveBeenCalled();
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
active: 'test1@google.com',
old: [],
});
});
it('should handle valid JSON with incorrect schema by starting fresh', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
// "old" is a string here instead of an array of strings.
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'test1@google.com', old: 'not-an-array' }),
);
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.cacheGoogleAccount('test2@google.com');
expect(consoleLogSpy).toHaveBeenCalled();
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
active: 'test2@google.com',
old: [],
});
});
});
// getCachedGoogleAccount: synchronous read of the active account.
describe('getCachedGoogleAccount', () => {
it('should return the active account if file exists and is valid', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'active@google.com', old: [] }, null, 2),
);
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBe('active@google.com');
});
it('should return null if file does not exist', () => {
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null if file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null and log if file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '{ "active": "test@google.com"'); // Invalid JSON
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
expect(consoleLogSpy).toHaveBeenCalled();
});
it('should return null if active key is missing', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), JSON.stringify({ old: [] }));
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
});
// clearCachedGoogleAccount: nulls the active account and archives it in "old".
// Every test here creates the parent directory before calling the method.
describe('clearCachedGoogleAccount', () => {
it('should set active to null and move it to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'active@google.com', old: ['old1@google.com'] },
null,
2,
),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['old1@google.com', 'active@google.com']);
});
it('should handle empty file gracefully', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual([]);
});
it('should handle corrupted JSON by creating a fresh file', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'not valid json');
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.clearCachedGoogleAccount();
expect(consoleLogSpy).toHaveBeenCalled();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual([]);
});
it('should be idempotent if active account is already null', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: null, old: ['old1@google.com'] }, null, 2),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['old1@google.com']);
});
it('should not add a duplicate to the old list', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
// The active account is already present in "old" before clearing.
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{
active: 'active@google.com',
old: ['active@google.com'],
},
null,
2,
),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['active@google.com']);
});
});
// getLifetimeGoogleAccounts: number of distinct accounts across the active
// entry and the "old" list.
describe('getLifetimeGoogleAccounts', () => {
it('should return 0 if the file does not exist', () => {
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'invalid json');
// Despite its name, this spy is attached to console.log (not console.debug).
const consoleDebugSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
expect(consoleDebugSpy).toHaveBeenCalled();
});
it('should return 1 if there is only an active account', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'test1@google.com', old: [] }),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(1);
});
it('should correctly count old accounts when active is null', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: null,
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
});
it('should correctly count both active and old accounts', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(3);
});
it('should handle valid JSON with incorrect schema by returning 0', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
// "old" is a number here instead of an array of strings.
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: null, old: 1 }),
);
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
expect(consoleLogSpy).toHaveBeenCalled();
});
it('should not double count if active account is also in old list', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: 'test1@google.com',
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
});
});
});

View File

@@ -1,140 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { promises as fsp, readFileSync } from 'node:fs';
import { Storage } from '../config/storage.js';
/** Shape of the JSON document stored in the Google accounts cache file. */
interface UserAccounts {
// The currently active account, or null when no account is active.
active: string | null;
// Accounts that were active at some earlier point.
old: string[];
}
/**
 * Reads and writes the cached Google accounts file, which records the
 * currently active account and the accounts that were active before it.
 *
 * The file location is obtained from `Storage.getGoogleAccountsPath()`.
 * Unreadable, empty, or malformed files are treated as an empty state
 * (`{ active: null, old: [] }`) instead of raising.
 */
export class UserAccountManager {
  /** Returns the path of the accounts cache file. */
  private getGoogleAccountsCachePath(): string {
    return Storage.getGoogleAccountsPath();
  }

  /**
   * Parses and validates the string content of an accounts file.
   * @param content The raw string content from the file.
   * @returns A valid UserAccounts object; the empty state when the content is
   *   blank or does not match the expected schema.
   * @throws SyntaxError when non-blank content is not valid JSON. Both read
   *   helpers in this class catch it and fall back to the empty state.
   */
  private parseAndValidateAccounts(content: string): UserAccounts {
    const defaultState: UserAccounts = { active: null, old: [] };
    if (!content.trim()) {
      return defaultState;
    }

    // Keep the parsed value as `unknown` so it must be checked before use.
    const parsed: unknown = JSON.parse(content);
    if (typeof parsed !== 'object' || parsed === null) {
      console.log('Invalid accounts file schema, starting fresh.');
      return defaultState;
    }

    const { active, old } = parsed as Partial<UserAccounts>;
    const isValid =
      (active === undefined || active === null || typeof active === 'string') &&
      (old === undefined ||
        (Array.isArray(old) && old.every((i) => typeof i === 'string')));
    if (!isValid) {
      console.log('Invalid accounts file schema, starting fresh.');
      return defaultState;
    }

    // Missing keys are allowed and normalised to their empty values.
    return {
      active: active ?? null,
      old: old ?? [],
    };
  }

  /** Tells whether `error` is a file-system "no such file or directory" error. */
  private isMissingFileError(error: unknown): boolean {
    return error instanceof Error && 'code' in error && error.code === 'ENOENT';
  }

  /**
   * Synchronously loads the accounts file.
   * @param filePath Path of the accounts file.
   * @returns The stored accounts, or the empty state when the file is missing
   *   or cannot be read/parsed (the latter is logged).
   */
  private readAccountsSync(filePath: string): UserAccounts {
    try {
      const content = readFileSync(filePath, 'utf-8');
      return this.parseAndValidateAccounts(content);
    } catch (error) {
      // A missing file is the normal first-run case, so it is not logged.
      if (!this.isMissingFileError(error)) {
        console.log(
          'Error during sync read of accounts, starting fresh.',
          error,
        );
      }
      return { active: null, old: [] };
    }
  }

  /**
   * Asynchronously loads the accounts file.
   * @param filePath Path of the accounts file.
   * @returns The stored accounts, or the empty state when the file is missing
   *   or cannot be read/parsed (the latter is logged).
   */
  private async readAccounts(filePath: string): Promise<UserAccounts> {
    try {
      const content = await fsp.readFile(filePath, 'utf-8');
      return this.parseAndValidateAccounts(content);
    } catch (error) {
      // A missing file is the normal first-run case, so it is not logged.
      if (!this.isMissingFileError(error)) {
        console.log('Could not parse accounts file, starting fresh.', error);
      }
      return { active: null, old: [] };
    }
  }

  /**
   * Serialises `accounts` to `filePath` as 2-space-indented JSON, creating the
   * parent directory first so the write cannot fail merely because the
   * directory has never been created.
   */
  private async writeAccounts(
    filePath: string,
    accounts: UserAccounts,
  ): Promise<void> {
    await fsp.mkdir(path.dirname(filePath), { recursive: true });
    await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
  }

  /**
   * Makes `email` the active account. A different previously active account is
   * appended to the `old` list (once), and `email` itself is removed from
   * `old` if it was there.
   * @param email The account to record as active.
   */
  async cacheGoogleAccount(email: string): Promise<void> {
    const filePath = this.getGoogleAccountsCachePath();
    const accounts = await this.readAccounts(filePath);

    if (accounts.active && accounts.active !== email) {
      if (!accounts.old.includes(accounts.active)) {
        accounts.old.push(accounts.active);
      }
    }

    // If the new email was in the old list, remove it
    accounts.old = accounts.old.filter((oldEmail) => oldEmail !== email);

    accounts.active = email;
    await this.writeAccounts(filePath, accounts);
  }

  /** Returns the active account from the cache file, or null when there is none. */
  getCachedGoogleAccount(): string | null {
    const filePath = this.getGoogleAccountsCachePath();
    return this.readAccountsSync(filePath).active;
  }

  /**
   * Returns how many distinct accounts are recorded in the cache file,
   * counting the active account and the `old` list without duplicates.
   */
  getLifetimeGoogleAccounts(): number {
    const filePath = this.getGoogleAccountsCachePath();
    const accounts = this.readAccountsSync(filePath);
    const allAccounts = new Set(accounts.old);
    if (accounts.active) {
      allAccounts.add(accounts.active);
    }
    return allAccounts.size;
  }

  /**
   * Clears the active account, moving it to the `old` list if it is not
   * already there, and rewrites the cache file. Works even when the cache
   * directory does not exist yet.
   */
  async clearCachedGoogleAccount(): Promise<void> {
    const filePath = this.getGoogleAccountsCachePath();
    const accounts = await this.readAccounts(filePath);

    if (accounts.active) {
      if (!accounts.old.includes(accounts.active)) {
        accounts.old.push(accounts.active);
      }
      accounts.active = null;
    }

    await this.writeAccounts(filePath, accounts);
  }
}

View File

@@ -1,24 +0,0 @@
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
indent_style = space
indent_size = 4
tab_width = 4
ij_continuation_indent_size = 8
[*.java]
ij_java_doc_align_exception_comments = false
ij_java_doc_align_param_comments = false
[*.{yaml, yml, sh, ps1}]
indent_size = 2
[*.{md, mkd, markdown}]
trim_trailing_whitespace = false
[{**/res/**.xml, **/AndroidManifest.xml}]
ij_continuation_indent_size = 4

View File

@@ -1,14 +0,0 @@
### IntelliJ IDEA ###
.idea
*.iws
*.iml
*.ipr
# Mac
.DS_Store
# Maven
log/
target/
/docs/

View File

@@ -1,378 +0,0 @@
# Qwen Code Java SDK
## Project Overview
The Qwen Code Java SDK is a minimal experimental SDK for programmatic access to Qwen Code functionality. It provides a Java interface to interact with the Qwen Code CLI, allowing developers to integrate Qwen Code capabilities into their Java applications.
**Context Information:**
- Current Date: Monday 5 January 2026
- Operating System: darwin
- Working Directory: /Users/weigeng/repos/qwen-code/packages/sdk-java
## Project Details
- **Group ID**: com.alibaba
- **Artifact ID**: qwencode-sdk (as per pom.xml)
- **Version**: 0.0.1-SNAPSHOT
- **Packaging**: JAR
- **Java Version**: 1.8+ (source and target)
- **License**: Apache-2.0
## Architecture
The SDK follows a layered architecture:
- **API Layer**: Provides the main entry points through `QwenCodeCli` class with simple static methods for basic usage
- **Session Layer**: Manages communication sessions with the Qwen Code CLI through the `Session` class
- **Transport Layer**: Handles the communication mechanism between the SDK and CLI process (currently using process transport via `ProcessTransport`)
- **Protocol Layer**: Defines data structures for communication based on the CLI protocol
- **Utils**: Common utilities for concurrent execution, timeout handling, and error management
## Key Components
### Main Classes
- `QwenCodeCli`: Main entry point with static methods for simple queries
- `Session`: Manages communication sessions with the CLI
- `Transport`: Abstracts the communication mechanism (currently using process transport)
- `ProcessTransport`: Implementation that communicates via process execution
- `TransportOptions`: Configuration class for transport layer settings
- `SessionEventSimpleConsumers`: High-level event handler for processing responses
- `AssistantContentSimpleConsumers`: Handles different types of content within assistant messages
### Dependencies
- **Logging**: ch.qos.logback:logback-classic
- **Utilities**: org.apache.commons:commons-lang3
- **JSON Processing**: com.alibaba.fastjson2:fastjson2
- **Testing**: JUnit 5 (org.junit.jupiter:junit-jupiter)
## Building and Running
### Prerequisites
- Java 8 or higher
- Apache Maven 3.6.0 or higher
### Build Commands
```bash
# Compile the project
mvn compile
# Run tests
mvn test
# Package the JAR
mvn package
# Install to local repository
mvn install
# Run checkstyle verification
mvn checkstyle:check
# Generate Javadoc
mvn javadoc:javadoc
```
### Testing
The project includes basic unit tests using JUnit 5. The main test class `QwenCodeCliTest` demonstrates how to use the SDK to make simple queries to the Qwen Code CLI.
### Code Quality
The project uses Checkstyle for code formatting and style enforcement. The configuration is defined in `checkstyle.xml` and includes rules for:
- Whitespace and indentation
- Naming conventions
- Import ordering
- Code structure
- Line endings (LF only)
- No trailing whitespace
- 8-space indentation for line wrapping
## Development Conventions
### Coding Standards
- Java 8 language features are supported
- Follow standard Java naming conventions
- Use UTF-8 encoding for source files
- Line endings should be LF (Unix-style)
- No trailing whitespace allowed
- Use 8-space indentation for line wrapping
### Testing Practices
- Write unit tests using JUnit 5
- Test classes should be in the `src/test/java` directory
- Follow the naming convention `*Test.java` for test classes
- Use appropriate assertions to validate functionality
### Documentation
- API documentation should follow JavaDoc conventions
- Update README files when adding new features
- Include examples in documentation
## API Reference
### QwenCodeCli Class
The main class provides several primary methods:
- `simpleQuery(String prompt)`: Synchronous method that returns a list of responses
- `simpleQuery(String prompt, TransportOptions transportOptions)`: Synchronous method with custom transport options
- `simpleQuery(String prompt, TransportOptions transportOptions, AssistantContentConsumers assistantContentConsumers)`: Advanced method with custom content consumers
- `newSession()`: Creates a new session with default options
- `newSession(TransportOptions transportOptions)`: Creates a new session with custom options
### Permission Modes
The SDK supports different permission modes for controlling tool execution:
- **`default`**: Write tools are denied unless approved via `canUseTool` callback or in `allowedTools`. Read-only tools execute without confirmation.
- **`plan`**: Blocks all write tools, instructing AI to present a plan first.
- **`auto-edit`**: Auto-approve edit tools (edit, write_file) while other tools require confirmation.
- **`yolo`**: All tools execute automatically without confirmation.
### Transport Options
The `TransportOptions` class allows configuration of how the SDK communicates with the Qwen Code CLI:
- `pathToQwenExecutable`: Path to the Qwen Code CLI executable
- `cwd`: Working directory for the CLI process
- `model`: AI model to use for the session
- `permissionMode`: Permission mode that controls tool execution
- `env`: Environment variables to pass to the CLI process
- `maxSessionTurns`: Limits the number of conversation turns in a session
- `coreTools`: List of core tools that should be available to the AI
- `excludeTools`: List of tools to exclude from being available to the AI
- `allowedTools`: List of tools that are pre-approved for use without additional confirmation
- `authType`: Authentication type to use for the session
- `includePartialMessages`: Enables receiving partial messages during streaming responses
- `skillsEnable`: Enables or disables skills functionality for the session
- `turnTimeout`: Timeout for a complete turn of conversation
- `messageTimeout`: Timeout for individual messages within a turn
- `resumeSessionId`: ID of a previous session to resume
- `otherOptions`: Additional command-line options to pass to the CLI
### Session Control Features
- **Session creation**: Use `QwenCodeCli.newSession()` to create a new session with custom options
- **Session management**: The `Session` class provides methods to send prompts, handle responses, and manage session state
- **Session cleanup**: Always close sessions using `session.close()` to properly terminate the CLI process
- **Session resumption**: Use `setResumeSessionId()` in `TransportOptions` to resume a previous session
- **Session interruption**: Use `session.interrupt()` to interrupt a currently running prompt
- **Dynamic model switching**: Use `session.setModel()` to change the model during a session
- **Dynamic permission mode switching**: Use `session.setPermissionMode()` to change the permission mode during a session
### Thread Pool Configuration
The SDK uses a thread pool for managing concurrent operations with the following default configuration:
- **Core Pool Size**: 30 threads
- **Maximum Pool Size**: 100 threads
- **Keep-Alive Time**: 60 seconds
- **Queue Capacity**: 300 tasks (using LinkedBlockingQueue)
- **Thread Naming**: "qwen_code_cli-pool-{number}"
- **Daemon Threads**: false
- **Rejected Execution Handler**: CallerRunsPolicy
### Session Event Consumers and Assistant Content Consumers
The SDK provides two key interfaces for handling events and content from the CLI:
#### SessionEventConsumers Interface
The `SessionEventConsumers` interface provides callbacks for different types of messages during a session:
- `onSystemMessage`: Handles system messages from the CLI (receives Session and SDKSystemMessage)
- `onResultMessage`: Handles result messages from the CLI (receives Session and SDKResultMessage)
- `onAssistantMessage`: Handles assistant messages (AI responses) (receives Session and SDKAssistantMessage)
- `onPartialAssistantMessage`: Handles partial assistant messages during streaming (receives Session and SDKPartialAssistantMessage)
- `onUserMessage`: Handles user messages (receives Session and SDKUserMessage)
- `onOtherMessage`: Handles other types of messages (receives Session and String message)
- `onControlResponse`: Handles control responses (receives Session and CLIControlResponse)
- `onControlRequest`: Handles control requests (receives Session and CLIControlRequest, returns CLIControlResponse)
- `onPermissionRequest`: Handles permission requests (receives Session and CLIControlRequest<CLIControlPermissionRequest>, returns Behavior)
#### AssistantContentConsumers Interface
The `AssistantContentConsumers` interface handles different types of content within assistant messages:
- `onText`: Handles text content (receives Session and TextAssistantContent)
- `onThinking`: Handles thinking content (receives Session and ThingkingAssistantContent)
- `onToolUse`: Handles tool use content (receives Session and ToolUseAssistantContent)
- `onToolResult`: Handles tool result content (receives Session and ToolResultAssistantContent)
- `onOtherContent`: Handles other content types (receives Session and AssistantContent)
- `onUsage`: Handles usage information (receives Session and AssistantUsage)
- `onPermissionRequest`: Handles permission requests (receives Session and CLIControlPermissionRequest, returns Behavior)
- `onOtherControlRequest`: Handles other control requests (receives Session and ControlRequestPayload, returns ControlResponsePayload)
#### Relationship Between the Interfaces
**Important Note on Event Hierarchy:**
- `SessionEventConsumers` is the **high-level** event processor that handles different message types (system, assistant, user, etc.)
- `AssistantContentConsumers` is the **low-level** content processor that handles different types of content within assistant messages (text, tools, thinking, etc.)
**Processor Relationship:**
- `SessionEventConsumers` → `AssistantContentConsumers` (SessionEventConsumers uses AssistantContentConsumers to process content within assistant messages)
**Event Derivation Relationships:**
- `onAssistantMessage` → `onText`, `onThinking`, `onToolUse`, `onToolResult`, `onOtherContent`, `onUsage`
- `onPartialAssistantMessage` → `onText`, `onThinking`, `onToolUse`, `onToolResult`, `onOtherContent`
- `onControlRequest` → `onPermissionRequest`, `onOtherControlRequest`
**Event Timeout Relationships:**
Each event handler method has a corresponding timeout method that allows customizing the timeout behavior for that specific event:
- `onSystemMessage` → `onSystemMessageTimeout`
- `onResultMessage` → `onResultMessageTimeout`
- `onAssistantMessage` → `onAssistantMessageTimeout`
- `onPartialAssistantMessage` → `onPartialAssistantMessageTimeout`
- `onUserMessage` → `onUserMessageTimeout`
- `onOtherMessage` → `onOtherMessageTimeout`
- `onControlResponse` → `onControlResponseTimeout`
- `onControlRequest` → `onControlRequestTimeout`
For AssistantContentConsumers timeout methods:
- `onText` → `onTextTimeout`
- `onThinking` → `onThinkingTimeout`
- `onToolUse` → `onToolUseTimeout`
- `onToolResult` → `onToolResultTimeout`
- `onOtherContent` → `onOtherContentTimeout`
- `onPermissionRequest` → `onPermissionRequestTimeout`
- `onOtherControlRequest` → `onOtherControlRequestTimeout`
**Default Timeout Values:**
- `SessionEventSimpleConsumers` default timeout: 180 seconds (Timeout.TIMEOUT_180_SECONDS)
- `AssistantContentSimpleConsumers` default timeout: 60 seconds (Timeout.TIMEOUT_60_SECONDS)
**Timeout Hierarchy Requirements:**
For proper operation, the following timeout relationships should be maintained:
- `onAssistantMessageTimeout` return value should be greater than `onTextTimeout`, `onThinkingTimeout`, `onToolUseTimeout`, `onToolResultTimeout`, and `onOtherContentTimeout` return values
- `onControlRequestTimeout` return value should be greater than `onPermissionRequestTimeout` and `onOtherControlRequestTimeout` return values
#### Relationship Between the Interfaces
- `AssistantContentSimpleConsumers` is the default implementation of `AssistantContentConsumers`
- `SessionEventSimpleConsumers` is the concrete implementation that combines both interfaces and depends on an `AssistantContentConsumers` instance to handle content within assistant messages
- The timeout methods in `SessionEventConsumers` now include the message object as a parameter (e.g., `onSystemMessageTimeout(Session session, SDKSystemMessage systemMessage)`)
Event processing is subject to the timeout settings configured in `TransportOptions` and `SessionEventConsumers`. For detailed timeout configuration options, see "Event Timeout Relationships", "Default Timeout Values" and "Timeout Hierarchy Requirements" above.
## Usage Examples
The SDK includes several example files in `src/test/java/com/alibaba/qwen/code/cli/example/` that demonstrate different aspects of the API:
### Basic Usage
- `QuickStartExample.java`: Demonstrates simple query usage, transport options configuration, and streaming content handling
### Session Control
- `SessionExample.java`: Shows session control features including permission mode changes, model switching, interruption, and event handling
### Configuration
- `ThreadPoolConfigurationExample.java`: Shows how to configure the thread pool used by the SDK
## Error Handling
The SDK provides specific exception types for different error scenarios:
- `SessionControlException`: Thrown when there's an issue with session control (creation, initialization, etc.)
- `SessionSendPromptException`: Thrown when there's an issue sending a prompt or receiving a response
- `SessionClosedException`: Thrown when attempting to use a closed session
## Project Structure
```
src/
├── example/
│ └── java/
│ └── com/
│ └── alibaba/
│ └── qwen/
│ └── code/
│ └── example/
├── main/
│ └── java/
│ └── com/
│ └── alibaba/
│ └── qwen/
│ └── code/
│ └── cli/
│ ├── QwenCodeCli.java
│ ├── protocol/
│ ├── session/
│ ├── transport/
│ └── utils/
└── test/
├── java/
│ └── com/
│ └── alibaba/
│ └── qwen/
│ └── code/
│ └── cli/
│ ├── QwenCodeCliTest.java
│ ├── session/
│ │ └── SessionTest.java
│ └── transport/
│ ├── PermissionModeTest.java
│ └── process/
│ └── ProcessTransportTest.java
└── temp/
```
## Configuration Files
- `pom.xml`: Maven build configuration and dependencies
- `checkstyle.xml`: Code style and formatting rules
- `.editorconfig`: Editor configuration settings
## FAQ / Troubleshooting
### Q: Do I need to install the Qwen CLI separately?
A: No, from v0.1.1, the CLI is bundled with the SDK, so no standalone CLI installation is needed.
### Q: What Java versions are supported?
A: The SDK requires Java 1.8 or higher.
### Q: How do I handle long-running requests?
A: The SDK includes timeout utilities. You can configure timeouts using the `Timeout` class in `TransportOptions`.
### Q: Why are some tools not executing?
A: This is likely due to permission modes. Check your permission mode settings and consider using `allowedTools` to pre-approve certain tools.
### Q: How do I resume a previous session?
A: Use the `setResumeSessionId()` method in `TransportOptions` to resume a previous session.
### Q: Can I customize the environment for the CLI process?
A: Yes, use the `setEnv()` method in `TransportOptions` to pass environment variables to the CLI process.
### Q: What happens if the CLI process crashes?
A: The SDK will throw appropriate exceptions. Make sure to handle `SessionControlException` and implement retry logic if needed.
## Maintainers
- **Developer**: skyfire (gengwei.gw(at)alibaba-inc.com)
- **Organization**: Alibaba Group

View File

@@ -1,312 +0,0 @@
# Qwen Code Java SDK
The Qwen Code Java SDK is a minimal experimental SDK for programmatic access to Qwen Code functionality. It provides a Java interface to interact with the Qwen Code CLI, allowing developers to integrate Qwen Code capabilities into their Java applications.
## Requirements
- Java >= 1.8
- Maven >= 3.6.0 (for building from source)
- qwen-code >= 0.5.0
### Dependencies
- **Logging**: ch.qos.logback:logback-classic
- **Utilities**: org.apache.commons:commons-lang3
- **JSON Processing**: com.alibaba.fastjson2:fastjson2
- **Testing**: JUnit 5 (org.junit.jupiter:junit-jupiter)
## Installation
Add the following dependency to your Maven `pom.xml`:
```xml
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>qwencode-sdk</artifactId>
<version>{$version}</version>
</dependency>
```
Or if using Gradle, add to your `build.gradle`:
```gradle
implementation 'com.alibaba:qwencode-sdk:{$version}'
```
## Building and Running
### Build Commands
```bash
# Compile the project
mvn compile
# Run tests
mvn test
# Package the JAR
mvn package
# Install to local repository
mvn install
```
## Quick Start
The simplest way to use the SDK is through the `QwenCodeCli.simpleQuery()` method:
```java
public static void runSimpleExample() {
List<String> result = QwenCodeCli.simpleQuery("hello world");
result.forEach(logger::info);
}
```
For more advanced usage with custom transport options:
```java
public static void runTransportOptionsExample() {
TransportOptions options = new TransportOptions()
.setModel("qwen3-coder-flash")
.setPermissionMode(PermissionMode.AUTO_EDIT)
.setCwd("./")
.setEnv(new HashMap<String, String>() {{put("CUSTOM_VAR", "value");}})
.setIncludePartialMessages(true)
.setTurnTimeout(new Timeout(120L, TimeUnit.SECONDS))
.setMessageTimeout(new Timeout(90L, TimeUnit.SECONDS))
.setAllowedTools(Arrays.asList("read_file", "write_file", "list_directory"));
List<String> result = QwenCodeCli.simpleQuery("who are you, what are your capabilities?", options);
result.forEach(logger::info);
}
```
For streaming content handling with custom content consumers:
```java
public static void runStreamingExample() {
QwenCodeCli.simpleQuery("who are you, what are your capabilities?",
new TransportOptions().setMessageTimeout(new Timeout(10L, TimeUnit.SECONDS)), new AssistantContentSimpleConsumers() {
@Override
public void onText(Session session, TextAssistantContent textAssistantContent) {
logger.info("Text content received: {}", textAssistantContent.getText());
}
@Override
public void onThinking(Session session, ThingkingAssistantContent thingkingAssistantContent) {
logger.info("Thinking content received: {}", thingkingAssistantContent.getThinking());
}
@Override
public void onToolUse(Session session, ToolUseAssistantContent toolUseContent) {
logger.info("Tool use content received: {} with arguments: {}",
toolUseContent, toolUseContent.getInput());
}
@Override
public void onToolResult(Session session, ToolResultAssistantContent toolResultContent) {
logger.info("Tool result content received: {}", toolResultContent.getContent());
}
@Override
public void onOtherContent(Session session, AssistantContent<?> other) {
logger.info("Other content received: {}", other);
}
@Override
public void onUsage(Session session, AssistantUsage assistantUsage) {
logger.info("Usage information received: Input tokens: {}, Output tokens: {}",
assistantUsage.getUsage().getInputTokens(), assistantUsage.getUsage().getOutputTokens());
}
}.setDefaultPermissionOperation(Operation.allow));
logger.info("Streaming example completed.");
}
```
For more examples, see `src/test/java/com/alibaba/qwen/code/cli/example`.
## Architecture
The SDK follows a layered architecture:
- **API Layer**: Provides the main entry points through `QwenCodeCli` class with simple static methods for basic usage
- **Session Layer**: Manages communication sessions with the Qwen Code CLI through the `Session` class
- **Transport Layer**: Handles the communication mechanism between the SDK and CLI process (currently using process transport via `ProcessTransport`)
- **Protocol Layer**: Defines data structures for communication based on the CLI protocol
- **Utils**: Common utilities for concurrent execution, timeout handling, and error management
## Key Features
### Permission Modes
The SDK supports different permission modes for controlling tool execution:
- **`default`**: Write tools are denied unless approved via `canUseTool` callback or in `allowedTools`. Read-only tools execute without confirmation.
- **`plan`**: Blocks all write tools, instructing AI to present a plan first.
- **`auto-edit`**: Auto-approve edit tools (edit, write_file) while other tools require confirmation.
- **`yolo`**: All tools execute automatically without confirmation.
### Session Event Consumers and Assistant Content Consumers
The SDK provides two key interfaces for handling events and content from the CLI:
#### SessionEventConsumers Interface
The `SessionEventConsumers` interface provides callbacks for different types of messages during a session:
- `onSystemMessage`: Handles system messages from the CLI (receives Session and SDKSystemMessage)
- `onResultMessage`: Handles result messages from the CLI (receives Session and SDKResultMessage)
- `onAssistantMessage`: Handles assistant messages (AI responses) (receives Session and SDKAssistantMessage)
- `onPartialAssistantMessage`: Handles partial assistant messages during streaming (receives Session and SDKPartialAssistantMessage)
- `onUserMessage`: Handles user messages (receives Session and SDKUserMessage)
- `onOtherMessage`: Handles other types of messages (receives Session and String message)
- `onControlResponse`: Handles control responses (receives Session and CLIControlResponse)
- `onControlRequest`: Handles control requests (receives Session and CLIControlRequest, returns CLIControlResponse)
- `onPermissionRequest`: Handles permission requests (receives Session and CLIControlRequest<CLIControlPermissionRequest>, returns Behavior)
#### AssistantContentConsumers Interface
The `AssistantContentConsumers` interface handles different types of content within assistant messages:
- `onText`: Handles text content (receives Session and TextAssistantContent)
- `onThinking`: Handles thinking content (receives Session and ThingkingAssistantContent)
- `onToolUse`: Handles tool use content (receives Session and ToolUseAssistantContent)
- `onToolResult`: Handles tool result content (receives Session and ToolResultAssistantContent)
- `onOtherContent`: Handles other content types (receives Session and AssistantContent)
- `onUsage`: Handles usage information (receives Session and AssistantUsage)
- `onPermissionRequest`: Handles permission requests (receives Session and CLIControlPermissionRequest, returns Behavior)
- `onOtherControlRequest`: Handles other control requests (receives Session and ControlRequestPayload, returns ControlResponsePayload)
#### Relationship Between the Interfaces
**Important Note on Event Hierarchy:**
- `SessionEventConsumers` is the **high-level** event processor that handles different message types (system, assistant, user, etc.)
- `AssistantContentConsumers` is the **low-level** content processor that handles different types of content within assistant messages (text, tools, thinking, etc.)
**Processor Relationship:**
- `SessionEventConsumers` → `AssistantContentConsumers` (SessionEventConsumers uses AssistantContentConsumers to process content within assistant messages)
**Event Derivation Relationships:**
- `onAssistantMessage` → `onText`, `onThinking`, `onToolUse`, `onToolResult`, `onOtherContent`, `onUsage`
- `onPartialAssistantMessage` → `onText`, `onThinking`, `onToolUse`, `onToolResult`, `onOtherContent`
- `onControlRequest` → `onPermissionRequest`, `onOtherControlRequest`
**Event Timeout Relationships:**
Each event handler method has a corresponding timeout method that allows customizing the timeout behavior for that specific event:
- `onSystemMessage` → `onSystemMessageTimeout`
- `onResultMessage` → `onResultMessageTimeout`
- `onAssistantMessage` → `onAssistantMessageTimeout`
- `onPartialAssistantMessage` → `onPartialAssistantMessageTimeout`
- `onUserMessage` → `onUserMessageTimeout`
- `onOtherMessage` → `onOtherMessageTimeout`
- `onControlResponse` → `onControlResponseTimeout`
- `onControlRequest` → `onControlRequestTimeout`
For AssistantContentConsumers timeout methods:
- `onText` → `onTextTimeout`
- `onThinking` → `onThinkingTimeout`
- `onToolUse` → `onToolUseTimeout`
- `onToolResult` → `onToolResultTimeout`
- `onOtherContent` → `onOtherContentTimeout`
- `onPermissionRequest` → `onPermissionRequestTimeout`
- `onOtherControlRequest` → `onOtherControlRequestTimeout`
**Default Timeout Values:**
- `SessionEventSimpleConsumers` default timeout: 180 seconds (Timeout.TIMEOUT_180_SECONDS)
- `AssistantContentSimpleConsumers` default timeout: 60 seconds (Timeout.TIMEOUT_60_SECONDS)
**Timeout Hierarchy Requirements:**
For proper operation, the following timeout relationships should be maintained:
- `onAssistantMessageTimeout` return value should be greater than `onTextTimeout`, `onThinkingTimeout`, `onToolUseTimeout`, `onToolResultTimeout`, and `onOtherContentTimeout` return values
- `onControlRequestTimeout` return value should be greater than `onPermissionRequestTimeout` and `onOtherControlRequestTimeout` return values
### Transport Options
The `TransportOptions` class allows configuration of how the SDK communicates with the Qwen Code CLI:
- `pathToQwenExecutable`: Path to the Qwen Code CLI executable
- `cwd`: Working directory for the CLI process
- `model`: AI model to use for the session
- `permissionMode`: Permission mode that controls tool execution
- `env`: Environment variables to pass to the CLI process
- `maxSessionTurns`: Limits the number of conversation turns in a session
- `coreTools`: List of core tools that should be available to the AI
- `excludeTools`: List of tools to exclude from being available to the AI
- `allowedTools`: List of tools that are pre-approved for use without additional confirmation
- `authType`: Authentication type to use for the session
- `includePartialMessages`: Enables receiving partial messages during streaming responses
- `skillsEnable`: Enables or disables skills functionality for the session
- `turnTimeout`: Timeout for a complete turn of conversation
- `messageTimeout`: Timeout for individual messages within a turn
- `resumeSessionId`: ID of a previous session to resume
- `otherOptions`: Additional command-line options to pass to the CLI
### Session Control Features
- **Session creation**: Use `QwenCodeCli.newSession()` to create a new session with custom options
- **Session management**: The `Session` class provides methods to send prompts, handle responses, and manage session state
- **Session cleanup**: Always close sessions using `session.close()` to properly terminate the CLI process
- **Session resumption**: Use `setResumeSessionId()` in `TransportOptions` to resume a previous session
- **Session interruption**: Use `session.interrupt()` to interrupt a currently running prompt
- **Dynamic model switching**: Use `session.setModel()` to change the model during a session
- **Dynamic permission mode switching**: Use `session.setPermissionMode()` to change the permission mode during a session
### Thread Pool Configuration
The SDK uses a thread pool for managing concurrent operations with the following default configuration:
- **Core Pool Size**: 30 threads
- **Maximum Pool Size**: 100 threads
- **Keep-Alive Time**: 60 seconds
- **Queue Capacity**: 300 tasks (using LinkedBlockingQueue)
- **Thread Naming**: "qwen_code_cli-pool-{number}"
- **Daemon Threads**: false
- **Rejected Execution Handler**: CallerRunsPolicy
## Error Handling
The SDK provides specific exception types for different error scenarios:
- `SessionControlException`: Thrown when there's an issue with session control (creation, initialization, etc.)
- `SessionSendPromptException`: Thrown when there's an issue sending a prompt or receiving a response
- `SessionClosedException`: Thrown when attempting to use a closed session
## FAQ / Troubleshooting
### Q: Do I need to install the Qwen CLI separately?
A: No, from v0.1.1, the CLI is bundled with the SDK, so no standalone CLI installation is needed.
### Q: What Java versions are supported?
A: The SDK requires Java 1.8 or higher.
### Q: How do I handle long-running requests?
A: The SDK includes timeout utilities. You can configure timeouts using the `Timeout` class in `TransportOptions`.
### Q: Why are some tools not executing?
A: This is likely due to permission modes. Check your permission mode settings and consider using `allowedTools` to pre-approve certain tools.
### Q: How do I resume a previous session?
A: Use the `setResumeSessionId()` method in `TransportOptions` to resume a previous session.
### Q: Can I customize the environment for the CLI process?
A: Yes, use the `setEnv()` method in `TransportOptions` to pass environment variables to the CLI process.
## License
Apache-2.0 - see [LICENSE](./LICENSE) for details.

View File

@@ -1,131 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE module PUBLIC
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
"http://checkstyle.sourceforge.net/dtds/configuration_1_3.dtd">
<module name="Checker">
<module name="FileTabCharacter" />
<module name="NewlineAtEndOfFile">
<property name="lineSeparator" value="lf" />
</module>
<module name="RegexpMultiline">
<property name="format" value="\r" />
<property name="message" value="Line contains carriage return" />
</module>
<module name="RegexpMultiline">
<property name="format" value=" \n" />
<property name="message" value="Line has trailing whitespace" />
</module>
<module name="RegexpMultiline">
<property name="format" value="\n\n\n" />
<property name="message" value="Multiple consecutive blank lines" />
</module>
<module name="RegexpMultiline">
<property name="format" value="\n\n\Z" />
<property name="message" value="Blank line before end of file" />
</module>
<module name="RegexpMultiline">
<property name="format" value="\{\n\n" />
<property name="message" value="Blank line after opening brace" />
</module>
<module name="RegexpMultiline">
<property name="format" value="\n\n\s*\}" />
<property name="message" value="Blank line before closing brace" />
</module>
<module name="RegexpMultiline">
<property name="format" value="->\s*\{\s+\}" />
<property name="message" value="Whitespace inside empty lambda body" />
</module>
<module name="TreeWalker">
<module name="SuppressWarningsHolder" />
<module name="EmptyBlock">
<property name="option" value="text" />
<property name="tokens" value="
LITERAL_DO, LITERAL_ELSE, LITERAL_FINALLY, LITERAL_IF,
LITERAL_FOR, LITERAL_TRY, LITERAL_WHILE, INSTANCE_INIT, STATIC_INIT" />
</module>
<module name="EmptyStatement" />
<module name="EmptyForInitializerPad" />
<module name="MethodParamPad">
<property name="allowLineBreaks" value="true" />
<property name="option" value="nospace" />
</module>
<module name="ParenPad" />
<module name="TypecastParenPad" />
<module name="NeedBraces" />
<module name="LeftCurly">
<property name="option" value="eol" />
<property name="tokens" value="
LITERAL_CATCH, LITERAL_DO, LITERAL_ELSE, LITERAL_FINALLY, LITERAL_FOR,
LITERAL_IF, LITERAL_SWITCH, LITERAL_SYNCHRONIZED, LITERAL_TRY, LITERAL_WHILE" />
</module>
<module name="GenericWhitespace" />
<module name="WhitespaceAfter" />
<module name="NoWhitespaceAfter" />
<module name="NoWhitespaceBefore" />
<module name="SingleSpaceSeparator" />
<module name="Indentation">
<property name="throwsIndent" value="8" />
<property name="lineWrappingIndentation" value="8" />
</module>
<module name="UpperEll" />
<module name="DefaultComesLast" />
<module name="ArrayTypeStyle" />
<module name="ModifierOrder" />
<module name="OneStatementPerLine" />
<module name="StringLiteralEquality" />
<module name="MutableException" />
<module name="EqualsHashCode" />
<module name="ExplicitInitialization" />
<module name="OneTopLevelClass" />
<module name="MemberName" />
<module name="PackageName" />
<module name="ClassTypeParameterName">
<property name="format" value="^[A-Z][0-9]?$" />
</module>
<module name="MethodTypeParameterName">
<property name="format" value="^[A-Z][0-9]?$" />
</module>
<module name="AnnotationUseStyle">
<property name="trailingArrayComma" value="ignore" />
</module>
<module name="RedundantImport" />
<module name="UnusedImports" />
<!-- <module name="ImportOrder">-->
<!-- <property name="groups" value="*,javax,java" />-->
<!-- <property name="separated" value="true" />-->
<!-- <property name="option" value="bottom" />-->
<!-- <property name="sortStaticImportsAlphabetically" value="true" />-->
<!-- </module>-->
<module name="WhitespaceAround">
<property name="allowEmptyConstructors" value="true" />
<property name="allowEmptyMethods" value="true" />
<property name="allowEmptyLambdas" value="true" />
<property name="ignoreEnhancedForColon" value="false" />
<property name="tokens" value="
ASSIGN, BAND, BAND_ASSIGN, BOR, BOR_ASSIGN, BSR, BSR_ASSIGN,
BXOR, BXOR_ASSIGN, COLON, DIV, DIV_ASSIGN, DO_WHILE, EQUAL, GE, GT, LAND,
LAMBDA, LE, LITERAL_ASSERT, LITERAL_CATCH, LITERAL_DO, LITERAL_ELSE,
LITERAL_FINALLY, LITERAL_FOR, LITERAL_IF, LITERAL_RETURN, LITERAL_SWITCH,
LITERAL_SYNCHRONIZED, LITERAL_TRY, LITERAL_WHILE,
LOR, LT, MINUS, MINUS_ASSIGN, MOD, MOD_ASSIGN, NOT_EQUAL,
PLUS, PLUS_ASSIGN, QUESTION, SL, SLIST, SL_ASSIGN, SR, SR_ASSIGN,
STAR, STAR_ASSIGN, TYPE_EXTENSION_AND" />
</module>
<!-- NOTE(review): WhitespaceAfter is already declared earlier in this TreeWalker with no properties; this entry repeats it. Confirm whether the repetition is intended. -->
<module name="WhitespaceAfter" />
<!-- Second NoWhitespaceAfter instance, restricted to the DOT token and with allowLineBreaks turned off. -->
<module name="NoWhitespaceAfter">
<property name="tokens" value="DOT" />
<property name="allowLineBreaks" value="false" />
</module>
<module name="MissingOverride"/>
</module>
</module>

View File

@@ -1,190 +0,0 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.alibaba</groupId>
<artifactId>qwencode-sdk</artifactId>
<packaging>jar</packaging>
<version>0.0.1-alpha1</version>
<name>qwencode-sdk</name>
<url>https://maven.apache.org</url>
<licenses>
<license>
<name>Apache 2</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
<scm>
<url>https://github.com/QwenLM/qwen-code</url>
<connection>scm:git:https://github.com/QwenLM/qwen-code.git</connection>
</scm>
<!-- Build-wide settings and version numbers referenced elsewhere in this POM through ${...} placeholders. -->
<properties>
<!-- Java source/target level used by the compiler. -->
<maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>1.8</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<checkstyle-maven-plugin.version>3.6.0</checkstyle-maven-plugin.version>
<jacoco-maven-plugin.version>0.8.12</jacoco-maven-plugin.version>
<junit5.version>5.14.1</junit5.version>
<logback-classic.version>1.3.16</logback-classic.version>
<fastjson2.version>2.0.60</fastjson2.version>
<!-- NOTE(review): no plugin declaration in this POM references this property; confirm whether it is still needed. -->
<maven-compiler-plugin.version>3.13.0</maven-compiler-plugin.version>
<!-- Holds only the middle version component: the plugin declaration wraps it as 0.${...}.0, giving 0.9.0. -->
<central-publishing-maven-plugin.version>9</central-publishing-maven-plugin.version>
<!-- Holds only the leading version component: the plugin declaration appends ".2.1", giving 2.2.1. -->
<maven-source-plugin.version>2</maven-source-plugin.version>
<maven-javadoc-plugin.version>2.9.1</maven-javadoc-plugin.version>
<maven-gpg-plugin.version>1.5</maven-gpg-plugin.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.junit</groupId>
<artifactId>junit-bom</artifactId>
<type>pom</type>
<version>${junit5.version}</version>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback-classic.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.20.0</version>
</dependency>
<dependency>
<groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId>
<version>${fastjson2.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>${checkstyle-maven-plugin.version}</version>
<configuration>
<configLocation>checkstyle.xml</configLocation>
</configuration>
<executions>
<execution>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>${jacoco-maven-plugin.version}</version>
<executions>
<execution>
<goals>
<goal>prepare-agent</goal>
</goals>
</execution>
<execution>
<id>report</id>
<phase>test</phase>
<goals>
<goal>report</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.central</groupId>
<artifactId>central-publishing-maven-plugin</artifactId>
<version>0.${central-publishing-maven-plugin.version}.0</version>
<extensions>true</extensions>
<configuration>
<publishingServerId>central</publishingServerId>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>${maven-source-plugin.version}.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>${maven-javadoc-plugin.version}</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>${maven-gpg-plugin.version}</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<organization>
<name>Alibaba Group</name>
<url>https://github.com/alibaba</url>
</organization>
<developers>
<developer>
<id>skyfire</id>
<name>skyfire</name>
<email>gengwei.gw(at)alibaba-inc.com</email>
<roles>
<role>Developer</role>
<role>Designer</role>
</roles>
<timezone>+8</timezone>
<url>https://github.com/gwinthis</url>
</developer>
</developers>
<distributionManagement>
<snapshotRepository>
<id>central</id>
<url>https://central.sonatype.com/repository/maven-snapshots/</url>
</snapshotRepository>
<repository>
<id>central</id>
<url>https://central.sonatype.com/service/local/staging/deploy/maven2/</url>
</repository>
</distributionManagement>
</project>

View File

@@ -1,142 +0,0 @@
package com.alibaba.qwen.code.cli;
import java.util.ArrayList;
import java.util.List;
import com.alibaba.fastjson2.JSON;
import com.alibaba.qwen.code.cli.protocol.data.AssistantUsage;
import com.alibaba.qwen.code.cli.protocol.data.AssistantContent;
import com.alibaba.qwen.code.cli.protocol.data.AssistantContent.TextAssistantContent;
import com.alibaba.qwen.code.cli.protocol.data.AssistantContent.ThingkingAssistantContent;
import com.alibaba.qwen.code.cli.protocol.data.AssistantContent.ToolResultAssistantContent;
import com.alibaba.qwen.code.cli.protocol.data.AssistantContent.ToolUseAssistantContent;
import com.alibaba.qwen.code.cli.protocol.data.behavior.Behavior.Operation;
import com.alibaba.qwen.code.cli.session.Session;
import com.alibaba.qwen.code.cli.session.event.consumers.AssistantContentConsumers;
import com.alibaba.qwen.code.cli.session.event.consumers.AssistantContentSimpleConsumers;
import com.alibaba.qwen.code.cli.session.event.consumers.SessionEventSimpleConsumers;
import com.alibaba.qwen.code.cli.transport.Transport;
import com.alibaba.qwen.code.cli.transport.TransportOptions;
import com.alibaba.qwen.code.cli.transport.process.ProcessTransport;
import com.alibaba.qwen.code.cli.utils.MyConcurrentUtils;
import com.alibaba.qwen.code.cli.utils.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Static facade over the Qwen Code CLI. Offers one-shot query helpers and factory methods
 * that build a {@link Session} on top of a {@link ProcessTransport}.
 *
 * @author skyfire
 * @version $Id: 0.0.1
 */
public class QwenCodeCli {
    private static final Logger log = LoggerFactory.getLogger(QwenCodeCli.class);

    /**
     * Runs a one-shot query using a freshly created, default {@link TransportOptions}.
     *
     * @param prompt The input prompt to send to the CLI
     * @return The collected responses, in the order the callbacks delivered them
     */
    public static List<String> simpleQuery(String prompt) {
        TransportOptions defaults = new TransportOptions();
        return simpleQuery(prompt, defaults);
    }

    /**
     * Runs a one-shot query and gathers everything the assistant emits into a list of strings.
     * The query is handed to {@link MyConcurrentUtils#runAndWait} together with
     * {@link Timeout#TIMEOUT_30_MINUTES}.
     *
     * @param prompt The input prompt to send to the CLI
     * @param transportOptions Configuration options for the transport layer
     * @return The collected responses, in the order the callbacks delivered them
     */
    public static List<String> simpleQuery(String prompt, TransportOptions transportOptions) {
        final List<String> collected = new ArrayList<>();
        MyConcurrentUtils.runAndWait(
                () -> simpleQuery(prompt, transportOptions, collectingConsumers(collected)),
                Timeout.TIMEOUT_30_MINUTES);
        return collected;
    }

    /**
     * Runs a one-shot query, routing assistant content to the supplied consumers.
     * A new session is opened for the call and is always closed afterwards; a failure
     * while closing is logged and not rethrown.
     *
     * @param prompt The input prompt to send to the CLI
     * @param transportOptions Configuration options for the transport layer
     * @param assistantContentConsumers Consumers for handling different types of assistant content
     * @throws RuntimeException wrapping any exception raised while sending the prompt
     */
    public static void simpleQuery(String prompt, TransportOptions transportOptions, AssistantContentConsumers assistantContentConsumers) {
        final Session activeSession = newSession(transportOptions);
        try {
            activeSession.sendPrompt(
                    prompt,
                    new SessionEventSimpleConsumers().setAssistantContentConsumer(assistantContentConsumers));
        } catch (Exception sendFailure) {
            throw new RuntimeException("sendPrompt error!", sendFailure);
        } finally {
            closeQuietly(activeSession);
        }
    }

    /**
     * Opens a session using a freshly created, default {@link TransportOptions}.
     *
     * @return A new Session instance
     */
    public static Session newSession() {
        TransportOptions defaults = new TransportOptions();
        return newSession(defaults);
    }

    /**
     * Opens a session backed by a {@link ProcessTransport} built from the given options.
     *
     * @param transportOptions Configuration options for the transport layer
     * @return A new Session instance
     * @throws RuntimeException if either the transport or the session cannot be constructed
     */
    public static Session newSession(TransportOptions transportOptions) {
        final Transport processTransport;
        try {
            processTransport = new ProcessTransport(transportOptions);
        } catch (Exception transportFailure) {
            throw new RuntimeException("initialized ProcessTransport error!", transportFailure);
        }

        try {
            return new Session(processTransport);
        } catch (Exception sessionFailure) {
            throw new RuntimeException("initialized Session error!", sessionFailure);
        }
    }

    /**
     * Builds the consumers used by the list-returning query: every piece of content is
     * appended to {@code sink} as a string, usage is only logged, and the default
     * permission operation is set to {@link Operation#allow}.
     */
    private static AssistantContentConsumers collectingConsumers(final List<String> sink) {
        return new AssistantContentSimpleConsumers() {
            @Override
            public void onText(Session session, TextAssistantContent text) {
                sink.add(text.getText());
            }

            @Override
            public void onThinking(Session session, ThingkingAssistantContent thinking) {
                sink.add(thinking.getThinking());
            }

            @Override
            public void onToolUse(Session session, ToolUseAssistantContent toolUse) {
                sink.add(JSON.toJSONString(toolUse.getContentOfAssistant()));
            }

            @Override
            public void onToolResult(Session session, ToolResultAssistantContent toolResult) {
                // The whole result object is serialized here, not just its content payload.
                sink.add(JSON.toJSONString(toolResult));
            }

            public void onOtherContent(Session session, AssistantContent<?> other) {
                sink.add(JSON.toJSONString(other.getContentOfAssistant()));
            }

            @Override
            public void onUsage(Session session, AssistantUsage assistantUsage) {
                log.info("received usage {} of message {}", assistantUsage.getUsage(), assistantUsage.getMessageId());
            }
        }.setDefaultPermissionOperation(Operation.allow);
    }

    /** Closes the session, logging (rather than propagating) any exception raised by close(). */
    private static void closeQuietly(Session session) {
        try {
            session.close();
        } catch (Exception closeFailure) {
            log.error("close session error!", closeFailure);
        }
    }
}

View File

@@ -1,95 +0,0 @@
package com.alibaba.qwen.code.cli.protocol.data;
import java.util.Map;
/**
 * A single piece of assistant output within a Qwen Code session.
 *
 * @param <C> The type of the payload returned by {@link #getContentOfAssistant()}
 * @author skyfire
 * @version $Id: 0.0.1
 */
public interface AssistantContent<C> {

    /**
     * @return The type of this assistant content
     */
    String getType();

    /**
     * @return The payload carried by this content
     */
    C getContentOfAssistant();

    /**
     * @return The ID of the message this content belongs to
     */
    String getMessageId();

    /**
     * Assistant content carrying text; the payload type is {@link String}.
     */
    interface TextAssistantContent extends AssistantContent<String> {

        /**
         * @return The text content
         */
        String getText();
    }

    /**
     * Assistant content carrying thinking output; the payload type is {@link String}.
     * The spelling of the type name is part of the public API.
     */
    interface ThingkingAssistantContent extends AssistantContent<String> {

        /**
         * @return The thinking content
         */
        String getThinking();
    }

    /**
     * Assistant content describing a tool invocation; the payload type is a string-keyed map.
     */
    interface ToolUseAssistantContent extends AssistantContent<Map<String, Object>> {

        /**
         * @return The tool input
         */
        Map<String, Object> getInput();
    }

    /**
     * Assistant content describing the outcome of a tool invocation; the payload type is {@link String}.
     */
    interface ToolResultAssistantContent extends AssistantContent<String> {

        /**
         * @return Whether the tool result indicates an error
         */
        Boolean getIsError();

        /**
         * @return The tool result content
         */
        String getContent();

        /**
         * @return The ID of the tool use this result answers
         */
        String getToolUseId();
    }
}

View File

@@ -1,76 +0,0 @@
package com.alibaba.qwen.code.cli.protocol.data;
import com.alibaba.fastjson2.JSON;
/**
 * Pairs a message ID with the usage information reported for that message.
 *
 * @author skyfire
 * @version $Id: 0.0.1
 */
public class AssistantUsage {
    // Both fields are declared without an access modifier (package-private); that
    // visibility is kept as-is so any same-package code reading them keeps compiling.

    /**
     * The ID of the message.
     */
    String messageId;

    /**
     * The usage information.
     */
    Usage usage;

    /**
     * Constructs a new AssistantUsage instance.
     *
     * @param messageId The message ID
     * @param usage The usage information
     */
    public AssistantUsage(String messageId, Usage usage) {
        this.messageId = messageId;
        this.usage = usage;
    }

    /**
     * Gets the message ID.
     *
     * @return The message ID
     */
    public String getMessageId() {
        return messageId;
    }

    /**
     * Sets the message ID.
     *
     * @param messageId The message ID
     */
    public void setMessageId(String messageId) {
        this.messageId = messageId;
    }

    /**
     * Gets the usage information.
     *
     * @return The usage information
     */
    public Usage getUsage() {
        return usage;
    }

    /**
     * Sets the usage information.
     *
     * @param usage The usage information
     */
    public void setUsage(Usage usage) {
        this.usage = usage;
    }

    /**
     * Renders this object as a JSON string via {@link JSON#toJSONString(Object)}.
     *
     * @return The JSON representation of this instance
     */
    @Override
    public String toString() {
        return JSON.toJSONString(this);
    }
}

View File

@@ -1,83 +0,0 @@
package com.alibaba.qwen.code.cli.protocol.data;
import com.alibaba.fastjson2.annotation.JSONField;
/**
 * Data holder describing a tool call that the CLI refused to run.
 * Each field is mapped to a snake_case JSON key through {@link JSONField}.
 *
 * @author skyfire
 * @version $Id: 0.0.1
 */
public class CLIPermissionDenial {

    /** Name of the tool that was denied (JSON key {@code tool_name}). */
    @JSONField(name = "tool_name")
    private String toolName;

    /** ID of the denied tool use (JSON key {@code tool_use_id}). */
    @JSONField(name = "tool_use_id")
    private String toolUseId;

    /** Input that was supplied to the denied tool (JSON key {@code tool_input}). */
    @JSONField(name = "tool_input")
    private Object toolInput;

    /**
     * @return The name of the denied tool
     */
    public String getToolName() {
        return this.toolName;
    }

    /**
     * @param toolName The name of the denied tool
     */
    public void setToolName(String toolName) {
        this.toolName = toolName;
    }

    /**
     * @return The ID of the denied tool use
     */
    public String getToolUseId() {
        return this.toolUseId;
    }

    /**
     * @param toolUseId The ID of the denied tool use
     */
    public void setToolUseId(String toolUseId) {
        this.toolUseId = toolUseId;
    }

    /**
     * @return The input for the denied tool
     */
    public Object getToolInput() {
        return this.toolInput;
    }

    /**
     * @param toolInput The input for the denied tool
     */
    public void setToolInput(Object toolInput) {
        this.toolInput = toolInput;
    }
}

Some files were not shown because too many files have changed in this diff Show More