feat(oauth): add Qwen OAuth integration

This commit is contained in:
mingholy.lmh
2025-08-08 09:48:31 +08:00
parent ffc2d27ca3
commit ea7dcf8347
37 changed files with 7795 additions and 169 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,862 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import crypto from 'crypto';
import path from 'node:path';
import { promises as fs } from 'node:fs';
import * as os from 'os';
import open from 'open';
import { EventEmitter } from 'events';
import { Config } from '../config/config.js';
import { randomUUID } from 'node:crypto';
// OAuth endpoints. The device-code and token endpoint URLs are both derived
// from QWEN_OAUTH_BASE_URL.
// NOTE(review): the active base host is `pre4-chat.qwen.ai` while the
// `chat.qwen.ai` host is commented out — presumably a pre-release
// environment; confirm which host should ship.
const QWEN_OAUTH_BASE_URL = 'https://pre4-chat.qwen.ai';
// const QWEN_OAUTH_BASE_URL = 'https://chat.qwen.ai';
const QWEN_OAUTH_DEVICE_CODE_ENDPOINT = `${QWEN_OAUTH_BASE_URL}/api/v1/oauth2/device/code`;
const QWEN_OAUTH_TOKEN_ENDPOINT = `${QWEN_OAUTH_BASE_URL}/api/v1/oauth2/token`;
// OAuth client configuration sent with every device-code, token and refresh request.
// NOTE(review): the commented-out client id presumably pairs with the
// commented-out base URL above — confirm before switching hosts.
// const QWEN_OAUTH_CLIENT_ID = '93a239d6ed36412c8c442e91b60fa305';
const QWEN_OAUTH_CLIENT_ID = 'f0304373b74a44d2b584a3fb70ca9e56';
const QWEN_OAUTH_SCOPE = 'openid profile email model.completion';
const QWEN_OAUTH_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code';
// File system configuration: credentials are cached at
// <home directory>/.qwen/oauth_creds.json.
const QWEN_DIR = '.qwen';
const QWEN_CREDENTIAL_FILENAME = 'oauth_creds.json';
// Token configuration: an access token is treated as no longer valid once it
// is within this many milliseconds of its expiry date.
const TOKEN_REFRESH_BUFFER_MS = 30 * 1000; // 30 seconds
/**
* PKCE (Proof Key for Code Exchange) utilities
* Implements RFC 7636 - Proof Key for Code Exchange by OAuth Public Clients
*/
/**
 * Create a fresh PKCE code verifier.
 *
 * 32 random bytes are encoded as unpadded base64url, which always yields a
 * 43-character string drawn from `A-Z a-z 0-9 - _`.
 * @returns The newly generated code verifier
 */
export function generateCodeVerifier(): string {
  const entropy = crypto.randomBytes(32);
  const verifier = entropy.toString('base64url');
  return verifier;
}
/**
 * Derive the PKCE code challenge for a verifier (the `S256` method):
 * the SHA-256 digest of the verifier, encoded as unpadded base64url.
 * @param codeVerifier The code verifier string
 * @returns The matching code challenge
 */
export function generateCodeChallenge(codeVerifier: string): string {
  return crypto.createHash('sha256').update(codeVerifier).digest('base64url');
}
/**
 * Produce a matching PKCE verifier/challenge pair.
 * @returns `code_verifier` (kept by the client) and the `code_challenge`
 *   derived from it (sent with the authorization request)
 */
export function generatePKCEPair(): {
  code_verifier: string;
  code_challenge: string;
} {
  const code_verifier = generateCodeVerifier();
  return {
    code_verifier,
    code_challenge: generateCodeChallenge(code_verifier),
  };
}
/**
 * Serialize a flat string map as an `application/x-www-form-urlencoded` body.
 * Keys and values are both percent-encoded with `encodeURIComponent` and the
 * pairs are joined with `&` in the object's own key order.
 * @param data The key/value pairs to encode
 * @returns The encoded body (empty string for an empty object)
 */
function objectToUrlEncoded(data: Record<string, string>): string {
  const pairs: string[] = [];
  for (const [name, value] of Object.entries(data)) {
    pairs.push(`${encodeURIComponent(name)}=${encodeURIComponent(value)}`);
  }
  return pairs.join('&');
}
/**
 * Error payload shape shared by the device-authorization, token-polling and
 * token-refresh responses.
 */
export interface ErrorData {
// Machine-readable error code (e.g. compared against 'authorization_pending'
// and 'slow_down' while polling for a device token).
error: string;
// Human-readable detail; appended to the messages of thrown errors.
error_description: string;
}
/**
 * Credentials held by the Qwen OAuth2 client and cached on disk as JSON.
 * Every field is optional because a freshly constructed client starts with
 * an empty object.
 */
export interface QwenCredentials {
// Bearer token handed out by getAccessToken() while it is still valid.
access_token?: string;
// Used to obtain a new access token once the current one is missing or stale.
refresh_token?: string;
id_token?: string;
// Absolute expiry as epoch milliseconds (computed as Date.now() + expires_in * 1000).
expiry_date?: number;
token_type?: string;
// Copied verbatim from token responses when present; its consumer is not
// visible in this file.
resource_url?: string;
}
/**
 * Successful response from the device-code endpoint.
 */
export interface DeviceAuthorizationData {
// Opaque code sent back to the token endpoint on every poll.
device_code: string;
user_code: string;
verification_uri: string;
// URL printed to the console and opened in the user's browser.
verification_uri_complete: string;
// Lifetime of the device code in seconds; bounds the number of poll attempts.
expires_in: number;
}
/**
 * Result of a device-authorization request: either the success payload or an
 * error payload. Use isDeviceAuthorizationSuccess() to tell them apart.
 */
export type DeviceAuthorizationResponse = DeviceAuthorizationData | ErrorData;
/**
 * Type guard: a device-authorization response counts as successful when it
 * carries a `device_code` property.
 */
export function isDeviceAuthorizationSuccess(
  response: DeviceAuthorizationResponse,
): response is DeviceAuthorizationData {
  const carriesDeviceCode = 'device_code' in response;
  return carriesDeviceCode;
}
/**
 * Token payload returned by the token endpoint during device-code polling.
 * Several fields are declared nullable; isDeviceTokenSuccess() only accepts
 * payloads whose access_token is a non-empty string.
 */
export interface DeviceTokenData {
access_token: string | null;
refresh_token?: string | null;
token_type: string;
// Token lifetime in seconds (multiplied by 1000 when the expiry date is computed).
expires_in: number | null;
scope?: string | null;
endpoint?: string;
// Copied into the stored credentials when present.
resource_url?: string;
}
/**
 * Synthetic "keep polling" result that pollDeviceToken() returns instead of
 * throwing when the server reports that authorization is still pending or
 * that the client should slow down.
 */
export interface DeviceTokenPendingData {
status: 'pending';
slowDown?: boolean; // Indicates if client should increase polling interval
}
/**
 * Every shape a device-token poll can resolve to: a token payload, a
 * "still pending" marker, or an error payload. Discriminate with
 * isDeviceTokenSuccess(), isDeviceTokenPending() and isErrorResponse().
 */
export type DeviceTokenResponse =
| DeviceTokenData
| DeviceTokenPendingData
| ErrorData;
/**
 * Type guard: a device-token response is a success only when it has an
 * `access_token` property holding a non-empty string (null, undefined,
 * non-string and empty values are all rejected).
 */
export function isDeviceTokenSuccess(
  response: DeviceTokenResponse,
): response is DeviceTokenData {
  if (!('access_token' in response)) {
    return false;
  }
  const candidate: unknown = response.access_token;
  if (typeof candidate !== 'string') {
    return false;
  }
  return candidate.length > 0;
}
/**
 * Type guard: a device-token response is "pending" when it has a `status`
 * property whose value is exactly 'pending'.
 */
export function isDeviceTokenPending(
  response: DeviceTokenResponse,
): response is DeviceTokenPendingData {
  if (!('status' in response)) {
    return false;
  }
  const { status } = response as DeviceTokenPendingData;
  return status === 'pending';
}
/**
 * Type guard shared by all three response unions: a response is an error
 * payload when it carries an `error` property.
 */
export function isErrorResponse(
  response:
    | DeviceAuthorizationResponse
    | DeviceTokenResponse
    | TokenRefreshResponse,
): response is ErrorData {
  const carriesErrorField = 'error' in response;
  return carriesErrorField;
}
/**
 * Successful response from the token endpoint for a `refresh_token` grant.
 */
export interface TokenRefreshData {
access_token: string;
token_type: string;
// Token lifetime in seconds (multiplied by 1000 when the expiry date is computed).
expires_in: number;
refresh_token?: string; // Some OAuth servers may return a new refresh token
// Copied into the stored credentials as-is, including when it is absent.
resource_url?: string;
}
/**
 * Result of a token refresh: either the refreshed token payload or an error
 * payload. Use isErrorResponse() to tell them apart.
 */
export type TokenRefreshResponse = TokenRefreshData | ErrorData;
/**
 * Contract of the Qwen OAuth2 client: credential storage plus the three
 * network operations of the device-authorization flow (request a device
 * code, poll for the token, refresh an access token).
 */
export interface IQwenOAuth2Client {
// Replace / read the credentials currently held by the client.
setCredentials(credentials: QwenCredentials): void;
getCredentials(): QwenCredentials;
// Resolves with `token: undefined` when no usable token can be produced.
getAccessToken(): Promise<{ token?: string }>;
// Starts the device flow; `code_challenge` is the PKCE challenge for the
// verifier later passed to pollDeviceToken().
requestDeviceAuthorization(options: {
scope: string;
code_challenge: string;
code_challenge_method: string;
}): Promise<DeviceAuthorizationResponse>;
// Performs a single poll of the token endpoint for the given device code.
pollDeviceToken(options: {
device_code: string;
code_verifier: string;
}): Promise<DeviceTokenResponse>;
// Exchanges the stored refresh token for a new access token.
refreshAccessToken(): Promise<TokenRefreshResponse>;
}
/**
 * Qwen OAuth2 client implementation.
 *
 * Holds the current credentials in memory and talks to the Qwen device-code
 * and token endpoints using form-encoded POST requests.
 */
export class QwenOAuth2Client implements IQwenOAuth2Client {
  private credentials: QwenCredentials = {};
  // Stored from the constructor options; not read by any method in this class.
  private proxy?: string;

  constructor(options: { proxy?: string }) {
    this.proxy = options.proxy;
  }

  /** Replace the credentials held by this client. */
  setCredentials(credentials: QwenCredentials): void {
    this.credentials = credentials;
  }

  /** Return the credentials currently held by this client. */
  getCredentials(): QwenCredentials {
    return this.credentials;
  }

