Add Gemini provider, remove legacy Google OAuth, and tune generation defaults

This commit is contained in:
tanzhenxin
2025-12-19 16:26:54 +08:00
parent d07ae35c90
commit 17129024f4
86 changed files with 2409 additions and 7506 deletions

View File

@@ -23,8 +23,8 @@
"scripts/postinstall.js"
],
"dependencies": {
"@google/genai": "1.16.0",
"@modelcontextprotocol/sdk": "^1.11.0",
"@google/genai": "1.30.0",
"@modelcontextprotocol/sdk": "^1.25.1",
"@opentelemetry/api": "^1.9.0",
"async-mutex": "^0.5.0",
"@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0",
@@ -34,7 +34,6 @@
"@opentelemetry/exporter-trace-otlp-grpc": "^0.203.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.203.0",
"@opentelemetry/instrumentation-http": "^0.203.0",
"@opentelemetry/resource-detector-gcp": "^0.40.0",
"@opentelemetry/sdk-node": "^0.203.0",
"@types/html-to-text": "^9.0.4",
"@xterm/headless": "5.5.0",
@@ -48,7 +47,7 @@
"fdir": "^6.4.6",
"fzf": "^0.5.2",
"glob": "^10.5.0",
"google-auth-library": "^9.11.0",
"google-auth-library": "^10.5.0",
"html-to-text": "^9.0.5",
"https-proxy-agent": "^7.0.6",
"ignore": "^7.0.0",

View File

@@ -1,54 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ContentGenerator } from '../core/contentGenerator.js';
import { AuthType } from '../core/contentGenerator.js';
import { getOauthClient } from './oauth2.js';
import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
authType: AuthType,
config: Config,
sessionId?: string,
): Promise<ContentGenerator> {
if (
authType === AuthType.LOGIN_WITH_GOOGLE ||
authType === AuthType.CLOUD_SHELL
) {
const authClient = await getOauthClient(authType, config);
const userData = await setupUser(authClient);
return new CodeAssistServer(
authClient,
userData.projectId,
httpOptions,
sessionId,
userData.userTier,
);
}
throw new Error(`Unsupported authType: ${authType}`);
}
export function getCodeAssistServer(
config: Config,
): CodeAssistServer | undefined {
let server = config.getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
if (!(server instanceof CodeAssistServer)) {
return undefined;
}
return server;
}

View File

@@ -1,456 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import type { CaGenerateContentResponse } from './converter.js';
import {
toGenerateContentRequest,
fromGenerateContentResponse,
toContents,
} from './converter.js';
import type {
ContentListUnion,
GenerateContentParameters,
} from '@google/genai';
import {
GenerateContentResponse,
FinishReason,
BlockedReason,
type Part,
} from '@google/genai';
describe('converter', () => {
describe('toCodeAssistRequest', () => {
it('should convert a simple request with project', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'my-session',
},
user_prompt_id: 'my-prompt',
});
});
it('should convert a request without a project', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
undefined,
'my-session',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: undefined,
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'my-session',
},
user_prompt_id: 'my-prompt',
});
});
it('should convert a request with sessionId', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'session-123',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'session-123',
},
user_prompt_id: 'my-prompt',
});
});
it('should handle string content', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
]);
});
it('should handle Part[] content', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ text: 'Hello' }, { text: 'World' }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'user', parts: [{ text: 'World' }] },
]);
});
it('should handle system instructions', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
systemInstruction: 'You are a helpful assistant.',
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.systemInstruction).toEqual({
role: 'user',
parts: [{ text: 'You are a helpful assistant.' }],
});
});
it('should handle generation config', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
temperature: 0.8,
topK: 40,
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.8,
topK: 40,
});
});
it('should handle all generation config fields', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
temperature: 0.1,
topP: 0.2,
topK: 3,
candidateCount: 4,
maxOutputTokens: 5,
stopSequences: ['a'],
responseLogprobs: true,
logprobs: 6,
presencePenalty: 0.7,
frequencyPenalty: 0.8,
seed: 9,
responseMimeType: 'application/json',
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.1,
topP: 0.2,
topK: 3,
candidateCount: 4,
maxOutputTokens: 5,
stopSequences: ['a'],
responseLogprobs: true,
logprobs: 6,
presencePenalty: 0.7,
frequencyPenalty: 0.8,
seed: 9,
responseMimeType: 'application/json',
});
});
});
describe('fromCodeAssistResponse', () => {
it('should convert a simple response', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'Hi there!' }],
},
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
});
it('should handle prompt feedback and usage metadata', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
promptFeedback: {
blockReason: BlockedReason.SAFETY,
safetyRatings: [],
},
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 20,
totalTokenCount: 30,
},
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.promptFeedback).toEqual(
codeAssistRes.response.promptFeedback,
);
expect(genaiRes.usageMetadata).toEqual(
codeAssistRes.response.usageMetadata,
);
});
it('should handle automatic function calling history', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
automaticFunctionCallingHistory: [
{
role: 'model',
parts: [
{
functionCall: {
name: 'test_function',
args: {
foo: 'bar',
},
},
},
],
},
],
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
codeAssistRes.response.automaticFunctionCallingHistory,
);
});
it('should handle modelVersion', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
modelVersion: 'qwen3-coder-plus',
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.modelVersion).toEqual('qwen3-coder-plus');
});
});
describe('toContents', () => {
it('should handle Content', () => {
const content: ContentListUnion = {
role: 'user',
parts: [{ text: 'hello' }],
};
expect(toContents(content)).toEqual([
{ role: 'user', parts: [{ text: 'hello' }] },
]);
});
it('should handle array of Contents', () => {
const contents: ContentListUnion = [
{ role: 'user', parts: [{ text: 'hello' }] },
{ role: 'model', parts: [{ text: 'hi' }] },
];
expect(toContents(contents)).toEqual([
{ role: 'user', parts: [{ text: 'hello' }] },
{ role: 'model', parts: [{ text: 'hi' }] },
]);
});
it('should handle Part', () => {
const part: ContentListUnion = { text: 'a part' };
expect(toContents(part)).toEqual([
{ role: 'user', parts: [{ text: 'a part' }] },
]);
});
it('should handle array of Parts', () => {
const parts = [{ text: 'part 1' }, 'part 2'];
expect(toContents(parts)).toEqual([
{ role: 'user', parts: [{ text: 'part 1' }] },
{ role: 'user', parts: [{ text: 'part 2' }] },
]);
});
it('should handle string', () => {
const str: ContentListUnion = 'a string';
expect(toContents(str)).toEqual([
{ role: 'user', parts: [{ text: 'a string' }] },
]);
});
it('should handle array of strings', () => {
const strings: ContentListUnion = ['string 1', 'string 2'];
expect(toContents(strings)).toEqual([
{ role: 'user', parts: [{ text: 'string 1' }] },
{ role: 'user', parts: [{ text: 'string 2' }] },
]);
});
it('should convert thought parts to text parts for API compatibility', () => {
const contentWithThought: ContentListUnion = {
role: 'model',
parts: [
{ text: 'regular text' },
{ thought: 'thinking about the problem' } as Part & {
thought: string;
},
{ text: 'more text' },
],
};
expect(toContents(contentWithThought)).toEqual([
{
role: 'model',
parts: [
{ text: 'regular text' },
{ text: '[Thought: thinking about the problem]' },
{ text: 'more text' },
],
},
]);
});
it('should combine text and thought for text parts with thoughts', () => {
const contentWithTextAndThought: ContentListUnion = {
role: 'model',
parts: [
{
text: 'Here is my response',
thought: 'I need to be careful here',
} as Part & { thought: string },
],
};
expect(toContents(contentWithTextAndThought)).toEqual([
{
role: 'model',
parts: [
{
text: 'Here is my response\n[Thought: I need to be careful here]',
},
],
},
]);
});
it('should preserve non-thought properties while removing thought', () => {
const contentWithComplexPart: ContentListUnion = {
role: 'model',
parts: [
{
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
thought: 'Performing calculation',
} as Part & { thought: string },
],
};
expect(toContents(contentWithComplexPart)).toEqual([
{
role: 'model',
parts: [
{
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
},
],
},
]);
});
it('should convert invalid text content to valid text part with thought', () => {
const contentWithInvalidText: ContentListUnion = {
role: 'model',
parts: [
{
text: 123, // Invalid - should be string
thought: 'Processing number',
} as Part & { thought: string; text: number },
],
};
expect(toContents(contentWithInvalidText)).toEqual([
{
role: 'model',
parts: [
{
text: '123\n[Thought: Processing number]',
},
],
},
]);
});
});
});

View File

@@ -1,285 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Content,
ContentListUnion,
ContentUnion,
GenerateContentConfig,
GenerateContentParameters,
CountTokensParameters,
CountTokensResponse,
GenerationConfigRoutingConfig,
MediaResolution,
Candidate,
ModelSelectionConfig,
GenerateContentResponsePromptFeedback,
GenerateContentResponseUsageMetadata,
Part,
SafetySetting,
PartUnion,
SpeechConfigUnion,
ThinkingConfig,
ToolListUnion,
ToolConfig,
} from '@google/genai';
import { GenerateContentResponse } from '@google/genai';
export interface CAGenerateContentRequest {
model: string;
project?: string;
user_prompt_id?: string;
request: VertexGenerateContentRequest;
}
interface VertexGenerateContentRequest {
contents: Content[];
systemInstruction?: Content;
cachedContent?: string;
tools?: ToolListUnion;
toolConfig?: ToolConfig;
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig;
session_id?: string;
}
interface VertexGenerationConfig {
temperature?: number;
topP?: number;
topK?: number;
candidateCount?: number;
maxOutputTokens?: number;
stopSequences?: string[];
responseLogprobs?: boolean;
logprobs?: number;
presencePenalty?: number;
frequencyPenalty?: number;
seed?: number;
responseMimeType?: string;
responseJsonSchema?: unknown;
responseSchema?: unknown;
routingConfig?: GenerationConfigRoutingConfig;
modelSelectionConfig?: ModelSelectionConfig;
responseModalities?: string[];
mediaResolution?: MediaResolution;
speechConfig?: SpeechConfigUnion;
audioTimestamp?: boolean;
thinkingConfig?: ThinkingConfig;
}
export interface CaGenerateContentResponse {
response: VertexGenerateContentResponse;
}
interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
modelVersion?: string;
}
export interface CaCountTokenRequest {
request: VertexCountTokenRequest;
}
interface VertexCountTokenRequest {
model: string;
contents: Content[];
}
export interface CaCountTokenResponse {
totalTokens: number;
}
export function toCountTokenRequest(
req: CountTokensParameters,
): CaCountTokenRequest {
return {
request: {
model: 'models/' + req.model,
contents: toContents(req.contents),
},
};
}
export function fromCountTokenResponse(
res: CaCountTokenResponse,
): CountTokensResponse {
return {
totalTokens: res.totalTokens,
};
}
export function toGenerateContentRequest(
req: GenerateContentParameters,
userPromptId: string,
project?: string,
sessionId?: string,
): CAGenerateContentRequest {
return {
model: req.model,
project,
user_prompt_id: userPromptId,
request: toVertexGenerateContentRequest(req, sessionId),
};
}
export function fromGenerateContentResponse(
res: CaGenerateContentResponse,
): GenerateContentResponse {
const inres = res.response;
const out = new GenerateContentResponse();
out.candidates = inres.candidates;
out.automaticFunctionCallingHistory = inres.automaticFunctionCallingHistory;
out.promptFeedback = inres.promptFeedback;
out.usageMetadata = inres.usageMetadata;
out.modelVersion = inres.modelVersion;
return out;
}
function toVertexGenerateContentRequest(
req: GenerateContentParameters,
sessionId?: string,
): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
systemInstruction: maybeToContent(req.config?.systemInstruction),
cachedContent: req.config?.cachedContent,
tools: req.config?.tools,
toolConfig: req.config?.toolConfig,
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
generationConfig: toVertexGenerationConfig(req.config),
session_id: sessionId,
};
}
export function toContents(contents: ContentListUnion): Content[] {
if (Array.isArray(contents)) {
// it's a Content[] or a PartsUnion[]
return contents.map(toContent);
}
// it's a Content or a PartsUnion
return [toContent(contents)];
}
function maybeToContent(content?: ContentUnion): Content | undefined {
if (!content) {
return undefined;
}
return toContent(content);
}
function toContent(content: ContentUnion): Content {
if (Array.isArray(content)) {
// it's a PartsUnion[]
return {
role: 'user',
parts: toParts(content),
};
}
if (typeof content === 'string') {
// it's a string
return {
role: 'user',
parts: [{ text: content }],
};
}
if ('parts' in content) {
// it's a Content - process parts to handle thought filtering
return {
...content,
parts: content.parts
? toParts(content.parts.filter((p) => p != null))
: [],
};
}
// it's a Part
return {
role: 'user',
parts: [toPart(content as Part)],
};
}
export function toParts(parts: PartUnion[]): Part[] {
return parts.map(toPart);
}
function toPart(part: PartUnion): Part {
if (typeof part === 'string') {
// it's a string
return { text: part };
}
// Handle thought parts for CountToken API compatibility
// The CountToken API expects parts to have certain required "oneof" fields initialized,
// but thought parts don't conform to this schema and cause API failures
if ('thought' in part && part.thought) {
const thoughtText = `[Thought: ${part.thought}]`;
const newPart = { ...part };
delete (newPart as Record<string, unknown>)['thought'];
const hasApiContent =
'functionCall' in newPart ||
'functionResponse' in newPart ||
'inlineData' in newPart ||
'fileData' in newPart;
if (hasApiContent) {
// It's a functionCall or other non-text part. Just strip the thought.
return newPart;
}
// If no other valid API content, this must be a text part.
// Combine existing text (if any) with the thought, preserving other properties.
const text = (newPart as { text?: unknown }).text;
const existingText = text ? String(text) : '';
const combinedText = existingText
? `${existingText}\n${thoughtText}`
: thoughtText;
return {
...newPart,
text: combinedText,
};
}
return part;
}
function toVertexGenerationConfig(
config?: GenerateContentConfig,
): VertexGenerationConfig | undefined {
if (!config) {
return undefined;
}
return {
temperature: config.temperature,
topP: config.topP,
topK: config.topK,
candidateCount: config.candidateCount,
maxOutputTokens: config.maxOutputTokens,
stopSequences: config.stopSequences,
responseLogprobs: config.responseLogprobs,
logprobs: config.logprobs,
presencePenalty: config.presencePenalty,
frequencyPenalty: config.frequencyPenalty,
seed: config.seed,
responseMimeType: config.responseMimeType,
responseSchema: config.responseSchema,
responseJsonSchema: config.responseJsonSchema,
routingConfig: config.routingConfig,
modelSelectionConfig: config.modelSelectionConfig,
responseModalities: config.responseModalities,
mediaResolution: config.mediaResolution,
speechConfig: config.speechConfig,
audioTimestamp: config.audioTimestamp,
thinkingConfig: config.thinkingConfig,
};
}

