mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00
Add Gemini provider, remove legacy Google OAuth, and tune generation defaults
This commit is contained in:
@@ -23,8 +23,8 @@
|
||||
"scripts/postinstall.js"
|
||||
],
|
||||
"dependencies": {
|
||||
"@google/genai": "1.16.0",
|
||||
"@modelcontextprotocol/sdk": "^1.11.0",
|
||||
"@google/genai": "1.30.0",
|
||||
"@modelcontextprotocol/sdk": "^1.25.1",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"async-mutex": "^0.5.0",
|
||||
"@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0",
|
||||
@@ -34,7 +34,6 @@
|
||||
"@opentelemetry/exporter-trace-otlp-grpc": "^0.203.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.203.0",
|
||||
"@opentelemetry/instrumentation-http": "^0.203.0",
|
||||
"@opentelemetry/resource-detector-gcp": "^0.40.0",
|
||||
"@opentelemetry/sdk-node": "^0.203.0",
|
||||
"@types/html-to-text": "^9.0.4",
|
||||
"@xterm/headless": "5.5.0",
|
||||
@@ -48,7 +47,7 @@
|
||||
"fdir": "^6.4.6",
|
||||
"fzf": "^0.5.2",
|
||||
"glob": "^10.5.0",
|
||||
"google-auth-library": "^9.11.0",
|
||||
"google-auth-library": "^10.5.0",
|
||||
"html-to-text": "^9.0.5",
|
||||
"https-proxy-agent": "^7.0.6",
|
||||
"ignore": "^7.0.0",
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { getOauthClient } from './oauth2.js';
|
||||
import { setupUser } from './setup.js';
|
||||
import type { HttpOptions } from './server.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
|
||||
|
||||
export async function createCodeAssistContentGenerator(
|
||||
httpOptions: HttpOptions,
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
if (
|
||||
authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||
authType === AuthType.CLOUD_SHELL
|
||||
) {
|
||||
const authClient = await getOauthClient(authType, config);
|
||||
const userData = await setupUser(authClient);
|
||||
return new CodeAssistServer(
|
||||
authClient,
|
||||
userData.projectId,
|
||||
httpOptions,
|
||||
sessionId,
|
||||
userData.userTier,
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported authType: ${authType}`);
|
||||
}
|
||||
|
||||
export function getCodeAssistServer(
|
||||
config: Config,
|
||||
): CodeAssistServer | undefined {
|
||||
let server = config.getContentGenerator();
|
||||
|
||||
// Unwrap LoggingContentGenerator if present
|
||||
if (server instanceof LoggingContentGenerator) {
|
||||
server = server.getWrapped();
|
||||
}
|
||||
|
||||
if (!(server instanceof CodeAssistServer)) {
|
||||
return undefined;
|
||||
}
|
||||
return server;
|
||||
}
|
||||
@@ -1,456 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import type { CaGenerateContentResponse } from './converter.js';
|
||||
import {
|
||||
toGenerateContentRequest,
|
||||
fromGenerateContentResponse,
|
||||
toContents,
|
||||
} from './converter.js';
|
||||
import type {
|
||||
ContentListUnion,
|
||||
GenerateContentParameters,
|
||||
} from '@google/genai';
|
||||
import {
|
||||
GenerateContentResponse,
|
||||
FinishReason,
|
||||
BlockedReason,
|
||||
type Part,
|
||||
} from '@google/genai';
|
||||
|
||||
describe('converter', () => {
|
||||
describe('toCodeAssistRequest', () => {
|
||||
it('should convert a simple request with project', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'my-session',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request without a project', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
undefined,
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: undefined,
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'my-session',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request with sessionId', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'session-123',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'session-123',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle string content', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle Part[] content', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ text: 'Hello' }, { text: 'World' }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'user', parts: [{ text: 'World' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle system instructions', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
systemInstruction: 'You are a helpful assistant.',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.systemInstruction).toEqual({
|
||||
role: 'user',
|
||||
parts: [{ text: 'You are a helpful assistant.' }],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle generation config', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle all generation config fields', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
topK: 3,
|
||||
candidateCount: 4,
|
||||
maxOutputTokens: 5,
|
||||
stopSequences: ['a'],
|
||||
responseLogprobs: true,
|
||||
logprobs: 6,
|
||||
presencePenalty: 0.7,
|
||||
frequencyPenalty: 0.8,
|
||||
seed: 9,
|
||||
responseMimeType: 'application/json',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
topK: 3,
|
||||
candidateCount: 4,
|
||||
maxOutputTokens: 5,
|
||||
stopSequences: ['a'],
|
||||
responseLogprobs: true,
|
||||
logprobs: 6,
|
||||
presencePenalty: 0.7,
|
||||
frequencyPenalty: 0.8,
|
||||
seed: 9,
|
||||
responseMimeType: 'application/json',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('fromCodeAssistResponse', () => {
|
||||
it('should convert a simple response', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'Hi there!' }],
|
||||
},
|
||||
finishReason: FinishReason.STOP,
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
|
||||
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
|
||||
});
|
||||
|
||||
it('should handle prompt feedback and usage metadata', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
promptFeedback: {
|
||||
blockReason: BlockedReason.SAFETY,
|
||||
safetyRatings: [],
|
||||
},
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 20,
|
||||
totalTokenCount: 30,
|
||||
},
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.promptFeedback).toEqual(
|
||||
codeAssistRes.response.promptFeedback,
|
||||
);
|
||||
expect(genaiRes.usageMetadata).toEqual(
|
||||
codeAssistRes.response.usageMetadata,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle automatic function calling history', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
automaticFunctionCallingHistory: [
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: 'test_function',
|
||||
args: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
|
||||
codeAssistRes.response.automaticFunctionCallingHistory,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle modelVersion', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
modelVersion: 'qwen3-coder-plus',
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.modelVersion).toEqual('qwen3-coder-plus');
|
||||
});
|
||||
});
|
||||
|
||||
describe('toContents', () => {
|
||||
it('should handle Content', () => {
|
||||
const content: ContentListUnion = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
};
|
||||
expect(toContents(content)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of Contents', () => {
|
||||
const contents: ContentListUnion = [
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'hi' }] },
|
||||
];
|
||||
expect(toContents(contents)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'hi' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle Part', () => {
|
||||
const part: ContentListUnion = { text: 'a part' };
|
||||
expect(toContents(part)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'a part' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of Parts', () => {
|
||||
const parts = [{ text: 'part 1' }, 'part 2'];
|
||||
expect(toContents(parts)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'part 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'part 2' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle string', () => {
|
||||
const str: ContentListUnion = 'a string';
|
||||
expect(toContents(str)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'a string' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of strings', () => {
|
||||
const strings: ContentListUnion = ['string 1', 'string 2'];
|
||||
expect(toContents(strings)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'string 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'string 2' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should convert thought parts to text parts for API compatibility', () => {
|
||||
const contentWithThought: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'regular text' },
|
||||
{ thought: 'thinking about the problem' } as Part & {
|
||||
thought: string;
|
||||
},
|
||||
{ text: 'more text' },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithThought)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'regular text' },
|
||||
{ text: '[Thought: thinking about the problem]' },
|
||||
{ text: 'more text' },
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should combine text and thought for text parts with thoughts', () => {
|
||||
const contentWithTextAndThought: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 'Here is my response',
|
||||
thought: 'I need to be careful here',
|
||||
} as Part & { thought: string },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithTextAndThought)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 'Here is my response\n[Thought: I need to be careful here]',
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should preserve non-thought properties while removing thought', () => {
|
||||
const contentWithComplexPart: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
|
||||
thought: 'Performing calculation',
|
||||
} as Part & { thought: string },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithComplexPart)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should convert invalid text content to valid text part with thought', () => {
|
||||
const contentWithInvalidText: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 123, // Invalid - should be string
|
||||
thought: 'Processing number',
|
||||
} as Part & { thought: string; text: number },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithInvalidText)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: '123\n[Thought: Processing number]',
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,285 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
Content,
|
||||
ContentListUnion,
|
||||
ContentUnion,
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
GenerationConfigRoutingConfig,
|
||||
MediaResolution,
|
||||
Candidate,
|
||||
ModelSelectionConfig,
|
||||
GenerateContentResponsePromptFeedback,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
Part,
|
||||
SafetySetting,
|
||||
PartUnion,
|
||||
SpeechConfigUnion,
|
||||
ThinkingConfig,
|
||||
ToolListUnion,
|
||||
ToolConfig,
|
||||
} from '@google/genai';
|
||||
import { GenerateContentResponse } from '@google/genai';
|
||||
|
||||
export interface CAGenerateContentRequest {
|
||||
model: string;
|
||||
project?: string;
|
||||
user_prompt_id?: string;
|
||||
request: VertexGenerateContentRequest;
|
||||
}
|
||||
|
||||
interface VertexGenerateContentRequest {
|
||||
contents: Content[];
|
||||
systemInstruction?: Content;
|
||||
cachedContent?: string;
|
||||
tools?: ToolListUnion;
|
||||
toolConfig?: ToolConfig;
|
||||
labels?: Record<string, string>;
|
||||
safetySettings?: SafetySetting[];
|
||||
generationConfig?: VertexGenerationConfig;
|
||||
session_id?: string;
|
||||
}
|
||||
|
||||
interface VertexGenerationConfig {
|
||||
temperature?: number;
|
||||
topP?: number;
|
||||
topK?: number;
|
||||
candidateCount?: number;
|
||||
maxOutputTokens?: number;
|
||||
stopSequences?: string[];
|
||||
responseLogprobs?: boolean;
|
||||
logprobs?: number;
|
||||
presencePenalty?: number;
|
||||
frequencyPenalty?: number;
|
||||
seed?: number;
|
||||
responseMimeType?: string;
|
||||
responseJsonSchema?: unknown;
|
||||
responseSchema?: unknown;
|
||||
routingConfig?: GenerationConfigRoutingConfig;
|
||||
modelSelectionConfig?: ModelSelectionConfig;
|
||||
responseModalities?: string[];
|
||||
mediaResolution?: MediaResolution;
|
||||
speechConfig?: SpeechConfigUnion;
|
||||
audioTimestamp?: boolean;
|
||||
thinkingConfig?: ThinkingConfig;
|
||||
}
|
||||
|
||||
export interface CaGenerateContentResponse {
|
||||
response: VertexGenerateContentResponse;
|
||||
}
|
||||
|
||||
interface VertexGenerateContentResponse {
|
||||
candidates: Candidate[];
|
||||
automaticFunctionCallingHistory?: Content[];
|
||||
promptFeedback?: GenerateContentResponsePromptFeedback;
|
||||
usageMetadata?: GenerateContentResponseUsageMetadata;
|
||||
modelVersion?: string;
|
||||
}
|
||||
|
||||
export interface CaCountTokenRequest {
|
||||
request: VertexCountTokenRequest;
|
||||
}
|
||||
|
||||
interface VertexCountTokenRequest {
|
||||
model: string;
|
||||
contents: Content[];
|
||||
}
|
||||
|
||||
export interface CaCountTokenResponse {
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export function toCountTokenRequest(
|
||||
req: CountTokensParameters,
|
||||
): CaCountTokenRequest {
|
||||
return {
|
||||
request: {
|
||||
model: 'models/' + req.model,
|
||||
contents: toContents(req.contents),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function fromCountTokenResponse(
|
||||
res: CaCountTokenResponse,
|
||||
): CountTokensResponse {
|
||||
return {
|
||||
totalTokens: res.totalTokens,
|
||||
};
|
||||
}
|
||||
|
||||
export function toGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
project?: string,
|
||||
sessionId?: string,
|
||||
): CAGenerateContentRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
user_prompt_id: userPromptId,
|
||||
request: toVertexGenerateContentRequest(req, sessionId),
|
||||
};
|
||||
}
|
||||
|
||||
export function fromGenerateContentResponse(
|
||||
res: CaGenerateContentResponse,
|
||||
): GenerateContentResponse {
|
||||
const inres = res.response;
|
||||
const out = new GenerateContentResponse();
|
||||
out.candidates = inres.candidates;
|
||||
out.automaticFunctionCallingHistory = inres.automaticFunctionCallingHistory;
|
||||
out.promptFeedback = inres.promptFeedback;
|
||||
out.usageMetadata = inres.usageMetadata;
|
||||
out.modelVersion = inres.modelVersion;
|
||||
return out;
|
||||
}
|
||||
|
||||
function toVertexGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
sessionId?: string,
|
||||
): VertexGenerateContentRequest {
|
||||
return {
|
||||
contents: toContents(req.contents),
|
||||
systemInstruction: maybeToContent(req.config?.systemInstruction),
|
||||
cachedContent: req.config?.cachedContent,
|
||||
tools: req.config?.tools,
|
||||
toolConfig: req.config?.toolConfig,
|
||||
labels: req.config?.labels,
|
||||
safetySettings: req.config?.safetySettings,
|
||||
generationConfig: toVertexGenerationConfig(req.config),
|
||||
session_id: sessionId,
|
||||
};
|
||||
}
|
||||
|
||||
export function toContents(contents: ContentListUnion): Content[] {
|
||||
if (Array.isArray(contents)) {
|
||||
// it's a Content[] or a PartsUnion[]
|
||||
return contents.map(toContent);
|
||||
}
|
||||
// it's a Content or a PartsUnion
|
||||
return [toContent(contents)];
|
||||
}
|
||||
|
||||
function maybeToContent(content?: ContentUnion): Content | undefined {
|
||||
if (!content) {
|
||||
return undefined;
|
||||
}
|
||||
return toContent(content);
|
||||
}
|
||||
|
||||
function toContent(content: ContentUnion): Content {
|
||||
if (Array.isArray(content)) {
|
||||
// it's a PartsUnion[]
|
||||
return {
|
||||
role: 'user',
|
||||
parts: toParts(content),
|
||||
};
|
||||
}
|
||||
if (typeof content === 'string') {
|
||||
// it's a string
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [{ text: content }],
|
||||
};
|
||||
}
|
||||
if ('parts' in content) {
|
||||
// it's a Content - process parts to handle thought filtering
|
||||
return {
|
||||
...content,
|
||||
parts: content.parts
|
||||
? toParts(content.parts.filter((p) => p != null))
|
||||
: [],
|
||||
};
|
||||
}
|
||||
// it's a Part
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [toPart(content as Part)],
|
||||
};
|
||||
}
|
||||
|
||||
export function toParts(parts: PartUnion[]): Part[] {
|
||||
return parts.map(toPart);
|
||||
}
|
||||
|
||||
function toPart(part: PartUnion): Part {
|
||||
if (typeof part === 'string') {
|
||||
// it's a string
|
||||
return { text: part };
|
||||
}
|
||||
|
||||
// Handle thought parts for CountToken API compatibility
|
||||
// The CountToken API expects parts to have certain required "oneof" fields initialized,
|
||||
// but thought parts don't conform to this schema and cause API failures
|
||||
if ('thought' in part && part.thought) {
|
||||
const thoughtText = `[Thought: ${part.thought}]`;
|
||||
|
||||
const newPart = { ...part };
|
||||
delete (newPart as Record<string, unknown>)['thought'];
|
||||
|
||||
const hasApiContent =
|
||||
'functionCall' in newPart ||
|
||||
'functionResponse' in newPart ||
|
||||
'inlineData' in newPart ||
|
||||
'fileData' in newPart;
|
||||
|
||||
if (hasApiContent) {
|
||||
// It's a functionCall or other non-text part. Just strip the thought.
|
||||
return newPart;
|
||||
}
|
||||
|
||||
// If no other valid API content, this must be a text part.
|
||||
// Combine existing text (if any) with the thought, preserving other properties.
|
||||
const text = (newPart as { text?: unknown }).text;
|
||||
const existingText = text ? String(text) : '';
|
||||
const combinedText = existingText
|
||||
? `${existingText}\n${thoughtText}`
|
||||
: thoughtText;
|
||||
|
||||
return {
|
||||
...newPart,
|
||||
text: combinedText,
|
||||
};
|
||||
}
|
||||
|
||||
return part;
|
||||
}
|
||||
|
||||
function toVertexGenerationConfig(
|
||||
config?: GenerateContentConfig,
|
||||
): VertexGenerationConfig | undefined {
|
||||
if (!config) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
temperature: config.temperature,
|
||||
topP: config.topP,
|
||||
topK: config.topK,
|
||||
candidateCount: config.candidateCount,
|
||||
maxOutputTokens: config.maxOutputTokens,
|
||||
stopSequences: config.stopSequences,
|
||||
responseLogprobs: config.responseLogprobs,
|
||||
logprobs: config.logprobs,
|
||||
presencePenalty: config.presencePenalty,
|
||||
frequencyPenalty: config.frequencyPenalty,
|
||||
seed: config.seed,
|
||||
responseMimeType: config.responseMimeType,
|
||||
responseSchema: config.responseSchema,
|
||||
responseJsonSchema: config.responseJsonSchema,
|
||||
routingConfig: config.routingConfig,
|
||||
modelSelectionConfig: config.modelSelectionConfig,
|
||||
responseModalities: config.responseModalities,
|
||||
mediaResolution: config.mediaResolution,
|
||||
speechConfig: config.speechConfig,
|
||||
audioTimestamp: config.audioTimestamp,
|
||||
thinkingConfig: config.thinkingConfig,
|
||||
};
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Credentials } from 'google-auth-library';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { promises as fs } from 'node:fs';
|
||||
|
||||
// Mock external dependencies
|
||||
const mockHybridTokenStorage = vi.hoisted(() => ({
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
}));
|
||||
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
|
||||
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
|
||||
}));
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
rm: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.mock('node:os');
|
||||
vi.mock('node:path');
|
||||
|
||||
describe('OAuthCredentialStorage', () => {
|
||||
const mockCredentials: Credentials = {
|
||||
access_token: 'mock_access_token',
|
||||
refresh_token: 'mock_refresh_token',
|
||||
expiry_date: Date.now() + 3600 * 1000,
|
||||
token_type: 'Bearer',
|
||||
scope: 'email profile',
|
||||
};
|
||||
|
||||
const mockMcpCredentials: OAuthCredentials = {
|
||||
serverName: 'main-account',
|
||||
token: {
|
||||
accessToken: 'mock_access_token',
|
||||
refreshToken: 'mock_refresh_token',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'email profile',
|
||||
expiresAt: mockCredentials.expiry_date!,
|
||||
},
|
||||
updatedAt: expect.any(Number),
|
||||
};
|
||||
|
||||
const oldFilePath = '/mock/home/.qwen/oauth.json';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
|
||||
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
|
||||
|
||||
vi.spyOn(os, 'homedir').mockReturnValue('/mock/home');
|
||||
vi.spyOn(path, 'join').mockReturnValue(oldFilePath);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('loadCredentials', () => {
|
||||
it('should load credentials from HybridTokenStorage if available', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should return null if no credentials found and no old file to migrate', async () => {
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue({
|
||||
message: 'File not found',
|
||||
code: 'ENOENT',
|
||||
});
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw an error if loading fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
|
||||
new Error('Loading error'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if read file fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(
|
||||
new Error('Permission denied'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not throw error if migration file removal failed', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
it('should save credentials to HybridTokenStorage', async () => {
|
||||
await OAuthCredentialStorage.saveCredentials(mockCredentials);
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if access_token is missing', async () => {
|
||||
const invalidCredentials: Credentials = {
|
||||
...mockCredentials,
|
||||
access_token: undefined,
|
||||
};
|
||||
await expect(
|
||||
OAuthCredentialStorage.saveCredentials(invalidCredentials),
|
||||
).rejects.toThrow(
|
||||
'Attempted to save credentials without an access token.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCredentials', () => {
|
||||
it('should delete credentials from HybridTokenStorage', async () => {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
});
|
||||
|
||||
it('should attempt to remove the old file-based storage', async () => {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
|
||||
});
|
||||
|
||||
it('should not throw an error if deleting old file fails', async () => {
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
|
||||
|
||||
await expect(
|
||||
OAuthCredentialStorage.clearCredentials(),
|
||||
).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
|
||||
new Error('Deletion error'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
|
||||
'Failed to clear OAuth credentials',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,130 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Credentials } from 'google-auth-library';
|
||||
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
|
||||
import { OAUTH_FILE } from '../config/storage.js';
|
||||
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { promises as fs } from 'node:fs';
|
||||
|
||||
const QWEN_DIR = '.qwen';
|
||||
const KEYCHAIN_SERVICE_NAME = 'qwen-code-oauth';
|
||||
const MAIN_ACCOUNT_KEY = 'main-account';
|
||||
|
||||
export class OAuthCredentialStorage {
|
||||
private static storage: HybridTokenStorage = new HybridTokenStorage(
|
||||
KEYCHAIN_SERVICE_NAME,
|
||||
);
|
||||
|
||||
/**
|
||||
* Load cached OAuth credentials
|
||||
*/
|
||||
static async loadCredentials(): Promise<Credentials | null> {
|
||||
try {
|
||||
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
if (credentials?.token) {
|
||||
const { accessToken, refreshToken, expiresAt, tokenType, scope } =
|
||||
credentials.token;
|
||||
// Convert from OAuthCredentials format to Google Credentials format
|
||||
const googleCreds: Credentials = {
|
||||
access_token: accessToken,
|
||||
refresh_token: refreshToken || undefined,
|
||||
token_type: tokenType || undefined,
|
||||
scope: scope || undefined,
|
||||
};
|
||||
|
||||
if (expiresAt) {
|
||||
googleCreds.expiry_date = expiresAt;
|
||||
}
|
||||
|
||||
return googleCreds;
|
||||
}
|
||||
|
||||
// Fallback: Try to migrate from old file-based storage
|
||||
return await this.migrateFromFileStorage();
|
||||
} catch (error: unknown) {
|
||||
console.error(error);
|
||||
throw new Error('Failed to load OAuth credentials');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save OAuth credentials
|
||||
*/
|
||||
static async saveCredentials(credentials: Credentials): Promise<void> {
|
||||
if (!credentials.access_token) {
|
||||
throw new Error('Attempted to save credentials without an access token.');
|
||||
}
|
||||
|
||||
// Convert Google Credentials to OAuthCredentials format
|
||||
const mcpCredentials: OAuthCredentials = {
|
||||
serverName: MAIN_ACCOUNT_KEY,
|
||||
token: {
|
||||
accessToken: credentials.access_token,
|
||||
refreshToken: credentials.refresh_token || undefined,
|
||||
tokenType: credentials.token_type || 'Bearer',
|
||||
scope: credentials.scope || undefined,
|
||||
expiresAt: credentials.expiry_date || undefined,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
await this.storage.setCredentials(mcpCredentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached OAuth credentials
|
||||
*/
|
||||
static async clearCredentials(): Promise<void> {
|
||||
try {
|
||||
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
// Also try to remove the old file if it exists
|
||||
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
|
||||
await fs.rm(oldFilePath, { force: true }).catch(() => {});
|
||||
} catch (error: unknown) {
|
||||
console.error(error);
|
||||
throw new Error('Failed to clear OAuth credentials');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate credentials from old file-based storage to keychain
|
||||
*/
|
||||
private static async migrateFromFileStorage(): Promise<Credentials | null> {
|
||||
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
|
||||
|
||||
let credsJson: string;
|
||||
try {
|
||||
credsJson = await fs.readFile(oldFilePath, 'utf-8');
|
||||
} catch (error: unknown) {
|
||||
if (
|
||||
typeof error === 'object' &&
|
||||
error !== null &&
|
||||
'code' in error &&
|
||||
error.code === 'ENOENT'
|
||||
) {
|
||||
// File doesn't exist, so no migration.
|
||||
return null;
|
||||
}
|
||||
// Other read errors should propagate.
|
||||
throw error;
|
||||
}
|
||||
|
||||
const credentials = JSON.parse(credsJson) as Credentials;
|
||||
|
||||
// Save to new storage
|
||||
await this.saveCredentials(credentials);
|
||||
|
||||
// Remove old file after successful migration
|
||||
await fs.rm(oldFilePath, { force: true }).catch(() => {});
|
||||
|
||||
return credentials;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,563 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Credentials } from 'google-auth-library';
|
||||
import {
|
||||
CodeChallengeMethod,
|
||||
Compute,
|
||||
OAuth2Client,
|
||||
} from 'google-auth-library';
|
||||
import crypto from 'node:crypto';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as http from 'node:http';
|
||||
import * as net from 'node:net';
|
||||
import path from 'node:path';
|
||||
import readline from 'node:readline';
|
||||
import url from 'node:url';
|
||||
import open from 'open';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { FatalAuthenticationError, getErrorMessage } from '../utils/errors.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
|
||||
|
||||
const userAccountManager = new UserAccountManager();
|
||||
|
||||
// OAuth Client ID used to initiate OAuth2Client class.
|
||||
const OAUTH_CLIENT_ID =
|
||||
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
|
||||
|
||||
// OAuth Secret value used to initiate OAuth2Client class.
|
||||
// Note: It's ok to save this in git because this is an installed application
|
||||
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
|
||||
// "The process results in a client ID and, in some cases, a client secret,
|
||||
// which you embed in the source code of your application. (In this context,
|
||||
// the client secret is obviously not treated as a secret.)"
|
||||
const OAUTH_CLIENT_SECRET = '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';
|
||||
|
||||
// OAuth Scopes for Cloud Code authorization.
|
||||
const OAUTH_SCOPE = [
|
||||
'https://www.googleapis.com/auth/cloud-platform',
|
||||
'https://www.googleapis.com/auth/userinfo.email',
|
||||
'https://www.googleapis.com/auth/userinfo.profile',
|
||||
];
|
||||
|
||||
const HTTP_REDIRECT = 301;
|
||||
const SIGN_IN_SUCCESS_URL =
|
||||
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
|
||||
const SIGN_IN_FAILURE_URL =
|
||||
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
|
||||
|
||||
/**
|
||||
* An Authentication URL for updating the credentials of a Oauth2Client
|
||||
* as well as a promise that will resolve when the credentials have
|
||||
* been refreshed (or which throws error when refreshing credentials failed).
|
||||
*/
|
||||
export interface OauthWebLogin {
|
||||
authUrl: string;
|
||||
loginCompletePromise: Promise<void>;
|
||||
}
|
||||
|
||||
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
|
||||
|
||||
function getUseEncryptedStorageFlag() {
|
||||
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
}
|
||||
|
||||
async function initOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
const client = new OAuth2Client({
|
||||
clientId: OAUTH_CLIENT_ID,
|
||||
clientSecret: OAUTH_CLIENT_SECRET,
|
||||
transporterOptions: {
|
||||
proxy: config.getProxy(),
|
||||
},
|
||||
});
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
|
||||
if (
|
||||
process.env['GOOGLE_GENAI_USE_GCA'] &&
|
||||
process.env['GOOGLE_CLOUD_ACCESS_TOKEN']
|
||||
) {
|
||||
client.setCredentials({
|
||||
access_token: process.env['GOOGLE_CLOUD_ACCESS_TOKEN'],
|
||||
});
|
||||
await fetchAndCacheUserInfo(client);
|
||||
return client;
|
||||
}
|
||||
|
||||
client.on('tokens', async (tokens: Credentials) => {
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.saveCredentials(tokens);
|
||||
} else {
|
||||
await cacheCredentials(tokens);
|
||||
}
|
||||
});
|
||||
|
||||
// If there are cached creds on disk, they always take precedence
|
||||
if (await loadCachedCredentials(client)) {
|
||||
// Found valid cached credentials.
|
||||
// Check if we need to retrieve Google Account ID or Email
|
||||
if (!userAccountManager.getCachedGoogleAccount()) {
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
// Non-fatal, continue with existing auth.
|
||||
console.warn('Failed to fetch user info:', getErrorMessage(error));
|
||||
}
|
||||
}
|
||||
console.log('Loaded cached credentials.');
|
||||
return client;
|
||||
}
|
||||
|
||||
// In Google Cloud Shell, we can use Application Default Credentials (ADC)
|
||||
// provided via its metadata server to authenticate non-interactively using
|
||||
// the identity of the user logged into Cloud Shell.
|
||||
if (authType === AuthType.CLOUD_SHELL) {
|
||||
try {
|
||||
console.log("Attempting to authenticate via Cloud Shell VM's ADC.");
|
||||
const computeClient = new Compute({
|
||||
// We can leave this empty, since the metadata server will provide
|
||||
// the service account email.
|
||||
});
|
||||
await computeClient.getAccessToken();
|
||||
console.log('Authentication successful.');
|
||||
|
||||
// Do not cache creds in this case; note that Compute client will handle its own refresh
|
||||
return computeClient;
|
||||
} catch (e) {
|
||||
throw new Error(
|
||||
`Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ${getErrorMessage(
|
||||
e,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (config.isBrowserLaunchSuppressed()) {
|
||||
let success = false;
|
||||
const maxRetries = 2;
|
||||
for (let i = 0; !success && i < maxRetries; i++) {
|
||||
success = await authWithUserCode(client);
|
||||
if (!success) {
|
||||
console.error(
|
||||
'\nFailed to authenticate with user code.',
|
||||
i === maxRetries - 1 ? '' : 'Retrying...\n',
|
||||
);
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
throw new FatalAuthenticationError(
|
||||
'Failed to authenticate with user code.',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const webLogin = await authWithWeb(client);
|
||||
|
||||
console.log(
|
||||
`\n\nCode Assist login required.\n` +
|
||||
`Attempting to open authentication page in your browser.\n` +
|
||||
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
|
||||
);
|
||||
try {
|
||||
// Attempt to open the authentication URL in the default browser.
|
||||
// We do not use the `wait` option here because the main script's execution
|
||||
// is already paused by `loginCompletePromise`, which awaits the server callback.
|
||||
const childProcess = await open(webLogin.authUrl);
|
||||
|
||||
// IMPORTANT: Attach an error handler to the returned child process.
|
||||
// Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
|
||||
// in a minimal Docker container), it will emit an unhandled 'error' event,
|
||||
// causing the entire Node.js process to crash.
|
||||
childProcess.on('error', (error) => {
|
||||
console.error(
|
||||
'Failed to open browser automatically. Please try running again with NO_BROWSER=true set.',
|
||||
);
|
||||
console.error('Browser error details:', getErrorMessage(error));
|
||||
});
|
||||
} catch (err) {
|
||||
console.error(
|
||||
'An unexpected error occurred while trying to open the browser:',
|
||||
getErrorMessage(err),
|
||||
'\nThis might be due to browser compatibility issues or system configuration.',
|
||||
'\nPlease try running again with NO_BROWSER=true set for manual authentication.',
|
||||
);
|
||||
throw new FatalAuthenticationError(
|
||||
`Failed to open browser: ${getErrorMessage(err)}`,
|
||||
);
|
||||
}
|
||||
console.log('Waiting for authentication...');
|
||||
|
||||
// Add timeout to prevent infinite waiting when browser tab gets stuck
|
||||
const authTimeout = 5 * 60 * 1000; // 5 minutes timeout
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(() => {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'Authentication timed out after 5 minutes. The browser tab may have gotten stuck in a loading state. ' +
|
||||
'Please try again or use NO_BROWSER=true for manual authentication.',
|
||||
),
|
||||
);
|
||||
}, authTimeout);
|
||||
});
|
||||
|
||||
await Promise.race([webLogin.loginCompletePromise, timeoutPromise]);
|
||||
}
|
||||
|
||||
return client;
|
||||
}
|
||||
|
||||
export async function getOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
if (!oauthClientPromises.has(authType)) {
|
||||
oauthClientPromises.set(authType, initOauthClient(authType, config));
|
||||
}
|
||||
return oauthClientPromises.get(authType)!;
|
||||
}
|
||||
|
||||
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
|
||||
const redirectUri = 'https://codeassist.google.com/authcode';
|
||||
const codeVerifier = await client.generateCodeVerifierAsync();
|
||||
const state = crypto.randomBytes(32).toString('hex');
|
||||
const authUrl: string = client.generateAuthUrl({
|
||||
redirect_uri: redirectUri,
|
||||
access_type: 'offline',
|
||||
scope: OAUTH_SCOPE,
|
||||
code_challenge_method: CodeChallengeMethod.S256,
|
||||
code_challenge: codeVerifier.codeChallenge,
|
||||
state,
|
||||
});
|
||||
console.log('Please visit the following URL to authorize the application:');
|
||||
console.log('');
|
||||
console.log(authUrl);
|
||||
console.log('');
|
||||
|
||||
const code = await new Promise<string>((resolve) => {
|
||||
const rl = readline.createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
});
|
||||
rl.question('Enter the authorization code: ', (code) => {
|
||||
rl.close();
|
||||
resolve(code.trim());
|
||||
});
|
||||
});
|
||||
|
||||
if (!code) {
|
||||
console.error('Authorization code is required.');
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const { tokens } = await client.getToken({
|
||||
code,
|
||||
codeVerifier: codeVerifier.codeVerifier,
|
||||
redirect_uri: redirectUri,
|
||||
});
|
||||
client.setCredentials(tokens);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'Failed to authenticate with authorization code:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
|
||||
const port = await getAvailablePort();
|
||||
// The hostname used for the HTTP server binding (e.g., '0.0.0.0' in Docker).
|
||||
const host = process.env['OAUTH_CALLBACK_HOST'] || 'localhost';
|
||||
// The `redirectUri` sent to Google's authorization server MUST use a loopback IP literal
|
||||
// (i.e., 'localhost' or '127.0.0.1'). This is a strict security policy for credentials of
|
||||
// type 'Desktop app' or 'Web application' (when using loopback flow) to mitigate
|
||||
// authorization code interception attacks.
|
||||
const redirectUri = `http://localhost:${port}/oauth2callback`;
|
||||
const state = crypto.randomBytes(32).toString('hex');
|
||||
const authUrl = client.generateAuthUrl({
|
||||
redirect_uri: redirectUri,
|
||||
access_type: 'offline',
|
||||
scope: OAUTH_SCOPE,
|
||||
state,
|
||||
});
|
||||
|
||||
const loginCompletePromise = new Promise<void>((resolve, reject) => {
|
||||
const server = http.createServer(async (req, res) => {
|
||||
try {
|
||||
if (req.url!.indexOf('/oauth2callback') === -1) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'OAuth callback not received. Unexpected request: ' + req.url,
|
||||
),
|
||||
);
|
||||
}
|
||||
// acquire the code from the querystring, and close the web server.
|
||||
const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
|
||||
if (qs.get('error')) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
|
||||
const errorCode = qs.get('error');
|
||||
const errorDescription =
|
||||
qs.get('error_description') || 'No additional details provided';
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Google OAuth error: ${errorCode}. ${errorDescription}`,
|
||||
),
|
||||
);
|
||||
} else if (qs.get('state') !== state) {
|
||||
res.end('State mismatch. Possible CSRF attack');
|
||||
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'OAuth state mismatch. Possible CSRF attack or browser session issue.',
|
||||
),
|
||||
);
|
||||
} else if (qs.get('code')) {
|
||||
try {
|
||||
const { tokens } = await client.getToken({
|
||||
code: qs.get('code')!,
|
||||
redirect_uri: redirectUri,
|
||||
});
|
||||
client.setCredentials(tokens);
|
||||
|
||||
// Retrieve and cache Google Account ID during authentication
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to retrieve Google Account ID during authentication:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
// Don't fail the auth flow if Google Account ID retrieval fails
|
||||
}
|
||||
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
|
||||
res.end();
|
||||
resolve();
|
||||
} catch (error) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Failed to exchange authorization code for tokens: ${getErrorMessage(error)}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'No authorization code received from Google OAuth. Please try authenticating again.',
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
// Provide more specific error message for unexpected errors during OAuth flow
|
||||
if (e instanceof FatalAuthenticationError) {
|
||||
reject(e);
|
||||
} else {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Unexpected error during OAuth authentication: ${getErrorMessage(e)}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
server.close();
|
||||
}
|
||||
});
|
||||
|
||||
server.listen(port, host, () => {
|
||||
// Server started successfully
|
||||
});
|
||||
|
||||
server.on('error', (err) => {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`OAuth callback server error: ${getErrorMessage(err)}`,
|
||||
),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
authUrl,
|
||||
loginCompletePromise,
|
||||
};
|
||||
}
|
||||
|
||||
export function getAvailablePort(): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
let port = 0;
|
||||
try {
|
||||
const portStr = process.env['OAUTH_CALLBACK_PORT'];
|
||||
if (portStr) {
|
||||
port = parseInt(portStr, 10);
|
||||
if (isNaN(port) || port <= 0 || port > 65535) {
|
||||
return reject(
|
||||
new Error(`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`),
|
||||
);
|
||||
}
|
||||
return resolve(port);
|
||||
}
|
||||
const server = net.createServer();
|
||||
server.listen(0, () => {
|
||||
const address = server.address()! as net.AddressInfo;
|
||||
port = address.port;
|
||||
});
|
||||
server.on('listening', () => {
|
||||
server.close();
|
||||
server.unref();
|
||||
});
|
||||
server.on('error', (e) => reject(e));
|
||||
server.on('close', () => resolve(port));
|
||||
} catch (e) {
|
||||
reject(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
const credentials = await OAuthCredentialStorage.loadCredentials();
|
||||
if (credentials) {
|
||||
client.setCredentials(credentials);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const pathsToTry = [
|
||||
Storage.getOAuthCredsPath(),
|
||||
process.env['GOOGLE_APPLICATION_CREDENTIALS'],
|
||||
].filter((p): p is string => !!p);
|
||||
|
||||
for (const keyFile of pathsToTry) {
|
||||
try {
|
||||
const creds = await fs.readFile(keyFile, 'utf-8');
|
||||
client.setCredentials(JSON.parse(creds));
|
||||
|
||||
// This will verify locally that the credentials look good.
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This will check with the server to see if it hasn't been revoked.
|
||||
await client.getTokenInfo(token);
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
// Log specific error for debugging, but continue trying other paths
|
||||
console.debug(
|
||||
`Failed to load credentials from ${keyFile}:`,
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
async function cacheCredentials(credentials: Credentials) {
|
||||
const filePath = Storage.getOAuthCredsPath();
|
||||
await fs.mkdir(path.dirname(filePath), { recursive: true });
|
||||
|
||||
const credString = JSON.stringify(credentials, null, 2);
|
||||
await fs.writeFile(filePath, credString, { mode: 0o600 });
|
||||
try {
|
||||
await fs.chmod(filePath, 0o600);
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
}
|
||||
|
||||
export function clearOauthClientCache() {
|
||||
oauthClientPromises.clear();
|
||||
}
|
||||
|
||||
export async function clearCachedCredentialFile() {
|
||||
try {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
} else {
|
||||
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
|
||||
}
|
||||
// Clear the Google Account ID cache when credentials are cleared
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
// Clear the in-memory OAuth client cache to force re-authentication
|
||||
clearOauthClientCache();
|
||||
|
||||
/**
|
||||
* Also clear Qwen SharedTokenManager cache and credentials file to prevent stale credentials
|
||||
* when switching between auth types
|
||||
* TODO: We do not depend on code_assist, we'll have to build an independent auth-cleaning procedure.
|
||||
*/
|
||||
try {
|
||||
const { SharedTokenManager } = await import(
|
||||
'../qwen/sharedTokenManager.js'
|
||||
);
|
||||
const { clearQwenCredentials } = await import('../qwen/qwenOAuth2.js');
|
||||
|
||||
const sharedManager = SharedTokenManager.getInstance();
|
||||
sharedManager.clearCache();
|
||||
|
||||
await clearQwenCredentials();
|
||||
} catch (qwenError) {
|
||||
console.debug('Could not clear Qwen credentials:', qwenError);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to clear cached credentials:', e);
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
|
||||
try {
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
console.error(
|
||||
'Failed to fetch user info:',
|
||||
response.status,
|
||||
response.statusText,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const userInfo = await response.json();
|
||||
await userAccountManager.cacheGoogleAccount(userInfo.email);
|
||||
} catch (error) {
|
||||
console.error('Error retrieving user info:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to ensure test isolation
|
||||
export function resetOauthClientForTesting() {
|
||||
oauthClientPromises.clear();
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, it, expect, vi } from 'vitest';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import { OAuth2Client } from 'google-auth-library';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
vi.mock('google-auth-library');
|
||||
|
||||
describe('CodeAssistServer', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
it('should be able to be constructed', () => {
|
||||
const auth = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
auth,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
expect(server).toBeInstanceOf(CodeAssistServer);
|
||||
});
|
||||
|
||||
it('should call the generateContent endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.generateContent(
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
},
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'generateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'response',
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = (async function* () {
|
||||
yield {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const stream = await server.generateContentStream(
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
},
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
for await (const res of stream) {
|
||||
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
||||
'streamGenerateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
|
||||
}
|
||||
});
|
||||
|
||||
it('should call the onboardUser endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
name: 'operations/123',
|
||||
done: true,
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.onboardUser({
|
||||
tierId: 'test-tier',
|
||||
cloudaicompanionProject: 'test-project',
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'onboardUser',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response.name).toBe('operations/123');
|
||||
});
|
||||
|
||||
it('should call the loadCodeAssist endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
currentTier: {
|
||||
id: UserTierId.FREE,
|
||||
name: 'Free',
|
||||
description: 'free tier',
|
||||
},
|
||||
allowedTiers: [],
|
||||
ineligibleTiers: [],
|
||||
cloudaicompanionProject: 'projects/test',
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.loadCodeAssist({
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response).toEqual(mockResponse);
|
||||
});
|
||||
|
||||
it('should return 0 for countTokens', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
totalTokens: 100,
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.countTokens({
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
});
|
||||
expect(response.totalTokens).toBe(100);
|
||||
});
|
||||
|
||||
it('should throw an error for embedContent', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
await expect(
|
||||
server.embedContent({
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockVpcScError = {
|
||||
response: {
|
||||
data: {
|
||||
error: {
|
||||
details: [
|
||||
{
|
||||
reason: 'SECURITY_POLICY_VIOLATED',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);
|
||||
|
||||
const response = await server.loadCodeAssist({
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response).toEqual({
|
||||
currentTier: { id: UserTierId.STANDARD },
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,253 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type {
|
||||
CodeAssistGlobalUserSettingResponse,
|
||||
GoogleRpcResponse,
|
||||
LoadCodeAssistRequest,
|
||||
LoadCodeAssistResponse,
|
||||
LongRunningOperationResponse,
|
||||
OnboardUserRequest,
|
||||
SetCodeAssistGlobalUserSettingRequest,
|
||||
} from './types.js';
|
||||
import type {
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import * as readline from 'node:readline';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { UserTierId } from './types.js';
|
||||
import type {
|
||||
CaCountTokenResponse,
|
||||
CaGenerateContentResponse,
|
||||
} from './converter.js';
|
||||
import {
|
||||
fromCountTokenResponse,
|
||||
fromGenerateContentResponse,
|
||||
toCountTokenRequest,
|
||||
toGenerateContentRequest,
|
||||
} from './converter.js';
|
||||
|
||||
/** HTTP options to be used in each of the requests. */
|
||||
export interface HttpOptions {
|
||||
/** Additional HTTP headers to be sent with the request. */
|
||||
headers?: Record<string, string>;
|
||||
}
|
||||
|
||||
export const CODE_ASSIST_ENDPOINT = 'https://localhost:0'; // Disable Google Code Assist API Request
|
||||
export const CODE_ASSIST_API_VERSION = 'v1internal';
|
||||
|
||||
export class CodeAssistServer implements ContentGenerator {
|
||||
constructor(
|
||||
readonly client: OAuth2Client,
|
||||
readonly projectId?: string,
|
||||
readonly httpOptions: HttpOptions = {},
|
||||
readonly sessionId?: string,
|
||||
readonly userTier?: UserTierId,
|
||||
) {}
|
||||
|
||||
async generateContentStream(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
||||
'streamGenerateContent',
|
||||
toGenerateContentRequest(
|
||||
req,
|
||||
userPromptId,
|
||||
this.projectId,
|
||||
this.sessionId,
|
||||
),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||
for await (const resp of resps) {
|
||||
yield fromGenerateContentResponse(resp);
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const resp = await this.requestPost<CaGenerateContentResponse>(
|
||||
'generateContent',
|
||||
toGenerateContentRequest(
|
||||
req,
|
||||
userPromptId,
|
||||
this.projectId,
|
||||
this.sessionId,
|
||||
),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return fromGenerateContentResponse(resp);
|
||||
}
|
||||
|
||||
async onboardUser(
|
||||
req: OnboardUserRequest,
|
||||
): Promise<LongRunningOperationResponse> {
|
||||
return await this.requestPost<LongRunningOperationResponse>(
|
||||
'onboardUser',
|
||||
req,
|
||||
);
|
||||
}
|
||||
|
||||
async loadCodeAssist(
|
||||
req: LoadCodeAssistRequest,
|
||||
): Promise<LoadCodeAssistResponse> {
|
||||
try {
|
||||
return await this.requestPost<LoadCodeAssistResponse>(
|
||||
'loadCodeAssist',
|
||||
req,
|
||||
);
|
||||
} catch (e) {
|
||||
if (isVpcScAffectedUser(e)) {
|
||||
return {
|
||||
currentTier: { id: UserTierId.STANDARD },
|
||||
};
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
|
||||
return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
|
||||
'getCodeAssistGlobalUserSetting',
|
||||
);
|
||||
}
|
||||
|
||||
async setCodeAssistGlobalUserSetting(
|
||||
req: SetCodeAssistGlobalUserSettingRequest,
|
||||
): Promise<CodeAssistGlobalUserSettingResponse> {
|
||||
return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
|
||||
'setCodeAssistGlobalUserSetting',
|
||||
req,
|
||||
);
|
||||
}
|
||||
|
||||
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
|
||||
const resp = await this.requestPost<CaCountTokenResponse>(
|
||||
'countTokens',
|
||||
toCountTokenRequest(req),
|
||||
);
|
||||
return fromCountTokenResponse(resp);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
_req: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
throw Error();
|
||||
}
|
||||
|
||||
async requestPost<T>(
|
||||
method: string,
|
||||
req: object,
|
||||
signal?: AbortSignal,
|
||||
): Promise<T> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'json',
|
||||
body: JSON.stringify(req),
|
||||
signal,
|
||||
});
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'json',
|
||||
signal,
|
||||
});
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestStreamingPost<T>(
|
||||
method: string,
|
||||
req: object,
|
||||
signal?: AbortSignal,
|
||||
): Promise<AsyncGenerator<T>> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'POST',
|
||||
params: {
|
||||
alt: 'sse',
|
||||
},
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'stream',
|
||||
body: JSON.stringify(req),
|
||||
signal,
|
||||
});
|
||||
|
||||
return (async function* (): AsyncGenerator<T> {
|
||||
const rl = readline.createInterface({
|
||||
input: res.data as NodeJS.ReadableStream,
|
||||
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
|
||||
});
|
||||
|
||||
let bufferedLines: string[] = [];
|
||||
for await (const line of rl) {
|
||||
// blank lines are used to separate JSON objects in the stream
|
||||
if (line === '') {
|
||||
if (bufferedLines.length === 0) {
|
||||
continue; // no data to yield
|
||||
}
|
||||
yield JSON.parse(bufferedLines.join('\n')) as T;
|
||||
bufferedLines = []; // Reset the buffer after yielding
|
||||
} else if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else {
|
||||
throw new Error(`Unexpected line format in response: ${line}`);
|
||||
}
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
getMethodUrl(method: string): string {
|
||||
const endpoint =
|
||||
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
|
||||
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
|
||||
}
|
||||
}
|
||||
|
||||
function isVpcScAffectedUser(error: unknown): boolean {
|
||||
if (error && typeof error === 'object' && 'response' in error) {
|
||||
const gaxiosError = error as {
|
||||
response?: {
|
||||
data?: unknown;
|
||||
};
|
||||
};
|
||||
const response = gaxiosError.response?.data as
|
||||
| GoogleRpcResponse
|
||||
| undefined;
|
||||
if (Array.isArray(response?.error?.details)) {
|
||||
return response.error.details.some(
|
||||
(detail) => detail.reason === 'SECURITY_POLICY_VIOLATED',
|
||||
);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { setupUser, ProjectIdRequiredError } from './setup.js';
|
||||
import { CodeAssistServer } from '../code_assist/server.js';
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type { GeminiUserTier } from './types.js';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
vi.mock('../code_assist/server.js');
|
||||
|
||||
const mockPaidTier: GeminiUserTier = {
|
||||
id: UserTierId.STANDARD,
|
||||
name: 'paid',
|
||||
description: 'Paid tier',
|
||||
isDefault: true,
|
||||
};
|
||||
|
||||
const mockFreeTier: GeminiUserTier = {
|
||||
id: UserTierId.FREE,
|
||||
name: 'free',
|
||||
description: 'Free tier',
|
||||
isDefault: true,
|
||||
};
|
||||
|
||||
describe('setupUser for existing user', () => {
|
||||
let mockLoad: ReturnType<typeof vi.fn>;
|
||||
let mockOnboardUser: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockLoad = vi.fn();
|
||||
mockOnboardUser = vi.fn().mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
vi.mocked(CodeAssistServer).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
loadCodeAssist: mockLoad,
|
||||
onboardUser: mockOnboardUser,
|
||||
}) as unknown as CodeAssistServer,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
currentTier: mockPaidTier,
|
||||
});
|
||||
await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
cloudaicompanionProject: 'server-project',
|
||||
currentTier: mockPaidTier,
|
||||
});
|
||||
const projectId = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(projectId).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
// And the server itself requires a project ID internally
|
||||
vi.mocked(CodeAssistServer).mockImplementation(() => {
|
||||
throw new ProjectIdRequiredError();
|
||||
});
|
||||
|
||||
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
|
||||
ProjectIdRequiredError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setupUser for new user', () => {
|
||||
let mockLoad: ReturnType<typeof vi.fn>;
|
||||
let mockOnboardUser: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockLoad = vi.fn();
|
||||
mockOnboardUser = vi.fn().mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
vi.mocked(CodeAssistServer).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
loadCodeAssist: mockLoad,
|
||||
onboardUser: mockOnboardUser,
|
||||
}) as unknown as CodeAssistServer,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(mockLoad).toHaveBeenCalled();
|
||||
expect(mockOnboardUser).toHaveBeenCalledWith({
|
||||
tierId: 'standard-tier',
|
||||
cloudaicompanionProject: 'test-project',
|
||||
metadata: {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
duetProject: 'test-project',
|
||||
},
|
||||
});
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockFreeTier],
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
undefined,
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(mockLoad).toHaveBeenCalled();
|
||||
expect(mockOnboardUser).toHaveBeenCalledWith({
|
||||
tierId: 'free-tier',
|
||||
cloudaicompanionProject: undefined,
|
||||
metadata: {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
},
|
||||
});
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'free-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
mockOnboardUser.mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: undefined,
|
||||
},
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(userData).toEqual({
|
||||
projectId: 'test-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
mockOnboardUser.mockResolvedValue({
|
||||
done: true,
|
||||
response: {},
|
||||
});
|
||||
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
|
||||
ProjectIdRequiredError,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,124 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
ClientMetadata,
|
||||
GeminiUserTier,
|
||||
LoadCodeAssistResponse,
|
||||
OnboardUserRequest,
|
||||
} from './types.js';
|
||||
import { UserTierId } from './types.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
|
||||
export class ProjectIdRequiredError extends Error {
|
||||
constructor() {
|
||||
super(
|
||||
'This account requires setting the GOOGLE_CLOUD_PROJECT env var. See https://goo.gle/gemini-cli-auth-docs#workspace-gca',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export interface UserData {
|
||||
projectId: string;
|
||||
userTier: UserTierId;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param projectId the user's project id, if any
|
||||
* @returns the user's actual project id
|
||||
*/
|
||||
export async function setupUser(client: OAuth2Client): Promise<UserData> {
|
||||
const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
|
||||
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
|
||||
const coreClientMetadata: ClientMetadata = {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
};
|
||||
|
||||
const loadRes = await caServer.loadCodeAssist({
|
||||
cloudaicompanionProject: projectId,
|
||||
metadata: {
|
||||
...coreClientMetadata,
|
||||
duetProject: projectId,
|
||||
},
|
||||
});
|
||||
|
||||
if (loadRes.currentTier) {
|
||||
if (!loadRes.cloudaicompanionProject) {
|
||||
if (projectId) {
|
||||
return {
|
||||
projectId,
|
||||
userTier: loadRes.currentTier.id,
|
||||
};
|
||||
}
|
||||
throw new ProjectIdRequiredError();
|
||||
}
|
||||
return {
|
||||
projectId: loadRes.cloudaicompanionProject,
|
||||
userTier: loadRes.currentTier.id,
|
||||
};
|
||||
}
|
||||
|
||||
const tier = getOnboardTier(loadRes);
|
||||
|
||||
let onboardReq: OnboardUserRequest;
|
||||
if (tier.id === UserTierId.FREE) {
|
||||
// The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
|
||||
onboardReq = {
|
||||
tierId: tier.id,
|
||||
cloudaicompanionProject: undefined,
|
||||
metadata: coreClientMetadata,
|
||||
};
|
||||
} else {
|
||||
onboardReq = {
|
||||
tierId: tier.id,
|
||||
cloudaicompanionProject: projectId,
|
||||
metadata: {
|
||||
...coreClientMetadata,
|
||||
duetProject: projectId,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Poll onboardUser until long running operation is complete.
|
||||
let lroRes = await caServer.onboardUser(onboardReq);
|
||||
while (!lroRes.done) {
|
||||
await new Promise((f) => setTimeout(f, 5000));
|
||||
lroRes = await caServer.onboardUser(onboardReq);
|
||||
}
|
||||
|
||||
if (!lroRes.response?.cloudaicompanionProject?.id) {
|
||||
if (projectId) {
|
||||
return {
|
||||
projectId,
|
||||
userTier: tier.id,
|
||||
};
|
||||
}
|
||||
throw new ProjectIdRequiredError();
|
||||
}
|
||||
|
||||
return {
|
||||
projectId: lroRes.response.cloudaicompanionProject.id,
|
||||
userTier: tier.id,
|
||||
};
|
||||
}
|
||||
|
||||
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
|
||||
for (const tier of res.allowedTiers || []) {
|
||||
if (tier.isDefault) {
|
||||
return tier;
|
||||
}
|
||||
}
|
||||
return {
|
||||
name: '',
|
||||
description: '',
|
||||
id: UserTierId.LEGACY,
|
||||
userDefinedCloudaicompanionProject: true,
|
||||
};
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export interface ClientMetadata {
|
||||
ideType?: ClientMetadataIdeType;
|
||||
ideVersion?: string;
|
||||
pluginVersion?: string;
|
||||
platform?: ClientMetadataPlatform;
|
||||
updateChannel?: string;
|
||||
duetProject?: string;
|
||||
pluginType?: ClientMetadataPluginType;
|
||||
ideName?: string;
|
||||
}
|
||||
|
||||
export type ClientMetadataIdeType =
|
||||
| 'IDE_UNSPECIFIED'
|
||||
| 'VSCODE'
|
||||
| 'INTELLIJ'
|
||||
| 'VSCODE_CLOUD_WORKSTATION'
|
||||
| 'INTELLIJ_CLOUD_WORKSTATION'
|
||||
| 'CLOUD_SHELL';
|
||||
export type ClientMetadataPlatform =
|
||||
| 'PLATFORM_UNSPECIFIED'
|
||||
| 'DARWIN_AMD64'
|
||||
| 'DARWIN_ARM64'
|
||||
| 'LINUX_AMD64'
|
||||
| 'LINUX_ARM64'
|
||||
| 'WINDOWS_AMD64';
|
||||
export type ClientMetadataPluginType =
|
||||
| 'PLUGIN_UNSPECIFIED'
|
||||
| 'CLOUD_CODE'
|
||||
| 'GEMINI'
|
||||
| 'AIPLUGIN_INTELLIJ'
|
||||
| 'AIPLUGIN_STUDIO';
|
||||
|
||||
export interface LoadCodeAssistRequest {
|
||||
cloudaicompanionProject?: string;
|
||||
metadata: ClientMetadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents LoadCodeAssistResponse proto json field
|
||||
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
|
||||
*/
|
||||
export interface LoadCodeAssistResponse {
|
||||
currentTier?: GeminiUserTier | null;
|
||||
allowedTiers?: GeminiUserTier[] | null;
|
||||
ineligibleTiers?: IneligibleTier[] | null;
|
||||
cloudaicompanionProject?: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist.
|
||||
*/
|
||||
export interface GeminiUserTier {
|
||||
id: UserTierId;
|
||||
name?: string;
|
||||
description?: string;
|
||||
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
|
||||
userDefinedCloudaicompanionProject?: boolean | null;
|
||||
isDefault?: boolean;
|
||||
privacyNotice?: PrivacyNotice;
|
||||
hasAcceptedTos?: boolean;
|
||||
hasOnboardedPreviously?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
|
||||
* @param reasonCode mnemonic code representing the reason for in-eligibility.
|
||||
* @param reasonMessage message to display to the user.
|
||||
* @param tierId id of the tier.
|
||||
* @param tierName name of the tier.
|
||||
*/
|
||||
export interface IneligibleTier {
|
||||
reasonCode: IneligibleTierReasonCode;
|
||||
reasonMessage: string;
|
||||
tierId: UserTierId;
|
||||
tierName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* List of predefined reason codes when a tier is blocked from a specific tier.
|
||||
* https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
|
||||
*/
|
||||
export enum IneligibleTierReasonCode {
|
||||
// go/keep-sorted start
|
||||
DASHER_USER = 'DASHER_USER',
|
||||
INELIGIBLE_ACCOUNT = 'INELIGIBLE_ACCOUNT',
|
||||
NON_USER_ACCOUNT = 'NON_USER_ACCOUNT',
|
||||
RESTRICTED_AGE = 'RESTRICTED_AGE',
|
||||
RESTRICTED_NETWORK = 'RESTRICTED_NETWORK',
|
||||
UNKNOWN = 'UNKNOWN',
|
||||
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
|
||||
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
|
||||
// go/keep-sorted end
|
||||
}
|
||||
/**
|
||||
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
|
||||
*
|
||||
* //depot/google3/cloud/developer_experience/cloudcode/pa/service/usertier.go;l=16
|
||||
*/
|
||||
export enum UserTierId {
|
||||
FREE = 'free-tier',
|
||||
LEGACY = 'legacy-tier',
|
||||
STANDARD = 'standard-tier',
|
||||
}
|
||||
|
||||
/**
|
||||
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
|
||||
* privacy notice.
|
||||
*/
|
||||
export interface PrivacyNotice {
|
||||
showNotice: boolean;
|
||||
noticeText?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Proto signature of OnboardUserRequest as payload to OnboardUser call
|
||||
*/
|
||||
export interface OnboardUserRequest {
|
||||
tierId: string | undefined;
|
||||
cloudaicompanionProject: string | undefined;
|
||||
metadata: ClientMetadata | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents LongRunningOperation proto
|
||||
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
|
||||
*/
|
||||
export interface LongRunningOperationResponse {
|
||||
name: string;
|
||||
done?: boolean;
|
||||
response?: OnboardUserResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents OnboardUserResponse proto
|
||||
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
|
||||
*/
|
||||
export interface OnboardUserResponse {
|
||||
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
|
||||
cloudaicompanionProject?: {
|
||||
id: string;
|
||||
name: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Status code of user license status
|
||||
* it does not strictly correspond to the proto
|
||||
* Error value is an additional value assigned to error responses from OnboardUser
|
||||
*/
|
||||
export enum OnboardUserStatusCode {
|
||||
Default = 'DEFAULT',
|
||||
Notice = 'NOTICE',
|
||||
Warning = 'WARNING',
|
||||
Error = 'ERROR',
|
||||
}
|
||||
|
||||
/**
|
||||
* Status of user onboarded to gemini
|
||||
*/
|
||||
export interface OnboardUserStatus {
|
||||
statusCode: OnboardUserStatusCode;
|
||||
displayMessage: string;
|
||||
helpLink: HelpLinkUrl | undefined;
|
||||
}
|
||||
|
||||
export interface HelpLinkUrl {
|
||||
description: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface SetCodeAssistGlobalUserSettingRequest {
|
||||
cloudaicompanionProject?: string;
|
||||
freeTierDataCollectionOptin: boolean;
|
||||
}
|
||||
|
||||
export interface CodeAssistGlobalUserSettingResponse {
|
||||
cloudaicompanionProject?: string;
|
||||
freeTierDataCollectionOptin: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Relevant fields that can be returned from a Google RPC response
|
||||
*/
|
||||
export interface GoogleRpcResponse {
|
||||
error?: {
|
||||
details?: GoogleRpcErrorInfo[];
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Relevant fields that can be returned in the details of an error returned from GoogleRPCs
|
||||
*/
|
||||
interface GoogleRpcErrorInfo {
|
||||
reason?: string;
|
||||
}
|
||||
@@ -283,23 +283,6 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
|
||||
it('should strip thoughts when switching from GenAI to Vertex', async () => {
|
||||
const config = new Config(baseParams);
|
||||
|
||||
vi.mocked(createContentGeneratorConfig).mockImplementation(
|
||||
(_: Config, authType: AuthType | undefined) =>
|
||||
({ authType }) as unknown as ContentGeneratorConfig,
|
||||
);
|
||||
|
||||
await config.refreshAuth(AuthType.USE_GEMINI);
|
||||
|
||||
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
|
||||
|
||||
expect(
|
||||
config.getGeminiClient().stripThoughtsFromHistory,
|
||||
).toHaveBeenCalledWith();
|
||||
});
|
||||
|
||||
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
|
||||
const config = new Config(baseParams);
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
|
||||
AuthType} from '../core/contentGenerator.js';
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
|
||||
import type { ShellExecutionConfig } from '../services/shellExecutionService.js';
|
||||
@@ -26,7 +27,6 @@ import type { AnyToolInvocation } from '../tools/tools.js';
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import {
|
||||
AuthType,
|
||||
createContentGenerator,
|
||||
createContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
@@ -684,16 +684,6 @@ export class Config {
|
||||
}
|
||||
|
||||
async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) {
|
||||
// Vertex and Genai have incompatible encryption and sending history with
|
||||
// throughtSignature from Genai to Vertex will fail, we need to strip them
|
||||
if (
|
||||
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
|
||||
authMethod === AuthType.LOGIN_WITH_GOOGLE
|
||||
) {
|
||||
// Restore the conversation history to the new client
|
||||
this.geminiClient.stripThoughtsFromHistory();
|
||||
}
|
||||
|
||||
const newContentGeneratorConfig = createContentGeneratorConfig(
|
||||
this,
|
||||
authMethod,
|
||||
|
||||
@@ -31,7 +31,7 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
config as unknown as { contentGeneratorConfig: unknown }
|
||||
).contentGeneratorConfig = {
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
};
|
||||
});
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
// Create generator instance
|
||||
@@ -299,6 +300,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
@@ -333,6 +335,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
|
||||
@@ -146,13 +146,11 @@ describe('BaseLlmClient', () => {
|
||||
// Validate the parameters passed to the underlying generator
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
contents: defaultOptions.contents,
|
||||
config: {
|
||||
config: expect.objectContaining({
|
||||
abortSignal: defaultOptions.abortSignal,
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
tools: [
|
||||
{
|
||||
functionDeclarations: [
|
||||
@@ -164,9 +162,8 @@ describe('BaseLlmClient', () => {
|
||||
],
|
||||
},
|
||||
],
|
||||
// Crucial: systemInstruction should NOT be in the config object if not provided
|
||||
},
|
||||
},
|
||||
}),
|
||||
}),
|
||||
'test-prompt-id',
|
||||
);
|
||||
});
|
||||
@@ -189,7 +186,6 @@ describe('BaseLlmClient', () => {
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
temperature: 0.8,
|
||||
topP: 1, // Default should remain if not overridden
|
||||
topK: 10,
|
||||
tools: expect.any(Array),
|
||||
}),
|
||||
|
||||
@@ -64,12 +64,6 @@ export interface GenerateJsonOptions {
|
||||
* A client dedicated to stateless, utility-focused LLM calls.
|
||||
*/
|
||||
export class BaseLlmClient {
|
||||
// Default configuration for utility tasks
|
||||
private readonly defaultUtilityConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
};
|
||||
|
||||
constructor(
|
||||
private readonly contentGenerator: ContentGenerator,
|
||||
private readonly config: Config,
|
||||
@@ -90,7 +84,6 @@ export class BaseLlmClient {
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...this.defaultUtilityConfig,
|
||||
...options.config,
|
||||
...(systemInstruction && { systemInstruction }),
|
||||
};
|
||||
|
||||
@@ -15,11 +15,7 @@ import {
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import {
|
||||
isThinkingDefault,
|
||||
isThinkingSupported,
|
||||
GeminiClient,
|
||||
} from './client.js';
|
||||
import { GeminiClient } from './client.js';
|
||||
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
|
||||
import {
|
||||
AuthType,
|
||||
@@ -247,40 +243,6 @@ describe('findCompressSplitPoint', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('isThinkingSupported', () => {
|
||||
it('should return true for gemini-2.5', () => {
|
||||
expect(isThinkingSupported('gemini-2.5')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5-pro', () => {
|
||||
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for other models', () => {
|
||||
expect(isThinkingSupported('gemini-1.5-flash')).toBe(false);
|
||||
expect(isThinkingSupported('some-other-model')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isThinkingDefault', () => {
|
||||
it('should return false for gemini-2.5-flash-lite', () => {
|
||||
expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5', () => {
|
||||
expect(isThinkingDefault('gemini-2.5')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5-pro', () => {
|
||||
expect(isThinkingDefault('gemini-2.5-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for other models', () => {
|
||||
expect(isThinkingDefault('gemini-1.5-flash')).toBe(false);
|
||||
expect(isThinkingDefault('some-other-model')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
@@ -2304,16 +2266,15 @@ ${JSON.stringify(
|
||||
);
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
config: {
|
||||
config: expect.objectContaining({
|
||||
abortSignal,
|
||||
systemInstruction: getCoreSystemPrompt(''),
|
||||
temperature: 0.5,
|
||||
topP: 1,
|
||||
},
|
||||
}),
|
||||
contents,
|
||||
},
|
||||
}),
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -15,11 +15,7 @@ import type {
|
||||
|
||||
// Config
|
||||
import { ApprovalMode, type Config } from '../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
DEFAULT_THINKING_MODE,
|
||||
} from '../config/models.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
|
||||
// Core modules
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
@@ -78,25 +74,10 @@ import { type File, type IdeContext } from '../ide/types.js';
|
||||
// Fallback handling
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||
}
|
||||
|
||||
export function isThinkingDefault(model: string) {
|
||||
if (model.startsWith('gemini-2.5-flash-lite')) {
|
||||
return false;
|
||||
}
|
||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||
}
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private readonly generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
};
|
||||
private sessionTurnCount = 0;
|
||||
|
||||
private readonly loopDetector: LoopDetectionService;
|
||||
@@ -208,20 +189,10 @@ export class GeminiClient {
|
||||
const model = this.config.getModel();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory, model);
|
||||
|
||||
const config: GenerateContentConfig = { ...this.generateContentConfig };
|
||||
|
||||
if (isThinkingSupported(model)) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: DEFAULT_THINKING_MODE,
|
||||
};
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
this.config,
|
||||
{
|
||||
systemInstruction,
|
||||
...config,
|
||||
tools,
|
||||
},
|
||||
history,
|
||||
@@ -618,11 +589,6 @@ export class GeminiClient {
|
||||
): Promise<GenerateContentResponse> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
const configToUse: GenerateContentConfig = {
|
||||
...this.generateContentConfig,
|
||||
...generationConfig,
|
||||
};
|
||||
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const finalSystemInstruction = generationConfig.systemInstruction
|
||||
@@ -631,7 +597,7 @@ export class GeminiClient {
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...configToUse,
|
||||
...generationConfig,
|
||||
systemInstruction: finalSystemInstruction,
|
||||
};
|
||||
|
||||
@@ -672,7 +638,7 @@ export class GeminiClient {
|
||||
`Error generating content via API with model ${currentAttemptModel}.`,
|
||||
{
|
||||
requestContents: contents,
|
||||
requestConfig: configToUse,
|
||||
requestConfig: generationConfig,
|
||||
},
|
||||
'generateContent-api',
|
||||
);
|
||||
|
||||
@@ -5,42 +5,19 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { createContentGenerator, AuthType } from './contentGenerator.js';
|
||||
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import { LoggingContentGenerator } from './geminiContentGenerator/loggingContentGenerator.js';
|
||||
|
||||
vi.mock('../code_assist/codeAssist.js');
|
||||
vi.mock('@google/genai');
|
||||
|
||||
const mockConfig = {
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
|
||||
describe('createContentGenerator', () => {
|
||||
it('should create a CodeAssistContentGenerator', async () => {
|
||||
const mockGenerator = {} as unknown as ContentGenerator;
|
||||
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
||||
mockGenerator as never,
|
||||
);
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
model: 'test-model',
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
|
||||
expect(generator).toEqual(
|
||||
new LoggingContentGenerator(mockGenerator, mockConfig),
|
||||
);
|
||||
});
|
||||
|
||||
it('should create a GoogleGenAI content generator', async () => {
|
||||
it('should create a Gemini content generator', async () => {
|
||||
const mockConfig = {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getContentGeneratorConfig: () => ({}),
|
||||
getCliVersion: () => '1.0.0',
|
||||
} as unknown as Config;
|
||||
|
||||
const mockGenerator = {
|
||||
@@ -65,17 +42,17 @@ describe('createContentGenerator', () => {
|
||||
},
|
||||
},
|
||||
});
|
||||
expect(generator).toEqual(
|
||||
new LoggingContentGenerator(
|
||||
(mockGenerator as GoogleGenAI).models,
|
||||
mockConfig,
|
||||
),
|
||||
);
|
||||
// We expect it to be a LoggingContentGenerator wrapping a GeminiContentGenerator
|
||||
expect(generator).toBeInstanceOf(LoggingContentGenerator);
|
||||
const wrapped = (generator as LoggingContentGenerator).getWrapped();
|
||||
expect(wrapped).toBeDefined();
|
||||
});
|
||||
|
||||
it('should create a GoogleGenAI content generator with client install id logging disabled', async () => {
|
||||
it('should create a Gemini content generator with client install id logging disabled', async () => {
|
||||
const mockConfig = {
|
||||
getUsageStatisticsEnabled: () => false,
|
||||
getContentGeneratorConfig: () => ({}),
|
||||
getCliVersion: () => '1.0.0',
|
||||
} as unknown as Config;
|
||||
const mockGenerator = {
|
||||
models: {},
|
||||
@@ -98,11 +75,6 @@ describe('createContentGenerator', () => {
|
||||
},
|
||||
},
|
||||
});
|
||||
expect(generator).toEqual(
|
||||
new LoggingContentGenerator(
|
||||
(mockGenerator as GoogleGenAI).models,
|
||||
mockConfig,
|
||||
),
|
||||
);
|
||||
expect(generator).toBeInstanceOf(LoggingContentGenerator);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,15 +12,9 @@ import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
||||
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { InstallationManager } from '../utils/installationManager.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
|
||||
/**
|
||||
* Interface abstracting the core functionalities for generating content and counting tokens.
|
||||
*/
|
||||
@@ -38,15 +32,11 @@ export interface ContentGenerator {
|
||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||
|
||||
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
|
||||
|
||||
userTier?: UserTierId;
|
||||
}
|
||||
|
||||
export enum AuthType {
|
||||
LOGIN_WITH_GOOGLE = 'oauth-personal',
|
||||
USE_GEMINI = 'gemini-api-key',
|
||||
USE_VERTEX_AI = 'vertex-ai',
|
||||
CLOUD_SHELL = 'cloud-shell',
|
||||
USE_OPENAI = 'openai',
|
||||
QWEN_OAUTH = 'qwen-oauth',
|
||||
}
|
||||
@@ -59,12 +49,9 @@ export type ContentGeneratorConfig = {
|
||||
authType?: AuthType | undefined;
|
||||
enableOpenAILogging?: boolean;
|
||||
openAILoggingDir?: string;
|
||||
// Timeout configuration in milliseconds
|
||||
timeout?: number;
|
||||
// Maximum retries for failed requests
|
||||
maxRetries?: number;
|
||||
// Disable cache control for DashScope providers
|
||||
disableCacheControl?: boolean;
|
||||
timeout?: number; // Timeout configuration in milliseconds
|
||||
maxRetries?: number; // Maximum retries for failed requests
|
||||
disableCacheControl?: boolean; // Disable cache control for DashScope providers
|
||||
samplingParams?: {
|
||||
top_p?: number;
|
||||
top_k?: number;
|
||||
@@ -74,6 +61,9 @@ export type ContentGeneratorConfig = {
|
||||
temperature?: number;
|
||||
max_tokens?: number;
|
||||
};
|
||||
reasoning?: {
|
||||
effort?: 'low' | 'medium' | 'high';
|
||||
};
|
||||
proxy?: string | undefined;
|
||||
userAgent?: string;
|
||||
// Schema compliance mode for tool definitions
|
||||
@@ -123,48 +113,14 @@ export async function createContentGenerator(
|
||||
gcConfig: Config,
|
||||
isInitialAuth?: boolean,
|
||||
): Promise<ContentGenerator> {
|
||||
const version = process.env['CLI_VERSION'] || process.version;
|
||||
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
|
||||
const baseHeaders: Record<string, string> = {
|
||||
'User-Agent': userAgent,
|
||||
};
|
||||
|
||||
if (
|
||||
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||
config.authType === AuthType.CLOUD_SHELL
|
||||
) {
|
||||
const httpOptions = { headers: baseHeaders };
|
||||
return new LoggingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
),
|
||||
gcConfig,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
config.authType === AuthType.USE_GEMINI ||
|
||||
config.authType === AuthType.USE_VERTEX_AI
|
||||
) {
|
||||
let headers: Record<string, string> = { ...baseHeaders };
|
||||
if (gcConfig?.getUsageStatisticsEnabled()) {
|
||||
const installationManager = new InstallationManager();
|
||||
const installationId = installationManager.getInstallationId();
|
||||
headers = {
|
||||
...headers,
|
||||
'x-gemini-api-privileged-user-id': `${installationId}`,
|
||||
};
|
||||
}
|
||||
const httpOptions = { headers };
|
||||
|
||||
const googleGenAI = new GoogleGenAI({
|
||||
apiKey: config.apiKey === '' ? undefined : config.apiKey,
|
||||
vertexai: config.vertexai,
|
||||
httpOptions,
|
||||
});
|
||||
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
|
||||
const { createGeminiContentGenerator } = await import(
|
||||
'./geminiContentGenerator/index.js'
|
||||
);
|
||||
return createGeminiContentGenerator(config, gcConfig);
|
||||
}
|
||||
|
||||
if (config.authType === AuthType.USE_OPENAI) {
|
||||
|
||||
@@ -240,7 +240,7 @@ describe('CoreToolScheduler', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -318,7 +318,7 @@ describe('CoreToolScheduler', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -497,7 +497,7 @@ describe('CoreToolScheduler', () => {
|
||||
getExcludeTools: () => ['write_file', 'edit', 'run_shell_command'],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -584,7 +584,7 @@ describe('CoreToolScheduler', () => {
|
||||
getExcludeTools: () => ['write_file', 'edit'], // Different excluded tools
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -674,7 +674,7 @@ describe('CoreToolScheduler with payload', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1001,7 +1001,7 @@ describe('CoreToolScheduler edit cancellation', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1108,7 +1108,7 @@ describe('CoreToolScheduler YOLO mode', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1258,7 +1258,7 @@ describe('CoreToolScheduler cancellation during executing with live output', ()
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getShellExecutionConfig: () => ({
|
||||
@@ -1350,7 +1350,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1482,7 +1482,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getToolRegistry: () => toolRegistry,
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 80,
|
||||
@@ -1586,7 +1586,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1854,7 +1854,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
@@ -1975,7 +1975,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
|
||||
@@ -111,7 +111,7 @@ describe('GeminiChat', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'oauth-personal', // Ensure this is set for fallback tests
|
||||
authType: 'gemini-api-key', // Ensure this is set for fallback tests
|
||||
model: 'test-model',
|
||||
}),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
@@ -1382,7 +1382,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
|
||||
const authType = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const authType = AuthType.USE_GEMINI;
|
||||
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
|
||||
model: 'test-model',
|
||||
authType,
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { GeminiContentGenerator } from './geminiContentGenerator.js';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
|
||||
vi.mock('@google/genai', () => {
|
||||
const mockGenerateContent = vi.fn();
|
||||
const mockGenerateContentStream = vi.fn();
|
||||
const mockCountTokens = vi.fn();
|
||||
const mockEmbedContent = vi.fn();
|
||||
|
||||
return {
|
||||
GoogleGenAI: vi.fn().mockImplementation(() => ({
|
||||
models: {
|
||||
generateContent: mockGenerateContent,
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
countTokens: mockCountTokens,
|
||||
embedContent: mockEmbedContent,
|
||||
},
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
describe('GeminiContentGenerator', () => {
|
||||
let generator: GeminiContentGenerator;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let mockGoogleGenAI: any;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
generator = new GeminiContentGenerator({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
mockGoogleGenAI = vi.mocked(GoogleGenAI).mock.results[0].value;
|
||||
});
|
||||
|
||||
it('should call generateContent on the underlying model', async () => {
|
||||
const request = { model: 'gemini-1.5-flash', contents: [] };
|
||||
const expectedResponse = { responseId: 'test-id' };
|
||||
mockGoogleGenAI.models.generateContent.mockResolvedValue(expectedResponse);
|
||||
|
||||
const response = await generator.generateContent(request, 'prompt-id');
|
||||
|
||||
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
...request,
|
||||
config: expect.objectContaining({
|
||||
temperature: 1,
|
||||
topP: 0.95,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingLevel: 'HIGH',
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(response).toBe(expectedResponse);
|
||||
});
|
||||
|
||||
it('should call generateContentStream on the underlying model', async () => {
|
||||
const request = { model: 'gemini-1.5-flash', contents: [] };
|
||||
const mockStream = (async function* () {
|
||||
yield { responseId: '1' };
|
||||
})();
|
||||
mockGoogleGenAI.models.generateContentStream.mockResolvedValue(mockStream);
|
||||
|
||||
const stream = await generator.generateContentStream(request, 'prompt-id');
|
||||
|
||||
expect(mockGoogleGenAI.models.generateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
...request,
|
||||
config: expect.objectContaining({
|
||||
temperature: 1,
|
||||
topP: 0.95,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingLevel: 'HIGH',
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(stream).toBe(mockStream);
|
||||
});
|
||||
|
||||
it('should call countTokens on the underlying model', async () => {
|
||||
const request = { model: 'gemini-1.5-flash', contents: [] };
|
||||
const expectedResponse = { totalTokens: 10 };
|
||||
mockGoogleGenAI.models.countTokens.mockResolvedValue(expectedResponse);
|
||||
|
||||
const response = await generator.countTokens(request);
|
||||
|
||||
expect(mockGoogleGenAI.models.countTokens).toHaveBeenCalledWith(request);
|
||||
expect(response).toBe(expectedResponse);
|
||||
});
|
||||
|
||||
it('should call embedContent on the underlying model', async () => {
|
||||
const request = { model: 'embedding-model', contents: [] };
|
||||
const expectedResponse = { embeddings: [] };
|
||||
mockGoogleGenAI.models.embedContent.mockResolvedValue(expectedResponse);
|
||||
|
||||
const response = await generator.embedContent(request);
|
||||
|
||||
expect(mockGoogleGenAI.models.embedContent).toHaveBeenCalledWith(request);
|
||||
expect(response).toBe(expectedResponse);
|
||||
});
|
||||
|
||||
it('should prioritize contentGeneratorConfig samplingParams over request config', async () => {
|
||||
const generatorWithParams = new GeminiContentGenerator({ apiKey: 'test' }, {
|
||||
model: 'gemini-1.5-flash',
|
||||
samplingParams: {
|
||||
temperature: 0.1,
|
||||
top_p: 0.2,
|
||||
},
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
|
||||
const request = {
|
||||
model: 'gemini-1.5-flash',
|
||||
contents: [],
|
||||
config: {
|
||||
temperature: 0.9,
|
||||
topP: 0.9,
|
||||
},
|
||||
};
|
||||
|
||||
await generatorWithParams.generateContent(request, 'prompt-id');
|
||||
|
||||
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should map reasoning effort to thinkingConfig', async () => {
|
||||
const generatorWithReasoning = new GeminiContentGenerator(
|
||||
{ apiKey: 'test' },
|
||||
{
|
||||
model: 'gemini-2.5-pro',
|
||||
reasoning: {
|
||||
effort: 'high',
|
||||
},
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any,
|
||||
);
|
||||
|
||||
const request = {
|
||||
model: 'gemini-2.5-pro',
|
||||
contents: [],
|
||||
};
|
||||
|
||||
await generatorWithReasoning.generateContent(request, 'prompt-id');
|
||||
|
||||
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingLevel: 'HIGH',
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,140 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
GenerateContentConfig,
|
||||
ThinkingLevel,
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
} from '../contentGenerator.js';
|
||||
|
||||
/**
|
||||
* A wrapper for GoogleGenAI that implements the ContentGenerator interface.
|
||||
*/
|
||||
export class GeminiContentGenerator implements ContentGenerator {
|
||||
private readonly googleGenAI: GoogleGenAI;
|
||||
private readonly contentGeneratorConfig?: ContentGeneratorConfig;
|
||||
|
||||
constructor(
|
||||
options: {
|
||||
apiKey?: string;
|
||||
vertexai?: boolean;
|
||||
httpOptions?: { headers: Record<string, string> };
|
||||
},
|
||||
contentGeneratorConfig?: ContentGeneratorConfig,
|
||||
) {
|
||||
this.googleGenAI = new GoogleGenAI(options);
|
||||
this.contentGeneratorConfig = contentGeneratorConfig;
|
||||
}
|
||||
|
||||
private buildSamplingParameters(
|
||||
request: GenerateContentParameters,
|
||||
): GenerateContentConfig {
|
||||
const configSamplingParams = this.contentGeneratorConfig?.samplingParams;
|
||||
const requestConfig = request.config || {};
|
||||
|
||||
// Helper function to get parameter value with priority: config > request > default
|
||||
const getParameterValue = <T>(
|
||||
configValue: T | undefined,
|
||||
requestKey: keyof GenerateContentConfig,
|
||||
defaultValue?: T,
|
||||
): T | undefined => {
|
||||
const requestValue = requestConfig[requestKey] as T | undefined;
|
||||
|
||||
if (configValue !== undefined) return configValue;
|
||||
if (requestValue !== undefined) return requestValue;
|
||||
return defaultValue;
|
||||
};
|
||||
|
||||
return {
|
||||
...requestConfig,
|
||||
temperature: getParameterValue<number>(
|
||||
configSamplingParams?.temperature,
|
||||
'temperature',
|
||||
1,
|
||||
),
|
||||
topP: getParameterValue<number>(
|
||||
configSamplingParams?.top_p,
|
||||
'topP',
|
||||
0.95,
|
||||
),
|
||||
topK: getParameterValue<number>(configSamplingParams?.top_k, 'topK', 64),
|
||||
maxOutputTokens: getParameterValue<number>(
|
||||
configSamplingParams?.max_tokens,
|
||||
'maxOutputTokens',
|
||||
),
|
||||
presencePenalty: getParameterValue<number>(
|
||||
configSamplingParams?.presence_penalty,
|
||||
'presencePenalty',
|
||||
),
|
||||
frequencyPenalty: getParameterValue<number>(
|
||||
configSamplingParams?.frequency_penalty,
|
||||
'frequencyPenalty',
|
||||
),
|
||||
thinkingConfig: getParameterValue(
|
||||
this.contentGeneratorConfig?.reasoning
|
||||
? {
|
||||
includeThoughts: true,
|
||||
thinkingLevel: (this.contentGeneratorConfig.reasoning.effort ===
|
||||
'low'
|
||||
? 'LOW'
|
||||
: this.contentGeneratorConfig.reasoning.effort === 'high'
|
||||
? 'HIGH'
|
||||
: 'THINKING_LEVEL_UNSPECIFIED') as ThinkingLevel,
|
||||
}
|
||||
: undefined,
|
||||
'thinkingConfig',
|
||||
{
|
||||
includeThoughts: true,
|
||||
thinkingLevel: 'HIGH' as ThinkingLevel,
|
||||
},
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const finalRequest = {
|
||||
...request,
|
||||
config: this.buildSamplingParameters(request),
|
||||
};
|
||||
return this.googleGenAI.models.generateContent(finalRequest);
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const finalRequest = {
|
||||
...request,
|
||||
config: this.buildSamplingParameters(request),
|
||||
};
|
||||
return this.googleGenAI.models.generateContentStream(finalRequest);
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.googleGenAI.models.countTokens(request);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return this.googleGenAI.models.embedContent(request);
|
||||
}
|
||||
}
|
||||
47
packages/core/src/core/geminiContentGenerator/index.test.ts
Normal file
47
packages/core/src/core/geminiContentGenerator/index.test.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { createGeminiContentGenerator } from './index.js';
|
||||
import { GeminiContentGenerator } from './geminiContentGenerator.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { AuthType } from '../contentGenerator.js';
|
||||
|
||||
vi.mock('./geminiContentGenerator.js', () => ({
|
||||
GeminiContentGenerator: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
vi.mock('./loggingContentGenerator.js', () => ({
|
||||
LoggingContentGenerator: vi.fn().mockImplementation((wrapped) => wrapped),
|
||||
}));
|
||||
|
||||
describe('createGeminiContentGenerator', () => {
|
||||
let mockConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockConfig = {
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
it('should create a GeminiContentGenerator wrapped in LoggingContentGenerator', () => {
|
||||
const config = {
|
||||
model: 'gemini-1.5-flash',
|
||||
apiKey: 'test-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
|
||||
const generator = createGeminiContentGenerator(config, mockConfig);
|
||||
|
||||
expect(GeminiContentGenerator).toHaveBeenCalled();
|
||||
expect(LoggingContentGenerator).toHaveBeenCalled();
|
||||
expect(generator).toBeDefined();
|
||||
});
|
||||
});
|
||||
55
packages/core/src/core/geminiContentGenerator/index.ts
Normal file
55
packages/core/src/core/geminiContentGenerator/index.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GeminiContentGenerator } from './geminiContentGenerator.js';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
} from '../contentGenerator.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { InstallationManager } from '../../utils/installationManager.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
|
||||
export { GeminiContentGenerator } from './geminiContentGenerator.js';
|
||||
export { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
|
||||
/**
|
||||
* Create a Gemini content generator.
|
||||
*/
|
||||
export function createGeminiContentGenerator(
|
||||
config: ContentGeneratorConfig,
|
||||
gcConfig: Config,
|
||||
): ContentGenerator {
|
||||
const version = process.env['CLI_VERSION'] || process.version;
|
||||
const userAgent =
|
||||
config.userAgent ||
|
||||
`QwenCode/${version} (${process.platform}; ${process.arch})`;
|
||||
const baseHeaders: Record<string, string> = {
|
||||
'User-Agent': userAgent,
|
||||
};
|
||||
|
||||
let headers: Record<string, string> = { ...baseHeaders };
|
||||
if (gcConfig?.getUsageStatisticsEnabled()) {
|
||||
const installationManager = new InstallationManager();
|
||||
const installationId = installationManager.getInstallationId();
|
||||
headers = {
|
||||
...headers,
|
||||
'x-gemini-api-privileged-user-id': `${installationId}`,
|
||||
};
|
||||
}
|
||||
const httpOptions = { headers };
|
||||
|
||||
const geminiContentGenerator = new GeminiContentGenerator(
|
||||
{
|
||||
apiKey: config.apiKey === '' ? undefined : config.apiKey,
|
||||
vertexai: config.vertexai,
|
||||
httpOptions,
|
||||
},
|
||||
config,
|
||||
);
|
||||
|
||||
return new LoggingContentGenerator(geminiContentGenerator, gcConfig);
|
||||
}
|
||||
@@ -13,21 +13,24 @@ import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
GenerateContentResponse,
|
||||
ContentListUnion,
|
||||
ContentUnion,
|
||||
Part,
|
||||
PartUnion,
|
||||
} from '@google/genai';
|
||||
import {
|
||||
ApiRequestEvent,
|
||||
ApiResponseEvent,
|
||||
ApiErrorEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
} from '../../telemetry/types.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
logApiError,
|
||||
logApiRequest,
|
||||
logApiResponse,
|
||||
} from '../telemetry/loggers.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { toContents } from '../code_assist/converter.js';
|
||||
import { isStructuredError } from '../utils/quotaErrorDetection.js';
|
||||
} from '../../telemetry/loggers.js';
|
||||
import type { ContentGenerator } from '../contentGenerator.js';
|
||||
import { isStructuredError } from '../../utils/quotaErrorDetection.js';
|
||||
|
||||
interface StructuredError {
|
||||
status: number;
|
||||
@@ -112,7 +115,7 @@ export class LoggingContentGenerator implements ContentGenerator {
|
||||
userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const startTime = Date.now();
|
||||
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
|
||||
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
|
||||
try {
|
||||
const response = await this.wrapped.generateContent(req, userPromptId);
|
||||
const durationMs = Date.now() - startTime;
|
||||
@@ -137,7 +140,7 @@ export class LoggingContentGenerator implements ContentGenerator {
|
||||
userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const startTime = Date.now();
|
||||
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
|
||||
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
|
||||
|
||||
let stream: AsyncGenerator<GenerateContentResponse>;
|
||||
try {
|
||||
@@ -205,4 +208,91 @@ export class LoggingContentGenerator implements ContentGenerator {
|
||||
): Promise<EmbedContentResponse> {
|
||||
return this.wrapped.embedContent(req);
|
||||
}
|
||||
|
||||
private toContents(contents: ContentListUnion): Content[] {
|
||||
if (Array.isArray(contents)) {
|
||||
// it's a Content[] or a PartsUnion[]
|
||||
return contents.map((c) => this.toContent(c));
|
||||
}
|
||||
// it's a Content or a PartsUnion
|
||||
return [this.toContent(contents)];
|
||||
}
|
||||
|
||||
private toContent(content: ContentUnion): Content {
|
||||
if (Array.isArray(content)) {
|
||||
// it's a PartsUnion[]
|
||||
return {
|
||||
role: 'user',
|
||||
parts: this.toParts(content),
|
||||
};
|
||||
}
|
||||
if (typeof content === 'string') {
|
||||
// it's a string
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [{ text: content }],
|
||||
};
|
||||
}
|
||||
if ('parts' in content) {
|
||||
// it's a Content - process parts to handle thought filtering
|
||||
return {
|
||||
...content,
|
||||
parts: content.parts
|
||||
? this.toParts(content.parts.filter((p) => p != null))
|
||||
: [],
|
||||
};
|
||||
}
|
||||
// it's a Part
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [this.toPart(content as Part)],
|
||||
};
|
||||
}
|
||||
|
||||
private toParts(parts: PartUnion[]): Part[] {
|
||||
return parts.map((p) => this.toPart(p));
|
||||
}
|
||||
|
||||
private toPart(part: PartUnion): Part {
|
||||
if (typeof part === 'string') {
|
||||
// it's a string
|
||||
return { text: part };
|
||||
}
|
||||
|
||||
// Handle thought parts for CountToken API compatibility
|
||||
// The CountToken API expects parts to have certain required "oneof" fields initialized,
|
||||
// but thought parts don't conform to this schema and cause API failures
|
||||
if ('thought' in part && part.thought) {
|
||||
const thoughtText = `[Thought: ${part.thought}]`;
|
||||
|
||||
const newPart = { ...part };
|
||||
delete (newPart as Record<string, unknown>)['thought'];
|
||||
|
||||
const hasApiContent =
|
||||
'functionCall' in newPart ||
|
||||
'functionResponse' in newPart ||
|
||||
'inlineData' in newPart ||
|
||||
'fileData' in newPart;
|
||||
|
||||
if (hasApiContent) {
|
||||
// It's a functionCall or other non-text part. Just strip the thought.
|
||||
return newPart;
|
||||
}
|
||||
|
||||
// If no other valid API content, this must be a text part.
|
||||
// Combine existing text (if any) with the thought, preserving other properties.
|
||||
const text = (newPart as { text?: unknown }).text;
|
||||
const existingText = text ? String(text) : '';
|
||||
const combinedText = existingText
|
||||
? `${existingText}\n${thoughtText}`
|
||||
: thoughtText;
|
||||
|
||||
return {
|
||||
...newPart,
|
||||
text: combinedText,
|
||||
};
|
||||
}
|
||||
|
||||
return part;
|
||||
}
|
||||
}
|
||||
@@ -47,7 +47,7 @@ describe('executeToolCall', () => {
|
||||
getDebugMode: () => false,
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
|
||||
@@ -99,6 +99,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
|
||||
},
|
||||
} as unknown as OpenAI),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
generator = new OpenAIContentGenerator(
|
||||
@@ -211,6 +212,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
|
||||
},
|
||||
} as unknown as OpenAI),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
const testGenerator = new TestGenerator(
|
||||
@@ -277,6 +279,7 @@ describe('OpenAIContentGenerator (Refactored)', () => {
|
||||
},
|
||||
} as unknown as OpenAI),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
const testGenerator = new TestGenerator(
|
||||
|
||||
@@ -60,6 +60,7 @@ describe('ContentGenerationPipeline', () => {
|
||||
buildClient: vi.fn().mockReturnValue(mockClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
buildHeaders: vi.fn().mockReturnValue({}),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
// Mock telemetry service
|
||||
|
||||
@@ -283,16 +283,22 @@ export class ContentGenerationPipeline {
|
||||
private buildSamplingParameters(
|
||||
request: GenerateContentParameters,
|
||||
): Record<string, unknown> {
|
||||
const defaultSamplingParams =
|
||||
this.config.provider.getDefaultGenerationConfig();
|
||||
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
|
||||
|
||||
// Helper function to get parameter value with priority: config > request > default
|
||||
const getParameterValue = <T>(
|
||||
configKey: keyof NonNullable<typeof configSamplingParams>,
|
||||
requestKey: keyof NonNullable<typeof request.config>,
|
||||
defaultValue?: T,
|
||||
requestKey?: keyof NonNullable<typeof request.config>,
|
||||
): T | undefined => {
|
||||
const configValue = configSamplingParams?.[configKey] as T | undefined;
|
||||
const requestValue = request.config?.[requestKey] as T | undefined;
|
||||
const requestValue = requestKey
|
||||
? (request.config?.[requestKey] as T | undefined)
|
||||
: undefined;
|
||||
const defaultValue = requestKey
|
||||
? (defaultSamplingParams[requestKey] as T)
|
||||
: undefined;
|
||||
|
||||
if (configValue !== undefined) return configValue;
|
||||
if (requestValue !== undefined) return requestValue;
|
||||
@@ -304,12 +310,8 @@ export class ContentGenerationPipeline {
|
||||
key: string,
|
||||
configKey: keyof NonNullable<typeof configSamplingParams>,
|
||||
requestKey?: keyof NonNullable<typeof request.config>,
|
||||
defaultValue?: T,
|
||||
): Record<string, T> | Record<string, never> => {
|
||||
const value = requestKey
|
||||
? getParameterValue(configKey, requestKey, defaultValue)
|
||||
: ((configSamplingParams?.[configKey] as T | undefined) ??
|
||||
defaultValue);
|
||||
): Record<string, T | undefined> => {
|
||||
const value = getParameterValue<T>(configKey, requestKey);
|
||||
|
||||
return value !== undefined ? { [key]: value } : {};
|
||||
};
|
||||
@@ -323,10 +325,18 @@ export class ContentGenerationPipeline {
|
||||
...addParameterIfDefined('max_tokens', 'max_tokens', 'maxOutputTokens'),
|
||||
|
||||
// Config-only parameters (no request fallback)
|
||||
...addParameterIfDefined('top_k', 'top_k'),
|
||||
...addParameterIfDefined('top_k', 'top_k', 'topK'),
|
||||
...addParameterIfDefined('repetition_penalty', 'repetition_penalty'),
|
||||
...addParameterIfDefined('presence_penalty', 'presence_penalty'),
|
||||
...addParameterIfDefined('frequency_penalty', 'frequency_penalty'),
|
||||
...addParameterIfDefined(
|
||||
'presence_penalty',
|
||||
'presence_penalty',
|
||||
'presencePenalty',
|
||||
),
|
||||
...addParameterIfDefined(
|
||||
'frequency_penalty',
|
||||
'frequency_penalty',
|
||||
'frequencyPenalty',
|
||||
),
|
||||
};
|
||||
|
||||
return params;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai';
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type { Config } from '../../../config/config.js';
|
||||
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { AuthType } from '../../contentGenerator.js';
|
||||
@@ -141,6 +142,14 @@ export class DashScopeOpenAICompatibleProvider
|
||||
};
|
||||
}
|
||||
|
||||
getDefaultGenerationConfig(): GenerateContentConfig {
|
||||
return {
|
||||
temperature: 0.7,
|
||||
topP: 0.8,
|
||||
topK: 20,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Add cache control flag to specified message(s) for DashScope providers
|
||||
*/
|
||||
|
||||
@@ -8,6 +8,7 @@ import type OpenAI from 'openai';
|
||||
import type { Config } from '../../../config/config.js';
|
||||
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { DefaultOpenAICompatibleProvider } from './default.js';
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
|
||||
export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
|
||||
constructor(
|
||||
@@ -76,4 +77,10 @@ export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatiblePro
|
||||
messages,
|
||||
};
|
||||
}
|
||||
|
||||
override getDefaultGenerationConfig(): GenerateContentConfig {
|
||||
return {
|
||||
temperature: 0,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai';
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type { Config } from '../../../config/config.js';
|
||||
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
@@ -55,4 +56,11 @@ export class DefaultOpenAICompatibleProvider
|
||||
...request, // Preserve all original parameters including sampling params
|
||||
};
|
||||
}
|
||||
|
||||
getDefaultGenerationConfig(): GenerateContentConfig {
|
||||
return {
|
||||
temperature: 1,
|
||||
topP: 0.95,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type OpenAI from 'openai';
|
||||
|
||||
// Extended types to support cache_control for DashScope
|
||||
@@ -22,6 +23,7 @@ export interface OpenAICompatibleProvider {
|
||||
request: OpenAI.Chat.ChatCompletionCreateParams,
|
||||
userPromptId: string,
|
||||
): OpenAI.Chat.ChatCompletionCreateParams;
|
||||
getDefaultGenerationConfig(): GenerateContentConfig;
|
||||
}
|
||||
|
||||
export type DashScopeRequestMetadata = {
|
||||
|
||||
@@ -4,36 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
type Mock,
|
||||
type MockInstance,
|
||||
afterEach,
|
||||
} from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { handleFallback } from './handler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
import { logFlashFallback } from '../telemetry/index.js';
|
||||
import type { FallbackModelHandler } from './types.js';
|
||||
|
||||
// Mock the telemetry logger and event class
|
||||
vi.mock('../telemetry/index.js', () => ({
|
||||
logFlashFallback: vi.fn(),
|
||||
FlashFallbackEvent: class {},
|
||||
}));
|
||||
|
||||
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
|
||||
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const AUTH_API_KEY = AuthType.USE_GEMINI;
|
||||
|
||||
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
({
|
||||
@@ -45,174 +19,28 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
|
||||
describe('handleFallback', () => {
|
||||
let mockConfig: Config;
|
||||
let mockHandler: Mock<FallbackModelHandler>;
|
||||
let consoleErrorSpy: MockInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockHandler = vi.fn();
|
||||
// Default setup: OAuth user, Pro model failed, handler injected
|
||||
mockConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
});
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
mockConfig = createMockConfig();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should return null immediately if authType is not OAuth', async () => {
|
||||
it('should return null for unknown auth types', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_API_KEY,
|
||||
'test-model',
|
||||
'unknown-auth',
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if the failed model is already the fallback model', async () => {
|
||||
it('should handle Qwen OAuth error', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
FALLBACK_MODEL, // Failed model is Flash
|
||||
AUTH_OAUTH,
|
||||
'test-model',
|
||||
AuthType.QWEN_OAUTH,
|
||||
new Error('unauthorized'),
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if no fallbackHandler is injected in config', async () => {
|
||||
const configWithoutHandler = createMockConfig({
|
||||
fallbackModelHandler: undefined,
|
||||
});
|
||||
const result = await handleFallback(
|
||||
configWithoutHandler,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
describe('when handler returns "retry"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return true', async () => {
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "stop"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return false', async () => {
|
||||
mockHandler.mockResolvedValue('stop');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "auth"', () => {
|
||||
it('should NOT activate fallback mode and return false', async () => {
|
||||
mockHandler.mockResolvedValue('auth');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns an unexpected value', () => {
|
||||
it('should log an error and return null', async () => {
|
||||
mockHandler.mockResolvedValue(null);
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
new Error(
|
||||
'Unexpected fallback intent received from fallbackModelHandler: "null"',
|
||||
),
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
|
||||
const mockError = new Error('Quota Exceeded');
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
|
||||
|
||||
expect(mockHandler).toHaveBeenCalledWith(
|
||||
MOCK_PRO_MODEL,
|
||||
FALLBACK_MODEL,
|
||||
mockError,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
|
||||
// Setup config where fallback mode is already active
|
||||
const activeFallbackConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
isInFallbackMode: vi.fn(() => true), // Already active
|
||||
setFallbackMode: vi.fn(),
|
||||
});
|
||||
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
activeFallbackConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
// Should still return true to allow the retry (which will use the active fallback mode)
|
||||
expect(result).toBe(true);
|
||||
// Should still consult the handler
|
||||
expect(mockHandler).toHaveBeenCalled();
|
||||
// But should not mutate state or log telemetry again
|
||||
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should catch errors from the handler, log an error, and return null', async () => {
|
||||
const handlerError = new Error('UI interaction failed');
|
||||
mockHandler.mockRejectedValue(handlerError);
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
handlerError,
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
|
||||
|
||||
export async function handleFallback(
|
||||
config: Config,
|
||||
@@ -20,48 +18,7 @@ export async function handleFallback(
|
||||
return handleQwenOAuthError(error);
|
||||
}
|
||||
|
||||
// Applicability Checks
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;
|
||||
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
|
||||
if (failedModel === fallbackModel) return null;
|
||||
|
||||
// Consult UI Handler for Intent
|
||||
const fallbackModelHandler = config.fallbackModelHandler;
|
||||
if (typeof fallbackModelHandler !== 'function') return null;
|
||||
|
||||
try {
|
||||
// Pass the specific failed model to the UI handler.
|
||||
const intent = await fallbackModelHandler(
|
||||
failedModel,
|
||||
fallbackModel,
|
||||
error,
|
||||
);
|
||||
|
||||
// Process Intent and Update State
|
||||
switch (intent) {
|
||||
case 'retry':
|
||||
// Activate fallback mode. The NEXT retry attempt will pick this up.
|
||||
activateFallbackMode(config, authType);
|
||||
return true; // Signal retryWithBackoff to continue.
|
||||
|
||||
case 'stop':
|
||||
activateFallbackMode(config, authType);
|
||||
return false;
|
||||
|
||||
case 'auth':
|
||||
return false;
|
||||
|
||||
default:
|
||||
throw new Error(
|
||||
`Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
|
||||
);
|
||||
}
|
||||
} catch (handlerError) {
|
||||
console.error('Fallback UI handler failed:', handlerError);
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -118,12 +75,3 @@ async function handleQwenOAuthError(error?: unknown): Promise<string | null> {
|
||||
// For other errors, don't handle them specially
|
||||
return null;
|
||||
}
|
||||
|
||||
function activateFallbackMode(config: Config, authType: string | undefined) {
|
||||
if (!config.isInFallbackMode()) {
|
||||
config.setFallbackMode(true);
|
||||
if (authType) {
|
||||
logFlashFallback(config, new FlashFallbackEvent(authType));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ export * from './output/json-formatter.js';
|
||||
// Export Core Logic
|
||||
export * from './core/client.js';
|
||||
export * from './core/contentGenerator.js';
|
||||
export * from './core/loggingContentGenerator.js';
|
||||
export * from './core/geminiChat.js';
|
||||
export * from './core/logger.js';
|
||||
export * from './core/prompts.js';
|
||||
@@ -24,11 +23,7 @@ export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
export * from './fallback/types.js';
|
||||
|
||||
export * from './code_assist/codeAssist.js';
|
||||
export * from './code_assist/oauth2.js';
|
||||
export * from './qwen/qwenOAuth2.js';
|
||||
export * from './code_assist/server.js';
|
||||
export * from './code_assist/types.js';
|
||||
|
||||
// Export utilities
|
||||
export * from './utils/paths.js';
|
||||
|
||||
@@ -907,3 +907,5 @@ export async function clearQwenCredentials(): Promise<void> {
|
||||
function getQwenCachedCredentialPath(): string {
|
||||
return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME);
|
||||
}
|
||||
|
||||
export const clearCachedCredentialFile = clearQwenCredentials;
|
||||
|
||||
@@ -30,7 +30,6 @@ import {
|
||||
ToolCallEvent,
|
||||
} from '../types.js';
|
||||
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
|
||||
import { UserAccountManager } from '../../utils/userAccountManager.js';
|
||||
import { InstallationManager } from '../../utils/installationManager.js';
|
||||
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
|
||||
|
||||
@@ -90,10 +89,8 @@ expect.extend({
|
||||
},
|
||||
});
|
||||
|
||||
vi.mock('../../utils/userAccountManager.js');
|
||||
vi.mock('../../utils/installationManager.js');
|
||||
|
||||
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
|
||||
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
|
||||
|
||||
// TODO(richieforeman): Consider moving this to test setup globally.
|
||||
@@ -128,11 +125,7 @@ describe('ClearcutLogger', () => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
function setup({
|
||||
config = {} as Partial<ConfigParameters>,
|
||||
lifetimeGoogleAccounts = 1,
|
||||
cachedGoogleAccount = 'test@google.com',
|
||||
} = {}) {
|
||||
function setup({ config = {} as Partial<ConfigParameters> } = {}) {
|
||||
server.resetHandlers(
|
||||
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
|
||||
);
|
||||
@@ -146,10 +139,6 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
ClearcutLogger.clearInstance();
|
||||
|
||||
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
|
||||
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
|
||||
lifetimeGoogleAccounts,
|
||||
);
|
||||
mockInstallMgr.getInstallationId = vi
|
||||
.fn()
|
||||
.mockReturnValue('test-installation-id');
|
||||
@@ -195,19 +184,6 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
|
||||
describe('createLogEvent', () => {
|
||||
it('logs the total number of google accounts', () => {
|
||||
const { logger } = setup({
|
||||
lifetimeGoogleAccounts: 9001,
|
||||
});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: '9001',
|
||||
});
|
||||
});
|
||||
|
||||
it('logs the current surface from a github action', () => {
|
||||
const { logger } = setup({});
|
||||
|
||||
@@ -251,7 +227,6 @@ describe('ClearcutLogger', () => {
|
||||
// Define expected values
|
||||
const session_id = 'test-session-id';
|
||||
const auth_type = AuthType.USE_GEMINI;
|
||||
const google_accounts = 123;
|
||||
const surface = 'ide-1234';
|
||||
const cli_version = CLI_VERSION;
|
||||
const git_commit_hash = GIT_COMMIT_INFO;
|
||||
@@ -260,7 +235,6 @@ describe('ClearcutLogger', () => {
|
||||
|
||||
// Setup logger with expected values
|
||||
const { logger, loggerConfig } = setup({
|
||||
lifetimeGoogleAccounts: google_accounts,
|
||||
config: {},
|
||||
});
|
||||
vi.spyOn(loggerConfig, 'getContentGeneratorConfig').mockReturnValue({
|
||||
@@ -283,10 +257,6 @@ describe('ClearcutLogger', () => {
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
|
||||
value: JSON.stringify(auth_type),
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: `${google_accounts}`,
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: surface,
|
||||
@@ -404,10 +374,14 @@ describe('ClearcutLogger', () => {
|
||||
vi.stubEnv(key, value);
|
||||
}
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
expect(event?.event_metadata[0][3]).toEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: expectedValue,
|
||||
});
|
||||
expect(event?.event_metadata[0]).toEqual(
|
||||
expect.arrayContaining([
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: expectedValue,
|
||||
},
|
||||
]),
|
||||
);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -34,7 +34,6 @@ import type {
|
||||
import { EventMetadataKey } from './event-metadata-key.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { InstallationManager } from '../../utils/installationManager.js';
|
||||
import { UserAccountManager } from '../../utils/userAccountManager.js';
|
||||
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
|
||||
import { FixedDeque } from 'mnemonist';
|
||||
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
|
||||
@@ -157,7 +156,6 @@ export class ClearcutLogger {
|
||||
private sessionData: EventValue[] = [];
|
||||
private promptId: string = '';
|
||||
private readonly installationManager: InstallationManager;
|
||||
private readonly userAccountManager: UserAccountManager;
|
||||
|
||||
/**
|
||||
* Queue of pending events that need to be flushed to the server. New events
|
||||
@@ -186,7 +184,6 @@ export class ClearcutLogger {
|
||||
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
|
||||
this.promptId = config?.getSessionId() ?? '';
|
||||
this.installationManager = new InstallationManager();
|
||||
this.userAccountManager = new UserAccountManager();
|
||||
}
|
||||
|
||||
static getInstance(config?: Config): ClearcutLogger | undefined {
|
||||
@@ -233,14 +230,11 @@ export class ClearcutLogger {
|
||||
}
|
||||
|
||||
createLogEvent(eventName: EventNames, data: EventValue[] = []): LogEvent {
|
||||
const email = this.userAccountManager.getCachedGoogleAccount();
|
||||
|
||||
if (eventName !== EventNames.START_SESSION) {
|
||||
data.push(...this.sessionData);
|
||||
}
|
||||
const totalAccounts = this.userAccountManager.getLifetimeGoogleAccounts();
|
||||
|
||||
data = this.addDefaultFields(data, totalAccounts);
|
||||
data = this.addDefaultFields(data);
|
||||
|
||||
const logEvent: LogEvent = {
|
||||
console_type: 'GEMINI_CLI',
|
||||
@@ -249,12 +243,7 @@ export class ClearcutLogger {
|
||||
event_metadata: [data],
|
||||
};
|
||||
|
||||
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
|
||||
if (email) {
|
||||
logEvent.client_email = email;
|
||||
} else {
|
||||
logEvent.client_install_id = this.installationManager.getInstallationId();
|
||||
}
|
||||
logEvent.client_install_id = this.installationManager.getInstallationId();
|
||||
|
||||
return logEvent;
|
||||
}
|
||||
@@ -1018,7 +1007,7 @@ export class ClearcutLogger {
|
||||
* Adds default fields to data, and returns a new data array. This fields
|
||||
* should exist on all log events.
|
||||
*/
|
||||
addDefaultFields(data: EventValue[], totalAccounts: number): EventValue[] {
|
||||
addDefaultFields(data: EventValue[]): EventValue[] {
|
||||
const surface = determineSurface();
|
||||
|
||||
const defaultLogMetadata: EventValue[] = [
|
||||
@@ -1032,10 +1021,6 @@ export class ClearcutLogger {
|
||||
this.config?.getContentGeneratorConfig()?.authType,
|
||||
),
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: `${totalAccounts}`,
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: surface,
|
||||
|
||||
@@ -83,7 +83,6 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import * as uiTelemetry from './uiTelemetry.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
|
||||
describe('loggers', () => {
|
||||
@@ -101,10 +100,6 @@ describe('loggers', () => {
|
||||
vi.spyOn(uiTelemetry.uiTelemetryService, 'addEvent').mockImplementation(
|
||||
mockUiEvent.addEvent,
|
||||
);
|
||||
vi.spyOn(
|
||||
UserAccountManager.prototype,
|
||||
'getCachedGoogleAccount',
|
||||
).mockReturnValue('test-user@example.com');
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2025-01-01T00:00:00.000Z'));
|
||||
});
|
||||
@@ -188,7 +183,6 @@ describe('loggers', () => {
|
||||
body: 'CLI configuration loaded.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_CLI_CONFIG,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -233,7 +227,6 @@ describe('loggers', () => {
|
||||
body: 'User prompt. Length: 11.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_USER_PROMPT,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
prompt_length: 11,
|
||||
@@ -255,7 +248,7 @@ describe('loggers', () => {
|
||||
const event = new UserPromptEvent(
|
||||
11,
|
||||
'prompt-id-9',
|
||||
AuthType.CLOUD_SHELL,
|
||||
AuthType.USE_GEMINI,
|
||||
'test-prompt',
|
||||
);
|
||||
|
||||
@@ -265,12 +258,11 @@ describe('loggers', () => {
|
||||
body: 'User prompt. Length: 11.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_USER_PROMPT,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
prompt_length: 11,
|
||||
prompt_id: 'prompt-id-9',
|
||||
auth_type: 'cloud-shell',
|
||||
auth_type: 'gemini-api-key',
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -313,7 +305,7 @@ describe('loggers', () => {
|
||||
'test-model',
|
||||
100,
|
||||
'prompt-id-1',
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
AuthType.USE_GEMINI,
|
||||
usageData,
|
||||
'test-response',
|
||||
);
|
||||
@@ -324,7 +316,6 @@ describe('loggers', () => {
|
||||
body: 'API response from test-model. Status: 200. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_RESPONSE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
[SemanticAttributes.HTTP_STATUS_CODE]: 200,
|
||||
@@ -340,7 +331,7 @@ describe('loggers', () => {
|
||||
total_token_count: 0,
|
||||
response_text: 'test-response',
|
||||
prompt_id: 'prompt-id-1',
|
||||
auth_type: 'oauth-personal',
|
||||
auth_type: 'gemini-api-key',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -386,7 +377,6 @@ describe('loggers', () => {
|
||||
body: 'API request to test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_REQUEST,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -405,7 +395,6 @@ describe('loggers', () => {
|
||||
body: 'API request to test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_REQUEST,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -430,7 +419,6 @@ describe('loggers', () => {
|
||||
body: 'Switching to flash as Fallback.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_FLASH_FALLBACK,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
auth_type: 'vertex-ai',
|
||||
@@ -465,7 +453,6 @@ describe('loggers', () => {
|
||||
expect(emittedEvent.attributes).toEqual(
|
||||
expect.objectContaining({
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_RIPGREP_FALLBACK,
|
||||
error: 'ripgrep is not available',
|
||||
}),
|
||||
@@ -484,7 +471,6 @@ describe('loggers', () => {
|
||||
expect(emittedEvent.attributes).toEqual(
|
||||
expect.objectContaining({
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_RIPGREP_FALLBACK,
|
||||
error: 'rg not found',
|
||||
}),
|
||||
@@ -598,7 +584,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: accept. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -682,7 +667,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: reject. Success: false. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -759,7 +743,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: modify. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -835,7 +818,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -910,7 +892,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Success: false. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -999,7 +980,6 @@ describe('loggers', () => {
|
||||
body: 'Tool call: mock_mcp_tool. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'mock_mcp_tool',
|
||||
@@ -1047,7 +1027,6 @@ describe('loggers', () => {
|
||||
body: 'Malformed JSON response from test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_MALFORMED_JSON_RESPONSE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -1091,7 +1070,6 @@ describe('loggers', () => {
|
||||
body: 'File operation: read. Lines: 10.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_FILE_OPERATION,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
tool_name: 'test-tool',
|
||||
@@ -1137,7 +1115,6 @@ describe('loggers', () => {
|
||||
body: 'Tool output truncated for test-tool.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': 'tool_output_truncated',
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
eventName: 'tool_output_truncated',
|
||||
@@ -1184,7 +1161,6 @@ describe('loggers', () => {
|
||||
body: 'Installed extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_INSTALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1223,7 +1199,6 @@ describe('loggers', () => {
|
||||
body: 'Uninstalled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_UNINSTALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1260,7 +1235,6 @@ describe('loggers', () => {
|
||||
body: 'Enabled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_ENABLE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1297,7 +1271,6 @@ describe('loggers', () => {
|
||||
body: 'Disabled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_DISABLE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
|
||||
@@ -9,7 +9,6 @@ import { logs } from '@opentelemetry/api-logs';
|
||||
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import {
|
||||
EVENT_API_ERROR,
|
||||
EVENT_API_CANCEL,
|
||||
@@ -93,11 +92,8 @@ const shouldLogUserPrompts = (config: Config): boolean =>
|
||||
config.getTelemetryLogPromptsEnabled();
|
||||
|
||||
function getCommonAttributes(config: Config): LogAttributes {
|
||||
const userAccountManager = new UserAccountManager();
|
||||
const email = userAccountManager.getCachedGoogleAccount();
|
||||
return {
|
||||
'session.id': config.getSessionId(),
|
||||
...(email && { 'user.email': email }),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -217,9 +217,9 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toEqual(
|
||||
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
|
||||
);
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._url).toEqual(new URL('http://test-server'));
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
@@ -232,13 +232,13 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toEqual(
|
||||
new StreamableHTTPClientTransport(new URL('http://test-server'), {
|
||||
requestInit: {
|
||||
headers: { Authorization: 'derp' },
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._url).toEqual(new URL('http://test-server'));
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._requestInit?.headers).toEqual({
|
||||
Authorization: 'derp',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -251,9 +251,9 @@ describe('mcp-client', () => {
|
||||
},
|
||||
false,
|
||||
);
|
||||
expect(transport).toEqual(
|
||||
new SSEClientTransport(new URL('http://test-server'), {}),
|
||||
);
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._url).toEqual(new URL('http://test-server'));
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
@@ -266,13 +266,13 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toEqual(
|
||||
new SSEClientTransport(new URL('http://test-server'), {
|
||||
requestInit: {
|
||||
headers: { Authorization: 'derp' },
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._url).toEqual(new URL('http://test-server'));
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((transport as any)._requestInit?.headers).toEqual({
|
||||
Authorization: 'derp',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { parseAndFormatApiError } from './errorParsing.js';
|
||||
import { isProQuotaExceededError } from './quotaErrorDetection.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { UserTierId } from '../code_assist/types.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import type { StructuredError } from '../core/turn.js';
|
||||
|
||||
@@ -27,32 +24,10 @@ describe('parseAndFormatApiError', () => {
|
||||
it('should format a 429 API error with the default message', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
undefined,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
const result = parseAndFormatApiError(errorMessage, undefined);
|
||||
expect(result).toContain('[API Error: Rate limit exceeded');
|
||||
expect(result).toContain(
|
||||
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
|
||||
);
|
||||
});
|
||||
|
||||
it('should format a 429 API error with the personal message', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain('[API Error: Rate limit exceeded');
|
||||
expect(result).toContain(
|
||||
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
|
||||
'Possible quota limitations in place or slow response times detected. Please wait and try again later.',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -132,230 +107,4 @@ describe('parseAndFormatApiError', () => {
|
||||
const expected = '[API Error: An unknown error occurred.]';
|
||||
expect(parseAndFormatApiError(error)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should format a 429 API error with Pro quota exceeded message for Google auth (Free tier)', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
|
||||
);
|
||||
expect(result).toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
expect(result).toContain('upgrade to get higher limits');
|
||||
});
|
||||
|
||||
it('should format a regular 429 API error with standard message for Google auth', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain('[API Error: Rate limit exceeded');
|
||||
expect(result).toContain(
|
||||
'Possible quota limitations in place or slow response times detected. Switching to the gemini-2.5-flash model',
|
||||
);
|
||||
expect(result).not.toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
});
|
||||
|
||||
it('should format a 429 API error with generic quota exceeded message for Google auth', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
|
||||
);
|
||||
expect(result).toContain('You have reached your daily quota limit');
|
||||
expect(result).not.toContain(
|
||||
'You have reached your daily Gemini 2.5 Pro quota limit',
|
||||
);
|
||||
});
|
||||
|
||||
it('should prioritize Pro quota message over generic quota message for Google auth', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
|
||||
);
|
||||
expect(result).toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
expect(result).not.toContain('You have reached your daily quota limit');
|
||||
});
|
||||
|
||||
it('should format a 429 API error with Pro quota exceeded message for Google auth (Standard tier)', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
UserTierId.STANDARD,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
|
||||
);
|
||||
expect(result).toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
expect(result).toContain(
|
||||
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
|
||||
);
|
||||
expect(result).not.toContain('upgrade to get higher limits');
|
||||
});
|
||||
|
||||
it('should format a 429 API error with Pro quota exceeded message for Google auth (Legacy tier)', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
UserTierId.LEGACY,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
|
||||
);
|
||||
expect(result).toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
expect(result).toContain(
|
||||
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
|
||||
);
|
||||
expect(result).not.toContain('upgrade to get higher limits');
|
||||
});
|
||||
|
||||
it('should handle different Gemini 2.5 version strings in Pro quota exceeded errors', () => {
|
||||
const errorMessage25 =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5 Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const errorMessagePreview =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'Gemini 2.5-preview Pro Requests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
|
||||
const result25 = parseAndFormatApiError(
|
||||
errorMessage25,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
const resultPreview = parseAndFormatApiError(
|
||||
errorMessagePreview,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
'gemini-2.5-preview-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(result25).toContain(
|
||||
'You have reached your daily gemini-2.5-pro quota limit',
|
||||
);
|
||||
expect(resultPreview).toContain(
|
||||
'You have reached your daily gemini-2.5-preview-pro quota limit',
|
||||
);
|
||||
expect(result25).toContain('upgrade to get higher limits');
|
||||
expect(resultPreview).toContain('upgrade to get higher limits');
|
||||
});
|
||||
|
||||
it('should not match non-Pro models with similar version strings', () => {
|
||||
// Test that Flash models with similar version strings don't match
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5 Flash Requests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5-preview Flash Requests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
|
||||
// Test other model types
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5 Ultra Requests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5 Standard Requests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
|
||||
// Test generic quota messages
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'GenerationRequests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isProQuotaExceededError(
|
||||
"Quota exceeded for quota metric 'EmbeddingRequests' and limit",
|
||||
),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should format a generic quota exceeded message for Google auth (Standard tier)', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Quota exceeded for quota metric \'GenerationRequests\' and limit \'RequestsPerDay\' of service \'generativelanguage.googleapis.com\' for consumer \'project_number:123456789\'.","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
UserTierId.STANDARD,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain(
|
||||
"[API Error: Quota exceeded for quota metric 'GenerationRequests'",
|
||||
);
|
||||
expect(result).toContain('You have reached your daily quota limit');
|
||||
expect(result).toContain(
|
||||
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
|
||||
);
|
||||
expect(result).not.toContain('upgrade to get higher limits');
|
||||
});
|
||||
|
||||
it('should format a regular 429 API error with standard message for Google auth (Standard tier)', () => {
|
||||
const errorMessage =
|
||||
'got status: 429 Too Many Requests. {"error":{"code":429,"message":"Rate limit exceeded","status":"RESOURCE_EXHAUSTED"}}';
|
||||
const result = parseAndFormatApiError(
|
||||
errorMessage,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
UserTierId.STANDARD,
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(result).toContain('[API Error: Rate limit exceeded');
|
||||
expect(result).toContain(
|
||||
'We appreciate you for choosing Gemini Code Assist and the Gemini CLI',
|
||||
);
|
||||
expect(result).not.toContain('upgrade to get higher limits');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,120 +4,36 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
isApiError,
|
||||
isStructuredError,
|
||||
} from './quotaErrorDetection.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} from '../config/models.js';
|
||||
import { UserTierId } from '../code_assist/types.js';
|
||||
import { isApiError, isStructuredError } from './quotaErrorDetection.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
|
||||
// Free Tier message functions
|
||||
const getRateLimitErrorMessageGoogleFree = (
|
||||
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
|
||||
) =>
|
||||
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
|
||||
|
||||
const getRateLimitErrorMessageGoogleProQuotaFree = (
|
||||
currentModel: string = DEFAULT_GEMINI_MODEL,
|
||||
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
|
||||
) =>
|
||||
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
|
||||
const getRateLimitErrorMessageGoogleGenericQuotaFree = () =>
|
||||
`\nYou have reached your daily quota limit. To increase your limits, upgrade to get higher limits at https://goo.gle/set-up-gemini-code-assist, or use /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
|
||||
// Legacy/Standard Tier message functions
|
||||
const getRateLimitErrorMessageGooglePaid = (
|
||||
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
|
||||
) =>
|
||||
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI.`;
|
||||
|
||||
const getRateLimitErrorMessageGoogleProQuotaPaid = (
|
||||
currentModel: string = DEFAULT_GEMINI_MODEL,
|
||||
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
|
||||
) =>
|
||||
`\nYou have reached your daily ${currentModel} quota limit. You will be switched to the ${fallbackModel} model for the rest of this session. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
|
||||
const getRateLimitErrorMessageGoogleGenericQuotaPaid = (
|
||||
currentModel: string = DEFAULT_GEMINI_MODEL,
|
||||
) =>
|
||||
`\nYou have reached your daily quota limit. We appreciate you for choosing Gemini Code Assist and the Gemini CLI. To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
const RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI =
|
||||
'\nPlease wait and try again later. To increase your limits, request a quota increase through AI Studio, or switch to another /auth method';
|
||||
const RATE_LIMIT_ERROR_MESSAGE_VERTEX =
|
||||
'\nPlease wait and try again later. To increase your limits, request a quota increase through Vertex, or switch to another /auth method';
|
||||
const getRateLimitErrorMessageDefault = (
|
||||
fallbackModel: string = DEFAULT_GEMINI_FLASH_MODEL,
|
||||
) =>
|
||||
`\nPossible quota limitations in place or slow response times detected. Switching to the ${fallbackModel} model for the rest of this session.`;
|
||||
const RATE_LIMIT_ERROR_MESSAGE_DEFAULT =
|
||||
'\nPossible quota limitations in place or slow response times detected. Please wait and try again later.';
|
||||
|
||||
function getRateLimitMessage(
|
||||
authType?: AuthType,
|
||||
error?: unknown,
|
||||
userTier?: UserTierId,
|
||||
currentModel?: string,
|
||||
fallbackModel?: string,
|
||||
): string {
|
||||
function getRateLimitMessage(authType?: AuthType): string {
|
||||
switch (authType) {
|
||||
case AuthType.LOGIN_WITH_GOOGLE: {
|
||||
// Determine if user is on a paid tier (Legacy or Standard) - default to FREE if not specified
|
||||
const isPaidTier =
|
||||
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
|
||||
|
||||
if (isProQuotaExceededError(error)) {
|
||||
return isPaidTier
|
||||
? getRateLimitErrorMessageGoogleProQuotaPaid(
|
||||
currentModel || DEFAULT_GEMINI_MODEL,
|
||||
fallbackModel,
|
||||
)
|
||||
: getRateLimitErrorMessageGoogleProQuotaFree(
|
||||
currentModel || DEFAULT_GEMINI_MODEL,
|
||||
fallbackModel,
|
||||
);
|
||||
} else if (isGenericQuotaExceededError(error)) {
|
||||
return isPaidTier
|
||||
? getRateLimitErrorMessageGoogleGenericQuotaPaid(
|
||||
currentModel || DEFAULT_GEMINI_MODEL,
|
||||
)
|
||||
: getRateLimitErrorMessageGoogleGenericQuotaFree();
|
||||
} else {
|
||||
return isPaidTier
|
||||
? getRateLimitErrorMessageGooglePaid(fallbackModel)
|
||||
: getRateLimitErrorMessageGoogleFree(fallbackModel);
|
||||
}
|
||||
}
|
||||
case AuthType.USE_GEMINI:
|
||||
return RATE_LIMIT_ERROR_MESSAGE_USE_GEMINI;
|
||||
case AuthType.USE_VERTEX_AI:
|
||||
return RATE_LIMIT_ERROR_MESSAGE_VERTEX;
|
||||
default:
|
||||
return getRateLimitErrorMessageDefault(fallbackModel);
|
||||
return RATE_LIMIT_ERROR_MESSAGE_DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
export function parseAndFormatApiError(
|
||||
error: unknown,
|
||||
authType?: AuthType,
|
||||
userTier?: UserTierId,
|
||||
currentModel?: string,
|
||||
fallbackModel?: string,
|
||||
): string {
|
||||
if (isStructuredError(error)) {
|
||||
let text = `[API Error: ${error.message}]`;
|
||||
if (error.status === 429) {
|
||||
text += getRateLimitMessage(
|
||||
authType,
|
||||
error,
|
||||
userTier,
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
);
|
||||
text += getRateLimitMessage(authType);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
@@ -146,13 +62,7 @@ export function parseAndFormatApiError(
|
||||
}
|
||||
let text = `[API Error: ${finalMessage} (Status: ${parsedError.error.status})]`;
|
||||
if (parsedError.error.code === 429) {
|
||||
text += getRateLimitMessage(
|
||||
authType,
|
||||
parsedError,
|
||||
userTier,
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
);
|
||||
text += getRateLimitMessage(authType);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
@@ -11,12 +11,9 @@ import {
|
||||
setSimulate429,
|
||||
disableSimulationAfterFallback,
|
||||
shouldSimulate429,
|
||||
createSimulated429Error,
|
||||
resetRequestCounter,
|
||||
} from './testUtils.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { retryWithBackoff } from './retry.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
// Import the new types (Assuming this test file is in packages/core/src/utils/)
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
|
||||
@@ -61,84 +58,6 @@ describe('Retry Utility Fallback Integration', () => {
|
||||
expect(result).toBe('retry');
|
||||
});
|
||||
|
||||
// This test validates the retry utility's logic for triggering the callback.
|
||||
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
|
||||
let fallbackCalled = false;
|
||||
// Removed fallbackModel variable as it's no longer relevant here.
|
||||
|
||||
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
|
||||
const mockApiCall = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(createSimulated429Error())
|
||||
.mockRejectedValueOnce(createSimulated429Error())
|
||||
.mockResolvedValueOnce('success after fallback');
|
||||
|
||||
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
|
||||
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
|
||||
fallbackCalled = true;
|
||||
// Return true to signal retryWithBackoff to reset attempts and continue.
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
|
||||
const result = await retryWithBackoff(mockApiCall, {
|
||||
maxAttempts: 2,
|
||||
initialDelayMs: 1,
|
||||
maxDelayMs: 10,
|
||||
shouldRetryOnError: (error: Error) => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
});
|
||||
|
||||
// Verify fallback mechanism was triggered
|
||||
expect(fallbackCalled).toBe(true);
|
||||
expect(mockPersistent429Callback).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(result).toBe('success after fallback');
|
||||
// Should have: 2 failures, then fallback triggered, then 1 success after retry reset
|
||||
expect(mockApiCall).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should not trigger onPersistent429 for API key users', async () => {
|
||||
let fallbackCalled = false;
|
||||
|
||||
// Mock function that simulates 429 errors
|
||||
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
|
||||
|
||||
// Mock the callback
|
||||
const mockPersistent429Callback = vi.fn(async () => {
|
||||
fallbackCalled = true;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with API key auth type - should not trigger fallback
|
||||
try {
|
||||
await retryWithBackoff(mockApiCall, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 10,
|
||||
maxDelayMs: 100,
|
||||
shouldRetryOnError: (error: Error) => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.USE_GEMINI, // API key auth type
|
||||
});
|
||||
} catch (error) {
|
||||
// Expected to throw after max attempts
|
||||
expect((error as Error).message).toContain('Rate limit exceeded');
|
||||
}
|
||||
|
||||
// Verify fallback was NOT triggered for API key users
|
||||
expect(fallbackCalled).toBe(false);
|
||||
expect(mockPersistent429Callback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// This test validates the test utilities themselves.
|
||||
it('should properly disable simulation state after fallback (Test Utility)', () => {
|
||||
// Enable simulation
|
||||
|
||||
@@ -285,173 +285,6 @@ describe('retryWithBackoff', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Flash model fallback for OAuth users', () => {
|
||||
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
|
||||
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
|
||||
|
||||
let fallbackOccurred = false;
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
if (!fallbackOccurred) {
|
||||
const error: HttpError = new Error('Rate limit exceeded');
|
||||
error.status = 429;
|
||||
throw error;
|
||||
}
|
||||
return 'success';
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 100,
|
||||
onPersistent429: async (authType?: string) => {
|
||||
fallbackOccurred = true;
|
||||
return await fallbackCallback(authType);
|
||||
},
|
||||
authType: 'oauth-personal',
|
||||
});
|
||||
|
||||
// Advance all timers to complete retries
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
// Should succeed after fallback
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Verify callback was called with correct auth type
|
||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||
|
||||
// Should retry again after fallback
|
||||
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
|
||||
});
|
||||
|
||||
it('should NOT trigger fallback for API key users', async () => {
|
||||
const fallbackCallback = vi.fn();
|
||||
|
||||
const mockFn = vi.fn(async () => {
|
||||
const error: HttpError = new Error('Rate limit exceeded');
|
||||
error.status = 429;
|
||||
throw error;
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 100,
|
||||
onPersistent429: fallbackCallback,
|
||||
authType: 'gemini-api-key',
|
||||
});
|
||||
|
||||
// Handle the promise properly to avoid unhandled rejections
|
||||
const resultPromise = promise.catch((error) => error);
|
||||
await vi.runAllTimersAsync();
|
||||
const result = await resultPromise;
|
||||
|
||||
// Should fail after all retries without fallback
|
||||
expect(result).toBeInstanceOf(Error);
|
||||
expect(result.message).toBe('Rate limit exceeded');
|
||||
|
||||
// Callback should not be called for API key users
|
||||
expect(fallbackCallback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reset attempt counter and continue after successful fallback', async () => {
|
||||
let fallbackCalled = false;
|
||||
const fallbackCallback = vi.fn().mockImplementation(async () => {
|
||||
fallbackCalled = true;
|
||||
return 'gemini-2.5-flash';
|
||||
});
|
||||
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
if (!fallbackCalled) {
|
||||
const error: HttpError = new Error('Rate limit exceeded');
|
||||
error.status = 429;
|
||||
throw error;
|
||||
}
|
||||
return 'success';
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 100,
|
||||
onPersistent429: fallbackCallback,
|
||||
authType: 'oauth-personal',
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
expect(fallbackCallback).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should continue with original error if fallback is rejected', async () => {
|
||||
const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
|
||||
|
||||
const mockFn = vi.fn(async () => {
|
||||
const error: HttpError = new Error('Rate limit exceeded');
|
||||
error.status = 429;
|
||||
throw error;
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 100,
|
||||
onPersistent429: fallbackCallback,
|
||||
authType: 'oauth-personal',
|
||||
});
|
||||
|
||||
// Handle the promise properly to avoid unhandled rejections
|
||||
const resultPromise = promise.catch((error) => error);
|
||||
await vi.runAllTimersAsync();
|
||||
const result = await resultPromise;
|
||||
|
||||
// Should fail with original error when fallback is rejected
|
||||
expect(result).toBeInstanceOf(Error);
|
||||
expect(result.message).toBe('Rate limit exceeded');
|
||||
expect(fallbackCallback).toHaveBeenCalledWith(
|
||||
'oauth-personal',
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle mixed error types (only count consecutive 429s)', async () => {
|
||||
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
|
||||
let attempts = 0;
|
||||
let fallbackOccurred = false;
|
||||
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
attempts++;
|
||||
if (fallbackOccurred) {
|
||||
return 'success';
|
||||
}
|
||||
if (attempts === 1) {
|
||||
// First attempt: 500 error (resets consecutive count)
|
||||
const error: HttpError = new Error('Server error');
|
||||
error.status = 500;
|
||||
throw error;
|
||||
} else {
|
||||
// Remaining attempts: 429 errors
|
||||
const error: HttpError = new Error('Rate limit exceeded');
|
||||
error.status = 429;
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
onPersistent429: async (authType?: string) => {
|
||||
fallbackOccurred = true;
|
||||
return await fallbackCallback(authType);
|
||||
},
|
||||
authType: 'oauth-personal',
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Should trigger fallback after 2 consecutive 429s (attempts 2-3)
|
||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Qwen OAuth 429 error handling', () => {
|
||||
it('should retry for Qwen OAuth 429 errors that are throttling-related', async () => {
|
||||
const errorWith429: HttpError = new Error('Rate limit exceeded');
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
isQwenQuotaExceededError,
|
||||
isQwenThrottlingError,
|
||||
} from './quotaErrorDetection.js';
|
||||
@@ -90,7 +88,6 @@ export async function retryWithBackoff<T>(
|
||||
maxAttempts,
|
||||
initialDelayMs,
|
||||
maxDelayMs,
|
||||
onPersistent429,
|
||||
authType,
|
||||
shouldRetryOnError,
|
||||
shouldRetryOnContent,
|
||||
@@ -123,59 +120,6 @@ export async function retryWithBackoff<T>(
|
||||
} catch (error) {
|
||||
const errorStatus = getErrorStatus(error);
|
||||
|
||||
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
|
||||
if (
|
||||
errorStatus === 429 &&
|
||||
authType === AuthType.LOGIN_WITH_GOOGLE &&
|
||||
isProQuotaExceededError(error) &&
|
||||
onPersistent429
|
||||
) {
|
||||
try {
|
||||
const fallbackModel = await onPersistent429(authType, error);
|
||||
if (fallbackModel !== false && fallbackModel !== null) {
|
||||
// Reset attempt counter and try with new model
|
||||
attempt = 0;
|
||||
consecutive429Count = 0;
|
||||
currentDelay = initialDelayMs;
|
||||
// With the model updated, we continue to the next attempt
|
||||
continue;
|
||||
} else {
|
||||
// Fallback handler returned null/false, meaning don't continue - stop retry process
|
||||
throw error;
|
||||
}
|
||||
} catch (fallbackError) {
|
||||
// If fallback fails, continue with original error
|
||||
console.warn('Fallback to Flash model failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for generic quota exceeded error (but not Pro, which was handled above) - immediate fallback for OAuth users
|
||||
if (
|
||||
errorStatus === 429 &&
|
||||
authType === AuthType.LOGIN_WITH_GOOGLE &&
|
||||
!isProQuotaExceededError(error) &&
|
||||
isGenericQuotaExceededError(error) &&
|
||||
onPersistent429
|
||||
) {
|
||||
try {
|
||||
const fallbackModel = await onPersistent429(authType, error);
|
||||
if (fallbackModel !== false && fallbackModel !== null) {
|
||||
// Reset attempt counter and try with new model
|
||||
attempt = 0;
|
||||
consecutive429Count = 0;
|
||||
currentDelay = initialDelayMs;
|
||||
// With the model updated, we continue to the next attempt
|
||||
continue;
|
||||
} else {
|
||||
// Fallback handler returned null/false, meaning don't continue - stop retry process
|
||||
throw error;
|
||||
}
|
||||
} catch (fallbackError) {
|
||||
// If fallback fails, continue with original error
|
||||
console.warn('Fallback to Flash model failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for Qwen OAuth quota exceeded error - throw immediately without retry
|
||||
if (authType === AuthType.QWEN_OAUTH && isQwenQuotaExceededError(error)) {
|
||||
throw new Error(
|
||||
@@ -197,30 +141,7 @@ export async function retryWithBackoff<T>(
|
||||
consecutive429Count = 0;
|
||||
}
|
||||
|
||||
// If we have persistent 429s and a fallback callback for OAuth
|
||||
if (
|
||||
consecutive429Count >= 2 &&
|
||||
onPersistent429 &&
|
||||
authType === AuthType.LOGIN_WITH_GOOGLE
|
||||
) {
|
||||
try {
|
||||
const fallbackModel = await onPersistent429(authType, error);
|
||||
if (fallbackModel !== false && fallbackModel !== null) {
|
||||
// Reset attempt counter and try with new model
|
||||
attempt = 0;
|
||||
consecutive429Count = 0;
|
||||
currentDelay = initialDelayMs;
|
||||
// With the model updated, we continue to the next attempt
|
||||
continue;
|
||||
} else {
|
||||
// Fallback handler returned null/false, meaning don't continue - stop retry process
|
||||
throw error;
|
||||
}
|
||||
} catch (fallbackError) {
|
||||
// If fallback fails, continue with original error
|
||||
console.warn('Fallback to Flash model failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
console.debug('consecutive429Count', consecutive429Count);
|
||||
|
||||
// Check if we've exhausted retries or shouldn't retry
|
||||
if (attempt >= maxAttempts || !shouldRetryOnError(error as Error)) {
|
||||
@@ -240,7 +161,7 @@ export async function retryWithBackoff<T>(
|
||||
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
|
||||
currentDelay = initialDelayMs;
|
||||
} else {
|
||||
// Fall back to exponential backoff with jitter
|
||||
// Fallback to exponential backoff with jitter
|
||||
logRetryAttempt(attempt, error, errorStatus);
|
||||
// Add jitter: +/- 30% of currentDelay
|
||||
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Mock } from 'vitest';
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { UserAccountManager } from './userAccountManager.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as os from 'node:os';
|
||||
import path from 'node:path';
|
||||
|
||||
vi.mock('os', async (importOriginal) => {
|
||||
const os = await importOriginal<typeof import('os')>();
|
||||
return {
|
||||
...os,
|
||||
homedir: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('UserAccountManager', () => {
|
||||
let tempHomeDir: string;
|
||||
let userAccountManager: UserAccountManager;
|
||||
let accountsFile: () => string;
|
||||
|
||||
beforeEach(() => {
|
||||
tempHomeDir = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'qwen-code-test-home-'),
|
||||
);
|
||||
(os.homedir as Mock).mockReturnValue(tempHomeDir);
|
||||
accountsFile = () =>
|
||||
path.join(tempHomeDir, '.qwen', 'google_accounts.json');
|
||||
userAccountManager = new UserAccountManager();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fs.rmSync(tempHomeDir, { recursive: true, force: true });
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('cacheGoogleAccount', () => {
|
||||
it('should create directory and write initial account file', async () => {
|
||||
await userAccountManager.cacheGoogleAccount('test1@google.com');
|
||||
|
||||
// Verify Google Account ID was cached
|
||||
expect(fs.existsSync(accountsFile())).toBe(true);
|
||||
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
|
||||
JSON.stringify({ active: 'test1@google.com', old: [] }, null, 2),
|
||||
);
|
||||
});
|
||||
|
||||
it('should update active account and move previous to old', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify(
|
||||
{ active: 'test2@google.com', old: ['test1@google.com'] },
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
|
||||
await userAccountManager.cacheGoogleAccount('test3@google.com');
|
||||
|
||||
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
|
||||
JSON.stringify(
|
||||
{
|
||||
active: 'test3@google.com',
|
||||
old: ['test1@google.com', 'test2@google.com'],
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not add a duplicate to the old list', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify(
|
||||
{ active: 'test1@google.com', old: ['test2@google.com'] },
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
await userAccountManager.cacheGoogleAccount('test2@google.com');
|
||||
await userAccountManager.cacheGoogleAccount('test1@google.com');
|
||||
|
||||
expect(fs.readFileSync(accountsFile(), 'utf-8')).toBe(
|
||||
JSON.stringify(
|
||||
{ active: 'test1@google.com', old: ['test2@google.com'] },
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle corrupted JSON by starting fresh', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), 'not valid json');
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await userAccountManager.cacheGoogleAccount('test1@google.com');
|
||||
|
||||
expect(consoleLogSpy).toHaveBeenCalled();
|
||||
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
|
||||
active: 'test1@google.com',
|
||||
old: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle valid JSON with incorrect schema by starting fresh', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({ active: 'test1@google.com', old: 'not-an-array' }),
|
||||
);
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await userAccountManager.cacheGoogleAccount('test2@google.com');
|
||||
|
||||
expect(consoleLogSpy).toHaveBeenCalled();
|
||||
expect(JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'))).toEqual({
|
||||
active: 'test2@google.com',
|
||||
old: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCachedGoogleAccount', () => {
|
||||
it('should return the active account if file exists and is valid', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({ active: 'active@google.com', old: [] }, null, 2),
|
||||
);
|
||||
const account = userAccountManager.getCachedGoogleAccount();
|
||||
expect(account).toBe('active@google.com');
|
||||
});
|
||||
|
||||
it('should return null if file does not exist', () => {
|
||||
const account = userAccountManager.getCachedGoogleAccount();
|
||||
expect(account).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null if file is empty', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), '');
|
||||
const account = userAccountManager.getCachedGoogleAccount();
|
||||
expect(account).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null and log if file is corrupted', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), '{ "active": "test@google.com"'); // Invalid JSON
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const account = userAccountManager.getCachedGoogleAccount();
|
||||
|
||||
expect(account).toBeNull();
|
||||
expect(consoleLogSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if active key is missing', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), JSON.stringify({ old: [] }));
|
||||
const account = userAccountManager.getCachedGoogleAccount();
|
||||
expect(account).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCachedGoogleAccount', () => {
|
||||
it('should set active to null and move it to old', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify(
|
||||
{ active: 'active@google.com', old: ['old1@google.com'] },
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
|
||||
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
|
||||
expect(stored.active).toBeNull();
|
||||
expect(stored.old).toEqual(['old1@google.com', 'active@google.com']);
|
||||
});
|
||||
|
||||
it('should handle empty file gracefully', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), '');
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
|
||||
expect(stored.active).toBeNull();
|
||||
expect(stored.old).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle corrupted JSON by creating a fresh file', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), 'not valid json');
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
|
||||
expect(consoleLogSpy).toHaveBeenCalled();
|
||||
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
|
||||
expect(stored.active).toBeNull();
|
||||
expect(stored.old).toEqual([]);
|
||||
});
|
||||
|
||||
it('should be idempotent if active account is already null', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({ active: null, old: ['old1@google.com'] }, null, 2),
|
||||
);
|
||||
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
|
||||
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
|
||||
expect(stored.active).toBeNull();
|
||||
expect(stored.old).toEqual(['old1@google.com']);
|
||||
});
|
||||
|
||||
it('should not add a duplicate to the old list', async () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify(
|
||||
{
|
||||
active: 'active@google.com',
|
||||
old: ['active@google.com'],
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
|
||||
const stored = JSON.parse(fs.readFileSync(accountsFile(), 'utf-8'));
|
||||
expect(stored.active).toBeNull();
|
||||
expect(stored.old).toEqual(['active@google.com']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getLifetimeGoogleAccounts', () => {
|
||||
it('should return 0 if the file does not exist', () => {
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
|
||||
});
|
||||
|
||||
it('should return 0 if the file is empty', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), '');
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
|
||||
});
|
||||
|
||||
it('should return 0 if the file is corrupted', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(accountsFile(), 'invalid json');
|
||||
const consoleDebugSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
|
||||
expect(consoleDebugSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 1 if there is only an active account', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({ active: 'test1@google.com', old: [] }),
|
||||
);
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(1);
|
||||
});
|
||||
|
||||
it('should correctly count old accounts when active is null', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({
|
||||
active: null,
|
||||
old: ['test1@google.com', 'test2@google.com'],
|
||||
}),
|
||||
);
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
|
||||
});
|
||||
|
||||
it('should correctly count both active and old accounts', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({
|
||||
active: 'test3@google.com',
|
||||
old: ['test1@google.com', 'test2@google.com'],
|
||||
}),
|
||||
);
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle valid JSON with incorrect schema by returning 0', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({ active: null, old: 1 }),
|
||||
);
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'log')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(0);
|
||||
expect(consoleLogSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not double count if active account is also in old list', () => {
|
||||
fs.mkdirSync(path.dirname(accountsFile()), { recursive: true });
|
||||
fs.writeFileSync(
|
||||
accountsFile(),
|
||||
JSON.stringify({
|
||||
active: 'test1@google.com',
|
||||
old: ['test1@google.com', 'test2@google.com'],
|
||||
}),
|
||||
);
|
||||
expect(userAccountManager.getLifetimeGoogleAccounts()).toBe(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,140 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import path from 'node:path';
|
||||
import { promises as fsp, readFileSync } from 'node:fs';
|
||||
import { Storage } from '../config/storage.js';
|
||||
|
||||
interface UserAccounts {
|
||||
active: string | null;
|
||||
old: string[];
|
||||
}
|
||||
|
||||
export class UserAccountManager {
|
||||
private getGoogleAccountsCachePath(): string {
|
||||
return Storage.getGoogleAccountsPath();
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses and validates the string content of an accounts file.
|
||||
* @param content The raw string content from the file.
|
||||
* @returns A valid UserAccounts object.
|
||||
*/
|
||||
private parseAndValidateAccounts(content: string): UserAccounts {
|
||||
const defaultState = { active: null, old: [] };
|
||||
if (!content.trim()) {
|
||||
return defaultState;
|
||||
}
|
||||
|
||||
const parsed = JSON.parse(content);
|
||||
|
||||
// Inlined validation logic
|
||||
if (typeof parsed !== 'object' || parsed === null) {
|
||||
console.log('Invalid accounts file schema, starting fresh.');
|
||||
return defaultState;
|
||||
}
|
||||
const { active, old } = parsed as Partial<UserAccounts>;
|
||||
const isValid =
|
||||
(active === undefined || active === null || typeof active === 'string') &&
|
||||
(old === undefined ||
|
||||
(Array.isArray(old) && old.every((i) => typeof i === 'string')));
|
||||
|
||||
if (!isValid) {
|
||||
console.log('Invalid accounts file schema, starting fresh.');
|
||||
return defaultState;
|
||||
}
|
||||
|
||||
return {
|
||||
active: parsed.active ?? null,
|
||||
old: parsed.old ?? [],
|
||||
};
|
||||
}
|
||||
|
||||
private readAccountsSync(filePath: string): UserAccounts {
|
||||
const defaultState = { active: null, old: [] };
|
||||
try {
|
||||
const content = readFileSync(filePath, 'utf-8');
|
||||
return this.parseAndValidateAccounts(content);
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
'code' in error &&
|
||||
error.code === 'ENOENT'
|
||||
) {
|
||||
return defaultState;
|
||||
}
|
||||
console.log('Error during sync read of accounts, starting fresh.', error);
|
||||
return defaultState;
|
||||
}
|
||||
}
|
||||
|
||||
private async readAccounts(filePath: string): Promise<UserAccounts> {
|
||||
const defaultState = { active: null, old: [] };
|
||||
try {
|
||||
const content = await fsp.readFile(filePath, 'utf-8');
|
||||
return this.parseAndValidateAccounts(content);
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
'code' in error &&
|
||||
error.code === 'ENOENT'
|
||||
) {
|
||||
return defaultState;
|
||||
}
|
||||
console.log('Could not parse accounts file, starting fresh.', error);
|
||||
return defaultState;
|
||||
}
|
||||
}
|
||||
|
||||
async cacheGoogleAccount(email: string): Promise<void> {
|
||||
const filePath = this.getGoogleAccountsCachePath();
|
||||
await fsp.mkdir(path.dirname(filePath), { recursive: true });
|
||||
|
||||
const accounts = await this.readAccounts(filePath);
|
||||
|
||||
if (accounts.active && accounts.active !== email) {
|
||||
if (!accounts.old.includes(accounts.active)) {
|
||||
accounts.old.push(accounts.active);
|
||||
}
|
||||
}
|
||||
|
||||
// If the new email was in the old list, remove it
|
||||
accounts.old = accounts.old.filter((oldEmail) => oldEmail !== email);
|
||||
|
||||
accounts.active = email;
|
||||
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
|
||||
}
|
||||
|
||||
getCachedGoogleAccount(): string | null {
|
||||
const filePath = this.getGoogleAccountsCachePath();
|
||||
const accounts = this.readAccountsSync(filePath);
|
||||
return accounts.active;
|
||||
}
|
||||
|
||||
getLifetimeGoogleAccounts(): number {
|
||||
const filePath = this.getGoogleAccountsCachePath();
|
||||
const accounts = this.readAccountsSync(filePath);
|
||||
const allAccounts = new Set(accounts.old);
|
||||
if (accounts.active) {
|
||||
allAccounts.add(accounts.active);
|
||||
}
|
||||
return allAccounts.size;
|
||||
}
|
||||
|
||||
async clearCachedGoogleAccount(): Promise<void> {
|
||||
const filePath = this.getGoogleAccountsCachePath();
|
||||
const accounts = await this.readAccounts(filePath);
|
||||
|
||||
if (accounts.active) {
|
||||
if (!accounts.old.includes(accounts.active)) {
|
||||
accounts.old.push(accounts.active);
|
||||
}
|
||||
accounts.active = null;
|
||||
}
|
||||
|
||||
await fsp.writeFile(filePath, JSON.stringify(accounts, null, 2), 'utf-8');
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user