Remove auto-execution on Flash in the event of a 429/Quota failover (#3662)

Co-authored-by: Jenna Inouye <jinouye@google.com>
This commit is contained in:
Bryan Morgan
2025-07-09 13:55:56 -04:00
committed by GitHub
parent 01e756481f
commit 8a6509ffeb
14 changed files with 292 additions and 86 deletions

View File

@@ -31,7 +31,23 @@ import {
toCountTokenRequest,
toGenerateContentRequest,
} from './converter.js';
import { PassThrough } from 'node:stream';
import { Readable } from 'node:stream';
interface ErrorData {
error?: {
message?: string;
};
}
interface GaxiosResponse {
status: number;
data: unknown;
}
interface StreamError extends Error {
status?: number;
response?: GaxiosResponse;
}
/** HTTP options to be used in each of the requests. */
export interface HttpOptions {
@@ -177,8 +193,45 @@ export class CodeAssistServer implements ContentGenerator {
});
return (async function* (): AsyncGenerator<T> {
// Convert ReadableStream to Node.js stream if needed
let nodeStream: NodeJS.ReadableStream;
if (res.data instanceof ReadableStream) {
// Convert Web ReadableStream to Node.js Readable stream
// eslint-disable-next-line @typescript-eslint/no-explicit-any
nodeStream = Readable.fromWeb(res.data as any);
} else if (
res.data &&
typeof (res.data as NodeJS.ReadableStream).on === 'function'
) {
// Already a Node.js stream
nodeStream = res.data as NodeJS.ReadableStream;
} else {
// If res.data is not a stream, it might be an error response
// Try to extract error information from the response
let errorMessage =
'Response data is not a readable stream. This may indicate a server error or quota issue.';
if (res.data && typeof res.data === 'object') {
// Check if this is an error response with error details
const errorData = res.data as ErrorData;
if (errorData.error?.message) {
errorMessage = errorData.error.message;
} else if (typeof errorData === 'string') {
errorMessage = errorData;
}
}
// Create an error that looks like a quota error if it contains quota information
const error: StreamError = new Error(errorMessage);
// Add status and response properties so it can be properly handled by retry logic
error.status = res.status;
error.response = res;
throw error;
}
const rl = readline.createInterface({
input: res.data as PassThrough,
input: nodeStream,
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
});

View File

@@ -104,7 +104,7 @@ export type FlashFallbackHandler = (
currentModel: string,
fallbackModel: string,
error?: unknown,
) => Promise<boolean>;
) => Promise<boolean | string | null>;
export interface ConfigParameters {
sessionId: string;
@@ -183,6 +183,7 @@ export class Config {
private readonly listExtensions: boolean;
private readonly _activeExtensions: ActiveExtension[];
flashFallbackHandler?: FlashFallbackHandler;
private quotaErrorOccurred: boolean = false;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -304,6 +305,14 @@ export class Config {
this.flashFallbackHandler = handler;
}
setQuotaErrorOccurred(value: boolean): void {
this.quotaErrorOccurred = value;
}
getQuotaErrorOccurred(): boolean {
return this.quotaErrorOccurred;
}
getEmbeddingModel(): string {
return this.embeddingModel;
}

View File

@@ -178,6 +178,8 @@ describe('Gemini Client (client.ts)', () => {
getProxy: vi.fn().mockReturnValue(undefined),
getWorkingDir: vi.fn().mockReturnValue('/test/dir'),
getFileService: vi.fn().mockReturnValue(fileService),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
};
return mock as unknown as Config;
});
@@ -351,7 +353,7 @@ describe('Gemini Client (client.ts)', () => {
await client.generateJson(contents, schema, abortSignal);
expect(mockGenerateContentFn).toHaveBeenCalledWith({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: 'test-model', // Should use current model from config
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt(''),

View File

@@ -262,6 +262,7 @@ export class GeminiClient {
request: PartListUnion,
signal: AbortSignal,
turns: number = this.MAX_TURNS,
originalModel?: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, this.MAX_TURNS);
@@ -269,6 +270,9 @@ export class GeminiClient {
return new Turn(this.getChat());
}
// Track the original model from the first call to detect model switching
const initialModel = originalModel || this.config.getModel();
const compressed = await this.tryCompressChat();
if (compressed) {
yield { type: GeminiEventType.ChatCompressed, value: compressed };
@@ -279,6 +283,14 @@ export class GeminiClient {
yield event;
}
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
// Check if model was switched during the call (likely due to quota error)
const currentModel = this.config.getModel();
if (currentModel !== initialModel) {
// Model was switched (likely due to quota error fallback)
// Don't continue with recursive call to prevent unwanted Flash execution
return turn;
}
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this,
@@ -288,7 +300,12 @@ export class GeminiClient {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
// turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, boundedTurns - 1);
yield* this.sendMessageStream(
nextRequest,
signal,
boundedTurns - 1,
initialModel,
);
}
}
return turn;
@@ -298,9 +315,12 @@ export class GeminiClient {
contents: Content[],
schema: SchemaUnion,
abortSignal: AbortSignal,
model: string = DEFAULT_GEMINI_FLASH_MODEL,
model?: string,
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
// Use current model from config instead of hardcoded Flash model
const modelToUse =
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
@@ -312,7 +332,7 @@ export class GeminiClient {
const apiCall = () =>
this.getContentGenerator().generateContent({
model,
model: modelToUse,
config: {
...requestConfig,
systemInstruction,
@@ -585,10 +605,14 @@ export class GeminiClient {
fallbackModel,
error,
);
if (accepted) {
if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel);
return fallbackModel;
}
// Check if the model was switched manually in the handler
if (this.config.getModel() === fallbackModel) {
return null; // Model was switched but don't continue with current prompt
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}

View File

@@ -43,6 +43,8 @@ describe('GeminiChat', () => {
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
setModel: vi.fn(),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
flashFallbackHandler: undefined,
} as unknown as Config;

View File

@@ -217,10 +217,14 @@ export class GeminiChat {
fallbackModel,
error,
);
if (accepted) {
if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel);
return fallbackModel;
}
// Check if the model was switched manually in the handler
if (this.config.getModel() === fallbackModel) {
return null; // Model was switched but don't continue with current prompt
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}
@@ -262,12 +266,25 @@ export class GeminiChat {
let response: GenerateContentResponse;
try {
const apiCall = () =>
this.contentGenerator.generateContent({
model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL,
const apiCall = () => {
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
// Prevent Flash model calls immediately after quota error
if (
this.config.getQuotaErrorOccurred() &&
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
) {
throw new Error(
'Please submit a new query to continue with the Flash model.',
);
}
return this.contentGenerator.generateContent({
model: modelToUse,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
};
response = await retryWithBackoff(apiCall, {
shouldRetry: (error: Error) => {
@@ -354,12 +371,25 @@ export class GeminiChat {
const startTime = Date.now();
try {
const apiCall = () =>
this.contentGenerator.generateContentStream({
model: this.config.getModel(),
const apiCall = () => {
const modelToUse = this.config.getModel();
// Prevent Flash model calls immediately after quota error
if (
this.config.getQuotaErrorOccurred() &&
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
) {
throw new Error(
'Please submit a new query to continue with the Flash model.',
);
}
return this.contentGenerator.generateContentStream({
model: modelToUse,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
};
// Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
// for transient issues internally before yielding the async generator, this retry will re-initiate

View File

@@ -214,6 +214,8 @@ describe('editCorrector', () => {
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
configParams.alwaysSkipModificationConfirmation = skip;
}),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
} as unknown as Config;
callCount = 0;
@@ -654,6 +656,8 @@ describe('editCorrector', () => {
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
configParams.alwaysSkipModificationConfirmation = skip;
}),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
} as unknown as Config;
callCount = 0;

View File

@@ -41,14 +41,23 @@ export function isProQuotaExceededError(error: unknown): boolean {
// Check for Pro quota exceeded errors by looking for the specific pattern
// This will match patterns like:
// - "Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'"
// - "Quota exceeded for quota metric 'Gemini 1.5-preview Pro Requests'"
// - "Quota exceeded for quota metric 'Gemini beta-3.0 Pro Requests'"
// - "Quota exceeded for quota metric 'Gemini experimental-v2 Pro Requests'"
// - "Quota exceeded for quota metric 'Gemini 2.5-preview Pro Requests'"
// We use string methods instead of regex to avoid ReDoS vulnerabilities
const checkMessage = (message: string): boolean =>
message.includes("Quota exceeded for quota metric 'Gemini") &&
message.includes("Pro Requests'");
const checkMessage = (message: string): boolean => {
console.log('[DEBUG] isProQuotaExceededError checking message:', message);
const result =
message.includes("Quota exceeded for quota metric 'Gemini") &&
message.includes("Pro Requests'");
console.log('[DEBUG] isProQuotaExceededError result:', result);
return result;
};
// Log the full error object to understand its structure
console.log(
'[DEBUG] isProQuotaExceededError - full error object:',
JSON.stringify(error, null, 2),
);
if (typeof error === 'string') {
return checkMessage(error);
@@ -62,6 +71,38 @@ export function isProQuotaExceededError(error: unknown): boolean {
return checkMessage(error.error.message);
}
// Check if it's a Gaxios error with response data
if (error && typeof error === 'object' && 'response' in error) {
const gaxiosError = error as {
response?: {
data?: unknown;
};
};
if (gaxiosError.response && gaxiosError.response.data) {
console.log(
'[DEBUG] isProQuotaExceededError - checking response data:',
gaxiosError.response.data,
);
if (typeof gaxiosError.response.data === 'string') {
return checkMessage(gaxiosError.response.data);
}
if (
typeof gaxiosError.response.data === 'object' &&
gaxiosError.response.data !== null &&
'error' in gaxiosError.response.data
) {
const errorData = gaxiosError.response.data as {
error?: { message?: string };
};
return checkMessage(errorData.error?.message || '');
}
}
}
console.log(
'[DEBUG] isProQuotaExceededError - no matching error format for:',
error,
);
return false;
}

View File

@@ -18,7 +18,7 @@ export interface RetryOptions {
onPersistent429?: (
authType?: string,
error?: unknown,
) => Promise<string | null>;
) => Promise<string | boolean | null>;
authType?: string;
}
@@ -102,13 +102,16 @@ export async function retryWithBackoff<T>(
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel) {
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
@@ -126,13 +129,16 @@ export async function retryWithBackoff<T>(
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel) {
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
@@ -155,13 +161,16 @@ export async function retryWithBackoff<T>(
) {
try {
const fallbackModel = await onPersistent429(authType, error);
if (fallbackModel) {
if (fallbackModel !== false && fallbackModel !== null) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
} else {
// Fallback handler returned null/false, meaning don't continue - stop retry process
throw error;
}
} catch (fallbackError) {
// If fallback fails, continue with original error