View File

@@ -1,217 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
// Mock external dependencies
const mockHybridTokenStorage = vi.hoisted(() => ({
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
}));
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
}));
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
rm: vi.fn(),
},
}));
vi.mock('node:os');
vi.mock('node:path');
describe('OAuthCredentialStorage', () => {
const mockCredentials: Credentials = {
access_token: 'mock_access_token',
refresh_token: 'mock_refresh_token',
expiry_date: Date.now() + 3600 * 1000,
token_type: 'Bearer',
scope: 'email profile',
};
const mockMcpCredentials: OAuthCredentials = {
serverName: 'main-account',
token: {
accessToken: 'mock_access_token',
refreshToken: 'mock_refresh_token',
tokenType: 'Bearer',
scope: 'email profile',
expiresAt: mockCredentials.expiry_date!,
},
updatedAt: expect.any(Number),
};
const oldFilePath = '/mock/home/.qwen/oauth.json';
beforeEach(() => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
vi.spyOn(os, 'homedir').mockReturnValue('/mock/home');
vi.spyOn(path, 'join').mockReturnValue(oldFilePath);
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('loadCredentials', () => {
it('should load credentials from HybridTokenStorage if available', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
mockMcpCredentials,
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(result).toEqual(mockCredentials);
});
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
expect(result).toEqual(mockCredentials);
});
it('should return null if no credentials found and no old file to migrate', async () => {
vi.spyOn(fs, 'readFile').mockRejectedValue({
message: 'File not found',
code: 'ENOENT',
});
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toBeNull();
});
it('should throw an error if loading fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
new Error('Loading error'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
it('should throw an error if read file fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(
new Error('Permission denied'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
it('should not throw error if migration file removal failed', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toEqual(mockCredentials);
});
});
describe('saveCredentials', () => {
it('should save credentials to HybridTokenStorage', async () => {
await OAuthCredentialStorage.saveCredentials(mockCredentials);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
mockMcpCredentials,
);
});
it('should throw an error if access_token is missing', async () => {
const invalidCredentials: Credentials = {
...mockCredentials,
access_token: undefined,
};
await expect(
OAuthCredentialStorage.saveCredentials(invalidCredentials),
).rejects.toThrow(
'Attempted to save credentials without an access token.',
);
});
});
describe('clearCredentials', () => {
it('should delete credentials from HybridTokenStorage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
'main-account',
);
});
it('should attempt to remove the old file-based storage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
});
it('should not throw an error if deleting old file fails', async () => {
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
await expect(
OAuthCredentialStorage.clearCredentials(),
).resolves.toBeUndefined();
});
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
new Error('Deletion error'),
);
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
'Failed to clear OAuth credentials',
);
});
});
});

View File

@@ -1,130 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
import { OAUTH_FILE } from '../config/storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
const QWEN_DIR = '.qwen';
const KEYCHAIN_SERVICE_NAME = 'qwen-code-oauth';
const MAIN_ACCOUNT_KEY = 'main-account';
export class OAuthCredentialStorage {
private static storage: HybridTokenStorage = new HybridTokenStorage(
KEYCHAIN_SERVICE_NAME,
);
/**
* Load cached OAuth credentials
*/
static async loadCredentials(): Promise<Credentials | null> {
try {
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
if (credentials?.token) {
const { accessToken, refreshToken, expiresAt, tokenType, scope } =
credentials.token;
// Convert from OAuthCredentials format to Google Credentials format
const googleCreds: Credentials = {
access_token: accessToken,
refresh_token: refreshToken || undefined,
token_type: tokenType || undefined,
scope: scope || undefined,
};
if (expiresAt) {
googleCreds.expiry_date = expiresAt;
}
return googleCreds;
}
// Fallback: Try to migrate from old file-based storage
return await this.migrateFromFileStorage();
} catch (error: unknown) {
console.error(error);
throw new Error('Failed to load OAuth credentials');
}
}
/**
* Save OAuth credentials
*/
static async saveCredentials(credentials: Credentials): Promise<void> {
if (!credentials.access_token) {
throw new Error('Attempted to save credentials without an access token.');
}
// Convert Google Credentials to OAuthCredentials format
const mcpCredentials: OAuthCredentials = {
serverName: MAIN_ACCOUNT_KEY,
token: {
accessToken: credentials.access_token,
refreshToken: credentials.refresh_token || undefined,
tokenType: credentials.token_type || 'Bearer',
scope: credentials.scope || undefined,
expiresAt: credentials.expiry_date || undefined,
},
updatedAt: Date.now(),
};
await this.storage.setCredentials(mcpCredentials);
}
/**
* Clear cached OAuth credentials
*/
static async clearCredentials(): Promise<void> {
try {
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
// Also try to remove the old file if it exists
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
await fs.rm(oldFilePath, { force: true }).catch(() => {});
} catch (error: unknown) {
console.error(error);
throw new Error('Failed to clear OAuth credentials');
}
}
/**
* Migrate credentials from old file-based storage to keychain
*/
private static async migrateFromFileStorage(): Promise<Credentials | null> {
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
let credsJson: string;
try {
credsJson = await fs.readFile(oldFilePath, 'utf-8');
} catch (error: unknown) {
if (
typeof error === 'object' &&
error !== null &&
'code' in error &&
error.code === 'ENOENT'
) {
// File doesn't exist, so no migration.
return null;
}
// Other read errors should propagate.
throw error;
}
const credentials = JSON.parse(credsJson) as Credentials;
// Save to new storage
await this.saveCredentials(credentials);
// Remove old file after successful migration
await fs.rm(oldFilePath, { force: true }).catch(() => {});
return credentials;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,563 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Credentials } from 'google-auth-library';
import {
CodeChallengeMethod,
Compute,
OAuth2Client,
} from 'google-auth-library';
import crypto from 'node:crypto';
import { promises as fs } from 'node:fs';
import * as http from 'node:http';
import * as net from 'node:net';
import path from 'node:path';
import readline from 'node:readline';
import url from 'node:url';
import open from 'open';
import type { Config } from '../config/config.js';
import { Storage } from '../config/storage.js';
import { AuthType } from '../core/contentGenerator.js';
import { FatalAuthenticationError, getErrorMessage } from '../utils/errors.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
const userAccountManager = new UserAccountManager();
// OAuth Client ID used to initiate OAuth2Client class.
const OAUTH_CLIENT_ID =
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
// OAuth Secret value used to initiate OAuth2Client class.
// Note: It's ok to save this in git because this is an installed application
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
// "The process results in a client ID and, in some cases, a client secret,
// which you embed in the source code of your application. (In this context,
// the client secret is obviously not treated as a secret.)"
const OAUTH_CLIENT_SECRET = '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';
// OAuth Scopes for Cloud Code authorization.
const OAUTH_SCOPE = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
];
const HTTP_REDIRECT = 301;
const SIGN_IN_SUCCESS_URL =
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
const SIGN_IN_FAILURE_URL =
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
/**
* An Authentication URL for updating the credentials of a Oauth2Client
* as well as a promise that will resolve when the credentials have
* been refreshed (or which throws error when refreshing credentials failed).
*/
export interface OauthWebLogin {
authUrl: string;
loginCompletePromise: Promise<void>;
}
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
function getUseEncryptedStorageFlag() {
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
}
async function initOauthClient(
authType: AuthType,
config: Config,
): Promise<OAuth2Client> {
const client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
transporterOptions: {
proxy: config.getProxy(),
},
});
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (
process.env['GOOGLE_GENAI_USE_GCA'] &&
process.env['GOOGLE_CLOUD_ACCESS_TOKEN']
) {
client.setCredentials({
access_token: process.env['GOOGLE_CLOUD_ACCESS_TOKEN'],
});
await fetchAndCacheUserInfo(client);
return client;
}
client.on('tokens', async (tokens: Credentials) => {
if (useEncryptedStorage) {
await OAuthCredentialStorage.saveCredentials(tokens);
} else {
await cacheCredentials(tokens);
}
});
// If there are cached creds on disk, they always take precedence
if (await loadCachedCredentials(client)) {
// Found valid cached credentials.
// Check if we need to retrieve Google Account ID or Email
if (!userAccountManager.getCachedGoogleAccount()) {
try {
await fetchAndCacheUserInfo(client);
} catch (error) {
// Non-fatal, continue with existing auth.
console.warn('Failed to fetch user info:', getErrorMessage(error));
}
}
console.log('Loaded cached credentials.');
return client;
}
// In Google Cloud Shell, we can use Application Default Credentials (ADC)
// provided via its metadata server to authenticate non-interactively using
// the identity of the user logged into Cloud Shell.
if (authType === AuthType.CLOUD_SHELL) {
try {
console.log("Attempting to authenticate via Cloud Shell VM's ADC.");
const computeClient = new Compute({
// We can leave this empty, since the metadata server will provide
// the service account email.
});
await computeClient.getAccessToken();
console.log('Authentication successful.');
// Do not cache creds in this case; note that Compute client will handle its own refresh
return computeClient;
} catch (e) {
throw new Error(
`Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ${getErrorMessage(
e,
)}`,
);
}
}
if (config.isBrowserLaunchSuppressed()) {
let success = false;
const maxRetries = 2;
for (let i = 0; !success && i < maxRetries; i++) {
success = await authWithUserCode(client);
if (!success) {
console.error(
'\nFailed to authenticate with user code.',
i === maxRetries - 1 ? '' : 'Retrying...\n',
);
}
}
if (!success) {
throw new FatalAuthenticationError(
'Failed to authenticate with user code.',
);
}
} else {
const webLogin = await authWithWeb(client);
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
);
try {
// Attempt to open the authentication URL in the default browser.
// We do not use the `wait` option here because the main script's execution
// is already paused by `loginCompletePromise`, which awaits the server callback.
const childProcess = await open(webLogin.authUrl);
// IMPORTANT: Attach an error handler to the returned child process.
// Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
// in a minimal Docker container), it will emit an unhandled 'error' event,
// causing the entire Node.js process to crash.
childProcess.on('error', (error) => {
console.error(
'Failed to open browser automatically. Please try running again with NO_BROWSER=true set.',
);
console.error('Browser error details:', getErrorMessage(error));
});
} catch (err) {
console.error(
'An unexpected error occurred while trying to open the browser:',
getErrorMessage(err),
'\nThis might be due to browser compatibility issues or system configuration.',
'\nPlease try running again with NO_BROWSER=true set for manual authentication.',
);
throw new FatalAuthenticationError(
`Failed to open browser: ${getErrorMessage(err)}`,
);
}
console.log('Waiting for authentication...');
// Add timeout to prevent infinite waiting when browser tab gets stuck
const authTimeout = 5 * 60 * 1000; // 5 minutes timeout
const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout(() => {
reject(
new FatalAuthenticationError(
'Authentication timed out after 5 minutes. The browser tab may have gotten stuck in a loading state. ' +
'Please try again or use NO_BROWSER=true for manual authentication.',
),
);
}, authTimeout);
});
await Promise.race([webLogin.loginCompletePromise, timeoutPromise]);
}
return client;
}
export async function getOauthClient(
authType: AuthType,
config: Config,
): Promise<OAuth2Client> {
if (!oauthClientPromises.has(authType)) {
oauthClientPromises.set(authType, initOauthClient(authType, config));
}
return oauthClientPromises.get(authType)!;
}
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
const redirectUri = 'https://codeassist.google.com/authcode';
const codeVerifier = await client.generateCodeVerifierAsync();
const state = crypto.randomBytes(32).toString('hex');
const authUrl: string = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
code_challenge_method: CodeChallengeMethod.S256,
code_challenge: codeVerifier.codeChallenge,
state,
});
console.log('Please visit the following URL to authorize the application:');
console.log('');
console.log(authUrl);
console.log('');
const code = await new Promise<string>((resolve) => {
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
});
rl.question('Enter the authorization code: ', (code) => {
rl.close();
resolve(code.trim());
});
});
if (!code) {
console.error('Authorization code is required.');
return false;
}
try {
const { tokens } = await client.getToken({
code,
codeVerifier: codeVerifier.codeVerifier,
redirect_uri: redirectUri,
});
client.setCredentials(tokens);
} catch (error) {
console.error(
'Failed to authenticate with authorization code:',
getErrorMessage(error),
);
return false;
}
return true;
}
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
const port = await getAvailablePort();
// The hostname used for the HTTP server binding (e.g., '0.0.0.0' in Docker).
const host = process.env['OAUTH_CALLBACK_HOST'] || 'localhost';
// The `redirectUri` sent to Google's authorization server MUST use a loopback IP literal
// (i.e., 'localhost' or '127.0.0.1'). This is a strict security policy for credentials of
// type 'Desktop app' or 'Web application' (when using loopback flow) to mitigate
// authorization code interception attacks.
const redirectUri = `http://localhost:${port}/oauth2callback`;
const state = crypto.randomBytes(32).toString('hex');
const authUrl = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(
new FatalAuthenticationError(
'OAuth callback not received. Unexpected request: ' + req.url,
),
);
}
// acquire the code from the querystring, and close the web server.
const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
if (qs.get('error')) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
const errorCode = qs.get('error');
const errorDescription =
qs.get('error_description') || 'No additional details provided';
reject(
new FatalAuthenticationError(
`Google OAuth error: ${errorCode}. ${errorDescription}`,
),
);
} else if (qs.get('state') !== state) {
res.end('State mismatch. Possible CSRF attack');
reject(
new FatalAuthenticationError(
'OAuth state mismatch. Possible CSRF attack or browser session issue.',
),
);
} else if (qs.get('code')) {
try {
const { tokens } = await client.getToken({
code: qs.get('code')!,
redirect_uri: redirectUri,
});
client.setCredentials(tokens);
// Retrieve and cache Google Account ID during authentication
try {
await fetchAndCacheUserInfo(client);
} catch (error) {
console.warn(
'Failed to retrieve Google Account ID during authentication:',
getErrorMessage(error),
);
// Don't fail the auth flow if Google Account ID retrieval fails
}
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
resolve();
} catch (error) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(
new FatalAuthenticationError(
`Failed to exchange authorization code for tokens: ${getErrorMessage(error)}`,
),
);
}
} else {
reject(
new FatalAuthenticationError(
'No authorization code received from Google OAuth. Please try authenticating again.',
),
);
}
} catch (e) {
// Provide more specific error message for unexpected errors during OAuth flow
if (e instanceof FatalAuthenticationError) {
reject(e);
} else {
reject(
new FatalAuthenticationError(
`Unexpected error during OAuth authentication: ${getErrorMessage(e)}`,
),
);
}
} finally {
server.close();
}
});
server.listen(port, host, () => {
// Server started successfully
});
server.on('error', (err) => {
reject(
new FatalAuthenticationError(
`OAuth callback server error: ${getErrorMessage(err)}`,
),
);
});
});
return {
authUrl,
loginCompletePromise,
};
}
export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
const portStr = process.env['OAUTH_CALLBACK_PORT'];
if (portStr) {
port = parseInt(portStr, 10);
if (isNaN(port) || port <= 0 || port > 65535) {
return reject(
new Error(`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`),
);
}
return resolve(port);
}
const server = net.createServer();
server.listen(0, () => {
const address = server.address()! as net.AddressInfo;
port = address.port;
});
server.on('listening', () => {
server.close();
server.unref();
});
server.on('error', (e) => reject(e));
server.on('close', () => resolve(port));
} catch (e) {
reject(e);
}
});
}
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
const credentials = await OAuthCredentialStorage.loadCredentials();
if (credentials) {
client.setCredentials(credentials);
return true;
}
return false;
}
const pathsToTry = [
Storage.getOAuthCredsPath(),
process.env['GOOGLE_APPLICATION_CREDENTIALS'],
].filter((p): p is string => !!p);
for (const keyFile of pathsToTry) {
try {
const creds = await fs.readFile(keyFile, 'utf-8');
client.setCredentials(JSON.parse(creds));
// This will verify locally that the credentials look good.
const { token } = await client.getAccessToken();
if (!token) {
continue;
}
// This will check with the server to see if it hasn't been revoked.
await client.getTokenInfo(token);
return true;
} catch (error) {
// Log specific error for debugging, but continue trying other paths
console.debug(
`Failed to load credentials from ${keyFile}:`,
getErrorMessage(error),
);
}
}
return false;
}
async function cacheCredentials(credentials: Credentials) {
const filePath = Storage.getOAuthCredsPath();
await fs.mkdir(path.dirname(filePath), { recursive: true });
const credString = JSON.stringify(credentials, null, 2);
await fs.writeFile(filePath, credString, { mode: 0o600 });
try {
await fs.chmod(filePath, 0o600);
} catch {
/* empty */
}
}
export function clearOauthClientCache() {
oauthClientPromises.clear();
}
export async function clearCachedCredentialFile() {
try {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
await OAuthCredentialStorage.clearCredentials();
} else {
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
}
// Clear the Google Account ID cache when credentials are cleared
await userAccountManager.clearCachedGoogleAccount();
// Clear the in-memory OAuth client cache to force re-authentication
clearOauthClientCache();
/**
* Also clear Qwen SharedTokenManager cache and credentials file to prevent stale credentials
* when switching between auth types
* TODO: We do not depend on code_assist, we'll have to build an independent auth-cleaning procedure.
*/
try {
const { SharedTokenManager } = await import(
'../qwen/sharedTokenManager.js'
);
const { clearQwenCredentials } = await import('../qwen/qwenOAuth2.js');
const sharedManager = SharedTokenManager.getInstance();
sharedManager.clearCache();
await clearQwenCredentials();
} catch (qwenError) {
console.debug('Could not clear Qwen credentials:', qwenError);
}
} catch (e) {
console.error('Failed to clear cached credentials:', e);
}
}
async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
try {
const { token } = await client.getAccessToken();
if (!token) {
return;
}
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
headers: {
Authorization: `Bearer ${token}`,
},
},
);
if (!response.ok) {
console.error(
'Failed to fetch user info:',
response.status,
response.statusText,
);
return;
}
const userInfo = await response.json();
await userAccountManager.cacheGoogleAccount(userInfo.email);
} catch (error) {
console.error('Error retrieving user info:', error);
}
}
// Helper to ensure test isolation
export function resetOauthClientForTesting() {
oauthClientPromises.clear();
}