  /**
   * Return a usable access token.
   *
   * The stored token is returned while it is still valid; otherwise a
   * refresh is attempted when a refresh token is available.
   * @returns `{ token }`, with `token` undefined when there is neither a
   *   valid access token nor a refresh token
   * @throws Whatever refreshAccessToken() throws when the refresh fails
   */
  async getAccessToken(): Promise<{ token?: string }> {
    if (this.credentials.access_token && this.isTokenValid()) {
      return { token: this.credentials.access_token };
    }
    if (this.credentials.refresh_token) {
      const refreshResponse = await this.refreshAccessToken();
      // Narrow with the type guard instead of asserting the success shape.
      if (isErrorResponse(refreshResponse)) {
        return { token: undefined };
      }
      return { token: refreshResponse.access_token };
    }
    return { token: undefined };
  }

  /**
   * Start the device-authorization flow.
   * @param options Requested scope and the PKCE challenge/method
   * @returns The device-authorization payload
   * @throws Error when the HTTP status is not OK or the payload has no device code
   */
  async requestDeviceAuthorization(options: {
    scope: string;
    code_challenge: string;
    code_challenge_method: string;
  }): Promise<DeviceAuthorizationResponse> {
    const bodyData = {
      client_id: QWEN_OAUTH_CLIENT_ID,
      scope: options.scope,
      code_challenge: options.code_challenge,
      code_challenge_method: options.code_challenge_method,
    };
    const response = await fetch(QWEN_OAUTH_DEVICE_CODE_ENDPOINT, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
        Accept: 'application/json',
        'x-request-id': randomUUID(),
      },
      body: objectToUrlEncoded(bodyData),
    });
    if (!response.ok) {
      const errorData = await response.text();
      throw new Error(
        `Device authorization failed: ${response.status} ${response.statusText}. Response: ${errorData}`,
      );
    }
    // The payload is deliberately not logged: it contains the device code.
    const result = (await response.json()) as DeviceAuthorizationResponse;
    if (!isDeviceAuthorizationSuccess(result)) {
      const errorData = result as ErrorData;
      throw new Error(
        `Device authorization failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
      );
    }
    return result;
  }

  /**
   * Poll the token endpoint once for the given device code.
   *
   * `authorization_pending` and `slow_down` errors are translated into a
   * `{ status: 'pending' }` result so the caller keeps polling.
   * @param options The device code and the PKCE verifier for this flow
   * @returns The token payload, or a pending marker
   * @throws Error with a numeric `status` property for any other non-OK response
   */
  async pollDeviceToken(options: {
    device_code: string;
    code_verifier: string;
  }): Promise<DeviceTokenResponse> {
    const bodyData = {
      grant_type: QWEN_OAUTH_GRANT_TYPE,
      client_id: QWEN_OAUTH_CLIENT_ID,
      device_code: options.device_code,
      code_verifier: options.code_verifier,
    };
    const response = await fetch(QWEN_OAUTH_TOKEN_ENDPOINT, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
        Accept: 'application/json',
      },
      body: objectToUrlEncoded(bodyData),
    });
    if (response.ok) {
      return (await response.json()) as DeviceTokenResponse;
    }

    // A response body can only be consumed once, so read it as text a single
    // time and parse that. Only the parse itself is guarded: the errors
    // thrown below must not be caught and replaced by a parse-failure path.
    const rawBody = await response.text();
    let errorData: Partial<ErrorData> | undefined;
    try {
      const parsed: unknown = JSON.parse(rawBody);
      if (typeof parsed === 'object' && parsed !== null) {
        errorData = parsed as Partial<ErrorData>;
      }
    } catch (_parseError) {
      errorData = undefined;
    }

    let message: string;
    if (errorData) {
      if (
        response.status === 400 &&
        errorData.error === 'authorization_pending'
      ) {
        // User has not yet approved the authorization request. Continue polling.
        return { status: 'pending' };
      }
      // This server signals slow_down with 429; RFC 8628 servers use 400.
      if (
        (response.status === 429 || response.status === 400) &&
        errorData.error === 'slow_down'
      ) {
        return { status: 'pending', slowDown: true };
      }
      // access_denied, expired_token, etc. are real failures.
      message = `Device token poll failed: ${errorData.error || 'Unknown error'} - ${errorData.error_description || 'No details provided'}`;
    } else {
      // Body was not a JSON object: report the raw text instead.
      message = `Device token poll failed: ${response.status} ${response.statusText}. Response: ${rawBody}`;
    }
    const failure: Error & { status?: number } = new Error(message);
    failure.status = response.status;
    throw failure;
  }

  /**
   * Exchange the stored refresh token for a new access token, then update
   * both the in-memory credentials and the on-disk cache.
   * @returns The refresh payload from the server
   * @throws Error when no refresh token is stored, when the server rejects
   *   the request, or when the payload is an error payload
   */
  async refreshAccessToken(): Promise<TokenRefreshResponse> {
    if (!this.credentials.refresh_token) {
      throw new Error('No refresh token available');
    }
    const bodyData = {
      grant_type: 'refresh_token',
      refresh_token: this.credentials.refresh_token,
      client_id: QWEN_OAUTH_CLIENT_ID,
    };
    const response = await fetch(QWEN_OAUTH_TOKEN_ENDPOINT, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
        Accept: 'application/json',
      },
      body: objectToUrlEncoded(bodyData),
    });
    if (!response.ok) {
      // HTTP 400 is treated as "refresh token expired or invalid": the
      // cached credentials are removed so the user is asked to sign in again.
      if (response.status === 400) {
        await clearQwenCredentials();
        throw new Error(
          "Refresh token expired or invalid. Please use '/auth' to re-authenticate.",
        );
      }
      const errorData = await response.text();
      throw new Error(
        `Token refresh failed: ${response.status} ${response.statusText}. Response: ${errorData}`,
      );
    }
    const responseData = (await response.json()) as TokenRefreshResponse;
    if (isErrorResponse(responseData)) {
      throw new Error(
        `Token refresh failed: ${responseData.error || 'Unknown error'} - ${responseData.error_description || 'No details provided'}`,
      );
    }
    const tokens: QwenCredentials = {
      access_token: responseData.access_token,
      token_type: responseData.token_type,
      // Use new refresh token if provided, otherwise preserve existing one
      refresh_token:
        responseData.refresh_token || this.credentials.refresh_token,
      resource_url: responseData.resource_url, // Include resource_url if provided
      expiry_date: Date.now() + responseData.expires_in * 1000,
    };
    this.setCredentials(tokens);
    // Cache the updated credentials to file
    await cacheQwenCredentials(tokens);
    return responseData;
  }

  /**
   * Whether the stored access token is still usable: it needs an expiry date
   * that lies more than TOKEN_REFRESH_BUFFER_MS in the future.
   */
  private isTokenValid(): boolean {
    if (!this.credentials.expiry_date) {
      return false;
    }
    return Date.now() < this.credentials.expiry_date - TOKEN_REFRESH_BUFFER_MS;
  }
}
/**
 * Event names used on the shared qwenOAuth2Events emitter.
 */
export enum QwenOAuth2Event {
// Emitted with the device-authorization payload once a device code was obtained.
AuthUri = 'auth-uri',
// Emitted with a status string and a human-readable message as the flow advances.
AuthProgress = 'auth-progress',
// Listened for by the device flow; emitting it cancels an in-progress authentication.
AuthCancel = 'auth-cancel',
}
/**
 * Outcome of a device-flow authentication attempt. On failure, `reason`
 * tells the caller why, so it can choose a matching error message.
 */
export type AuthResult =
| { success: true }
| {
success: false;
reason: 'timeout' | 'cancelled' | 'error' | 'rate_limit';
};
/**
 * Module-level event emitter for QwenOAuth2 authentication events.
 * The authentication code in this module emits the events named in
 * QwenOAuth2Event on it and listens on it for the cancel event.
 */
export const qwenOAuth2Events = new EventEmitter();
/**
 * Build an authenticated Qwen OAuth2 client.
 *
 * Cached credentials win when they can be loaded: the token is refreshed and
 * the client returned. Otherwise a single device-authorization attempt is
 * made.
 * @param config Supplies the proxy setting and the browser-launch preference
 * @returns A client holding usable credentials
 * @throws Error when the refresh of cached credentials fails, or when the
 *   device flow times out, is cancelled, is rate limited or otherwise fails
 */
export async function getQwenOAuthClient(
  config: Config,
): Promise<QwenOAuth2Client> {
  const client = new QwenOAuth2Client({ proxy: config.getProxy() });

  // If there are cached creds on disk, they always take precedence
  const hasCachedCredentials = await loadCachedQwenCredentials(client);
  if (hasCachedCredentials) {
    console.log('Loaded cached Qwen credentials.');
    try {
      await client.refreshAccessToken();
    } catch (error: unknown) {
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      let userMessage: string;
      let throwMessage: string;
      if (errorMessage.includes('Refresh token expired or invalid')) {
        userMessage = 'Cached credentials are invalid. Please re-authenticate.';
        throwMessage =
          'Cached Qwen credentials are invalid. Please re-authenticate.';
      } else {
        userMessage = `Token refresh failed: ${errorMessage}`;
        throwMessage = `Qwen token refresh failed: ${errorMessage}`;
      }
      // Surface the refresh failure to any listening UI before throwing
      qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', userMessage);
      throw new Error(throwMessage);
    }
    return client;
  }

  // No usable cache: run the device authorization flow (single attempt)
  const result = await authWithQwenDeviceFlow(client, config);
  if (result.success) {
    return client;
  }

  // Only the timeout case is announced here; the other failure reasons are
  // expected to have been reported by the device flow itself
  if (result.reason === 'timeout') {
    qwenOAuth2Events.emit(
      QwenOAuth2Event.AuthProgress,
      'timeout',
      'Authentication timed out. Please try again or select a different authentication method.',
    );
  }

  // Map each specific failure reason to its message; anything else
  // (including 'error') falls back to the generic one
  const failureMessages = new Map<string, string>([
    ['timeout', 'Qwen OAuth authentication timed out'],
    ['cancelled', 'Qwen OAuth authentication was cancelled by user'],
    [
      'rate_limit',
      'Too many request for Qwen OAuth authentication, please try again later.',
    ],
  ]);
  throw new Error(
    failureMessages.get(result.reason) ?? 'Qwen OAuth authentication failed',
  );
}
/**
 * Resolve after `durationMs`, or earlier once `shouldStop()` reports true.
 * The stop condition is sampled every 100ms.
 */
function waitUnlessCancelled(
  durationMs: number,
  shouldStop: () => boolean,
): Promise<void> {
  return new Promise<void>((resolve) => {
    const tickMs = 100;
    let elapsedMs = 0;
    const timer = setInterval(() => {
      elapsedMs += tickMs;
      if (shouldStop() || elapsedMs >= durationMs) {
        clearInterval(timer);
        resolve();
      }
    }, tickMs);
  });
}

/**
 * Run one OAuth device-authorization attempt.
 *
 * Requests a device code (with PKCE), shows/opens the verification URL, then
 * polls the token endpoint until a token arrives, the user cancels via the
 * AuthCancel event, the server rejects the request, or the device code's
 * lifetime is used up. On success the credentials are stored on the client
 * and cached on disk. Progress is reported through qwenOAuth2Events.
 *
 * @param client Client used for the network calls; receives the credentials
 * @param config Consulted for whether launching a browser is suppressed
 * @returns `{ success: true }`, or the failure reason; this function does not
 *   throw — unexpected errors are reported as reason 'error'
 */
async function authWithQwenDeviceFlow(
  client: QwenOAuth2Client,
  config: Config,
): Promise<AuthResult> {
  let isCancelled = false;
  const cancelHandler = () => {
    isCancelled = true;
  };
  const wasCancelled = () => isCancelled;
  qwenOAuth2Events.once(QwenOAuth2Event.AuthCancel, cancelHandler);

  // Announce a user cancellation on both the console and the event bus.
  const reportCancelled = (): AuthResult => {
    console.log('\nAuthentication cancelled by user.');
    qwenOAuth2Events.emit(
      QwenOAuth2Event.AuthProgress,
      'error',
      'Authentication cancelled by user.',
    );
    return { success: false, reason: 'cancelled' };
  };

  try {
    // Generate PKCE code verifier and challenge
    const { code_verifier, code_challenge } = generatePKCEPair();

    const deviceAuth = await client.requestDeviceAuthorization({
      scope: QWEN_OAUTH_SCOPE,
      code_challenge,
      code_challenge_method: 'S256',
    });
    if (!isDeviceAuthorizationSuccess(deviceAuth)) {
      const errorData = deviceAuth as ErrorData;
      throw new Error(
        `Device authorization failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
      );
    }
    const authUrl = deviceAuth.verification_uri_complete;

    // Emit device authorization event for UI integration immediately
    qwenOAuth2Events.emit(QwenOAuth2Event.AuthUri, deviceAuth);
    console.log('\n=== Qwen OAuth Device Authorization ===');
    console.log(
      `Please visit the following URL on your phone or browser for authorization:`,
    );
    console.log(`\n${authUrl}\n`);

    // Printed when the browser could not be launched; repeats the URL so the
    // "Visit this URL" prompt is never left without one.
    const announceBrowserFailure = () => {
      console.log('Failed to open browser. Visit this URL to authorize:');
      console.log(`\n${authUrl}\n`);
    };

    if (!config.isBrowserLaunchSuppressed()) {
      try {
        const childProcess = await open(authUrl);
        // IMPORTANT: Attach an error handler to the returned child process.
        // Without this, if `open` fails to spawn a process (e.g., `xdg-open`
        // is not found in a minimal Docker container), it will emit an
        // unhandled 'error' event, causing the entire Node.js process to crash.
        if (childProcess) {
          childProcess.on('error', announceBrowserFailure);
        }
      } catch (_err) {
        announceBrowserFailure();
      }
    }

    qwenOAuth2Events.emit(
      QwenOAuth2Event.AuthProgress,
      'polling',
      'Waiting for authorization...',
    );
    console.log('Waiting for authorization...\n');

    // Poll for the token. The attempt budget is derived from the device
    // code's lifetime at the initial interval.
    let pollInterval = 2000; // 2 seconds, increased when slow_down is received
    const maxAttempts = Math.ceil(
      deviceAuth.expires_in / (pollInterval / 1000),
    );
    for (let attempt = 0; attempt < maxAttempts; attempt++) {
      if (isCancelled) {
        return reportCancelled();
      }
      try {
        const tokenResponse = await client.pollDeviceToken({
          device_code: deviceAuth.device_code,
          code_verifier,
        });

        if (isDeviceTokenSuccess(tokenResponse)) {
          const credentials: QwenCredentials = {
            // The guard above guarantees a non-empty string here.
            access_token: tokenResponse.access_token ?? undefined,
            refresh_token: tokenResponse.refresh_token || undefined,
            token_type: tokenResponse.token_type,
            resource_url: tokenResponse.resource_url,
            expiry_date: tokenResponse.expires_in
              ? Date.now() + tokenResponse.expires_in * 1000
              : undefined,
          };
          client.setCredentials(credentials);
          // Cache the new tokens
          await cacheQwenCredentials(credentials);
          qwenOAuth2Events.emit(
            QwenOAuth2Event.AuthProgress,
            'success',
            'Authentication successful! Access token obtained.',
          );
          console.log('Authentication successful! Access token obtained.');
          return { success: true };
        }

        if (isDeviceTokenPending(tokenResponse)) {
          // A slow_down request raises the interval for the rest of the
          // flow (RFC 8628 section 3.5), capped at 10 seconds.
          if (tokenResponse.slowDown) {
            pollInterval = Math.min(pollInterval * 1.5, 10000);
            console.log(
              `\nServer requested to slow down, increasing poll interval to ${pollInterval}ms`,
            );
          }
          qwenOAuth2Events.emit(
            QwenOAuth2Event.AuthProgress,
            'polling',
            `Polling... (attempt ${attempt + 1}/${maxAttempts})`,
          );
          process.stdout.write('.');
          await waitUnlessCancelled(pollInterval, wasCancelled);
          if (isCancelled) {
            return reportCancelled();
          }
          continue;
        }

        if (isErrorResponse(tokenResponse)) {
          const errorData = tokenResponse as ErrorData;
          throw new Error(
            `Token polling failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
          );
        }

        // Unrecognized payload (e.g. a token payload without a usable
        // access_token): wait before retrying so the endpoint is never
        // polled in a tight loop.
        await waitUnlessCancelled(pollInterval, wasCancelled);
      } catch (error: unknown) {
        const errorMessage =
          error instanceof Error ? error.message : String(error);
        const statusCode =
          error instanceof Error
            ? (error as Error & { status?: number }).status
            : null;

        if (errorMessage.includes('401') || statusCode === 401) {
          const message =
            'Device code expired or invalid, please restart the authorization process.';
          qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', message);
          return { success: false, reason: 'error' };
        }

        // Handle 429 Too Many Requests error
        if (errorMessage.includes('429') || statusCode === 429) {
          const message =
            'Too many requests. The server is rate limiting our requests. Please select a different authentication method or try again later.';
          qwenOAuth2Events.emit(
            QwenOAuth2Event.AuthProgress,
            'rate_limit',
            message,
          );
          console.log('\n' + message);
          // Stop polling so the caller can go back to auth selection
          return { success: false, reason: 'rate_limit' };
        }

        // Any other failure is reported and polling continues.
        const message = `Error polling for token: ${errorMessage}`;
        qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', message);
        if (isCancelled) {
          return { success: false, reason: 'cancelled' };
        }
        // Cancellable wait: the loop head reports the cancellation.
        await waitUnlessCancelled(pollInterval, wasCancelled);
      }
    }

    const timeoutMessage = 'Authorization timeout, please restart the process.';
    qwenOAuth2Events.emit(
      QwenOAuth2Event.AuthProgress,
      'timeout',
      timeoutMessage,
    );
    console.error('\n' + timeoutMessage);
    return { success: false, reason: 'timeout' };
  } catch (error: unknown) {
    const errorMessage = error instanceof Error ? error.message : String(error);
    console.error('Device authorization flow failed:', errorMessage);
    // Listeners learn about failures only through events, so report this one too.
    qwenOAuth2Events.emit(
      QwenOAuth2Event.AuthProgress,
      'error',
      `Device authorization flow failed: ${errorMessage}`,
    );
    return { success: false, reason: 'error' };
  } finally {
    // Clean up event listener
    qwenOAuth2Events.off(QwenOAuth2Event.AuthCancel, cancelHandler);
  }
}
/**
 * Try to restore credentials from the on-disk cache into `client`.
 *
 * The cache file is read and parsed, handed to the client, and then the
 * client is asked for an access token (which may trigger a refresh).
 * @param client The client that receives the cached credentials
 * @returns true when an access token could be obtained; false when the file
 *   is missing or unreadable, is not valid JSON, or no token is available
 */
async function loadCachedQwenCredentials(
  client: QwenOAuth2Client,
): Promise<boolean> {
  let accessToken: string | undefined;
  try {
    const credentialPath = getQwenCachedCredentialPath();
    const rawJson = await fs.readFile(credentialPath, 'utf-8');
    client.setCredentials(JSON.parse(rawJson) as QwenCredentials);
    // Verify that the credentials are still valid
    ({ token: accessToken } = await client.getAccessToken());
  } catch (_) {
    return false;
  }
  return Boolean(accessToken);
}
/**
 * Persist credentials to the on-disk cache as pretty-printed JSON, creating
 * the parent directory when needed.
 *
 * The file holds access and refresh tokens, so it is created readable and
 * writable by the owner only. The mode only takes effect when the file is
 * created; an existing file keeps its current permissions.
 * @param credentials The credentials to store
 */
async function cacheQwenCredentials(
  credentials: QwenCredentials,
): Promise<void> {
  const filePath = getQwenCachedCredentialPath();
  await fs.mkdir(path.dirname(filePath), { recursive: true });
  const credString = JSON.stringify(credentials, null, 2);
  await fs.writeFile(filePath, credString, { mode: 0o600 });
}
/**
 * Delete the cached Qwen credentials file.
 *
 * Never throws: a file that does not exist counts as already cleared, and
 * any other failure is only logged as a warning, because clearing the cache
 * is treated as non-critical.
 */
export async function clearQwenCredentials(): Promise<void> {
  try {
    await fs.unlink(getQwenCachedCredentialPath());
    console.log('Cached Qwen credentials cleared successfully.');
  } catch (error: unknown) {
    const alreadyGone =
      error instanceof Error && 'code' in error && error.code === 'ENOENT';
    if (!alreadyGone) {
      console.warn('Warning: Failed to clear cached Qwen credentials:', error);
    }
  }
}
/**
 * Location of the credential cache file: the credential filename inside the
 * Qwen directory under the current user's home directory.
 */
function getQwenCachedCredentialPath(): string {
  const homeDirectory = os.homedir();
  const qwenDirectory = path.join(homeDirectory, QWEN_DIR);
  return path.join(qwenDirectory, QWEN_CREDENTIAL_FILENAME);
}

View File

@@ -333,7 +333,7 @@ export class Config {
this.model = params.model;
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.sessionTokenLimit = params.sessionTokenLimit ?? 32000;
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
this.maxFolderItems = params.maxFolderItems ?? 20;
this.experimentalAcp = params.experimentalAcp ?? false;
this.listExtensions = params.listExtensions ?? false;

View File

@@ -128,7 +128,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
: String(thrownError);
expect(errorMessage).not.toMatch(/timeout after \d+s/);
expect(errorMessage).not.toMatch(/Troubleshooting tips:/);
expect(errorMessage).toMatch(/OpenAI API error:/);
// Should preserve the original error message
expect(errorMessage).toMatch(new RegExp(error.message));
}
}
});
@@ -161,7 +162,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
};
await expect(generator.generateContent(request)).rejects.toThrow(
'OpenAI API error: Invalid API key',
'Invalid API key',
);
});
@@ -238,6 +239,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
baseURL: '',
timeout: 120000,
maxRetries: 3,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
});
@@ -256,6 +260,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
baseURL: '',
timeout: 300000,
maxRetries: 5,
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
});
@@ -271,6 +278,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
baseURL: '',
timeout: 120000, // default
maxRetries: 3, // default
defaultHeaders: {
'User-Agent': expect.stringMatching(/^QwenCode/),
},
});
});
});

View File

@@ -797,6 +797,11 @@ export class GeminiClient {
authType?: string,
error?: unknown,
): Promise<string | null> {
// Handle different auth types
if (authType === AuthType.QWEN_OAUTH) {
return this.handleQwenOAuthError(error);
}
// Only handle fallback for OAuth users
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
return null;
@@ -835,4 +840,59 @@ export class GeminiClient {
return null;
}
/**
 * Handles Qwen OAuth authentication errors and rate limiting.
 *
 * The error is classified from its `status`/`code` and its lower-cased
 * message, and a diagnostic is logged for authentication and rate-limit
 * errors. Every path returns null; nothing else is done with the error here.
 * @param error The error to inspect; a falsy value is ignored
 * @returns Always null
 */
private async handleQwenOAuthError(error?: unknown): Promise<string | null> {
if (!error) {
return null;
}
// Lower-case the message once so the substring checks below are case-insensitive
const errorMessage =
error instanceof Error
? error.message.toLowerCase()
: String(error).toLowerCase();
// Prefer `status`; fall back to `code` when `status` is missing or 0
const errorCode =
(error as { status?: number; code?: number })?.status ||
(error as { status?: number; code?: number })?.code;
// Check if this is an authentication/authorization error
const isAuthError =
errorCode === 401 ||
errorCode === 403 ||
errorMessage.includes('unauthorized') ||
errorMessage.includes('forbidden') ||
errorMessage.includes('invalid api key') ||
errorMessage.includes('authentication') ||
errorMessage.includes('access denied') ||
(errorMessage.includes('token') && errorMessage.includes('expired'));
// Check if this is a rate limiting error
const isRateLimitError =
errorCode === 429 ||
errorMessage.includes('429') ||
errorMessage.includes('rate limit') ||
errorMessage.includes('too many requests');
// Authentication errors take precedence when both classifications match
if (isAuthError) {
console.warn('Qwen OAuth authentication error detected:', errorMessage);
// The QwenContentGenerator should automatically handle token refresh
// If it still fails, it likely means the refresh token is also expired
console.log(
'Note: If this persists, you may need to re-authenticate with Qwen OAuth',
);
return null;
}
if (isRateLimitError) {
console.warn('Qwen API rate limit encountered:', errorMessage);
// For rate limiting, we don't need to do anything special
// The retry mechanism will handle the backoff
return null;
}
// For other errors, don't handle them specially
return null;
}
}

View File

@@ -46,6 +46,7 @@ export enum AuthType {
USE_VERTEX_AI = 'vertex-ai',
CLOUD_SHELL = 'cloud-shell',
USE_OPENAI = 'openai',
QWEN_OAUTH = 'qwen-oauth',
}
export type ContentGeneratorConfig = {
@@ -131,6 +132,15 @@ export function createContentGeneratorConfig(
return contentGeneratorConfig;
}
if (authType === AuthType.QWEN_OAUTH) {
// For Qwen OAuth, we'll handle the API key dynamically in createContentGenerator
// Set a special marker to indicate this is Qwen OAuth
contentGeneratorConfig.apiKey = 'QWEN_OAUTH_DYNAMIC_TOKEN';
contentGeneratorConfig.model = config.getModel() || DEFAULT_GEMINI_MODEL;
return contentGeneratorConfig;
}
return contentGeneratorConfig;
}
@@ -184,6 +194,30 @@ export async function createContentGenerator(
return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig);
}
if (config.authType === AuthType.QWEN_OAUTH) {
if (config.apiKey !== 'QWEN_OAUTH_DYNAMIC_TOKEN') {
throw new Error('Invalid Qwen OAuth configuration');
}
// Import required classes dynamically
const { getQwenOAuthClient: getQwenOauthClient } = await import(
'../code_assist/qwenOAuth2.js'
);
const { QwenContentGenerator } = await import('./qwenContentGenerator.js');
try {
// Get the Qwen OAuth client (now includes integrated token management)
const qwenClient = await getQwenOauthClient(gcConfig);
// Create the content generator with dynamic token management
return new QwenContentGenerator(qwenClient, config.model, gcConfig);
} catch (error) {
throw new Error(
`Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`,
);
}
}
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
);

View File

@@ -201,6 +201,11 @@ export class GeminiChat {
authType?: string,
error?: unknown,
): Promise<string | null> {
// Handle different auth types
if (authType === AuthType.QWEN_OAUTH) {
return this.handleQwenOAuthError(error);
}
// Only handle fallback for OAuth users
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
return null;
@@ -674,4 +679,59 @@ export class GeminiChat {
content.parts[0].thought === true
);
}
/**
 * Handles Qwen OAuth authentication errors and rate limiting.
 *
 * The error is classified from its `status`/`code` and its lower-cased
 * message, and a diagnostic is logged for authentication and rate-limit
 * errors. Every path returns null; nothing else is done with the error here.
 * @param error The error to inspect; a falsy value is ignored
 * @returns Always null
 */
private async handleQwenOAuthError(error?: unknown): Promise<string | null> {
if (!error) {
return null;
}
// Lower-case the message once so the substring checks below are case-insensitive
const errorMessage =
error instanceof Error
? error.message.toLowerCase()
: String(error).toLowerCase();
// Prefer `status`; fall back to `code` when `status` is missing or 0
const errorCode =
(error as { status?: number; code?: number })?.status ||
(error as { status?: number; code?: number })?.code;
// Check if this is an authentication/authorization error
const isAuthError =
errorCode === 401 ||
errorCode === 403 ||
errorMessage.includes('unauthorized') ||
errorMessage.includes('forbidden') ||
errorMessage.includes('invalid api key') ||
errorMessage.includes('authentication') ||
errorMessage.includes('access denied') ||
(errorMessage.includes('token') && errorMessage.includes('expired'));
// Check if this is a rate limiting error
const isRateLimitError =
errorCode === 429 ||
errorMessage.includes('429') ||
errorMessage.includes('rate limit') ||
errorMessage.includes('too many requests');
// Authentication errors take precedence when both classifications match
if (isAuthError) {
console.warn('Qwen OAuth authentication error detected:', errorMessage);
// The QwenContentGenerator should automatically handle token refresh
// If it still fails, it likely means the refresh token is also expired
console.log(
'Note: If this persists, you may need to re-authenticate with Qwen OAuth',
);
return null;
}
if (isRateLimitError) {
console.warn('Qwen API rate limit encountered:', errorMessage);
// For rate limiting, we don't need to do anything special
// The retry mechanism will handle the backoff
return null;
}
// For other errors, don't handle them specially
return null;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2025 Google LLC
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
@@ -78,7 +78,7 @@ interface OpenAIResponseFormat {
}
export class OpenAIContentGenerator implements ContentGenerator {
private client: OpenAI;
protected client: OpenAI;
private model: string;
private config: Config;
private streamingToolCalls: Map<
@@ -114,14 +114,21 @@ export class OpenAIContentGenerator implements ContentGenerator {
timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries;
}
// Set up User-Agent header (same format as contentGenerator.ts)
const version = process.env.CLI_VERSION || process.version;
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
// Check if using OpenRouter and add required headers
const isOpenRouter = baseURL.includes('openrouter.ai');
const defaultHeaders = isOpenRouter
? {
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
'X-Title': 'Qwen Code',
}
: undefined;
const defaultHeaders = {
'User-Agent': userAgent,
...(isOpenRouter
? {
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
'X-Title': 'Qwen Code',
}
: {}),
};
this.client = new OpenAI({
apiKey,
@@ -132,6 +139,19 @@ export class OpenAIContentGenerator implements ContentGenerator {
});
}
/**
* Hook that lets subclasses silence the console error output of this generator
* for errors they handle themselves (for example by retrying).
*
* It is consulted right before the `OpenAI API Error:` and
* `OpenAI API Streaming Error:` messages are written with `console.error`;
* it only gates that console output.
*
* @param _error The error that occurred (unused by the default implementation)
* @param _request The original request (unused by the default implementation)
* @returns true if error logging should be suppressed, false otherwise
*/
protected shouldSuppressErrorLogging(
_error: unknown,
_request: GenerateContentParameters,
): boolean {
return false; // Default behavior: never suppress error logging
}
/**
* Check if an error is a timeout error
*/
@@ -275,7 +295,10 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
console.error('OpenAI API Error:', errorMessage);
// Allow subclasses to suppress error logging for specific scenarios
if (!this.shouldSuppressErrorLogging(error, request)) {
console.error('OpenAI API Error:', errorMessage);
}
// Provide helpful timeout-specific error message
if (isTimeoutError) {
@@ -288,7 +311,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
throw new Error(`OpenAI API error: ${errorMessage}`);
throw error;
}
}
@@ -486,7 +509,10 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
logApiResponse(this.config, errorEvent);
console.error('OpenAI API Streaming Error:', errorMessage);
// Allow subclasses to suppress error logging for specific scenarios
if (!this.shouldSuppressErrorLogging(error, request)) {
console.error('OpenAI API Streaming Error:', errorMessage);
}
// Provide helpful timeout-specific error message for streaming setup
if (isTimeoutError) {
@@ -499,7 +525,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
throw new Error(`OpenAI API error: ${errorMessage}`);
throw error;
}
}

View File

@@ -0,0 +1,794 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
IQwenOAuth2Client,
QwenCredentials,
ErrorData,
} from '../code_assist/qwenOAuth2.js';
import {
GenerateContentParameters,
GenerateContentResponse,
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
FinishReason,
} from '@google/genai';
import { QwenContentGenerator } from './qwenContentGenerator.js';
import { Config } from '../config/config.js';
// Swap the OpenAIContentGenerator parent class for a network-free stand-in.
// `createMockResponse` is only touched inside method bodies, i.e. after the
// module has finished loading.
vi.mock('./openaiContentGenerator.js', () => {
  class MockOpenAIContentGenerator {
    client: {
      apiKey: string;
      baseURL: string;
    };

    constructor(apiKey: string, _model: string, _config: Config) {
      this.client = { apiKey, baseURL: 'https://api.openai.com/v1' };
    }

    /** Always answers with a single canned response. */
    async generateContent(
      _request: GenerateContentParameters,
    ): Promise<GenerateContentResponse> {
      return createMockResponse('Generated content');
    }

    /** Always answers with a two-chunk stream. */
    async generateContentStream(
      _request: GenerateContentParameters,
    ): Promise<AsyncGenerator<GenerateContentResponse>> {
      async function* cannedChunks(): AsyncGenerator<GenerateContentResponse> {
        yield createMockResponse('Stream chunk 1');
        yield createMockResponse('Stream chunk 2');
      }
      return cannedChunks();
    }

    /** Reports a fixed token count. */
    async countTokens(
      _request: CountTokensParameters,
    ): Promise<CountTokensResponse> {
      return { totalTokens: 10 };
    }

    /** Reports a fixed three-dimensional embedding. */
    async embedContent(
      _request: EmbedContentParameters,
    ): Promise<EmbedContentResponse> {
      return { embeddings: [{ values: [0.1, 0.2, 0.3] }] };
    }

    /** Mirrors the base-class default: nothing is suppressed. */
    protected shouldSuppressErrorLogging(
      _error: unknown,
      _request: GenerateContentParameters,
    ): boolean {
      return false;
    }
  }

  return { OpenAIContentGenerator: MockOpenAIContentGenerator };
});
/**
 * Build a minimal successful model response whose only candidate carries
 * `text`, which is also exposed through the top-level `text` field.
 */
const createMockResponse = (text: string): GenerateContentResponse => {
  const response = {
    candidates: [
      {
        content: { role: 'model', parts: [{ text }] },
        finishReason: FinishReason.STOP,
        index: 0,
        safetyRatings: [],
      },
    ],
    promptFeedback: { safetyRatings: [] },
    text,
    data: undefined,
    functionCalls: [],
    executableCode: '',
    codeExecutionResult: '',
  } as GenerateContentResponse;
  return response;
};
describe('QwenContentGenerator', () => {
// Fixtures shared by every test; rebuilt from scratch in beforeEach.
let mockConfig: Config;
let mockQwenClient: IQwenOAuth2Client;
let qwenContentGenerator: QwenContentGenerator;

// Credentials handed out by the mocked OAuth client unless a test overrides them.
const mockCredentials: QwenCredentials = {
  access_token: 'test-access-token',
  refresh_token: 'test-refresh-token',
  resource_url: 'https://test-endpoint.com/v1',
};

beforeEach(() => {
  vi.clearAllMocks();

  // The generator under test never inspects the config in these tests.
  mockConfig = {} as Config;

  // Every OAuth client method is a bare mock; tests stub what they need.
  mockQwenClient = {
    requestDeviceAuthorization: vi.fn(),
    pollDeviceToken: vi.fn(),
    refreshAccessToken: vi.fn(),
    getAccessToken: vi.fn(),
    getCredentials: vi.fn(),
    setCredentials: vi.fn(),
  };

  qwenContentGenerator = new QwenContentGenerator(
    mockQwenClient,
    'qwen-turbo',
    mockConfig,
  );
});

afterEach(() => {
  vi.restoreAllMocks();
});
describe('Core Content Generation Methods', () => {
  // Arrange: the OAuth client hands out a usable token and full credentials.
  const stubValidToken = (): void => {
    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'valid-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
  };

  it('should generate content with valid token', async () => {
    stubValidToken();

    const response = await qwenContentGenerator.generateContent({
      model: 'qwen-turbo',
      contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
    });

    expect(response.text).toBe('Generated content');
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
  });

  it('should generate content stream with valid token', async () => {
    stubValidToken();

    const stream = await qwenContentGenerator.generateContentStream({
      model: 'qwen-turbo',
      contents: [{ role: 'user', parts: [{ text: 'Hello stream' }] }],
    });

    // Drain the stream, keeping only the text of each chunk.
    const received: string[] = [];
    for await (const piece of stream) {
      received.push(piece.text || '');
    }

    expect(received).toEqual(['Stream chunk 1', 'Stream chunk 2']);
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
  });

  it('should count tokens with valid token', async () => {
    stubValidToken();

    const counted = await qwenContentGenerator.countTokens({
      model: 'qwen-turbo',
      contents: [{ role: 'user', parts: [{ text: 'Count me' }] }],
    });

    expect(counted.totalTokens).toBe(10);
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
  });

  it('should embed content with valid token', async () => {
    stubValidToken();

    const embedded = await qwenContentGenerator.embedContent({
      model: 'qwen-turbo',
      contents: [{ parts: [{ text: 'Embed me' }] }],
    });

    expect(embedded.embeddings).toHaveLength(1);
    expect(embedded.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3]);
    expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
  });
});
describe('Token Management and Refresh Logic', () => {
  // The payload is irrelevant to these tests; a fresh copy is built per call.
  const helloRequest = (): GenerateContentParameters => ({
    model: 'qwen-turbo',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  });

  it('should refresh token on auth error and retry', async () => {
    // Fetching the token fails once with an auth error ...
    vi.mocked(mockQwenClient.getAccessToken).mockRejectedValueOnce({
      status: 401,
      message: 'Unauthorized',
    });
    // ... and the subsequent refresh succeeds.
    vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
      access_token: 'refreshed-token',
      token_type: 'Bearer',
      expires_in: 3600,
      resource_url: 'https://refreshed-endpoint.com',
    });

    const response = await qwenContentGenerator.generateContent(helloRequest());

    expect(response.text).toBe('Generated content');
    expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
  });

  it('should handle token refresh failure', async () => {
    // Neither the cached token nor the refresh can be obtained.
    vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue(
      new Error('Token expired'),
    );
    vi.mocked(mockQwenClient.refreshAccessToken).mockRejectedValue(
      new Error('Refresh failed'),
    );

    const pending = qwenContentGenerator.generateContent(helloRequest());

    await expect(pending).rejects.toThrow(
      'Failed to obtain valid Qwen access token. Please re-authenticate.',
    );
  });

  it('should update endpoint when token is refreshed', async () => {
    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'valid-token',
    });
    const credentialsWithNewEndpoint = {
      ...mockCredentials,
      resource_url: 'https://new-endpoint.com',
    };
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(
      credentialsWithNewEndpoint,
    );

    await qwenContentGenerator.generateContent(helloRequest());

    expect(mockQwenClient.getCredentials).toHaveBeenCalled();
  });
});
describe('Endpoint URL Normalization', () => {
  /**
   * Run one generateContent call with the given credentials and report the
   * baseURL that was installed on the client while the parent implementation
   * was executing.
   *
   * The parent's generateContent is patched on the prototype for the duration
   * of the call and is always put back, even when the call throws, so a
   * failing test cannot leak the patch into later tests.
   */
  const captureBaseURL = async (
    credentials: QwenCredentials,
  ): Promise<string> => {
    let capturedBaseURL = '';

    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'valid-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(credentials);

    // Two levels up: instance -> QwenContentGenerator.prototype -> parent prototype.
    const parentPrototype = Object.getPrototypeOf(
      Object.getPrototypeOf(qwenContentGenerator),
    );
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = vi.fn().mockImplementation(function (
      this: QwenContentGenerator,
    ) {
      capturedBaseURL = (this as unknown as { client: { baseURL: string } })
        .client.baseURL;
      return createMockResponse('Generated content');
    });

    try {
      const request: GenerateContentParameters = {
        model: 'qwen-turbo',
        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
      };
      await qwenContentGenerator.generateContent(request);
    } finally {
      parentPrototype.generateContent = originalGenerateContent;
    }

    return capturedBaseURL;
  };

  it('should use default endpoint when no custom endpoint provided', async () => {
    const baseURL = await captureBaseURL({
      access_token: 'test-token',
      refresh_token: 'test-refresh',
      // No resource_url provided
    });

    // Should use default endpoint with /v1 suffix
    expect(baseURL).toBe('https://dashscope.aliyuncs.com/compatible-mode/v1');
  });

  it('should normalize hostname-only endpoints by adding https protocol', async () => {
    const baseURL = await captureBaseURL({
      ...mockCredentials,
      resource_url: 'custom-endpoint.com',
    });

    // Should add https:// and /v1
    expect(baseURL).toBe('https://custom-endpoint.com/v1');
  });

  it('should preserve existing protocol in endpoint URLs', async () => {
    const baseURL = await captureBaseURL({
      ...mockCredentials,
      resource_url: 'https://custom-endpoint.com',
    });

    // Should preserve https:// and add /v1
    expect(baseURL).toBe('https://custom-endpoint.com/v1');
  });

  it('should not duplicate /v1 suffix if already present', async () => {
    const baseURL = await captureBaseURL({
      ...mockCredentials,
      resource_url: 'https://custom-endpoint.com/v1',
    });

    // Should not duplicate /v1
    expect(baseURL).toBe('https://custom-endpoint.com/v1');
  });
});
describe('Client State Management', () => {
  // The generator's `client` is not public; reach it through a cast.
  const getClient = (): { apiKey: string; baseURL: string } =>
    (
      qwenContentGenerator as unknown as {
        client: { apiKey: string; baseURL: string };
      }
    ).client;

  const request: GenerateContentParameters = {
    model: 'qwen-turbo',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  };

  it('should restore original client credentials after operations', async () => {
    const client = getClient();
    const originalApiKey = client.apiKey;
    const originalBaseURL = client.baseURL;

    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'temp-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
      ...mockCredentials,
      resource_url: 'https://temp-endpoint.com',
    });

    await qwenContentGenerator.generateContent(request);

    // Should restore original values after operation
    expect(client.apiKey).toBe(originalApiKey);
    expect(client.baseURL).toBe(originalBaseURL);
  });

  it('should restore credentials even when operation throws', async () => {
    const client = getClient();
    const originalApiKey = client.apiKey;
    const originalBaseURL = client.baseURL;

    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'temp-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);

    // Make the parent implementation fail with a non-auth error.
    const mockError = new Error('Network error');
    const parentPrototype = Object.getPrototypeOf(
      Object.getPrototypeOf(qwenContentGenerator),
    );
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = vi.fn().mockRejectedValue(mockError);

    try {
      // `rejects` fails the test if the call unexpectedly resolves; a bare
      // try/catch around the call would pass silently in that case.
      await expect(
        qwenContentGenerator.generateContent(request),
      ).rejects.toBe(mockError);

      // Credentials should still be restored
      expect(client.apiKey).toBe(originalApiKey);
      expect(client.baseURL).toBe(originalBaseURL);
    } finally {
      // Undo the prototype patch even if an assertion above failed.
      parentPrototype.generateContent = originalGenerateContent;
    }
  });
});
describe('Error Handling and Retry Logic', () => {
  const request: GenerateContentParameters = {
    model: 'qwen-turbo',
    contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  };

  // Parent prototype of the generator under test
  // (instance -> QwenContentGenerator.prototype -> parent prototype).
  const getParentPrototype = () =>
    Object.getPrototypeOf(Object.getPrototypeOf(qwenContentGenerator));

  it('should retry once on authentication errors', async () => {
    const authError = { status: 401, message: 'Unauthorized' };

    // First parent call fails with an auth error, the second one succeeds.
    const mockGenerateContent = vi
      .fn()
      .mockRejectedValueOnce(authError)
      .mockResolvedValueOnce(createMockResponse('Success after retry'));

    const parentPrototype = getParentPrototype();
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = mockGenerateContent;

    try {
      vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
        token: 'initial-token',
      });
      vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
      vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
        access_token: 'refreshed-token',
        token_type: 'Bearer',
        expires_in: 3600,
      });

      const result = await qwenContentGenerator.generateContent(request);

      expect(result.text).toBe('Success after retry');
      expect(mockGenerateContent).toHaveBeenCalledTimes(2);
      expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
    } finally {
      // Undo the prototype patch even if an assertion above failed.
      parentPrototype.generateContent = originalGenerateContent;
    }
  });

  it('should not retry non-authentication errors', async () => {
    const networkError = new Error('Network timeout');
    const mockGenerateContent = vi.fn().mockRejectedValue(networkError);

    const parentPrototype = getParentPrototype();
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = mockGenerateContent;

    try {
      vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
        token: 'valid-token',
      });
      vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);

      await expect(
        qwenContentGenerator.generateContent(request),
      ).rejects.toThrow('Network timeout');

      expect(mockGenerateContent).toHaveBeenCalledTimes(1);
      expect(mockQwenClient.refreshAccessToken).not.toHaveBeenCalled();
    } finally {
      // Undo the prototype patch even if an assertion above failed.
      parentPrototype.generateContent = originalGenerateContent;
    }
  });

  it('should handle error response from token refresh', async () => {
    vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue(
      new Error('Token expired'),
    );
    // The refresh call itself resolves, but with an OAuth error payload.
    vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
      error: 'invalid_grant',
      error_description: 'Refresh token expired',
    } as ErrorData);

    await expect(
      qwenContentGenerator.generateContent(request),
    ).rejects.toThrow('Failed to obtain valid Qwen access token');
  });
});
describe('Token State Management', () => {
  it('should cache and return current token', () => {
    expect(qwenContentGenerator.getCurrentToken()).toBeNull();

    // Simulate setting a token internally
    (
      qwenContentGenerator as unknown as { currentToken: string }
    ).currentToken = 'cached-token';

    expect(qwenContentGenerator.getCurrentToken()).toBe('cached-token');
  });

  it('should clear token and endpoint on clearToken()', () => {
    // View of the private token state, used both to seed and to inspect it.
    const internals = qwenContentGenerator as unknown as {
      currentToken: string | null;
      currentEndpoint: string | null;
      refreshPromise: Promise<string> | null;
    };
    internals.currentToken = 'cached-token';
    internals.currentEndpoint = 'https://cached-endpoint.com';
    internals.refreshPromise = Promise.resolve('token');

    qwenContentGenerator.clearToken();

    expect(qwenContentGenerator.getCurrentToken()).toBeNull();
    expect(internals.currentEndpoint).toBeNull();
    expect(internals.refreshPromise).toBeNull();
  });

  it('should handle concurrent token refresh requests', async () => {
    let refreshCallCount = 0;

    // Clear any existing cached token first
    qwenContentGenerator.clearToken();

    // The first parent call fails with an auth error, which triggers a refresh.
    const authError = { status: 401, message: 'Unauthorized' };
    let parentCallCount = 0;

    vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
      token: 'initial-token',
    });
    vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
    vi.mocked(mockQwenClient.refreshAccessToken).mockImplementation(
      async () => {
        refreshCallCount++;
        // Keep the refresh pending long enough for the other requests to overlap.
        await new Promise((resolve) => setTimeout(resolve, 50));
        return {
          access_token: 'refreshed-token',
          token_type: 'Bearer',
          expires_in: 3600,
        };
      },
    );

    // Patch the parent method to fail first and succeed afterwards.
    const parentPrototype = Object.getPrototypeOf(
      Object.getPrototypeOf(qwenContentGenerator),
    );
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = vi.fn().mockImplementation(async () => {
      parentCallCount++;
      if (parentCallCount === 1) {
        throw authError; // First call triggers auth error
      }
      return createMockResponse('Generated content');
    });

    try {
      const request: GenerateContentParameters = {
        model: 'qwen-turbo',
        contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
      };

      // Fire several requests at once while the refresh is in flight.
      const results = await Promise.all([
        qwenContentGenerator.generateContent(request),
        qwenContentGenerator.generateContent(request),
        qwenContentGenerator.generateContent(request),
      ]);

      // All should succeed
      results.forEach((result) => {
        expect(result.text).toBe('Generated content');
      });

      // The main test is that all requests succeed without crashing
      expect(results).toHaveLength(3);
      expect(refreshCallCount).toBeGreaterThanOrEqual(1);
    } finally {
      // Undo the prototype patch even if an assertion above failed.
      parentPrototype.generateContent = originalGenerateContent;
    }
  });
});
describe('Error Logging Suppression', () => {
  // Invoke the protected hook on the generator under test.
  const suppresses = (error: unknown): boolean =>
    (
      qwenContentGenerator as unknown as {
        shouldSuppressErrorLogging: (
          error: unknown,
          request: GenerateContentParameters,
        ) => boolean;
      }
    ).shouldSuppressErrorLogging(error, {} as GenerateContentParameters);

  it('should suppress logging for authentication errors', () => {
    const authErrors = [
      { status: 401 },
      { code: 403 },
      new Error('Unauthorized access'),
      new Error('Token expired'),
      new Error('Invalid API key'),
    ];

    for (const error of authErrors) {
      expect(suppresses(error)).toBe(true);
    }
  });

  it('should not suppress logging for non-auth errors', () => {
    const nonAuthErrors = [
      new Error('Network timeout'),
      new Error('Rate limit exceeded'),
      { status: 500 },
      new Error('Internal server error'),
    ];

    for (const error of nonAuthErrors) {
      expect(suppresses(error)).toBe(false);
    }
  });
});
describe('Integration Tests', () => {
  it('should handle complete workflow: get token, use it, refresh on auth error, retry', async () => {
    const authError = { status: 401, message: 'Token expired' };

    // Parent implementation: fail the first call, succeed on the retry.
    let callCount = 0;
    const mockGenerateContent = vi.fn().mockImplementation(async () => {
      callCount++;
      if (callCount === 1) {
        throw authError; // First call fails
      }
      return createMockResponse('Success after refresh'); // Second call succeeds
    });

    const parentPrototype = Object.getPrototypeOf(
      Object.getPrototypeOf(qwenContentGenerator),
    );
    const originalGenerateContent = parentPrototype.generateContent;
    parentPrototype.generateContent = mockGenerateContent;

    try {
      vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
        token: 'initial-token',
      });
      vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
        ...mockCredentials,
        resource_url: 'custom-endpoint.com',
      });
      vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
        access_token: 'new-token',
        token_type: 'Bearer',
        expires_in: 7200,
        resource_url: 'https://new-endpoint.com',
      });

      const request: GenerateContentParameters = {
        model: 'qwen-turbo',
        contents: [{ role: 'user', parts: [{ text: 'Test message' }] }],
      };

      const result = await qwenContentGenerator.generateContent(request);

      expect(result.text).toBe('Success after refresh');
      expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
      expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
      expect(callCount).toBe(2); // Initial call + retry
    } finally {
      // The patched method lives on the shared parent prototype; put the
      // original back so the patch cannot outlive this test.
      parentPrototype.generateContent = originalGenerateContent;
    }
  });
});
});

View File

@@ -0,0 +1,356 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
import {
IQwenOAuth2Client,
type TokenRefreshData,
type ErrorData,
isErrorResponse,
} from '../code_assist/qwenOAuth2.js';
import { Config } from '../config/config.js';
import {
GenerateContentParameters,
GenerateContentResponse,
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
} from '@google/genai';
// Default fallback base URL if no endpoint is provided
const DEFAULT_QWEN_BASE_URL =
'https://dashscope.aliyuncs.com/compatible-mode/v1';
/**
* Qwen Content Generator that uses Qwen OAuth tokens with automatic refresh
*/
export class QwenContentGenerator extends OpenAIContentGenerator {
private qwenClient: IQwenOAuth2Client;
// Token management (integrated from QwenTokenManager)
private currentToken: string | null = null;
private currentEndpoint: string | null = null;
private refreshPromise: Promise<string> | null = null;
constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) {
// Initialize with empty API key, we'll override it dynamically
super('', model, config);
this.qwenClient = qwenClient;
// Set default base URL, will be updated dynamically
this.client.baseURL = DEFAULT_QWEN_BASE_URL;
}
/**
* Get the current endpoint URL with proper protocol and /v1 suffix
*/
private getCurrentEndpoint(): string {
const baseEndpoint = this.currentEndpoint || DEFAULT_QWEN_BASE_URL;
const suffix = '/v1';
// Normalize the URL: add protocol if missing, ensure /v1 suffix
const normalizedUrl = baseEndpoint.startsWith('http')
? baseEndpoint
: `https://${baseEndpoint}`;
return normalizedUrl.endsWith(suffix)
? normalizedUrl
: `${normalizedUrl}${suffix}`;
}
/**
* Override error logging behavior to suppress auth errors during token refresh
*/
protected shouldSuppressErrorLogging(
error: unknown,
_request: GenerateContentParameters,
): boolean {
// Suppress logging for authentication errors that we handle with token refresh
return this.isAuthError(error);
}
/**
* Override to use dynamic token and endpoint
*/
async generateContent(
request: GenerateContentParameters,
): Promise<GenerateContentResponse> {
return this.withValidToken(async (token) => {
// Temporarily update the API key and base URL
const originalApiKey = this.client.apiKey;
const originalBaseURL = this.client.baseURL;
this.client.apiKey = token;
this.client.baseURL = this.getCurrentEndpoint();
try {
return await super.generateContent(request);
} finally {
// Restore original values
this.client.apiKey = originalApiKey;
this.client.baseURL = originalBaseURL;
}
});
}
/**
* Override to use dynamic token and endpoint
*/
async generateContentStream(
request: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> {
return this.withValidTokenForStream(async (token) => {
// Update the API key and base URL before streaming
const originalApiKey = this.client.apiKey;
const originalBaseURL = this.client.baseURL;
this.client.apiKey = token;
this.client.baseURL = this.getCurrentEndpoint();
try {
return await super.generateContentStream(request);
} catch (error) {
// Restore original values on error
this.client.apiKey = originalApiKey;
this.client.baseURL = originalBaseURL;
throw error;
}
// Note: We don't restore the values in finally for streaming because
// the generator may continue to be used after this method returns
});
}
/**
* Override to use dynamic token and endpoint
*/
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.withValidToken(async (token) => {
const originalApiKey = this.client.apiKey;
const originalBaseURL = this.client.baseURL;
this.client.apiKey = token;
this.client.baseURL = this.getCurrentEndpoint();
try {
return await super.countTokens(request);
} finally {
this.client.apiKey = originalApiKey;
this.client.baseURL = originalBaseURL;
}
});
}
/**
* Override to use dynamic token and endpoint
*/
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.withValidToken(async (token) => {
const originalApiKey = this.client.apiKey;
const originalBaseURL = this.client.baseURL;
this.client.apiKey = token;
this.client.baseURL = this.getCurrentEndpoint();
try {
return await super.embedContent(request);
} finally {
this.client.apiKey = originalApiKey;
this.client.baseURL = originalBaseURL;
}
});
}
/**
* Execute operation with a valid token, with retry on auth failure
*/
private async withValidToken<T>(
operation: (token: string) => Promise<T>,
): Promise<T> {
const token = await this.getTokenWithRetry();
try {
return await operation(token);
} catch (error) {
// Check if this is an authentication error
if (this.isAuthError(error)) {
// Refresh token and retry once silently
const newToken = await this.refreshToken();
return await operation(newToken);
}
throw error;
}
}
/**
* Execute operation with a valid token for streaming, with retry on auth failure
*/
private async withValidTokenForStream<T>(
operation: (token: string) => Promise<T>,
): Promise<T> {
const token = await this.getTokenWithRetry();
try {
return await operation(token);
} catch (error) {
// Check if this is an authentication error
if (this.isAuthError(error)) {
// Refresh token and retry once silently
const newToken = await this.refreshToken();
return await operation(newToken);
}
throw error;
}
}
/**
* Get token with retry logic
*/
private async getTokenWithRetry(): Promise<string> {
try {
return await this.getValidToken();
} catch (error) {
console.error('Failed to get valid token:', error);
throw new Error(
'Failed to obtain valid Qwen access token. Please re-authenticate.',
);
}
}
// Token management methods (integrated from QwenTokenManager)
/**
* Get a valid access token, refreshing if necessary
*/
private async getValidToken(): Promise<string> {
// If there's already a refresh in progress, wait for it
if (this.refreshPromise) {
return this.refreshPromise;
}
try {
const { token } = await this.qwenClient.getAccessToken();
if (token) {
this.currentToken = token;
// Also update endpoint from current credentials
const credentials = this.qwenClient.getCredentials();
if (credentials.resource_url) {
this.currentEndpoint = credentials.resource_url;
}
return token;
}
} catch (error) {
console.warn('Failed to get access token, attempting refresh:', error);
}
// Start a new refresh operation
this.refreshPromise = this.performTokenRefresh();
try {
const newToken = await this.refreshPromise;
return newToken;
} finally {
this.refreshPromise = null;
}
}
/**
* Force refresh the access token
*/
private async refreshToken(): Promise<string> {
this.refreshPromise = this.performTokenRefresh();
try {
const newToken = await this.refreshPromise;
return newToken;
} finally {
this.refreshPromise = null;
}
}
private async performTokenRefresh(): Promise<string> {
try {
const response = await this.qwenClient.refreshAccessToken();
if (isErrorResponse(response)) {
const errorData = response as ErrorData;
throw new Error(
`${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
);
}
const tokenData = response as TokenRefreshData;
if (!tokenData.access_token) {
throw new Error('Failed to refresh access token: no token returned');
}
this.currentToken = tokenData.access_token;
// Update endpoint if provided
if (tokenData.resource_url) {
this.currentEndpoint = tokenData.resource_url;
}
return tokenData.access_token;
} catch (error) {
throw new Error(
`${error instanceof Error ? error.message : String(error)}`,
);
}
}
/**
 * Decide whether a failure looks like an authentication/authorization problem.
 *
 * A value qualifies when its `status` (or, if that is falsy, its `code`) is
 * exactly 400, 401 or 403, or when its lower-cased message contains one of
 * the known auth-related phrases. Falsy inputs never qualify.
 *
 * @param error The caught value; may be an Error, a plain object or a primitive.
 * @returns `true` when the value matches any of the auth indicators.
 */
private isAuthError(error: unknown): boolean {
  if (!error) {
    return false;
  }

  // Message text is taken from Error instances directly, otherwise from the
  // string conversion of whatever was thrown.
  const rawMessage = error instanceof Error ? error.message : String(error);
  const text = rawMessage.toLowerCase();

  // Errors may expose a numeric/string `status` or `code`; a truthy status
  // takes precedence over code.
  const candidate = error as {
    status?: number | string;
    code?: number | string;
  };
  let statusOrCode = candidate.status;
  if (!statusOrCode) {
    statusOrCode = candidate.code;
  }

  if (statusOrCode === 400 || statusOrCode === 401 || statusOrCode === 403) {
    return true;
  }

  const authPhrases = [
    'unauthorized',
    'forbidden',
    'invalid api key',
    'invalid access token',
    'token expired',
    'authentication',
    'access denied',
  ];
  if (authPhrases.some((phrase) => text.includes(phrase))) {
    return true;
  }

  // Catch looser wordings such as "the token has expired".
  return text.includes('token') && text.includes('expired');
}
/**
 * Return the access token currently cached on this instance.
 *
 * This is a plain read of the cached field: no expiry check and no refresh
 * is performed, so the returned token may already be expired. The value is
 * `null` after `clearToken()` has been called.
 *
 * @returns The cached access token, or `null` when none is cached.
 */
getCurrentToken(): string | null {
return this.currentToken;
}
/**
 * Reset the cached authentication state held by this instance.
 *
 * Drops the cached token, the cached endpoint and the reference to any
 * refresh promise. Note that this only forgets the promise reference; it
 * does not cancel a refresh that is already running.
 */
clearToken(): void {
  // Forget the in-flight refresh handle, then the values it would have set.
  this.refreshPromise = null;
  this.currentEndpoint = null;
  this.currentToken = null;
}
}

View File

@@ -21,6 +21,7 @@ export * from './core/nonInteractiveToolExecutor.js';
export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js';
export * from './code_assist/qwenOAuth2.js';
export * from './code_assist/server.js';
export * from './code_assist/types.js';

View File

@@ -0,0 +1,194 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import {
isQwenQuotaExceededError,
isQwenThrottlingError,
isProQuotaExceededError,
isGenericQuotaExceededError,
isApiError,
isStructuredError,
type ApiError,
} from './quotaErrorDetection.js';
describe('quotaErrorDetection', () => {
describe('isQwenQuotaExceededError', () => {
  // API-shaped payload, declared on its own so it stays type-checked against ApiError.
  const apiQuotaError: ApiError = {
    error: {
      code: 429,
      message: 'insufficient_quota',
      status: 'RESOURCE_EXHAUSTED',
      details: [],
    },
  };

  // One row per scenario: test title, value passed to the detector, expected verdict.
  const scenarios: Array<{ title: string; error: unknown; expected: boolean }> =
    [
      {
        title: 'should detect insufficient_quota error message',
        error: new Error('insufficient_quota'),
        expected: true,
      },
      {
        title: 'should detect free allocated quota exceeded error message',
        error: new Error('Free allocated quota exceeded.'),
        expected: true,
      },
      {
        title: 'should detect quota exceeded error message',
        error: new Error('quota exceeded'),
        expected: true,
      },
      {
        title: 'should detect quota exceeded in string error',
        error: 'insufficient_quota',
        expected: true,
      },
      {
        title: 'should detect quota exceeded in structured error',
        error: { message: 'Free allocated quota exceeded.', status: 429 },
        expected: true,
      },
      {
        title: 'should detect quota exceeded in API error',
        error: apiQuotaError,
        expected: true,
      },
      {
        title: 'should not detect throttling errors as quota exceeded',
        error: new Error('requests throttling triggered'),
        expected: false,
      },
      {
        title: 'should not detect unrelated errors',
        error: new Error('Network error'),
        expected: false,
      },
    ];

  for (const { title, error, expected } of scenarios) {
    it(title, () => {
      expect(isQwenQuotaExceededError(error)).toBe(expected);
    });
  }
});
// Covers isQwenThrottlingError: throttling wording must be paired with a 429
// for object-shaped errors, while a bare 'throttling' string is accepted.
describe('isQwenThrottlingError', () => {
  it('should detect throttling error with 429 status', () => {
    const error = { message: 'throttling', status: 429 };
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect requests throttling triggered with 429 status', () => {
    const error = { message: 'requests throttling triggered', status: 429 };
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect rate limit error with 429 status', () => {
    const error = { message: 'rate limit exceeded', status: 429 };
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect too many requests with 429 status', () => {
    const error = { message: 'too many requests', status: 429 };
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect throttling in string error', () => {
    const error = 'throttling';
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect throttling in structured error with 429', () => {
    // Previously an exact copy of the plain-object case above; use a real
    // Error instance carrying a status so this test exercises a distinct shape.
    const error = Object.assign(new Error('requests throttling triggered'), {
      status: 429,
    });
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should detect throttling in API error with 429', () => {
    const error: ApiError = {
      error: {
        code: 429,
        message: 'throttling',
        status: 'RESOURCE_EXHAUSTED',
        details: [],
      },
    };
    expect(isQwenThrottlingError(error)).toBe(true);
  });

  it('should not detect throttling without 429 status in structured error', () => {
    const error = { message: 'throttling', status: 500 };
    expect(isQwenThrottlingError(error)).toBe(false);
  });

  it('should not detect quota exceeded as throttling', () => {
    const error = { message: 'insufficient_quota', status: 429 };
    expect(isQwenThrottlingError(error)).toBe(false);
  });

  it('should not detect unrelated errors as throttling', () => {
    const error = { message: 'Network error', status: 500 };
    expect(isQwenThrottlingError(error)).toBe(false);
  });
});
// Covers isProQuotaExceededError: only "Pro" quota metrics should match.
describe('isProQuotaExceededError', () => {
  it('should detect Gemini Pro quota exceeded error', () => {
    expect(
      isProQuotaExceededError(
        new Error("Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'"),
      ),
    ).toBe(true);
  });

  it('should detect Gemini preview Pro quota exceeded error', () => {
    expect(
      isProQuotaExceededError(
        new Error(
          "Quota exceeded for quota metric 'Gemini 2.5-preview Pro Requests'",
        ),
      ),
    ).toBe(true);
  });

  it('should not detect non-Pro quota errors', () => {
    expect(
      isProQuotaExceededError(
        new Error(
          "Quota exceeded for quota metric 'Gemini 1.5 Flash Requests'",
        ),
      ),
    ).toBe(false);
  });
});
// Covers isGenericQuotaExceededError: one matching and one non-matching message.
describe('isGenericQuotaExceededError', () => {
  it('should detect generic quota exceeded error', () => {
    const quotaError = new Error('Quota exceeded for quota metric');
    const detected = isGenericQuotaExceededError(quotaError);
    expect(detected).toBe(true);
  });

  it('should not detect non-quota errors', () => {
    const networkError = new Error('Network error');
    const detected = isGenericQuotaExceededError(networkError);
    expect(detected).toBe(false);
  });
});
// Covers the two shape guards used by the quota/throttling detectors.
describe('type guards', () => {
  describe('isApiError', () => {
    it('should detect valid API error', () => {
      // Annotated so the fixture itself is checked against the ApiError shape.
      const wellFormed: ApiError = {
        error: {
          code: 429,
          message: 'test error',
          status: 'RESOURCE_EXHAUSTED',
          details: [],
        },
      };
      expect(isApiError(wellFormed)).toBe(true);
    });

    it('should not detect invalid API error', () => {
      // A bare message has no nested `error` object.
      expect(isApiError({ message: 'test error' })).toBe(false);
    });
  });

  describe('isStructuredError', () => {
    it('should detect valid structured error', () => {
      expect(isStructuredError({ message: 'test error', status: 429 })).toBe(
        true,
      );
    });

    it('should not detect invalid structured error', () => {
      // No `message` field present.
      expect(isStructuredError({ code: 429 })).toBe(false);
    });
  });
});
});

View File

@@ -101,3 +101,70 @@ export function isGenericQuotaExceededError(error: unknown): boolean {
return false;
}
/**
 * Report whether an error indicates that the Qwen quota has been exhausted
 * (a condition callers should not retry).
 *
 * The message is taken from the value itself when it is a string, from
 * `message` for structured errors, and from `error.message` for API errors;
 * any other shape yields `false`. Matching is case-insensitive.
 *
 * @param error The caught value to classify.
 * @returns `true` when the message mentions `insufficient_quota`,
 *   `free allocated quota exceeded`, or both `quota` and `exceeded`.
 */
export function isQwenQuotaExceededError(error: unknown): boolean {
  // Step 1: pull the message text out of whichever shape we were handed.
  let message: string;
  if (typeof error === 'string') {
    message = error;
  } else if (isStructuredError(error)) {
    message = error.message;
  } else if (isApiError(error)) {
    message = error.error.message;
  } else {
    return false;
  }

  // Step 2: classify the lower-cased text.
  const normalized = message.toLowerCase();
  if (normalized.includes('insufficient_quota')) {
    return true;
  }
  if (normalized.includes('free allocated quota exceeded')) {
    return true;
  }
  return normalized.includes('quota') && normalized.includes('exceeded');
}
/**
 * Report whether an error indicates Qwen request throttling (a transient
 * condition callers may retry).
 *
 * - A bare string carries no status code, so it qualifies only when it
 *   mentions "throttling" (case-insensitive).
 * - A structured error qualifies when its `status` (or, if falsy, `code`) is
 *   429 and its message contains throttling wording.
 * - An API error qualifies when `error.code` is 429 and `error.message`
 *   contains throttling wording.
 *
 * Throttling wording is any of: `throttling`, `rate limit`,
 * `too many requests` (case-insensitive).
 *
 * @param error The caught value to classify.
 * @returns `true` when the value matches one of the rules above.
 */
export function isQwenThrottlingError(error: unknown): boolean {
  // 'requests throttling triggered' needs no separate entry: it already
  // contains 'throttling'.
  const hasThrottlingWording = (message: string): boolean => {
    const lowerMessage = message.toLowerCase();
    return (
      lowerMessage.includes('throttling') ||
      lowerMessage.includes('rate limit') ||
      lowerMessage.includes('too many requests')
    );
  };

  // Read a numeric status from object-shaped errors; `status` wins over `code`.
  const getStatusCode = (value: unknown): number | undefined => {
    if (value && typeof value === 'object') {
      const errorObj = value as { status?: number; code?: number };
      return errorObj.status || errorObj.code;
    }
    return undefined;
  };

  if (typeof error === 'string') {
    // A string can never carry a status code, so the former
    // `statusCode === 429 && ...` arm was unreachable here. Match the word
    // case-insensitively, like every other path in this module.
    return error.toLowerCase().includes('throttling');
  }

  if (isStructuredError(error)) {
    return getStatusCode(error) === 429 && hasThrottlingWording(error.message);
  }

  if (isApiError(error)) {
    return (
      error.error.code === 429 && hasThrottlingWording(error.error.message)
    );
  }

  return false;
}

View File

@@ -8,6 +8,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { retryWithBackoff, HttpError } from './retry.js';
import { setSimulate429 } from './testUtils.js';
import { AuthType } from '../core/contentGenerator.js';
// Helper to create a mock function that fails a certain number of times
const createFailingFunction = (
@@ -399,4 +400,173 @@ describe('retryWithBackoff', () => {
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
// Exercises retryWithBackoff when called with authType QWEN_OAUTH:
// - 429 errors with throttling/rate-limit wording are expected to be retried;
// - errors whose message signals an exhausted quota are expected to reject
//   after a single call with a "Qwen API quota exceeded" error;
// - other errors (e.g. a 500) are expected to be retried.
// NOTE(review): the retry tests call vi.runAllTimersAsync(), so fake timers
// are presumably enabled by setup outside this block — confirm.
describe('Qwen OAuth 429 error handling', () => {
it('should retry for Qwen OAuth 429 errors that are throttling-related', async () => {
// An Error carrying status 429 and rate-limit wording.
const errorWith429: HttpError = new Error('Rate limit exceeded');
errorWith429.status = 429;
// Fails once, then succeeds.
const fn = vi
.fn()
.mockRejectedValueOnce(errorWith429)
.mockResolvedValue('success');
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 100,
maxDelayMs: 1000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
// Fast-forward time for delays
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should be called twice (1 failure + 1 success)
expect(fn).toHaveBeenCalledTimes(2);
});
it('should throw immediately for Qwen OAuth with insufficient_quota message', async () => {
// No status is set here: the quota check is expected to work from the message alone.
const errorWithInsufficientQuota = new Error('insufficient_quota');
const fn = vi.fn().mockRejectedValue(errorWithInsufficientQuota);
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 1000,
maxDelayMs: 5000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
// No timers are advanced: the rejection must happen without waiting for a backoff delay.
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
// Should be called only once (no retries)
expect(fn).toHaveBeenCalledTimes(1);
});
it('should throw immediately for Qwen OAuth with free allocated quota exceeded message', async () => {
const errorWithQuotaExceeded = new Error(
'Free allocated quota exceeded.',
);
const fn = vi.fn().mockRejectedValue(errorWithQuotaExceeded);
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 1000,
maxDelayMs: 5000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
// Should be called only once (no retries)
expect(fn).toHaveBeenCalledTimes(1);
});
it('should retry for Qwen OAuth with throttling message', async () => {
const throttlingError: HttpError = new Error(
'requests throttling triggered',
);
throttlingError.status = 429;
// Two consecutive throttling failures before success.
const fn = vi
.fn()
.mockRejectedValueOnce(throttlingError)
.mockRejectedValueOnce(throttlingError)
.mockResolvedValue('success');
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 100,
maxDelayMs: 1000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
// Fast-forward time for delays
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should be called 3 times (2 failures + 1 success)
expect(fn).toHaveBeenCalledTimes(3);
});
it('should retry for Qwen OAuth with throttling error', async () => {
const throttlingError: HttpError = new Error('throttling');
throttlingError.status = 429;
const fn = vi
.fn()
.mockRejectedValueOnce(throttlingError)
.mockResolvedValue('success');
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 100,
maxDelayMs: 1000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
// Fast-forward time for delays
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should be called 2 times (1 failure + 1 success)
expect(fn).toHaveBeenCalledTimes(2);
});
it('should throw immediately for Qwen OAuth with quota message', async () => {
// Generic "quota exceeded" wording, again without a status code.
const errorWithQuota = new Error('quota exceeded');
const fn = vi.fn().mockRejectedValue(errorWithQuota);
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 1000,
maxDelayMs: 5000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
// Should be called only once (no retries)
expect(fn).toHaveBeenCalledTimes(1);
});
it('should retry normal errors for Qwen OAuth (not quota-related)', async () => {
const normalError: HttpError = new Error('Network error');
normalError.status = 500;
// createFailingFunction is defined earlier in this file (outside this block).
const fn = createFailingFunction(2, 'success');
// Replace the default 500 error with our normal error
fn.mockRejectedValueOnce(normalError)
.mockRejectedValueOnce(normalError)
.mockResolvedValue('success');
const promise = retryWithBackoff(fn, {
maxAttempts: 5,
initialDelayMs: 100,
maxDelayMs: 1000,
shouldRetry: () => true,
authType: AuthType.QWEN_OAUTH,
});
// Fast-forward time for delays
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should be called 3 times (2 failures + 1 success)
expect(fn).toHaveBeenCalledTimes(3);
});
});
});

View File

@@ -8,6 +8,8 @@ import { AuthType } from '../core/contentGenerator.js';
import {
isProQuotaExceededError,
isGenericQuotaExceededError,
isQwenQuotaExceededError,
isQwenThrottlingError,
} from './quotaErrorDetection.js';
export interface HttpError extends Error {
@@ -150,9 +152,23 @@ export async function retryWithBackoff<T>(
}
}
// Track consecutive 429 errors
// Check for Qwen OAuth quota exceeded error - throw immediately without retry
if (authType === AuthType.QWEN_OAUTH && isQwenQuotaExceededError(error)) {
throw new Error(
`Qwen API quota exceeded: Your Qwen API quota has been exhausted. Please wait for your quota to reset.`,
);
}
// Track consecutive 429 errors, but handle Qwen throttling differently
if (errorStatus === 429) {
consecutive429Count++;
// For Qwen throttling errors, we still want to track them for exponential backoff
// but not for quota fallback logic (since Qwen doesn't have model fallback)
if (authType === AuthType.QWEN_OAUTH && isQwenThrottlingError(error)) {
// Keep track of 429s but reset the consecutive count to avoid fallback logic
consecutive429Count = 0;
} else {
consecutive429Count++;
}
} else {
consecutive429Count = 0;
}