View File

@@ -1,255 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, describe, it, expect, vi } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
vi.mock('google-auth-library');
describe('CodeAssistServer', () => {
beforeEach(() => {
vi.resetAllMocks();
});
it('should be able to be constructed', () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(
auth,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
expect(server).toBeInstanceOf(CodeAssistServer);
});
it('should call the generateContent endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
safetyRatings: [],
},
],
},
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
expect(server.requestPost).toHaveBeenCalledWith(
'generateContent',
expect.any(Object),
undefined,
);
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'response',
);
});
it('should call the generateContentStream endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = (async function* () {
yield {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
safetyRatings: [],
},
],
},
};
})();
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
for await (const res of stream) {
expect(server.requestStreamingPost).toHaveBeenCalledWith(
'streamGenerateContent',
expect.any(Object),
undefined,
);
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
}
});
it('should call the onboardUser endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
name: 'operations/123',
done: true,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.onboardUser({
tierId: 'test-tier',
cloudaicompanionProject: 'test-project',
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'onboardUser',
expect.any(Object),
);
expect(response.name).toBe('operations/123');
});
it('should call the loadCodeAssist endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
currentTier: {
id: UserTierId.FREE,
name: 'Free',
description: 'free tier',
},
allowedTiers: [],
ineligibleTiers: [],
cloudaicompanionProject: 'projects/test',
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual(mockResponse);
});
it('should return 0 for countTokens', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
totalTokens: 100,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.countTokens({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
expect(response.totalTokens).toBe(100);
});
it('should throw an error for embedContent', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
await expect(
server.embedContent({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}),
).rejects.toThrow();
});
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockVpcScError = {
response: {
data: {
error: {
details: [
{
reason: 'SECURITY_POLICY_VIOLATED',
},
],
},
},
},
};
vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual({
currentTier: { id: UserTierId.STANDARD },
});
});
});

View File

@@ -1,253 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuth2Client } from 'google-auth-library';
import type {
CodeAssistGlobalUserSettingResponse,
GoogleRpcResponse,
LoadCodeAssistRequest,
LoadCodeAssistResponse,
LongRunningOperationResponse,
OnboardUserRequest,
SetCodeAssistGlobalUserSettingRequest,
} from './types.js';
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import * as readline from 'node:readline';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { UserTierId } from './types.js';
import type {
CaCountTokenResponse,
CaGenerateContentResponse,
} from './converter.js';
import {
fromCountTokenResponse,
fromGenerateContentResponse,
toCountTokenRequest,
toGenerateContentRequest,
} from './converter.js';
/** HTTP options to be used in each of the requests. */
export interface HttpOptions {
/** Additional HTTP headers to be sent with the request. */
headers?: Record<string, string>;
}
export const CODE_ASSIST_ENDPOINT = 'https://localhost:0'; // Disable Google Code Assist API Request
export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator {
constructor(
readonly client: OAuth2Client,
readonly projectId?: string,
readonly httpOptions: HttpOptions = {},
readonly sessionId?: string,
readonly userTier?: UserTierId,
) {}
async generateContentStream(
req: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
yield fromGenerateContentResponse(resp);
}
})();
}
async generateContent(
req: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
}
async onboardUser(
req: OnboardUserRequest,
): Promise<LongRunningOperationResponse> {
return await this.requestPost<LongRunningOperationResponse>(
'onboardUser',
req,
);
}
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
try {
return await this.requestPost<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
);
} catch (e) {
if (isVpcScAffectedUser(e)) {
return {
currentTier: { id: UserTierId.STANDARD },
};
} else {
throw e;
}
}
}
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
'getCodeAssistGlobalUserSetting',
);
}
async setCodeAssistGlobalUserSetting(
req: SetCodeAssistGlobalUserSettingRequest,
): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
'setCodeAssistGlobalUserSetting',
req,
);
}
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
const resp = await this.requestPost<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
return fromCountTokenResponse(resp);
}
async embedContent(
_req: EmbedContentParameters,
): Promise<EmbedContentResponse> {
throw Error();
}
async requestPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'POST',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
body: JSON.stringify(req),
signal,
});
return res.data as T;
}
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'GET',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
signal,
});
return res.data as T;
}
async requestStreamingPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<AsyncGenerator<T>> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'POST',
params: {
alt: 'sse',
},
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'stream',
body: JSON.stringify(req),
signal,
});
return (async function* (): AsyncGenerator<T> {
const rl = readline.createInterface({
input: res.data as NodeJS.ReadableStream,
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
});
let bufferedLines: string[] = [];
for await (const line of rl) {
// blank lines are used to separate JSON objects in the stream
if (line === '') {
if (bufferedLines.length === 0) {
continue; // no data to yield
}
yield JSON.parse(bufferedLines.join('\n')) as T;
bufferedLines = []; // Reset the buffer after yielding
} else if (line.startsWith('data: ')) {
bufferedLines.push(line.slice(6).trim());
} else {
throw new Error(`Unexpected line format in response: ${line}`);
}
}
})();
}
getMethodUrl(method: string): string {
const endpoint =
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
}
}
function isVpcScAffectedUser(error: unknown): boolean {
if (error && typeof error === 'object' && 'response' in error) {
const gaxiosError = error as {
response?: {
data?: unknown;
};
};
const response = gaxiosError.response?.data as
| GoogleRpcResponse
| undefined;
if (Array.isArray(response?.error?.details)) {
return response.error.details.some(
(detail) => detail.reason === 'SECURITY_POLICY_VIOLATED',
);
}
}
return false;
}

View File

@@ -1,224 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { setupUser, ProjectIdRequiredError } from './setup.js';
import { CodeAssistServer } from '../code_assist/server.js';
import type { OAuth2Client } from 'google-auth-library';
import type { GeminiUserTier } from './types.js';
import { UserTierId } from './types.js';
vi.mock('../code_assist/server.js');
const mockPaidTier: GeminiUserTier = {
id: UserTierId.STANDARD,
name: 'paid',
description: 'Paid tier',
isDefault: true,
};
const mockFreeTier: GeminiUserTier = {
id: UserTierId.FREE,
name: 'free',
description: 'Free tier',
isDefault: true,
};
describe('setupUser for existing user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
currentTier: mockPaidTier,
});
await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
});
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
cloudaicompanionProject: 'server-project',
currentTier: mockPaidTier,
});
const projectId = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(projectId).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
// And the server itself requires a project ID internally
vi.mocked(CodeAssistServer).mockImplementation(() => {
throw new ProjectIdRequiredError();
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'standard-tier',
cloudaicompanionProject: 'test-project',
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: 'test-project',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockFreeTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'free-tier',
cloudaicompanionProject: undefined,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'free-tier',
});
});
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: undefined,
},
});
const userData = await setupUser({} as OAuth2Client);
expect(userData).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {},
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});

View File

@@ -1,124 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
ClientMetadata,
GeminiUserTier,
LoadCodeAssistResponse,
OnboardUserRequest,
} from './types.js';
import { UserTierId } from './types.js';
import { CodeAssistServer } from './server.js';
import type { OAuth2Client } from 'google-auth-library';
export class ProjectIdRequiredError extends Error {
constructor() {
super(
'This account requires setting the GOOGLE_CLOUD_PROJECT env var. See https://goo.gle/gemini-cli-auth-docs#workspace-gca',
);
}
}
export interface UserData {
projectId: string;
userTier: UserTierId;
}
/**
*
* @param projectId the user's project id, if any
* @returns the user's actual project id
*/
export async function setupUser(client: OAuth2Client): Promise<UserData> {
const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
const coreClientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
};
const loadRes = await caServer.loadCodeAssist({
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
});
if (loadRes.currentTier) {
if (!loadRes.cloudaicompanionProject) {
if (projectId) {
return {
projectId,
userTier: loadRes.currentTier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: loadRes.cloudaicompanionProject,
userTier: loadRes.currentTier.id,
};
}
const tier = getOnboardTier(loadRes);
let onboardReq: OnboardUserRequest;
if (tier.id === UserTierId.FREE) {
// The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: undefined,
metadata: coreClientMetadata,
};
} else {
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
};
}
// Poll onboardUser until long running operation is complete.
let lroRes = await caServer.onboardUser(onboardReq);
while (!lroRes.done) {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await caServer.onboardUser(onboardReq);
}
if (!lroRes.response?.cloudaicompanionProject?.id) {
if (projectId) {
return {
projectId,
userTier: tier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: lroRes.response.cloudaicompanionProject.id,
userTier: tier.id,
};
}
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
for (const tier of res.allowedTiers || []) {
if (tier.isDefault) {
return tier;
}
}
return {
name: '',
description: '',
id: UserTierId.LEGACY,
userDefinedCloudaicompanionProject: true,
};
}

View File

@@ -1,201 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export interface ClientMetadata {
ideType?: ClientMetadataIdeType;
ideVersion?: string;
pluginVersion?: string;
platform?: ClientMetadataPlatform;
updateChannel?: string;
duetProject?: string;
pluginType?: ClientMetadataPluginType;
ideName?: string;
}
export type ClientMetadataIdeType =
| 'IDE_UNSPECIFIED'
| 'VSCODE'
| 'INTELLIJ'
| 'VSCODE_CLOUD_WORKSTATION'
| 'INTELLIJ_CLOUD_WORKSTATION'
| 'CLOUD_SHELL';
export type ClientMetadataPlatform =
| 'PLATFORM_UNSPECIFIED'
| 'DARWIN_AMD64'
| 'DARWIN_ARM64'
| 'LINUX_AMD64'
| 'LINUX_ARM64'
| 'WINDOWS_AMD64';
export type ClientMetadataPluginType =
| 'PLUGIN_UNSPECIFIED'
| 'CLOUD_CODE'
| 'GEMINI'
| 'AIPLUGIN_INTELLIJ'
| 'AIPLUGIN_STUDIO';
export interface LoadCodeAssistRequest {
cloudaicompanionProject?: string;
metadata: ClientMetadata;
}
/**
* Represents LoadCodeAssistResponse proto json field
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
*/
export interface LoadCodeAssistResponse {
currentTier?: GeminiUserTier | null;
allowedTiers?: GeminiUserTier[] | null;
ineligibleTiers?: IneligibleTier[] | null;
cloudaicompanionProject?: string | null;
}
/**
* GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist.
*/
export interface GeminiUserTier {
id: UserTierId;
name?: string;
description?: string;
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
userDefinedCloudaicompanionProject?: boolean | null;
isDefault?: boolean;
privacyNotice?: PrivacyNotice;
hasAcceptedTos?: boolean;
hasOnboardedPreviously?: boolean;
}
/**
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface IneligibleTier {
reasonCode: IneligibleTierReasonCode;
reasonMessage: string;
tierId: UserTierId;
tierName: string;
}
/**
* List of predefined reason codes when a tier is blocked from a specific tier.
* https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
*/
export enum IneligibleTierReasonCode {
// go/keep-sorted start
DASHER_USER = 'DASHER_USER',
INELIGIBLE_ACCOUNT = 'INELIGIBLE_ACCOUNT',
NON_USER_ACCOUNT = 'NON_USER_ACCOUNT',
RESTRICTED_AGE = 'RESTRICTED_AGE',
RESTRICTED_NETWORK = 'RESTRICTED_NETWORK',
UNKNOWN = 'UNKNOWN',
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
// go/keep-sorted end
}
/**
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
*
* //depot/google3/cloud/developer_experience/cloudcode/pa/service/usertier.go;l=16
*/
export enum UserTierId {
FREE = 'free-tier',
LEGACY = 'legacy-tier',
STANDARD = 'standard-tier',
}
/**
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
* privacy notice.
*/
export interface PrivacyNotice {
showNotice: boolean;
noticeText?: string;
}
/**
* Proto signature of OnboardUserRequest as payload to OnboardUser call
*/
export interface OnboardUserRequest {
tierId: string | undefined;
cloudaicompanionProject: string | undefined;
metadata: ClientMetadata | undefined;
}
/**
* Represents LongRunningOperation proto
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
*/
export interface LongRunningOperationResponse {
name: string;
done?: boolean;
response?: OnboardUserResponse;
}
/**
* Represents OnboardUserResponse proto
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
*/
export interface OnboardUserResponse {
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
/**
* Status code of user license status
* it does not strictly correspond to the proto
* Error value is an additional value assigned to error responses from OnboardUser
*/
export enum OnboardUserStatusCode {
Default = 'DEFAULT',
Notice = 'NOTICE',
Warning = 'WARNING',
Error = 'ERROR',
}
/**
* Status of user onboarded to gemini
*/
export interface OnboardUserStatus {
statusCode: OnboardUserStatusCode;
displayMessage: string;
helpLink: HelpLinkUrl | undefined;
}
export interface HelpLinkUrl {
description: string;
url: string;
}
export interface SetCodeAssistGlobalUserSettingRequest {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
export interface CodeAssistGlobalUserSettingResponse {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
/**
* Relevant fields that can be returned from a Google RPC response
*/
export interface GoogleRpcResponse {
error?: {
details?: GoogleRpcErrorInfo[];
};
}
/**
* Relevant fields that can be returned in the details of an error returned from GoogleRPCs
*/
interface GoogleRpcErrorInfo {
reason?: string;
}

View File

@@ -283,23 +283,6 @@ describe('Server Config (config.ts)', () => {
expect(config.isInFallbackMode()).toBe(false);
});
it('should strip thoughts when switching from GenAI to Vertex', async () => {
const config = new Config(baseParams);
vi.mocked(createContentGeneratorConfig).mockImplementation(
(_: Config, authType: AuthType | undefined) =>
({ authType }) as unknown as ContentGeneratorConfig,
);
await config.refreshAuth(AuthType.USE_GEMINI);
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
expect(
config.getGeminiClient().stripThoughtsFromHistory,
).toHaveBeenCalledWith();
});
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
const config = new Config(baseParams);

View File

@@ -16,7 +16,8 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../core/contentGenerator.js';
AuthType} from '../core/contentGenerator.js';
import type { FallbackModelHandler } from '../fallback/types.js';
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
import type { ShellExecutionConfig } from '../services/shellExecutionService.js';
@@ -26,7 +27,6 @@ import type { AnyToolInvocation } from '../tools/tools.js';
import { BaseLlmClient } from '../core/baseLlmClient.js';
import { GeminiClient } from '../core/client.js';
import {
AuthType,
createContentGenerator,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
@@ -684,16 +684,6 @@ export class Config {
}
async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) {
// Vertex and Genai have incompatible encryption and sending history with
// throughtSignature from Genai to Vertex will fail, we need to strip them
if (
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE
) {
// Restore the conversation history to the new client
this.geminiClient.stripThoughtsFromHistory();
}
const newContentGeneratorConfig = createContentGeneratorConfig(
this,
authMethod,

View File

@@ -31,7 +31,7 @@ describe('Flash Model Fallback Configuration', () => {
config as unknown as { contentGeneratorConfig: unknown }
).contentGeneratorConfig = {
model: DEFAULT_GEMINI_MODEL,
authType: 'oauth-personal',
authType: 'gemini-api-key',
};
});

View File

@@ -73,6 +73,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Create generator instance
@@ -299,6 +300,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(
@@ -333,6 +335,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(

View File

@@ -146,13 +146,11 @@ describe('BaseLlmClient', () => {
// Validate the parameters passed to the underlying generator
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
expect(mockGenerateContent).toHaveBeenCalledWith(
{
expect.objectContaining({
model: 'test-model',
contents: defaultOptions.contents,
config: {
config: expect.objectContaining({
abortSignal: defaultOptions.abortSignal,
temperature: 0,
topP: 1,
tools: [
{
functionDeclarations: [
@@ -164,9 +162,8 @@ describe('BaseLlmClient', () => {
],
},
],
// Crucial: systemInstruction should NOT be in the config object if not provided
},
},
}),
}),
'test-prompt-id',
);
});
@@ -189,7 +186,6 @@ describe('BaseLlmClient', () => {
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.8,
topP: 1, // Default should remain if not overridden
topK: 10,
tools: expect.any(Array),
}),

View File

@@ -64,12 +64,6 @@ export interface GenerateJsonOptions {
* A client dedicated to stateless, utility-focused LLM calls.
*/
export class BaseLlmClient {
// Default configuration for utility tasks
private readonly defaultUtilityConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
};
constructor(
private readonly contentGenerator: ContentGenerator,
private readonly config: Config,
@@ -90,7 +84,6 @@ export class BaseLlmClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...this.defaultUtilityConfig,
...options.config,
...(systemInstruction && { systemInstruction }),
};

View File

@@ -15,11 +15,7 @@ import {
} from 'vitest';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import {
isThinkingDefault,
isThinkingSupported,
GeminiClient,
} from './client.js';
import { GeminiClient } from './client.js';
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
import {
AuthType,
@@ -247,40 +243,6 @@ describe('findCompressSplitPoint', () => {
});
});
describe('isThinkingSupported', () => {
it('should return true for gemini-2.5', () => {
expect(isThinkingSupported('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingSupported('gemini-1.5-flash')).toBe(false);
expect(isThinkingSupported('some-other-model')).toBe(false);
});
});
describe('isThinkingDefault', () => {
it('should return false for gemini-2.5-flash-lite', () => {
expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
});
it('should return true for gemini-2.5', () => {
expect(isThinkingDefault('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingDefault('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingDefault('gemini-1.5-flash')).toBe(false);
expect(isThinkingDefault('some-other-model')).toBe(false);
});
});
describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
@@ -2304,16 +2266,15 @@ ${JSON.stringify(
);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
config: expect.objectContaining({
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
temperature: 0.5,
topP: 1,
},
}),
contents,
},
}),
'test-session-id',
);
});

View File

@@ -15,11 +15,7 @@ import type {
// Config
import { ApprovalMode, type Config } from '../config/config.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_THINKING_MODE,
} from '../config/models.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
// Core modules
import type { ContentGenerator } from './contentGenerator.js';
@@ -78,25 +74,10 @@ import { type File, type IdeContext } from '../ide/types.js';
// Fallback handling
import { handleFallback } from '../fallback/handler.js';
export function isThinkingSupported(model: string) {
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
export function isThinkingDefault(model: string) {
if (model.startsWith('gemini-2.5-flash-lite')) {
return false;
}
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
const MAX_TURNS = 100;
export class GeminiClient {
private chat?: GeminiChat;
private readonly generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
};
private sessionTurnCount = 0;
private readonly loopDetector: LoopDetectionService;
@@ -208,20 +189,10 @@ export class GeminiClient {
const model = this.config.getModel();
const systemInstruction = getCoreSystemPrompt(userMemory, model);
const config: GenerateContentConfig = { ...this.generateContentConfig };
if (isThinkingSupported(model)) {
config.thinkingConfig = {
includeThoughts: true,
thinkingBudget: DEFAULT_THINKING_MODE,
};
}
return new GeminiChat(
this.config,
{
systemInstruction,
...config,
tools,
},
history,
@@ -618,11 +589,6 @@ export class GeminiClient {
): Promise<GenerateContentResponse> {
let currentAttemptModel: string = model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
};
try {
const userMemory = this.config.getUserMemory();
const finalSystemInstruction = generationConfig.systemInstruction
@@ -631,7 +597,7 @@ export class GeminiClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...configToUse,
...generationConfig,
systemInstruction: finalSystemInstruction,
};
@@ -672,7 +638,7 @@ export class GeminiClient {
`Error generating content via API with model ${currentAttemptModel}.`,
{
requestContents: contents,
requestConfig: configToUse,
requestConfig: generationConfig,
},
'generateContent-api',
);

View File

@@ -5,42 +5,19 @@
*/
import { describe, it, expect, vi } from 'vitest';
import type { ContentGenerator } from './contentGenerator.js';
import { createContentGenerator, AuthType } from './contentGenerator.js';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { GoogleGenAI } from '@google/genai';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { LoggingContentGenerator } from './geminiContentGenerator/loggingContentGenerator.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
const mockConfig = {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
describe('createContentGenerator', () => {
it('should create a CodeAssistContentGenerator', async () => {
const mockGenerator = {} as unknown as ContentGenerator;
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
mockGenerator as never,
);
const generator = await createContentGenerator(
{
model: 'test-model',
authType: AuthType.LOGIN_WITH_GOOGLE,
},
mockConfig,
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
);
});
it('should create a GoogleGenAI content generator', async () => {
it('should create a Gemini content generator', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => true,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
@@ -65,17 +42,17 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
// We expect it to be a LoggingContentGenerator wrapping a GeminiContentGenerator
expect(generator).toBeInstanceOf(LoggingContentGenerator);
const wrapped = (generator as LoggingContentGenerator).getWrapped();
expect(wrapped).toBeDefined();
});
it('should create a GoogleGenAI content generator with client install id logging disabled', async () => {
it('should create a Gemini content generator with client install id logging disabled', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => false,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
models: {},
@@ -98,11 +75,6 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
expect(generator).toBeInstanceOf(LoggingContentGenerator);
});
});

View File

@@ -12,15 +12,9 @@ import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js';
import { InstallationManager } from '../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
*/
@@ -38,15 +32,11 @@ export interface ContentGenerator {
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
userTier?: UserTierId;
}
export enum AuthType {
LOGIN_WITH_GOOGLE = 'oauth-personal',
USE_GEMINI = 'gemini-api-key',
USE_VERTEX_AI = 'vertex-ai',
CLOUD_SHELL = 'cloud-shell',
USE_OPENAI = 'openai',
QWEN_OAUTH = 'qwen-oauth',
}
@@ -59,12 +49,9 @@ export type ContentGeneratorConfig = {
authType?: AuthType | undefined;
enableOpenAILogging?: boolean;
openAILoggingDir?: string;
// Timeout configuration in milliseconds
timeout?: number;
// Maximum retries for failed requests
maxRetries?: number;
// Disable cache control for DashScope providers
disableCacheControl?: boolean;
timeout?: number; // Timeout configuration in milliseconds
maxRetries?: number; // Maximum retries for failed requests
disableCacheControl?: boolean; // Disable cache control for DashScope providers
samplingParams?: {
top_p?: number;
top_k?: number;
@@ -74,6 +61,9 @@ export type ContentGeneratorConfig = {
temperature?: number;
max_tokens?: number;
};
reasoning?: {
effort?: 'low' | 'medium' | 'high';
};
proxy?: string | undefined;
userAgent?: string;
// Schema compliance mode for tool definitions
@@ -123,48 +113,14 @@ export async function createContentGenerator(
gcConfig: Config,
isInitialAuth?: boolean,
): Promise<ContentGenerator> {
const version = process.env['CLI_VERSION'] || process.version;
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
),
gcConfig,
);
}
if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI
) {
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
const { createGeminiContentGenerator } = await import(
'./geminiContentGenerator/index.js'
);
return createGeminiContentGenerator(config, gcConfig);
}
if (config.authType === AuthType.USE_OPENAI) {

View File

@@ -240,7 +240,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -318,7 +318,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -497,7 +497,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit', 'run_shell_command'],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -584,7 +584,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit'], // Different excluded tools
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -674,7 +674,7 @@ describe('CoreToolScheduler with payload', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1001,7 +1001,7 @@ describe('CoreToolScheduler edit cancellation', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1108,7 +1108,7 @@ describe('CoreToolScheduler YOLO mode', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1258,7 +1258,7 @@ describe('CoreToolScheduler cancellation during executing with live output', ()
getApprovalMode: () => ApprovalMode.DEFAULT,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getToolRegistry: () => mockToolRegistry,
getShellExecutionConfig: () => ({
@@ -1350,7 +1350,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1482,7 +1482,7 @@ describe('CoreToolScheduler request queueing', () => {
getToolRegistry: () => toolRegistry,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 80,
@@ -1586,7 +1586,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1854,7 +1854,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1975,7 +1975,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -111,7 +111,7 @@ describe('GeminiChat', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'oauth-personal', // Ensure this is set for fallback tests
authType: 'gemini-api-key', // Ensure this is set for fallback tests
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -1382,7 +1382,7 @@ describe('GeminiChat', () => {
});
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
const authType = AuthType.LOGIN_WITH_GOOGLE;
const authType = AuthType.USE_GEMINI;
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
model: 'test-model',
authType,

View File

@@ -0,0 +1,173 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { GoogleGenAI } from '@google/genai';
vi.mock('@google/genai', () => {
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();
const mockCountTokens = vi.fn();
const mockEmbedContent = vi.fn();
return {
GoogleGenAI: vi.fn().mockImplementation(() => ({
models: {
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
countTokens: mockCountTokens,
embedContent: mockEmbedContent,
},
})),
};
});
describe('GeminiContentGenerator', () => {
let generator: GeminiContentGenerator;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockGoogleGenAI: any;
beforeEach(() => {
vi.clearAllMocks();
generator = new GeminiContentGenerator({
apiKey: 'test-api-key',
});
mockGoogleGenAI = vi.mocked(GoogleGenAI).mock.results[0].value;
});
it('should call generateContent on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { responseId: 'test-id' };
mockGoogleGenAI.models.generateContent.mockResolvedValue(expectedResponse);
const response = await generator.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'HIGH',
},
}),
}),
);
expect(response).toBe(expectedResponse);
});
it('should call generateContentStream on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const mockStream = (async function* () {
yield { responseId: '1' };
})();
mockGoogleGenAI.models.generateContentStream.mockResolvedValue(mockStream);
const stream = await generator.generateContentStream(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'HIGH',
},
}),
}),
);
expect(stream).toBe(mockStream);
});
it('should call countTokens on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { totalTokens: 10 };
mockGoogleGenAI.models.countTokens.mockResolvedValue(expectedResponse);
const response = await generator.countTokens(request);
expect(mockGoogleGenAI.models.countTokens).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
it('should call embedContent on the underlying model', async () => {
const request = { model: 'embedding-model', contents: [] };
const expectedResponse = { embeddings: [] };
mockGoogleGenAI.models.embedContent.mockResolvedValue(expectedResponse);
const response = await generator.embedContent(request);
expect(mockGoogleGenAI.models.embedContent).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
it('should prioritize contentGeneratorConfig samplingParams over request config', async () => {
const generatorWithParams = new GeminiContentGenerator({ apiKey: 'test' }, {
model: 'gemini-1.5-flash',
samplingParams: {
temperature: 0.1,
top_p: 0.2,
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const request = {
model: 'gemini-1.5-flash',
contents: [],
config: {
temperature: 0.9,
topP: 0.9,
},
};
await generatorWithParams.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.1,
topP: 0.2,
}),
}),
);
});
it('should map reasoning effort to thinkingConfig', async () => {
const generatorWithReasoning = new GeminiContentGenerator(
{ apiKey: 'test' },
{
model: 'gemini-2.5-pro',
reasoning: {
effort: 'high',
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any,
);
const request = {
model: 'gemini-2.5-pro',
contents: [],
};
await generatorWithReasoning.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'HIGH',
},
}),
}),
);
});
});

View File

@@ -0,0 +1,140 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
GenerateContentConfig,
ThinkingLevel,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
/**
* A wrapper for GoogleGenAI that implements the ContentGenerator interface.
*/
export class GeminiContentGenerator implements ContentGenerator {
private readonly googleGenAI: GoogleGenAI;
private readonly contentGeneratorConfig?: ContentGeneratorConfig;
constructor(
options: {
apiKey?: string;
vertexai?: boolean;
httpOptions?: { headers: Record<string, string> };
},
contentGeneratorConfig?: ContentGeneratorConfig,
) {
this.googleGenAI = new GoogleGenAI(options);
this.contentGeneratorConfig = contentGeneratorConfig;
}
private buildSamplingParameters(
request: GenerateContentParameters,
): GenerateContentConfig {
const configSamplingParams = this.contentGeneratorConfig?.samplingParams;
const requestConfig = request.config || {};
// Helper function to get parameter value with priority: config > request > default
const getParameterValue = <T>(
configValue: T | undefined,
requestKey: keyof GenerateContentConfig,
defaultValue?: T,
): T | undefined => {
const requestValue = requestConfig[requestKey] as T | undefined;
if (configValue !== undefined) return configValue;
if (requestValue !== undefined) return requestValue;
return defaultValue;
};
return {
...requestConfig,
temperature: getParameterValue<number>(
configSamplingParams?.temperature,
'temperature',
1,
),
topP: getParameterValue<number>(
configSamplingParams?.top_p,
'topP',
0.95,
),
topK: getParameterValue<number>(configSamplingParams?.top_k, 'topK', 64),
maxOutputTokens: getParameterValue<number>(
configSamplingParams?.max_tokens,
'maxOutputTokens',
),
presencePenalty: getParameterValue<number>(
configSamplingParams?.presence_penalty,
'presencePenalty',
),
frequencyPenalty: getParameterValue<number>(
configSamplingParams?.frequency_penalty,
'frequencyPenalty',
),
thinkingConfig: getParameterValue(
this.contentGeneratorConfig?.reasoning
? {
includeThoughts: true,
thinkingLevel: (this.contentGeneratorConfig.reasoning.effort ===
'low'
? 'LOW'
: this.contentGeneratorConfig.reasoning.effort === 'high'
? 'HIGH'
: 'THINKING_LEVEL_UNSPECIFIED') as ThinkingLevel,
}
: undefined,
'thinkingConfig',
{
includeThoughts: true,
thinkingLevel: 'HIGH' as ThinkingLevel,
},
),
};
}
async generateContent(
request: GenerateContentParameters,
_userPromptId: string,
): Promise<GenerateContentResponse> {
const finalRequest = {
...request,
config: this.buildSamplingParameters(request),
};
return this.googleGenAI.models.generateContent(finalRequest);
}
async generateContentStream(
request: GenerateContentParameters,
_userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const finalRequest = {
...request,
config: this.buildSamplingParameters(request),
};
return this.googleGenAI.models.generateContentStream(finalRequest);
}
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.googleGenAI.models.countTokens(request);
}
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.googleGenAI.models.embedContent(request);
}
}

View File

@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createGeminiContentGenerator } from './index.js';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
vi.mock('./geminiContentGenerator.js', () => ({
GeminiContentGenerator: vi.fn().mockImplementation(() => ({})),
}));
vi.mock('./loggingContentGenerator.js', () => ({
LoggingContentGenerator: vi.fn().mockImplementation((wrapped) => wrapped),
}));
describe('createGeminiContentGenerator', () => {
let mockConfig: Config;
beforeEach(() => {
vi.clearAllMocks();
mockConfig = {
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
});
it('should create a GeminiContentGenerator wrapped in LoggingContentGenerator', () => {
const config = {
model: 'gemini-1.5-flash',
apiKey: 'test-key',
authType: AuthType.USE_GEMINI,
};
const generator = createGeminiContentGenerator(config, mockConfig);
expect(GeminiContentGenerator).toHaveBeenCalled();
expect(LoggingContentGenerator).toHaveBeenCalled();
expect(generator).toBeDefined();
});
});

View File

@@ -0,0 +1,55 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
export { GeminiContentGenerator } from './geminiContentGenerator.js';
export { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Create a Gemini content generator.
*/
export function createGeminiContentGenerator(
config: ContentGeneratorConfig,
gcConfig: Config,
): ContentGenerator {
const version = process.env['CLI_VERSION'] || process.version;
const userAgent =
config.userAgent ||
`QwenCode/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const geminiContentGenerator = new GeminiContentGenerator(
{
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
},
config,
);
return new LoggingContentGenerator(geminiContentGenerator, gcConfig);
}

View File

@@ -13,21 +13,24 @@ import type {
GenerateContentParameters,
GenerateContentResponseUsageMetadata,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
Part,
PartUnion,
} from '@google/genai';
import {
ApiRequestEvent,
ApiResponseEvent,
ApiErrorEvent,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
} from '../../telemetry/types.js';
import type { Config } from '../../config/config.js';
import {
logApiError,
logApiRequest,
logApiResponse,
} from '../telemetry/loggers.js';
import type { ContentGenerator } from './contentGenerator.js';
import { toContents } from '../code_assist/converter.js';
import { isStructuredError } from '../utils/quotaErrorDetection.js';
} from '../../telemetry/loggers.js';
import type { ContentGenerator } from '../contentGenerator.js';
import { isStructuredError } from '../../utils/quotaErrorDetection.js';
interface StructuredError {
status: number;
@@ -112,7 +115,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
try {
const response = await this.wrapped.generateContent(req, userPromptId);
const durationMs = Date.now() - startTime;
@@ -137,7 +140,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
let stream: AsyncGenerator<GenerateContentResponse>;
try {
@@ -205,4 +208,91 @@ export class LoggingContentGenerator implements ContentGenerator {
): Promise<EmbedContentResponse> {
return this.wrapped.embedContent(req);
}
private toContents(contents: ContentListUnion): Content[] {
if (Array.isArray(contents)) {
// it's a Content[] or a PartsUnion[]
return contents.map((c) => this.toContent(c));
}
// it's a Content or a PartsUnion
return [this.toContent(contents)];
}
private toContent(content: ContentUnion): Content {
if (Array.isArray(content)) {
// it's a PartsUnion[]
return {
role: 'user',
parts: this.toParts(content),
};
}
if (typeof content === 'string') {
// it's a string
return {
role: 'user',
parts: [{ text: content }],
};
}
if ('parts' in content) {
// it's a Content - process parts to handle thought filtering
return {
...content,
parts: content.parts
? this.toParts(content.parts.filter((p) => p != null))
: [],
};
}
// it's a Part
return {
role: 'user',
parts: [this.toPart(content as Part)],
};
}
private toParts(parts: PartUnion[]): Part[] {
return parts.map((p) => this.toPart(p));
}
private toPart(part: PartUnion): Part {
if (typeof part === 'string') {
// it's a string
return { text: part };
}
// Handle thought parts for CountToken API compatibility
// The CountToken API expects parts to have certain required "oneof" fields initialized,
// but thought parts don't conform to this schema and cause API failures
if ('thought' in part && part.thought) {
const thoughtText = `[Thought: ${part.thought}]`;
const newPart = { ...part };
delete (newPart as Record<string, unknown>)['thought'];
const hasApiContent =
'functionCall' in newPart ||
'functionResponse' in newPart ||
'inlineData' in newPart ||
'fileData' in newPart;
if (hasApiContent) {
// It's a functionCall or other non-text part. Just strip the thought.
return newPart;
}
// If no other valid API content, this must be a text part.
// Combine existing text (if any) with the thought, preserving other properties.
const text = (newPart as { text?: unknown }).text;
const existingText = text ? String(text) : '';
const combinedText = existingText
? `${existingText}\n${thoughtText}`
: thoughtText;
return {
...newPart,
text: combinedText,
};
}
return part;
}
}

View File

@@ -47,7 +47,7 @@ describe('executeToolCall', () => {
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
authType: 'gemini-api-key',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -99,6 +99,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
generator = new OpenAIContentGenerator(
@@ -211,6 +212,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(
@@ -277,6 +279,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(

View File

@@ -60,6 +60,7 @@ describe('ContentGenerationPipeline', () => {
buildClient: vi.fn().mockReturnValue(mockClient),
buildRequest: vi.fn().mockImplementation((req) => req),
buildHeaders: vi.fn().mockReturnValue({}),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Mock telemetry service

View File

@@ -283,16 +283,22 @@ export class ContentGenerationPipeline {
private buildSamplingParameters(
request: GenerateContentParameters,
): Record<string, unknown> {
const defaultSamplingParams =
this.config.provider.getDefaultGenerationConfig();
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
// Helper function to get parameter value with priority: config > request > default
const getParameterValue = <T>(
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey: keyof NonNullable<typeof request.config>,
defaultValue?: T,
requestKey?: keyof NonNullable<typeof request.config>,
): T | undefined => {
const configValue = configSamplingParams?.[configKey] as T | undefined;
const requestValue = request.config?.[requestKey] as T | undefined;
const requestValue = requestKey
? (request.config?.[requestKey] as T | undefined)
: undefined;
const defaultValue = requestKey
? (defaultSamplingParams[requestKey] as T)
: undefined;
if (configValue !== undefined) return configValue;
if (requestValue !== undefined) return requestValue;
@@ -304,12 +310,8 @@ export class ContentGenerationPipeline {
key: string,
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey?: keyof NonNullable<typeof request.config>,
defaultValue?: T,
): Record<string, T> | Record<string, never> => {
const value = requestKey
? getParameterValue(configKey, requestKey, defaultValue)
: ((configSamplingParams?.[configKey] as T | undefined) ??
defaultValue);
): Record<string, T | undefined> => {
const value = getParameterValue<T>(configKey, requestKey);
return value !== undefined ? { [key]: value } : {};
};
@@ -323,10 +325,18 @@ export class ContentGenerationPipeline {
...addParameterIfDefined('max_tokens', 'max_tokens', 'maxOutputTokens'),
// Config-only parameters (no request fallback)
...addParameterIfDefined('top_k', 'top_k'),
...addParameterIfDefined('top_k', 'top_k', 'topK'),
...addParameterIfDefined('repetition_penalty', 'repetition_penalty'),
...addParameterIfDefined('presence_penalty', 'presence_penalty'),
...addParameterIfDefined('frequency_penalty', 'frequency_penalty'),
...addParameterIfDefined(
'presence_penalty',
'presence_penalty',
'presencePenalty',
),
...addParameterIfDefined(
'frequency_penalty',
'frequency_penalty',
'frequencyPenalty',
),
};
return params;

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { AuthType } from '../../contentGenerator.js';
@@ -141,6 +142,14 @@ export class DashScopeOpenAICompatibleProvider
};
}
getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0.7,
topP: 0.8,
topK: 20,
};
}
/**
* Add cache control flag to specified message(s) for DashScope providers
*/

View File

@@ -8,6 +8,7 @@ import type OpenAI from 'openai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DefaultOpenAICompatibleProvider } from './default.js';
import type { GenerateContentConfig } from '@google/genai';
export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
constructor(
@@ -76,4 +77,10 @@ export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatiblePro
messages,
};
}
override getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0,
};
}
}

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
@@ -55,4 +56,11 @@ export class DefaultOpenAICompatibleProvider
...request, // Preserve all original parameters including sampling params
};
}
getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 1,
topP: 0.95,
};
}
}

View File

@@ -1,3 +1,4 @@
import type { GenerateContentConfig } from '@google/genai';
import type OpenAI from 'openai';
// Extended types to support cache_control for DashScope
@@ -22,6 +23,7 @@ export interface OpenAICompatibleProvider {
request: OpenAI.Chat.ChatCompletionCreateParams,
userPromptId: string,
): OpenAI.Chat.ChatCompletionCreateParams;
getDefaultGenerationConfig(): GenerateContentConfig;
}
export type DashScopeRequestMetadata = {

View File

@@ -4,36 +4,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
type Mock,
type MockInstance,
afterEach,
} from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';
// Mock the telemetry logger and event class
vi.mock('../telemetry/index.js', () => ({
logFlashFallback: vi.fn(),
FlashFallbackEvent: class {},
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
@@ -45,174 +19,28 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
describe('handleFallback', () => {
let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks();
mockHandler = vi.fn();
// Default setup: OAuth user, Pro model failed, handler injected
mockConfig = createMockConfig({
fallbackModelHandler: mockHandler,
});
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
mockConfig = createMockConfig();
});
afterEach(() => {
consoleErrorSpy.mockRestore();
});
it('should return null immediately if authType is not OAuth', async () => {
it('should return null for unknown auth types', async () => {
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_API_KEY,
'test-model',
'unknown-auth',
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
it('should return null if the failed model is already the fallback model', async () => {
it('should handle Qwen OAuth error', async () => {
const result = await handleFallback(
mockConfig,
FALLBACK_MODEL, // Failed model is Flash
AUTH_OAUTH,
'test-model',
AuthType.QWEN_OAUTH,
new Error('unauthorized'),
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
});
it('should return null if no fallbackHandler is injected in config', async () => {
const configWithoutHandler = createMockConfig({
fallbackModelHandler: undefined,
});
const result = await handleFallback(
configWithoutHandler,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
});
describe('when handler returns "retry"', () => {
it('should activate fallback mode, log telemetry, and return true', async () => {
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(true);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "stop"', () => {
it('should activate fallback mode, log telemetry, and return false', async () => {
mockHandler.mockResolvedValue('stop');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "auth"', () => {
it('should NOT activate fallback mode and return false', async () => {
mockHandler.mockResolvedValue('auth');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
});
describe('when handler returns an unexpected value', () => {
it('should log an error and return null', async () => {
mockHandler.mockResolvedValue(null);
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
new Error(
'Unexpected fallback intent received from fallbackModelHandler: "null"',
),
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
const mockError = new Error('Quota Exceeded');
mockHandler.mockResolvedValue('retry');
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
expect(mockHandler).toHaveBeenCalledWith(
MOCK_PRO_MODEL,
FALLBACK_MODEL,
mockError,
);
});
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
// Setup config where fallback mode is already active
const activeFallbackConfig = createMockConfig({
fallbackModelHandler: mockHandler,
isInFallbackMode: vi.fn(() => true), // Already active
setFallbackMode: vi.fn(),
});
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
activeFallbackConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
// Should still return true to allow the retry (which will use the active fallback mode)
expect(result).toBe(true);
// Should still consult the handler
expect(mockHandler).toHaveBeenCalled();
// But should not mutate state or log telemetry again
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
const handlerError = new Error('UI interaction failed');
mockHandler.mockRejectedValue(handlerError);
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
handlerError,
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});

View File

@@ -6,8 +6,6 @@
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
export async function handleFallback(
config: Config,
@@ -20,48 +18,7 @@ export async function handleFallback(
return handleQwenOAuthError(error);
}
// Applicability Checks
if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
if (failedModel === fallbackModel) return null;
// Consult UI Handler for Intent
const fallbackModelHandler = config.fallbackModelHandler;
if (typeof fallbackModelHandler !== 'function') return null;
try {
// Pass the specific failed model to the UI handler.
const intent = await fallbackModelHandler(
failedModel,
fallbackModel,
error,
);
// Process Intent and Update State
switch (intent) {
case 'retry':
// Activate fallback mode. The NEXT retry attempt will pick this up.
activateFallbackMode(config, authType);
return true; // Signal retryWithBackoff to continue.
case 'stop':
activateFallbackMode(config, authType);
return false;
case 'auth':
return false;
default:
throw new Error(
`Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
);
}
} catch (handlerError) {
console.error('Fallback UI handler failed:', handlerError);
return null;
}
return null;
}
/**
@@ -118,12 +75,3 @@ async function handleQwenOAuthError(error?: unknown): Promise<string | null> {
// For other errors, don't handle them specially
return null;
}
function activateFallbackMode(config: Config, authType: string | undefined) {
if (!config.isInFallbackMode()) {
config.setFallbackMode(true);
if (authType) {
logFlashFallback(config, new FlashFallbackEvent(authType));
}
}
}

View File

@@ -12,7 +12,6 @@ export * from './output/json-formatter.js';
// Export Core Logic
export * from './core/client.js';
export * from './core/contentGenerator.js';
export * from './core/loggingContentGenerator.js';
export * from './core/geminiChat.js';
export * from './core/logger.js';
export * from './core/prompts.js';
@@ -24,11 +23,7 @@ export * from './core/nonInteractiveToolExecutor.js';
export * from './fallback/types.js';
export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js';
export * from './qwen/qwenOAuth2.js';
export * from './code_assist/server.js';
export * from './code_assist/types.js';
// Export utilities
export * from './utils/paths.js';

View File

@@ -907,3 +907,5 @@ export async function clearQwenCredentials(): Promise<void> {
function getQwenCachedCredentialPath(): string {
return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME);
}
export const clearCachedCredentialFile = clearQwenCredentials;

View File

@@ -30,7 +30,6 @@ import {
ToolCallEvent,
} from '../types.js';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
@@ -90,10 +89,8 @@ expect.extend({
},
});
vi.mock('../../utils/userAccountManager.js');
vi.mock('../../utils/installationManager.js');
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
// TODO(richieforeman): Consider moving this to test setup globally.
@@ -128,11 +125,7 @@ describe('ClearcutLogger', () => {
vi.unstubAllEnvs();
});
function setup({
config = {} as Partial<ConfigParameters>,
lifetimeGoogleAccounts = 1,
cachedGoogleAccount = 'test@google.com',
} = {}) {
function setup({ config = {} as Partial<ConfigParameters> } = {}) {
server.resetHandlers(
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
);
@@ -146,10 +139,6 @@ describe('ClearcutLogger', () => {
});
ClearcutLogger.clearInstance();
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
lifetimeGoogleAccounts,
);
mockInstallMgr.getInstallationId = vi
.fn()
.mockReturnValue('test-installation-id');
@@ -195,19 +184,6 @@ describe('ClearcutLogger', () => {
});
describe('createLogEvent', () => {
it('logs the total number of google accounts', () => {
const { logger } = setup({
lifetimeGoogleAccounts: 9001,
});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: '9001',
});
});
it('logs the current surface from a github action', () => {
const { logger } = setup({});
@@ -251,7 +227,6 @@ describe('ClearcutLogger', () => {
// Define expected values
const session_id = 'test-session-id';
const auth_type = AuthType.USE_GEMINI;
const google_accounts = 123;
const surface = 'ide-1234';
const cli_version = CLI_VERSION;
const git_commit_hash = GIT_COMMIT_INFO;
@@ -260,7 +235,6 @@ describe('ClearcutLogger', () => {
// Setup logger with expected values
const { logger, loggerConfig } = setup({
lifetimeGoogleAccounts: google_accounts,
config: {},
});
vi.spyOn(loggerConfig, 'getContentGeneratorConfig').mockReturnValue({
@@ -283,10 +257,6 @@ describe('ClearcutLogger', () => {
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
value: JSON.stringify(auth_type),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${google_accounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
@@ -404,10 +374,14 @@ describe('ClearcutLogger', () => {
vi.stubEnv(key, value);
}
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0][3]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
});
expect(event?.event_metadata[0]).toEqual(
expect.arrayContaining([
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
},
]),
);
},
);
});

View File

@@ -34,7 +34,6 @@ import type {
import { EventMetadataKey } from './event-metadata-key.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
import { FixedDeque } from 'mnemonist';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
@@ -157,7 +156,6 @@ export class ClearcutLogger {
private sessionData: EventValue[] = [];
private promptId: string = '';
private readonly installationManager: InstallationManager;
private readonly userAccountManager: UserAccountManager;
/**
* Queue of pending events that need to be flushed to the server. New events
@@ -186,7 +184,6 @@ export class ClearcutLogger {
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
this.promptId = config?.getSessionId() ?? '';
this.installationManager = new InstallationManager();
this.userAccountManager = new UserAccountManager();
}
static getInstance(config?: Config): ClearcutLogger | undefined {
@@ -233,14 +230,11 @@ export class ClearcutLogger {
}
createLogEvent(eventName: EventNames, data: EventValue[] = []): LogEvent {
const email = this.userAccountManager.getCachedGoogleAccount();
if (eventName !== EventNames.START_SESSION) {
data.push(...this.sessionData);
}
const totalAccounts = this.userAccountManager.getLifetimeGoogleAccounts();
data = this.addDefaultFields(data, totalAccounts);
data = this.addDefaultFields(data);
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
@@ -249,12 +243,7 @@ export class ClearcutLogger {
event_metadata: [data],
};
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
if (email) {
logEvent.client_email = email;
} else {
logEvent.client_install_id = this.installationManager.getInstallationId();
}
logEvent.client_install_id = this.installationManager.getInstallationId();
return logEvent;
}
@@ -1018,7 +1007,7 @@ export class ClearcutLogger {
* Adds default fields to data, and returns a new data array. This fields
* should exist on all log events.
*/
addDefaultFields(data: EventValue[], totalAccounts: number): EventValue[] {
addDefaultFields(data: EventValue[]): EventValue[] {
const surface = determineSurface();
const defaultLogMetadata: EventValue[] = [
@@ -1032,10 +1021,6 @@ export class ClearcutLogger {
this.config?.getContentGeneratorConfig()?.authType,
),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${totalAccounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,

View File

@@ -83,7 +83,6 @@ import type {
} from '@google/genai';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import * as uiTelemetry from './uiTelemetry.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { makeFakeConfig } from '../test-utils/config.js';
describe('loggers', () => {
@@ -101,10 +100,6 @@ describe('loggers', () => {
vi.spyOn(uiTelemetry.uiTelemetryService, 'addEvent').mockImplementation(
mockUiEvent.addEvent,
);
vi.spyOn(
UserAccountManager.prototype,
'getCachedGoogleAccount',
).mockReturnValue('test-user@example.com');
vi.useFakeTimers();
vi.setSystemTime(new Date('2025-01-01T00:00:00.000Z'));
});
@@ -188,7 +183,6 @@ describe('loggers', () => {
body: 'CLI configuration loaded.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_CLI_CONFIG,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -233,7 +227,6 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
@@ -255,7 +248,7 @@ describe('loggers', () => {
const event = new UserPromptEvent(
11,
'prompt-id-9',
AuthType.CLOUD_SHELL,
AuthType.USE_GEMINI,
'test-prompt',
);
@@ -265,12 +258,11 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
prompt_id: 'prompt-id-9',
auth_type: 'cloud-shell',
auth_type: 'gemini-api-key',
},
});
});
@@ -313,7 +305,7 @@ describe('loggers', () => {
'test-model',
100,
'prompt-id-1',
AuthType.LOGIN_WITH_GOOGLE,
AuthType.USE_GEMINI,
usageData,
'test-response',
);
@@ -324,7 +316,6 @@ describe('loggers', () => {
body: 'API response from test-model. Status: 200. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
[SemanticAttributes.HTTP_STATUS_CODE]: 200,
@@ -340,7 +331,7 @@ describe('loggers', () => {
total_token_count: 0,
response_text: 'test-response',
prompt_id: 'prompt-id-1',
auth_type: 'oauth-personal',
auth_type: 'gemini-api-key',
},
});
@@ -386,7 +377,6 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -405,7 +395,6 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -430,7 +419,6 @@ describe('loggers', () => {
body: 'Switching to flash as Fallback.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FLASH_FALLBACK,
'event.timestamp': '2025-01-01T00:00:00.000Z',
auth_type: 'vertex-ai',
@@ -465,7 +453,6 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'ripgrep is not available',
}),
@@ -484,7 +471,6 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'rg not found',
}),
@@ -598,7 +584,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: accept. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -682,7 +667,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: reject. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -759,7 +743,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: modify. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -835,7 +818,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -910,7 +892,6 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -999,7 +980,6 @@ describe('loggers', () => {
body: 'Tool call: mock_mcp_tool. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'mock_mcp_tool',
@@ -1047,7 +1027,6 @@ describe('loggers', () => {
body: 'Malformed JSON response from test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_MALFORMED_JSON_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -1091,7 +1070,6 @@ describe('loggers', () => {
body: 'File operation: read. Lines: 10.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FILE_OPERATION,
'event.timestamp': '2025-01-01T00:00:00.000Z',
tool_name: 'test-tool',
@@ -1137,7 +1115,6 @@ describe('loggers', () => {
body: 'Tool output truncated for test-tool.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': 'tool_output_truncated',
'event.timestamp': '2025-01-01T00:00:00.000Z',
eventName: 'tool_output_truncated',
@@ -1184,7 +1161,6 @@ describe('loggers', () => {
body: 'Installed extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_INSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1223,7 +1199,6 @@ describe('loggers', () => {
body: 'Uninstalled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_UNINSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1260,7 +1235,6 @@ describe('loggers', () => {
body: 'Enabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_ENABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1297,7 +1271,6 @@ describe('loggers', () => {
body: 'Disabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_DISABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',

View File

@@ -9,7 +9,6 @@ import { logs } from '@opentelemetry/api-logs';
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
import type { Config } from '../config/config.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import {
EVENT_API_ERROR,
EVENT_API_CANCEL,
@@ -93,11 +92,8 @@ const shouldLogUserPrompts = (config: Config): boolean =>
config.getTelemetryLogPromptsEnabled();
function getCommonAttributes(config: Config): LogAttributes {
const userAccountManager = new UserAccountManager();
const email = userAccountManager.getCachedGoogleAccount();
return {
'session.id': config.getSessionId(),
...(email && { 'user.email': email }),
};
}

View File

@@ -217,9 +217,9 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
});
it('with headers', async () => {
@@ -232,13 +232,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
});
});
@@ -251,9 +251,9 @@ describe('mcp-client', () => {
},
false,
);
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {}),
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
});
it('with headers', async () => {
@@ -266,13 +266,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
});
});

View File

@@ -6,9 +6,6 @@
import { describe, it, expect } from 'vitest';
import { parseAndFormatApiError } from './errorParsing.js';
import { isProQuotaExceededError } from './quotaErrorDetection.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { AuthType } from '../core/contentGenerator.js';
import type { StructuredError } from '../core/turn.js';
@@ -27,32 +24,10 @@ describe('parseAndFormatApiError', () => {
it('should format a 429 API error with the default message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
undefined,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
const result = parseAndFormatApiError(errorMessage, undefined);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
});
it('should format a 429 API error with the personal message', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
'Possible quota limitations in place or slow response times detected. Please wait and try again later.',
);
});
@@ -132,230 +107,4 @@ describe('parseAndFormatApiError', () => {
const expected = '[API Error: An unknown error occurred.]';
expect(parseAndFormatApiError(error)).toBe(expected);
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Free tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain('upgrade to get higher limits');
});
it('should format a regular 429 API error with standard message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
);
expect(result).not.toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
});
it('should format a 429 API error with generic quota exceeded message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).not.toContain(
'You have reached your daily Gemini 2.5 Pro quota limit',
);
});
it('should prioritize Pro quota message over generic quota message for Google auth', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).not.toContain('You have reached your daily quota limit');
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should format a 429 API error with Pro quota exceeded message for Google auth (Legacy tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.LEGACY,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
);
expect(result).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should handle different Gemini 2.5 version strings in Pro quota exceeded errors', () => {
const errorMessage25 =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const errorMessagePreview =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5-preview Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result25 = parseAndFormatApiError(
errorMessage25,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
const resultPreview = parseAndFormatApiError(
errorMessagePreview,
AuthType.LOGIN_WITH_GOOGLE,
undefined,
'gemini-2.5-preview-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result25).toContain(
'You have reached your daily gemini-2.5-pro quota limit',
);
expect(resultPreview).toContain(
'You have reached your daily gemini-2.5-preview-pro quota limit',
);
expect(result25).toContain('upgrade to get higher limits');
expect(resultPreview).toContain('upgrade to get higher limits');
});
it('should not match non-Pro models with similar version strings', () => {
// Test that Flash models with similar version strings don't match
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Flash Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5-preview Flash Requests' and limit",
),
).toBe(false);
// Test other model types
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Ultra Requests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'Gemini 2.5 Standard Requests' and limit",
),
).toBe(false);
// Test generic quota messages
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'GenerationRequests' and limit",
),
).toBe(false);
expect(
isProQuotaExceededError(
"Quota exceeded for quota metric 'EmbeddingRequests' and limit",
),
).toBe(false);
});
it('should format a generic quota exceeded message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain(
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
);
expect(result).toContain('You have reached your daily quota limit');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
it('should format a regular 429 API error with standard message for Google auth (Standard tier)', () => {
const errorMessage =
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
const result = parseAndFormatApiError(
errorMessage,
AuthType.LOGIN_WITH_GOOGLE,
UserTierId.STANDARD,
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(result).toContain('[API Error: Rate limit exceeded');
expect(result).toContain(
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
);
expect(result).not.toContain('upgrade to get higher limits');
});
});

View File

@@ -4,120 +4,36 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isApiError,
isStructuredError,
} from './quotaErrorDetection.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
} from '../config/models.js';
import { UserTierId } from '../code_assist/types.js';
import { isApiError, isStructuredError } from './quotaErrorDetection.js';
import { AuthType } from '../core/contentGenerator.js';
// Free Tier message functions
const getRateLimitErrorMessageGoogleFree = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
const getRateLimitErrorMessageGoogleProQuotaFree = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaFree = () =>
`\nYou have reached your daily quota limit. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
// Legacy/Standard Tier message functions
const getRateLimitErrorMessageGooglePaid = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI.`;
const getRateLimitErrorMessageGoogleProQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const getRateLimitErrorMessageGoogleGenericQuotaPaid = (
currentModel: string = DEFAULT_GEMINI_MODEL,
) =>
`\nYou have reached your daily quota limit. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
const RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI =
'\nPlease wait and try again later. To increase your limits, request a quota increase through AI Studio, or switch to another /auth method';
const RATE_LIMIT_ERROR_MESSAGE_VERTEX =
'\nPlease wait and try again later. To increase your limits, request a quota increase through Vertex, or switch to another /auth method';
const getRateLimitErrorMessageDefault = (
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
) =>
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
const RATE_LIMIT_ERROR_MESSAGE_DEFAULT =
'\nPossible quota limitations in place or slow response times detected. Please wait and try again later.';
function getRateLimitMessage(
authType?: AuthType,
error?: unknown,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
function getRateLimitMessage(authType?: AuthType): string {
switch (authType) {
case AuthType.LOGIN_WITH_GOOGLE: {
// Determine if user is on a paid tier (Legacy or Standard) - default to FREE if not specified
const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
if (isProQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleProQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
)
: getRateLimitErrorMessageGoogleProQuotaFree(
currentModel || DEFAULT_GEMINI_MODEL,
fallbackModel,
);
} else if (isGenericQuotaExceededError(error)) {
return isPaidTier
? getRateLimitErrorMessageGoogleGenericQuotaPaid(
currentModel || DEFAULT_GEMINI_MODEL,
)
: getRateLimitErrorMessageGoogleGenericQuotaFree();
} else {
return isPaidTier
? getRateLimitErrorMessageGooglePaid(fallbackModel)
: getRateLimitErrorMessageGoogleFree(fallbackModel);
}
}
case AuthType.USE_GEMINI:
return RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI;
case AuthType.USE_VERTEX_AI:
return RATE_LIMIT_ERROR_MESSAGE_VERTEX;
default:
return getRateLimitErrorMessageDefault(fallbackModel);
return RATE_LIMIT_ERROR_MESSAGE_DEFAULT;
}
}
export function parseAndFormatApiError(
error: unknown,
authType?: AuthType,
userTier?: UserTierId,
currentModel?: string,
fallbackModel?: string,
): string {
if (isStructuredError(error)) {
let text = `[API Error: ${error.message}]`;
if (error.status === 429) {
text += getRateLimitMessage(
authType,
error,
userTier,
currentModel,
fallbackModel,
);
text += getRateLimitMessage(authType);
}
return text;
}
@@ -146,13 +62,7 @@ export function parseAndFormatApiError(
}
let text = `[API Error: ${finalMessage} (Status: ${parsedError.error.status})]`;
if (parsedError.error.code === 429) {
text += getRateLimitMessage(
authType,
parsedError,
userTier,
currentModel,
fallbackModel,
);
text += getRateLimitMessage(authType);
}
return text;
}

View File

@@ -11,12 +11,9 @@ import {
setSimulate429,
disableSimulationAfterFallback,
shouldSimulate429,
createSimulated429Error,
resetRequestCounter,
} from './testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
// Import the new types (Assuming this test file is in packages/core/src/utils/)
import type { FallbackModelHandler } from '../fallback/types.js';
@@ -61,84 +58,6 @@ describe('Retry Utility Fallback Integration', () => {
expect(result).toBe('retry');
});
// This test validates the retry utility's logic for triggering the callback.
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
let fallbackCalled = false;
// Removed fallbackModel variable as it's no longer relevant here.
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
const mockApiCall = vi
.fn()
.mockRejectedValueOnce(createSimulated429Error())
.mockRejectedValueOnce(createSimulated429Error())
.mockResolvedValueOnce('success after fallback');
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
fallbackCalled = true;
// Return true to signal retryWithBackoff to reset attempts and continue.
return true;
});
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
const result = await retryWithBackoff(mockApiCall, {
maxAttempts: 2,
initialDelayMs: 1,
maxDelayMs: 10,
shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockPersistent429Callback,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
// Verify fallback mechanism was triggered
expect(fallbackCalled).toBe(true);
expect(mockPersistent429Callback).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
expect.any(Error),
);
expect(result).toBe('success after fallback');
// Should have: 2 failures, then fallback triggered, then 1 success after retry reset
expect(mockApiCall).toHaveBeenCalledTimes(3);
});
it('should not trigger onPersistent429 for API key users', async () => {
let fallbackCalled = false;
// Mock function that simulates 429 errors
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
// Mock the callback
const mockPersistent429Callback = vi.fn(async () => {
fallbackCalled = true;
return true;
});
// Test with API key auth type - should not trigger fallback
try {
await retryWithBackoff(mockApiCall, {
maxAttempts: 5,
initialDelayMs: 10,
maxDelayMs: 100,
shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockPersistent429Callback,
authType: AuthType.USE_GEMINI, // API key auth type
});
} catch (error) {
// Expected to throw after max attempts
expect((error as Error).message).toContain('Rate limit exceeded');
}
// Verify fallback was NOT triggered for API key users
expect(fallbackCalled).toBe(false);
expect(mockPersistent429Callback).not.toHaveBeenCalled();
});
// This test validates the test utilities themselves.
it('should properly disable simulation state after fallback (Test Utility)', () => {
// Enable simulation

View File

@@ -285,173 +285,6 @@ describe('retryWithBackoff', () => {
});
});
describe('Flash model fallback for OAuth users', () => {
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackOccurred) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
// Advance all timers to complete retries
await vi.runAllTimersAsync();
// Should succeed after fallback
await expect(promise).resolves.toBe('success');
// Verify callback was called with correct auth type
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
});
it('should NOT trigger fallback for API key users', async () => {
const fallbackCallback = vi.fn();
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'gemini-api-key',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail after all retries without fallback
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
// Callback should not be called for API key users
expect(fallbackCallback).not.toHaveBeenCalled();
});
it('should reset attempt counter and continue after successful fallback', async () => {
let fallbackCalled = false;
const fallbackCallback = vi.fn().mockImplementation(async () => {
fallbackCalled = true;
return 'gemini-2.5-flash';
});
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackCalled) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
expect(fallbackCallback).toHaveBeenCalledOnce();
});
it('should continue with original error if fallback is rejected', async () => {
const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail with original error when fallback is rejected
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
expect(fallbackCallback).toHaveBeenCalledWith(
'oauth-personal',
expect.any(Error),
);
});
it('should handle mixed error types (only count consecutive 429s)', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let attempts = 0;
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
attempts++;
if (fallbackOccurred) {
return 'success';
}
if (attempts === 1) {
// First attempt: 500 error (resets consecutive count)
const error: HttpError = new Error('Server error');
error.status = 500;
throw error;
} else {
// Remaining attempts: 429 errors
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 5,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should trigger fallback after 2 consecutive 429s (attempts 2-3)
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
describe('Qwen OAuth 429 error handling', () => {
it('should retry for Qwen OAuth 429 errors that are throttling-related', async () => {
const errorWith429: HttpError = new Error('Rate limit exceeded');

View File

@@ -7,8 +7,6 @@
import type { GenerateContentResponse } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js';
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isQwenQuotaExceededError,
isQwenThrottlingError,
} from './quotaErrorDetection.js';
@@ -90,7 +88,6 @@ export async function retryWithBackoff<T>(
maxAttempts,
initialDelayMs,
maxDelayMs,
onPersistent429,
authType,
shouldRetryOnError,
shouldRetryOnContent,
@@ -123,59 +120,6 @@ export async function retryWithBackoff<T>(
} catch (error) {
const errorStatus = getErrorStatus(error);
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
isProQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check for generic quota exceeded error (but not Pro, which was handled above) - immediate fallback for OAuth users
if (
errorStatus === 429 &&
authType === AuthType.LOGIN_WITH_GOOGLE &&
!isProQuotaExceededError(error) &&
isGenericQuotaExceededError(error) &&
onPersistent429
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check for Qwen OAuth quota exceeded error - throw immediately without retry
if (authType === AuthType.QWEN_OAUTH && isQwenQuotaExceededError(error)) {
throw new Error(
@@ -197,30 +141,7 @@ export async function retryWithBackoff<T>(
consecutive429Count = 0;
}
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 2 &&
onPersistent429 &&
authType === AuthType.LOGIN_WITH_GOOGLE
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
console.debug('consecutive429Count', consecutive429Count);
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetryOnError(error as Error)) {
@@ -240,7 +161,7 @@ export async function retryWithBackoff<T>(
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
currentDelay = initialDelayMs;
} else {
// Fall back to exponential backoff with jitter
// Fallback to exponential backoff with jitter
logRetryAttempt(attempt, error, errorStatus);
// Add jitter: +/- 30% of currentDelay
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);

View File

@@ -1,340 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Mock } from 'vitest';
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { UserAccountManager } from './userAccountManager.js';
import * as fs from 'node:fs';
import * as os from 'node:os';
import path from 'node:path';
vi.mock('os', async (importOriginal) => {
const os = await importOriginal<typeof import('os')>();
return {
...os,
homedir: vi.fn(),
};
});
describe('UserAccountManager', () => {
let tempHomeDir: string;
let userAccountManager: UserAccountManager;
let accountsFile: () => string;
beforeEach(() => {
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'qwen-code-test-home-'),
);
(os.homedir as Mock).mockReturnValue(tempHomeDir);
accountsFile = () =>
path.join(tempHomeDir, '.qwen', 'google_accounts.json');
userAccountManager = new UserAccountManager();
});
afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true });
vi.clearAllMocks();
});
describe('cacheGoogleAccount', () => {
it('should create directory and write initial account file', async () => {
await userAccountManager.cacheGoogleAccount('test1@google.com');
// Verify Google Account ID was cached
expect(fs.existsSync(accountsFile())).toBe(true);
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify({ active: 'test1@google.com', old: [] }, null, 2),
);
});
it('should update active account and move previous to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test2@google.com', old: ['test1@google.com'] },
null,
2,
),
);
await userAccountManager.cacheGoogleAccount('test3@google.com');
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
},
null,
2,
),
);
});
it('should not add a duplicate to the old list', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
await userAccountManager.cacheGoogleAccount('test2@google.com');
await userAccountManager.cacheGoogleAccount('test1@google.com');
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
JSON.stringify(
{ active: 'test1@google.com', old: ['test2@google.com'] },
null,
2,
),
);
});
it('should handle corrupted JSON by starting fresh', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'not valid json');
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.cacheGoogleAccount('test1@google.com');
expect(consoleLogSpy).toHaveBeenCalled();
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
active: 'test1@google.com',
old: [],
});
});
it('should handle valid JSON with incorrect schema by starting fresh', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'test1@google.com', old: 'not-an-array' }),
);
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.cacheGoogleAccount('test2@google.com');
expect(consoleLogSpy).toHaveBeenCalled();
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
active: 'test2@google.com',
old: [],
});
});
});
describe('getCachedGoogleAccount', () => {
it('should return the active account if file exists and is valid', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'active@google.com', old: [] }, null, 2),
);
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBe('active@google.com');
});
it('should return null if file does not exist', () => {
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null if file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
it('should return null and log if file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '{ "active": "test@google.com"'); // Invalid JSON
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
expect(consoleLogSpy).toHaveBeenCalled();
});
it('should return null if active key is missing', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), JSON.stringify({ old: [] }));
const account = userAccountManager.getCachedGoogleAccount();
expect(account).toBeNull();
});
});
describe('clearCachedGoogleAccount', () => {
it('should set active to null and move it to old', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{ active: 'active@google.com', old: ['old1@google.com'] },
null,
2,
),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['old1@google.com', 'active@google.com']);
});
it('should handle empty file gracefully', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual([]);
});
it('should handle corrupted JSON by creating a fresh file', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'not valid json');
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
await userAccountManager.clearCachedGoogleAccount();
expect(consoleLogSpy).toHaveBeenCalled();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual([]);
});
it('should be idempotent if active account is already null', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: null, old: ['old1@google.com'] }, null, 2),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['old1@google.com']);
});
it('should not add a duplicate to the old list', async () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify(
{
active: 'active@google.com',
old: ['active@google.com'],
},
null,
2,
),
);
await userAccountManager.clearCachedGoogleAccount();
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
expect(stored.active).toBeNull();
expect(stored.old).toEqual(['active@google.com']);
});
});
describe('getLifetimeGoogleAccounts', () => {
it('should return 0 if the file does not exist', () => {
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is empty', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), '');
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
});
it('should return 0 if the file is corrupted', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(accountsFile(), 'invalid json');
const consoleDebugSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
expect(consoleDebugSpy).toHaveBeenCalled();
});
it('should return 1 if there is only an active account', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: 'test1@google.com', old: [] }),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(1);
});
it('should correctly count old accounts when active is null', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: null,
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
});
it('should correctly count both active and old accounts', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: 'test3@google.com',
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(3);
});
it('should handle valid JSON with incorrect schema by returning 0', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({ active: null, old: 1 }),
);
const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
expect(consoleLogSpy).toHaveBeenCalled();
});
it('should not double count if active account is also in old list', () => {
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
fs.writeFileSync(
accountsFile(),
JSON.stringify({
active: 'test1@google.com',
old: ['test1@google.com', 'test2@google.com'],
}),
);
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
});
});
});

View File

@@ -1,140 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import path from 'node:path';
import { promises as fsp, readFileSync } from 'node:fs';
import { Storage } from '../config/storage.js';
interface UserAccounts {
active: string | null;
old: string[];
}
export class UserAccountManager {
private getGoogleAccountsCachePath(): string {
return Storage.getGoogleAccountsPath();
}
/**
* Parses and validates the string content of an accounts file.
* @param content The raw string content from the file.
* @returns A valid UserAccounts object.
*/
private parseAndValidateAccounts(content: string): UserAccounts {
const defaultState = { active: null, old: [] };
if (!content.trim()) {
return defaultState;
}
const parsed = JSON.parse(content);
// Inlined validation logic
if (typeof parsed !== 'object' || parsed === null) {
console.log('Invalid accounts file schema, starting fresh.');
return defaultState;
}
const { active, old } = parsed as Partial<UserAccounts>;
const isValid =
(active === undefined || active === null || typeof active === 'string') &&
(old === undefined ||
(Array.isArray(old) && old.every((i) => typeof i === 'string')));
if (!isValid) {
console.log('Invalid accounts file schema, starting fresh.');
return defaultState;
}
return {
active: parsed.active ?? null,
old: parsed.old ?? [],
};
}
private readAccountsSync(filePath: string): UserAccounts {
const defaultState = { active: null, old: [] };
try {
const content = readFileSync(filePath, 'utf-8');
return this.parseAndValidateAccounts(content);
} catch (error) {
if (
error instanceof Error &&
'code' in error &&
error.code === 'ENOENT'
) {
return defaultState;
}
console.log('Error during sync read of accounts, starting fresh.', error);
return defaultState;
}
}
private async readAccounts(filePath: string): Promise<UserAccounts> {
const defaultState = { active: null, old: [] };
try {
const content = await fsp.readFile(filePath, 'utf-8');
return this.parseAndValidateAccounts(content);
} catch (error) {
if (
error instanceof Error &&
'code' in error &&
error.code === 'ENOENT'
) {
return defaultState;
}
console.log('Could not parse accounts file, starting fresh.', error);
return defaultState;
}
}
async cacheGoogleAccount(email: string): Promise<void> {
const filePath = this.getGoogleAccountsCachePath();
await fsp.mkdir(path.dirname(filePath), { recursive: true });
const accounts = await this.readAccounts(filePath);
if (accounts.active && accounts.active !== email) {
if (!accounts.old.includes(accounts.active)) {
accounts.old.push(accounts.active);
}
}
// If the new email was in the old list, remove it
accounts.old = accounts.old.filter((oldEmail) => oldEmail !== email);
accounts.active = email;
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
}
getCachedGoogleAccount(): string | null {
const filePath = this.getGoogleAccountsCachePath();
const accounts = this.readAccountsSync(filePath);
return accounts.active;
}
getLifetimeGoogleAccounts(): number {
const filePath = this.getGoogleAccountsCachePath();
const accounts = this.readAccountsSync(filePath);
const allAccounts = new Set(accounts.old);
if (accounts.active) {
allAccounts.add(accounts.active);
}
return allAccounts.size;
}
async clearCachedGoogleAccount(): Promise<void> {
const filePath = this.getGoogleAccountsCachePath();
const accounts = await this.readAccounts(filePath);
if (accounts.active) {
if (!accounts.old.includes(accounts.active)) {
accounts.old.push(accounts.active);
}
accounts.active = null;
}
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
}
}