mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 08:47:44 +00:00
feat(oauth): add Qwen OAuth integration
This commit is contained in:
1306
packages/core/src/code_assist/qwenOAuth2.test.ts
Normal file
1306
packages/core/src/code_assist/qwenOAuth2.test.ts
Normal file
File diff suppressed because it is too large
Load Diff
862
packages/core/src/code_assist/qwenOAuth2.ts
Normal file
862
packages/core/src/code_assist/qwenOAuth2.ts
Normal file
@@ -0,0 +1,862 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import crypto from 'crypto';
|
||||
import path from 'node:path';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as os from 'os';
|
||||
|
||||
import open from 'open';
|
||||
import { EventEmitter } from 'events';
|
||||
import { Config } from '../config/config.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
// OAuth Endpoints
|
||||
const QWEN_OAUTH_BASE_URL = 'https://pre4-chat.qwen.ai';
|
||||
// const QWEN_OAUTH_BASE_URL = 'https://chat.qwen.ai';
|
||||
|
||||
const QWEN_OAUTH_DEVICE_CODE_ENDPOINT = `${QWEN_OAUTH_BASE_URL}/api/v1/oauth2/device/code`;
|
||||
const QWEN_OAUTH_TOKEN_ENDPOINT = `${QWEN_OAUTH_BASE_URL}/api/v1/oauth2/token`;
|
||||
|
||||
// OAuth Client Configuration
|
||||
// const QWEN_OAUTH_CLIENT_ID = '93a239d6ed36412c8c442e91b60fa305';
|
||||
const QWEN_OAUTH_CLIENT_ID = 'f0304373b74a44d2b584a3fb70ca9e56';
|
||||
|
||||
const QWEN_OAUTH_SCOPE = 'openid profile email model.completion';
|
||||
const QWEN_OAUTH_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code';
|
||||
|
||||
// File System Configuration
|
||||
const QWEN_DIR = '.qwen';
|
||||
const QWEN_CREDENTIAL_FILENAME = 'oauth_creds.json';
|
||||
|
||||
// Token Configuration
|
||||
const TOKEN_REFRESH_BUFFER_MS = 30 * 1000; // 30 seconds
|
||||
|
||||
/**
|
||||
* PKCE (Proof Key for Code Exchange) utilities
|
||||
* Implements RFC 7636 - Proof Key for Code Exchange by OAuth Public Clients
|
||||
*/
|
||||
|
||||
/**
|
||||
* Generate a random code verifier for PKCE
|
||||
* @returns A random string of 43-128 characters
|
||||
*/
|
||||
export function generateCodeVerifier(): string {
|
||||
return crypto.randomBytes(32).toString('base64url');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a code challenge from a code verifier using SHA-256
|
||||
* @param codeVerifier The code verifier string
|
||||
* @returns The code challenge string
|
||||
*/
|
||||
export function generateCodeChallenge(codeVerifier: string): string {
|
||||
const hash = crypto.createHash('sha256');
|
||||
hash.update(codeVerifier);
|
||||
return hash.digest('base64url');
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge pair
|
||||
* @returns Object containing code_verifier and code_challenge
|
||||
*/
|
||||
export function generatePKCEPair(): {
|
||||
code_verifier: string;
|
||||
code_challenge: string;
|
||||
} {
|
||||
const codeVerifier = generateCodeVerifier();
|
||||
const codeChallenge = generateCodeChallenge(codeVerifier);
|
||||
return { code_verifier: codeVerifier, code_challenge: codeChallenge };
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert object to URL-encoded form data
|
||||
* @param data The object to convert
|
||||
* @returns URL-encoded string
|
||||
*/
|
||||
function objectToUrlEncoded(data: Record<string, string>): string {
|
||||
return Object.keys(data)
|
||||
.map((key) => `${encodeURIComponent(key)}=${encodeURIComponent(data[key])}`)
|
||||
.join('&');
|
||||
}
|
||||
|
||||
/**
|
||||
* Standard error response data
|
||||
*/
|
||||
export interface ErrorData {
|
||||
error: string;
|
||||
error_description: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Qwen OAuth2 credentials interface
|
||||
*/
|
||||
export interface QwenCredentials {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
id_token?: string;
|
||||
expiry_date?: number;
|
||||
token_type?: string;
|
||||
resource_url?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Device authorization success data
|
||||
*/
|
||||
export interface DeviceAuthorizationData {
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_uri: string;
|
||||
verification_uri_complete: string;
|
||||
expires_in: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Device authorization response interface
|
||||
*/
|
||||
export type DeviceAuthorizationResponse = DeviceAuthorizationData | ErrorData;
|
||||
|
||||
/**
|
||||
* Type guard to check if device authorization was successful
|
||||
*/
|
||||
export function isDeviceAuthorizationSuccess(
|
||||
response: DeviceAuthorizationResponse,
|
||||
): response is DeviceAuthorizationData {
|
||||
return 'device_code' in response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Device token success data
|
||||
*/
|
||||
export interface DeviceTokenData {
|
||||
access_token: string | null;
|
||||
refresh_token?: string | null;
|
||||
token_type: string;
|
||||
expires_in: number | null;
|
||||
scope?: string | null;
|
||||
endpoint?: string;
|
||||
resource_url?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Device token pending response
|
||||
*/
|
||||
export interface DeviceTokenPendingData {
|
||||
status: 'pending';
|
||||
slowDown?: boolean; // Indicates if client should increase polling interval
|
||||
}
|
||||
|
||||
/**
|
||||
* Device token response interface
|
||||
*/
|
||||
export type DeviceTokenResponse =
|
||||
| DeviceTokenData
|
||||
| DeviceTokenPendingData
|
||||
| ErrorData;
|
||||
|
||||
/**
|
||||
* Type guard to check if device token response was successful
|
||||
*/
|
||||
export function isDeviceTokenSuccess(
|
||||
response: DeviceTokenResponse,
|
||||
): response is DeviceTokenData {
|
||||
return (
|
||||
'access_token' in response &&
|
||||
response.access_token !== null &&
|
||||
response.access_token !== undefined &&
|
||||
typeof response.access_token === 'string' &&
|
||||
response.access_token.length > 0
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if device token response is pending
|
||||
*/
|
||||
export function isDeviceTokenPending(
|
||||
response: DeviceTokenResponse,
|
||||
): response is DeviceTokenPendingData {
|
||||
return (
|
||||
'status' in response &&
|
||||
(response as DeviceTokenPendingData).status === 'pending'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if response is an error
|
||||
*/
|
||||
export function isErrorResponse(
|
||||
response:
|
||||
| DeviceAuthorizationResponse
|
||||
| DeviceTokenResponse
|
||||
| TokenRefreshResponse,
|
||||
): response is ErrorData {
|
||||
return 'error' in response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Token refresh success data
|
||||
*/
|
||||
export interface TokenRefreshData {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string; // Some OAuth servers may return a new refresh token
|
||||
resource_url?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Token refresh response interface
|
||||
*/
|
||||
export type TokenRefreshResponse = TokenRefreshData | ErrorData;
|
||||
|
||||
/**
|
||||
* Qwen OAuth2 client interface
|
||||
*/
|
||||
export interface IQwenOAuth2Client {
|
||||
setCredentials(credentials: QwenCredentials): void;
|
||||
getCredentials(): QwenCredentials;
|
||||
getAccessToken(): Promise<{ token?: string }>;
|
||||
requestDeviceAuthorization(options: {
|
||||
scope: string;
|
||||
code_challenge: string;
|
||||
code_challenge_method: string;
|
||||
}): Promise<DeviceAuthorizationResponse>;
|
||||
pollDeviceToken(options: {
|
||||
device_code: string;
|
||||
code_verifier: string;
|
||||
}): Promise<DeviceTokenResponse>;
|
||||
refreshAccessToken(): Promise<TokenRefreshResponse>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Qwen OAuth2 client implementation
|
||||
*/
|
||||
export class QwenOAuth2Client implements IQwenOAuth2Client {
|
||||
private credentials: QwenCredentials = {};
|
||||
private proxy?: string;
|
||||
|
||||
constructor(options: { proxy?: string }) {
|
||||
this.proxy = options.proxy;
|
||||
}
|
||||
|
||||
setCredentials(credentials: QwenCredentials): void {
|
||||
this.credentials = credentials;
|
||||
}
|
||||
|
||||
getCredentials(): QwenCredentials {
|
||||
return this.credentials;
|
||||
}
|
||||
|
||||
async getAccessToken(): Promise<{ token?: string }> {
|
||||
if (this.credentials.access_token && this.isTokenValid()) {
|
||||
return { token: this.credentials.access_token };
|
||||
}
|
||||
|
||||
if (this.credentials.refresh_token) {
|
||||
const refreshResponse = await this.refreshAccessToken();
|
||||
const tokenData = refreshResponse as TokenRefreshData;
|
||||
return { token: tokenData.access_token };
|
||||
}
|
||||
|
||||
return { token: undefined };
|
||||
}
|
||||
|
||||
async requestDeviceAuthorization(options: {
|
||||
scope: string;
|
||||
code_challenge: string;
|
||||
code_challenge_method: string;
|
||||
}): Promise<DeviceAuthorizationResponse> {
|
||||
const bodyData = {
|
||||
client_id: QWEN_OAUTH_CLIENT_ID,
|
||||
scope: options.scope,
|
||||
code_challenge: options.code_challenge,
|
||||
code_challenge_method: options.code_challenge_method,
|
||||
};
|
||||
|
||||
const response = await fetch(QWEN_OAUTH_DEVICE_CODE_ENDPOINT, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
'x-request-id': randomUUID(),
|
||||
},
|
||||
body: objectToUrlEncoded(bodyData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
throw new Error(
|
||||
`Device authorization failed: ${response.status} ${response.statusText}. Response: ${errorData}`,
|
||||
);
|
||||
}
|
||||
|
||||
const result = (await response.json()) as DeviceAuthorizationResponse;
|
||||
console.log('Device authorization result:', result);
|
||||
|
||||
// Check if the response indicates success
|
||||
if (!isDeviceAuthorizationSuccess(result)) {
|
||||
const errorData = result as ErrorData;
|
||||
throw new Error(
|
||||
`Device authorization failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async pollDeviceToken(options: {
|
||||
device_code: string;
|
||||
code_verifier: string;
|
||||
}): Promise<DeviceTokenResponse> {
|
||||
const bodyData = {
|
||||
grant_type: QWEN_OAUTH_GRANT_TYPE,
|
||||
client_id: QWEN_OAUTH_CLIENT_ID,
|
||||
device_code: options.device_code,
|
||||
code_verifier: options.code_verifier,
|
||||
};
|
||||
|
||||
const response = await fetch(QWEN_OAUTH_TOKEN_ENDPOINT, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
body: objectToUrlEncoded(bodyData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
||||
try {
|
||||
const errorData = (await response.json()) as ErrorData;
|
||||
|
||||
console.log(errorData.error);
|
||||
|
||||
// According to OAuth RFC 8628, handle standard polling responses
|
||||
if (
|
||||
response.status === 400 &&
|
||||
errorData.error === 'authorization_pending'
|
||||
) {
|
||||
// User has not yet approved the authorization request. Continue polling.
|
||||
return { status: 'pending' } as DeviceTokenPendingData;
|
||||
}
|
||||
|
||||
if (response.status === 429 && errorData.error === 'slow_down') {
|
||||
// Client is polling too frequently. Return pending with slowDown flag.
|
||||
return {
|
||||
status: 'pending',
|
||||
slowDown: true,
|
||||
} as DeviceTokenPendingData;
|
||||
}
|
||||
|
||||
// Handle other 400 errors (access_denied, expired_token, etc.) as real errors
|
||||
|
||||
// For other errors, throw with proper error information
|
||||
const error = new Error(
|
||||
`Device token poll failed: ${errorData.error || 'Unknown error'} - ${errorData.error_description || 'No details provided'}`,
|
||||
);
|
||||
(error as Error & { status?: number }).status = response.status;
|
||||
throw error;
|
||||
} catch (_parseError) {
|
||||
// If JSON parsing fails, fall back to text response
|
||||
const errorData = await response.text();
|
||||
const error = new Error(
|
||||
`Device token poll failed: ${response.status} ${response.statusText}. Response: ${errorData}`,
|
||||
);
|
||||
(error as Error & { status?: number }).status = response.status;
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
return (await response.json()) as DeviceTokenResponse;
|
||||
}
|
||||
|
||||
async refreshAccessToken(): Promise<TokenRefreshResponse> {
|
||||
if (!this.credentials.refresh_token) {
|
||||
throw new Error('No refresh token available');
|
||||
}
|
||||
|
||||
const bodyData = {
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: this.credentials.refresh_token,
|
||||
client_id: QWEN_OAUTH_CLIENT_ID,
|
||||
};
|
||||
|
||||
const response = await fetch(QWEN_OAUTH_TOKEN_ENDPOINT, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
},
|
||||
body: objectToUrlEncoded(bodyData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
// Handle 401 errors which might indicate refresh token expiry
|
||||
if (response.status === 400) {
|
||||
await clearQwenCredentials();
|
||||
throw new Error(
|
||||
"Refresh token expired or invalid. Please use '/auth' to re-authenticate.",
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`Token refresh failed: ${response.status} ${response.statusText}. Response: ${errorData}`,
|
||||
);
|
||||
}
|
||||
|
||||
const responseData = (await response.json()) as TokenRefreshResponse;
|
||||
|
||||
// Check if the response indicates success
|
||||
if (isErrorResponse(responseData)) {
|
||||
const errorData = responseData as ErrorData;
|
||||
throw new Error(
|
||||
`Token refresh failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Handle successful response
|
||||
const tokenData = responseData as TokenRefreshData;
|
||||
const tokens: QwenCredentials = {
|
||||
access_token: tokenData.access_token,
|
||||
token_type: tokenData.token_type,
|
||||
// Use new refresh token if provided, otherwise preserve existing one
|
||||
refresh_token: tokenData.refresh_token || this.credentials.refresh_token,
|
||||
resource_url: tokenData.resource_url, // Include resource_url if provided
|
||||
expiry_date: Date.now() + tokenData.expires_in * 1000,
|
||||
};
|
||||
|
||||
this.setCredentials(tokens);
|
||||
|
||||
// Cache the updated credentials to file
|
||||
await cacheQwenCredentials(tokens);
|
||||
|
||||
return responseData;
|
||||
}
|
||||
|
||||
private isTokenValid(): boolean {
|
||||
if (!this.credentials.expiry_date) {
|
||||
return false;
|
||||
}
|
||||
// Check if token expires within the refresh buffer time
|
||||
return Date.now() < this.credentials.expiry_date - TOKEN_REFRESH_BUFFER_MS;
|
||||
}
|
||||
}
|
||||
|
||||
export enum QwenOAuth2Event {
|
||||
AuthUri = 'auth-uri',
|
||||
AuthProgress = 'auth-progress',
|
||||
AuthCancel = 'auth-cancel',
|
||||
}
|
||||
|
||||
/**
|
||||
* Authentication result types to distinguish different failure reasons
|
||||
*/
|
||||
export type AuthResult =
|
||||
| { success: true }
|
||||
| {
|
||||
success: false;
|
||||
reason: 'timeout' | 'cancelled' | 'error' | 'rate_limit';
|
||||
};
|
||||
|
||||
/**
|
||||
* Global event emitter instance for QwenOAuth2 authentication events
|
||||
*/
|
||||
export const qwenOAuth2Events = new EventEmitter();
|
||||
|
||||
export async function getQwenOAuthClient(
|
||||
config: Config,
|
||||
): Promise<QwenOAuth2Client> {
|
||||
const client = new QwenOAuth2Client({
|
||||
proxy: config.getProxy(),
|
||||
});
|
||||
|
||||
// If there are cached creds on disk, they always take precedence
|
||||
if (await loadCachedQwenCredentials(client)) {
|
||||
console.log('Loaded cached Qwen credentials.');
|
||||
|
||||
try {
|
||||
await client.refreshAccessToken();
|
||||
return client;
|
||||
} catch (error: unknown) {
|
||||
// Handle refresh token errors
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
|
||||
const isInvalidToken = errorMessage.includes(
|
||||
'Refresh token expired or invalid',
|
||||
);
|
||||
const userMessage = isInvalidToken
|
||||
? 'Cached credentials are invalid. Please re-authenticate.'
|
||||
: `Token refresh failed: ${errorMessage}`;
|
||||
const throwMessage = isInvalidToken
|
||||
? 'Cached Qwen credentials are invalid. Please re-authenticate.'
|
||||
: `Qwen token refresh failed: ${errorMessage}`;
|
||||
|
||||
// Emit token refresh error event
|
||||
qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', userMessage);
|
||||
throw new Error(throwMessage);
|
||||
}
|
||||
}
|
||||
|
||||
// Use device authorization flow for authentication (single attempt)
|
||||
const result = await authWithQwenDeviceFlow(client, config);
|
||||
if (!result.success) {
|
||||
// Only emit timeout event if the failure reason is actually timeout
|
||||
// Other error types (401, 429, etc.) have already emitted their specific events
|
||||
if (result.reason === 'timeout') {
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'timeout',
|
||||
'Authentication timed out. Please try again or select a different authentication method.',
|
||||
);
|
||||
}
|
||||
|
||||
// Throw error with appropriate message based on failure reason
|
||||
switch (result.reason) {
|
||||
case 'timeout':
|
||||
throw new Error('Qwen OAuth authentication timed out');
|
||||
case 'cancelled':
|
||||
throw new Error('Qwen OAuth authentication was cancelled by user');
|
||||
case 'rate_limit':
|
||||
throw new Error(
|
||||
'Too many request for Qwen OAuth authentication, please try again later.',
|
||||
);
|
||||
case 'error':
|
||||
default:
|
||||
throw new Error('Qwen OAuth authentication failed');
|
||||
}
|
||||
}
|
||||
|
||||
return client;
|
||||
}
|
||||
|
||||
async function authWithQwenDeviceFlow(
|
||||
client: QwenOAuth2Client,
|
||||
config: Config,
|
||||
): Promise<AuthResult> {
|
||||
let isCancelled = false;
|
||||
|
||||
// Set up cancellation listener
|
||||
const cancelHandler = () => {
|
||||
isCancelled = true;
|
||||
};
|
||||
qwenOAuth2Events.once(QwenOAuth2Event.AuthCancel, cancelHandler);
|
||||
|
||||
try {
|
||||
// Generate PKCE code verifier and challenge
|
||||
const { code_verifier, code_challenge } = generatePKCEPair();
|
||||
|
||||
// Request device authorization
|
||||
const deviceAuth = await client.requestDeviceAuthorization({
|
||||
scope: QWEN_OAUTH_SCOPE,
|
||||
code_challenge,
|
||||
code_challenge_method: 'S256',
|
||||
});
|
||||
|
||||
// Ensure we have a successful authorization response
|
||||
if (!isDeviceAuthorizationSuccess(deviceAuth)) {
|
||||
const errorData = deviceAuth as ErrorData;
|
||||
throw new Error(
|
||||
`Device authorization failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Emit device authorization event for UI integration immediately
|
||||
qwenOAuth2Events.emit(QwenOAuth2Event.AuthUri, deviceAuth);
|
||||
|
||||
console.log('\n=== Qwen OAuth Device Authorization ===');
|
||||
console.log(
|
||||
`Please visit the following URL on your phone or browser for authorization:`,
|
||||
);
|
||||
console.log(`\n${deviceAuth.verification_uri_complete}\n`);
|
||||
|
||||
const showFallbackMessage = () => {
|
||||
// Fallback message for console output
|
||||
};
|
||||
|
||||
// If browser launch is not suppressed, try to open the URL
|
||||
if (!config.isBrowserLaunchSuppressed()) {
|
||||
try {
|
||||
const childProcess = await open(deviceAuth.verification_uri_complete);
|
||||
|
||||
// IMPORTANT: Attach an error handler to the returned child process.
|
||||
// Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
|
||||
// in a minimal Docker container), it will emit an unhandled 'error' event,
|
||||
// causing the entire Node.js process to crash.
|
||||
if (childProcess) {
|
||||
childProcess.on('error', () => {
|
||||
console.log('Failed to open browser. Visit this URL to authorize:');
|
||||
showFallbackMessage();
|
||||
});
|
||||
}
|
||||
} catch (_err) {
|
||||
showFallbackMessage();
|
||||
}
|
||||
} else {
|
||||
// Browser launch is suppressed, show fallback message
|
||||
showFallbackMessage();
|
||||
}
|
||||
|
||||
// Emit auth progress event
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'polling',
|
||||
'Waiting for authorization...',
|
||||
);
|
||||
|
||||
console.log('Waiting for authorization...\n');
|
||||
|
||||
// Poll for the token
|
||||
let pollInterval = 2000; // 2 seconds, can be increased if slow_down is received
|
||||
const maxAttempts = Math.ceil(
|
||||
deviceAuth.expires_in / (pollInterval / 1000),
|
||||
);
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
// Check if authentication was cancelled
|
||||
if (isCancelled) {
|
||||
console.log('\nAuthentication cancelled by user.');
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'error',
|
||||
'Authentication cancelled by user.',
|
||||
);
|
||||
return { success: false, reason: 'cancelled' };
|
||||
}
|
||||
|
||||
try {
|
||||
console.log('polling for token...');
|
||||
const tokenResponse = await client.pollDeviceToken({
|
||||
device_code: deviceAuth.device_code,
|
||||
code_verifier,
|
||||
});
|
||||
|
||||
// Check if the response is successful and contains token data
|
||||
if (isDeviceTokenSuccess(tokenResponse)) {
|
||||
const tokenData = tokenResponse as DeviceTokenData;
|
||||
|
||||
// Convert to QwenCredentials format
|
||||
const credentials: QwenCredentials = {
|
||||
access_token: tokenData.access_token!, // Safe to assert as non-null due to isDeviceTokenSuccess check
|
||||
refresh_token: tokenData.refresh_token || undefined,
|
||||
token_type: tokenData.token_type,
|
||||
resource_url: tokenData.resource_url,
|
||||
expiry_date: tokenData.expires_in
|
||||
? /* ts-ignore */
|
||||
Date.now() + (tokenData.expires_in ?? 1) * 1000
|
||||
: undefined,
|
||||
};
|
||||
|
||||
client.setCredentials(credentials);
|
||||
|
||||
// Cache the new tokens
|
||||
await cacheQwenCredentials(credentials);
|
||||
|
||||
// Emit auth progress success event
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'success',
|
||||
'Authentication successful! Access token obtained.',
|
||||
);
|
||||
|
||||
console.log('Authentication successful! Access token obtained.');
|
||||
return { success: true };
|
||||
}
|
||||
|
||||
// Check if the response is pending
|
||||
if (isDeviceTokenPending(tokenResponse)) {
|
||||
const pendingData = tokenResponse as DeviceTokenPendingData;
|
||||
|
||||
console.log(pendingData);
|
||||
|
||||
// Handle slow_down error by increasing poll interval
|
||||
if (pendingData.slowDown) {
|
||||
pollInterval = Math.min(pollInterval * 1.5, 10000); // Increase by 50%, max 10 seconds
|
||||
console.log(
|
||||
`\nServer requested to slow down, increasing poll interval to ${pollInterval}ms`,
|
||||
);
|
||||
} else {
|
||||
pollInterval = 2000; // Reset to default interval
|
||||
}
|
||||
|
||||
// Emit polling progress event
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'polling',
|
||||
`Polling... (attempt ${attempt + 1}/${maxAttempts})`,
|
||||
);
|
||||
|
||||
process.stdout.write('.');
|
||||
|
||||
// Wait with cancellation check every 100ms
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkInterval = 100; // Check every 100ms
|
||||
let elapsedTime = 0;
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
elapsedTime += checkInterval;
|
||||
|
||||
// Check for cancellation during wait
|
||||
if (isCancelled) {
|
||||
clearInterval(intervalId);
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
// Complete wait when interval is reached
|
||||
if (elapsedTime >= pollInterval) {
|
||||
clearInterval(intervalId);
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
}, checkInterval);
|
||||
});
|
||||
|
||||
// Check for cancellation after waiting
|
||||
if (isCancelled) {
|
||||
console.log('\nAuthentication cancelled by user.');
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'error',
|
||||
'Authentication cancelled by user.',
|
||||
);
|
||||
return { success: false, reason: 'cancelled' };
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if (isErrorResponse(tokenResponse)) {
|
||||
const errorData = tokenResponse as ErrorData;
|
||||
throw new Error(
|
||||
`Token polling failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
|
||||
);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
// Handle specific error cases
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
const statusCode =
|
||||
error instanceof Error
|
||||
? (error as Error & { status?: number }).status
|
||||
: null;
|
||||
|
||||
if (errorMessage.includes('401') || statusCode === 401) {
|
||||
const message =
|
||||
'Device code expired or invalid, please restart the authorization process.';
|
||||
|
||||
// Emit error event
|
||||
qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', message);
|
||||
|
||||
return { success: false, reason: 'error' };
|
||||
}
|
||||
|
||||
// Handle 429 Too Many Requests error
|
||||
if (errorMessage.includes('429') || statusCode === 429) {
|
||||
const message =
|
||||
'Too many requests. The server is rate limiting our requests. Please select a different authentication method or try again later.';
|
||||
|
||||
// Emit rate limit event to notify user
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'rate_limit',
|
||||
message,
|
||||
);
|
||||
|
||||
console.log('\n' + message);
|
||||
|
||||
// Return false to stop polling and go back to auth selection
|
||||
return { success: false, reason: 'rate_limit' };
|
||||
}
|
||||
|
||||
const message = `Error polling for token: ${errorMessage}`;
|
||||
|
||||
// Emit error event
|
||||
qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', message);
|
||||
|
||||
// Check for cancellation before waiting
|
||||
if (isCancelled) {
|
||||
return { success: false, reason: 'cancelled' };
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
}
|
||||
}
|
||||
|
||||
const timeoutMessage = 'Authorization timeout, please restart the process.';
|
||||
|
||||
// Emit timeout error event
|
||||
qwenOAuth2Events.emit(
|
||||
QwenOAuth2Event.AuthProgress,
|
||||
'timeout',
|
||||
timeoutMessage,
|
||||
);
|
||||
|
||||
console.error('\n' + timeoutMessage);
|
||||
return { success: false, reason: 'timeout' };
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
console.error('Device authorization flow failed:', errorMessage);
|
||||
return { success: false, reason: 'error' };
|
||||
} finally {
|
||||
// Clean up event listener
|
||||
qwenOAuth2Events.off(QwenOAuth2Event.AuthCancel, cancelHandler);
|
||||
}
|
||||
}
|
||||
|
||||
async function loadCachedQwenCredentials(
|
||||
client: QwenOAuth2Client,
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
const keyFile = getQwenCachedCredentialPath();
|
||||
const creds = await fs.readFile(keyFile, 'utf-8');
|
||||
const credentials = JSON.parse(creds) as QwenCredentials;
|
||||
client.setCredentials(credentials);
|
||||
|
||||
// Verify that the credentials are still valid
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function cacheQwenCredentials(credentials: QwenCredentials) {
|
||||
const filePath = getQwenCachedCredentialPath();
|
||||
await fs.mkdir(path.dirname(filePath), { recursive: true });
|
||||
|
||||
const credString = JSON.stringify(credentials, null, 2);
|
||||
await fs.writeFile(filePath, credString);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached Qwen credentials from disk
|
||||
* This is useful when credentials have expired or need to be reset
|
||||
*/
|
||||
export async function clearQwenCredentials(): Promise<void> {
|
||||
try {
|
||||
const filePath = getQwenCachedCredentialPath();
|
||||
await fs.unlink(filePath);
|
||||
console.log('Cached Qwen credentials cleared successfully.');
|
||||
} catch (error: unknown) {
|
||||
// If file doesn't exist or can't be deleted, we consider it cleared
|
||||
if (error instanceof Error && 'code' in error && error.code === 'ENOENT') {
|
||||
// File doesn't exist, already cleared
|
||||
return;
|
||||
}
|
||||
// Log other errors but don't throw - clearing credentials should be non-critical
|
||||
console.warn('Warning: Failed to clear cached Qwen credentials:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function getQwenCachedCredentialPath(): string {
|
||||
return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME);
|
||||
}
|
||||
@@ -333,7 +333,7 @@ export class Config {
|
||||
this.model = params.model;
|
||||
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
|
||||
this.maxSessionTurns = params.maxSessionTurns ?? -1;
|
||||
this.sessionTokenLimit = params.sessionTokenLimit ?? 32000;
|
||||
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
|
||||
this.maxFolderItems = params.maxFolderItems ?? 20;
|
||||
this.experimentalAcp = params.experimentalAcp ?? false;
|
||||
this.listExtensions = params.listExtensions ?? false;
|
||||
|
||||
@@ -128,7 +128,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
: String(thrownError);
|
||||
expect(errorMessage).not.toMatch(/timeout after \d+s/);
|
||||
expect(errorMessage).not.toMatch(/Troubleshooting tips:/);
|
||||
expect(errorMessage).toMatch(/OpenAI API error:/);
|
||||
// Should preserve the original error message
|
||||
expect(errorMessage).toMatch(new RegExp(error.message));
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -161,7 +162,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
};
|
||||
|
||||
await expect(generator.generateContent(request)).rejects.toThrow(
|
||||
'OpenAI API error: Invalid API key',
|
||||
'Invalid API key',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -238,6 +239,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
baseURL: '',
|
||||
timeout: 120000,
|
||||
maxRetries: 3,
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -256,6 +260,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
baseURL: '',
|
||||
timeout: 300000,
|
||||
maxRetries: 5,
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -271,6 +278,9 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
baseURL: '',
|
||||
timeout: 120000, // default
|
||||
maxRetries: 3, // default
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -797,6 +797,11 @@ export class GeminiClient {
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Handle different auth types
|
||||
if (authType === AuthType.QWEN_OAUTH) {
|
||||
return this.handleQwenOAuthError(error);
|
||||
}
|
||||
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
@@ -835,4 +840,59 @@ export class GeminiClient {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles Qwen OAuth authentication errors and rate limiting
|
||||
*/
|
||||
private async handleQwenOAuthError(error?: unknown): Promise<string | null> {
|
||||
if (!error) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message.toLowerCase()
|
||||
: String(error).toLowerCase();
|
||||
const errorCode =
|
||||
(error as { status?: number; code?: number })?.status ||
|
||||
(error as { status?: number; code?: number })?.code;
|
||||
|
||||
// Check if this is an authentication/authorization error
|
||||
const isAuthError =
|
||||
errorCode === 401 ||
|
||||
errorCode === 403 ||
|
||||
errorMessage.includes('unauthorized') ||
|
||||
errorMessage.includes('forbidden') ||
|
||||
errorMessage.includes('invalid api key') ||
|
||||
errorMessage.includes('authentication') ||
|
||||
errorMessage.includes('access denied') ||
|
||||
(errorMessage.includes('token') && errorMessage.includes('expired'));
|
||||
|
||||
// Check if this is a rate limiting error
|
||||
const isRateLimitError =
|
||||
errorCode === 429 ||
|
||||
errorMessage.includes('429') ||
|
||||
errorMessage.includes('rate limit') ||
|
||||
errorMessage.includes('too many requests');
|
||||
|
||||
if (isAuthError) {
|
||||
console.warn('Qwen OAuth authentication error detected:', errorMessage);
|
||||
// The QwenContentGenerator should automatically handle token refresh
|
||||
// If it still fails, it likely means the refresh token is also expired
|
||||
console.log(
|
||||
'Note: If this persists, you may need to re-authenticate with Qwen OAuth',
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isRateLimitError) {
|
||||
console.warn('Qwen API rate limit encountered:', errorMessage);
|
||||
// For rate limiting, we don't need to do anything special
|
||||
// The retry mechanism will handle the backoff
|
||||
return null;
|
||||
}
|
||||
|
||||
// For other errors, don't handle them specially
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ export enum AuthType {
|
||||
USE_VERTEX_AI = 'vertex-ai',
|
||||
CLOUD_SHELL = 'cloud-shell',
|
||||
USE_OPENAI = 'openai',
|
||||
QWEN_OAUTH = 'qwen-oauth',
|
||||
}
|
||||
|
||||
export type ContentGeneratorConfig = {
|
||||
@@ -131,6 +132,15 @@ export function createContentGeneratorConfig(
|
||||
return contentGeneratorConfig;
|
||||
}
|
||||
|
||||
if (authType === AuthType.QWEN_OAUTH) {
|
||||
// For Qwen OAuth, we'll handle the API key dynamically in createContentGenerator
|
||||
// Set a special marker to indicate this is Qwen OAuth
|
||||
contentGeneratorConfig.apiKey = 'QWEN_OAUTH_DYNAMIC_TOKEN';
|
||||
contentGeneratorConfig.model = config.getModel() || DEFAULT_GEMINI_MODEL;
|
||||
|
||||
return contentGeneratorConfig;
|
||||
}
|
||||
|
||||
return contentGeneratorConfig;
|
||||
}
|
||||
|
||||
@@ -184,6 +194,30 @@ export async function createContentGenerator(
|
||||
return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig);
|
||||
}
|
||||
|
||||
if (config.authType === AuthType.QWEN_OAUTH) {
|
||||
if (config.apiKey !== 'QWEN_OAUTH_DYNAMIC_TOKEN') {
|
||||
throw new Error('Invalid Qwen OAuth configuration');
|
||||
}
|
||||
|
||||
// Import required classes dynamically
|
||||
const { getQwenOAuthClient: getQwenOauthClient } = await import(
|
||||
'../code_assist/qwenOAuth2.js'
|
||||
);
|
||||
const { QwenContentGenerator } = await import('./qwenContentGenerator.js');
|
||||
|
||||
try {
|
||||
// Get the Qwen OAuth client (now includes integrated token management)
|
||||
const qwenClient = await getQwenOauthClient(gcConfig);
|
||||
|
||||
// Create the content generator with dynamic token management
|
||||
return new QwenContentGenerator(qwenClient, config.model, gcConfig);
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
);
|
||||
|
||||
@@ -201,6 +201,11 @@ export class GeminiChat {
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Handle different auth types
|
||||
if (authType === AuthType.QWEN_OAUTH) {
|
||||
return this.handleQwenOAuthError(error);
|
||||
}
|
||||
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
@@ -674,4 +679,59 @@ export class GeminiChat {
|
||||
content.parts[0].thought === true
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles Qwen OAuth authentication errors and rate limiting
|
||||
*/
|
||||
private async handleQwenOAuthError(error?: unknown): Promise<string | null> {
|
||||
if (!error) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message.toLowerCase()
|
||||
: String(error).toLowerCase();
|
||||
const errorCode =
|
||||
(error as { status?: number; code?: number })?.status ||
|
||||
(error as { status?: number; code?: number })?.code;
|
||||
|
||||
// Check if this is an authentication/authorization error
|
||||
const isAuthError =
|
||||
errorCode === 401 ||
|
||||
errorCode === 403 ||
|
||||
errorMessage.includes('unauthorized') ||
|
||||
errorMessage.includes('forbidden') ||
|
||||
errorMessage.includes('invalid api key') ||
|
||||
errorMessage.includes('authentication') ||
|
||||
errorMessage.includes('access denied') ||
|
||||
(errorMessage.includes('token') && errorMessage.includes('expired'));
|
||||
|
||||
// Check if this is a rate limiting error
|
||||
const isRateLimitError =
|
||||
errorCode === 429 ||
|
||||
errorMessage.includes('429') ||
|
||||
errorMessage.includes('rate limit') ||
|
||||
errorMessage.includes('too many requests');
|
||||
|
||||
if (isAuthError) {
|
||||
console.warn('Qwen OAuth authentication error detected:', errorMessage);
|
||||
// The QwenContentGenerator should automatically handle token refresh
|
||||
// If it still fails, it likely means the refresh token is also expired
|
||||
console.log(
|
||||
'Note: If this persists, you may need to re-authenticate with Qwen OAuth',
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isRateLimitError) {
|
||||
console.warn('Qwen API rate limit encountered:', errorMessage);
|
||||
// For rate limiting, we don't need to do anything special
|
||||
// The retry mechanism will handle the backoff
|
||||
return null;
|
||||
}
|
||||
|
||||
// For other errors, don't handle them specially
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
2358
packages/core/src/core/openaiContentGenerator.test.ts
Normal file
2358
packages/core/src/core/openaiContentGenerator.test.ts
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
@@ -78,7 +78,7 @@ interface OpenAIResponseFormat {
|
||||
}
|
||||
|
||||
export class OpenAIContentGenerator implements ContentGenerator {
|
||||
private client: OpenAI;
|
||||
protected client: OpenAI;
|
||||
private model: string;
|
||||
private config: Config;
|
||||
private streamingToolCalls: Map<
|
||||
@@ -114,14 +114,21 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries;
|
||||
}
|
||||
|
||||
// Set up User-Agent header (same format as contentGenerator.ts)
|
||||
const version = process.env.CLI_VERSION || process.version;
|
||||
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
|
||||
|
||||
// Check if using OpenRouter and add required headers
|
||||
const isOpenRouter = baseURL.includes('openrouter.ai');
|
||||
const defaultHeaders = isOpenRouter
|
||||
? {
|
||||
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
|
||||
'X-Title': 'Qwen Code',
|
||||
}
|
||||
: undefined;
|
||||
const defaultHeaders = {
|
||||
'User-Agent': userAgent,
|
||||
...(isOpenRouter
|
||||
? {
|
||||
'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git',
|
||||
'X-Title': 'Qwen Code',
|
||||
}
|
||||
: {}),
|
||||
};
|
||||
|
||||
this.client = new OpenAI({
|
||||
apiKey,
|
||||
@@ -132,6 +139,19 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for subclasses to customize error handling behavior
|
||||
* @param error The error that occurred
|
||||
* @param request The original request
|
||||
* @returns true if error logging should be suppressed, false otherwise
|
||||
*/
|
||||
protected shouldSuppressErrorLogging(
|
||||
_error: unknown,
|
||||
_request: GenerateContentParameters,
|
||||
): boolean {
|
||||
return false; // Default behavior: never suppress error logging
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an error is a timeout error
|
||||
*/
|
||||
@@ -275,7 +295,10 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
);
|
||||
}
|
||||
|
||||
console.error('OpenAI API Error:', errorMessage);
|
||||
// Allow subclasses to suppress error logging for specific scenarios
|
||||
if (!this.shouldSuppressErrorLogging(error, request)) {
|
||||
console.error('OpenAI API Error:', errorMessage);
|
||||
}
|
||||
|
||||
// Provide helpful timeout-specific error message
|
||||
if (isTimeoutError) {
|
||||
@@ -288,7 +311,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(`OpenAI API error: ${errorMessage}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,7 +509,10 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
);
|
||||
logApiResponse(this.config, errorEvent);
|
||||
|
||||
console.error('OpenAI API Streaming Error:', errorMessage);
|
||||
// Allow subclasses to suppress error logging for specific scenarios
|
||||
if (!this.shouldSuppressErrorLogging(error, request)) {
|
||||
console.error('OpenAI API Streaming Error:', errorMessage);
|
||||
}
|
||||
|
||||
// Provide helpful timeout-specific error message for streaming setup
|
||||
if (isTimeoutError) {
|
||||
@@ -499,7 +525,7 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(`OpenAI API error: ${errorMessage}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
794
packages/core/src/core/qwenContentGenerator.test.ts
Normal file
794
packages/core/src/core/qwenContentGenerator.test.ts
Normal file
@@ -0,0 +1,794 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
IQwenOAuth2Client,
|
||||
QwenCredentials,
|
||||
ErrorData,
|
||||
} from '../code_assist/qwenOAuth2.js';
|
||||
import {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
FinishReason,
|
||||
} from '@google/genai';
|
||||
import { QwenContentGenerator } from './qwenContentGenerator.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
// Mock the OpenAIContentGenerator parent class
|
||||
vi.mock('./openaiContentGenerator.js', () => ({
|
||||
OpenAIContentGenerator: class {
|
||||
client: {
|
||||
apiKey: string;
|
||||
baseURL: string;
|
||||
};
|
||||
|
||||
constructor(apiKey: string, _model: string, _config: Config) {
|
||||
this.client = {
|
||||
apiKey,
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
};
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
_request: GenerateContentParameters,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return createMockResponse('Generated content');
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
_request: GenerateContentParameters,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
return (async function* () {
|
||||
yield createMockResponse('Stream chunk 1');
|
||||
yield createMockResponse('Stream chunk 2');
|
||||
})();
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
_request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return { totalTokens: 10 };
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
_request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return { embeddings: [{ values: [0.1, 0.2, 0.3] }] };
|
||||
}
|
||||
|
||||
protected shouldSuppressErrorLogging(
|
||||
_error: unknown,
|
||||
_request: GenerateContentParameters,
|
||||
): boolean {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
const createMockResponse = (text: string): GenerateContentResponse =>
|
||||
({
|
||||
candidates: [
|
||||
{
|
||||
content: { role: 'model', parts: [{ text }] },
|
||||
finishReason: FinishReason.STOP,
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
promptFeedback: { safetyRatings: [] },
|
||||
text,
|
||||
data: undefined,
|
||||
functionCalls: [],
|
||||
executableCode: '',
|
||||
codeExecutionResult: '',
|
||||
}) as GenerateContentResponse;
|
||||
|
||||
describe('QwenContentGenerator', () => {
|
||||
let mockQwenClient: IQwenOAuth2Client;
|
||||
let qwenContentGenerator: QwenContentGenerator;
|
||||
let mockConfig: Config;
|
||||
|
||||
const mockCredentials: QwenCredentials = {
|
||||
access_token: 'test-access-token',
|
||||
refresh_token: 'test-refresh-token',
|
||||
resource_url: 'https://test-endpoint.com/v1',
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock Config
|
||||
mockConfig = {} as Config;
|
||||
|
||||
// Mock QwenOAuth2Client
|
||||
mockQwenClient = {
|
||||
getAccessToken: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
refreshAccessToken: vi.fn(),
|
||||
requestDeviceAuthorization: vi.fn(),
|
||||
pollDeviceToken: vi.fn(),
|
||||
};
|
||||
|
||||
// Create QwenContentGenerator instance
|
||||
qwenContentGenerator = new QwenContentGenerator(
|
||||
mockQwenClient,
|
||||
'qwen-turbo',
|
||||
mockConfig,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('Core Content Generation Methods', () => {
|
||||
it('should generate content with valid token', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.generateContent(request);
|
||||
|
||||
expect(result.text).toBe('Generated content');
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should generate content stream with valid token', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello stream' }] }],
|
||||
};
|
||||
|
||||
const stream = await qwenContentGenerator.generateContentStream(request);
|
||||
const chunks: string[] = [];
|
||||
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk.text || '');
|
||||
}
|
||||
|
||||
expect(chunks).toEqual(['Stream chunk 1', 'Stream chunk 2']);
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should count tokens with valid token', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Count me' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.countTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBe(10);
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should embed content with valid token', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
const request: EmbedContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ parts: [{ text: 'Embed me' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.embedContent(request);
|
||||
|
||||
expect(result.embeddings).toHaveLength(1);
|
||||
expect(result.embeddings?.[0]?.values).toEqual([0.1, 0.2, 0.3]);
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token Management and Refresh Logic', () => {
|
||||
it('should refresh token on auth error and retry', async () => {
|
||||
const authError = { status: 401, message: 'Unauthorized' };
|
||||
|
||||
// First call fails with auth error
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockRejectedValueOnce(authError);
|
||||
|
||||
// Refresh succeeds
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
|
||||
access_token: 'refreshed-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
resource_url: 'https://refreshed-endpoint.com',
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.generateContent(request);
|
||||
|
||||
expect(result.text).toBe('Generated content');
|
||||
expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle token refresh failure', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue(
|
||||
new Error('Token expired'),
|
||||
);
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockRejectedValue(
|
||||
new Error('Refresh failed'),
|
||||
);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await expect(
|
||||
qwenContentGenerator.generateContent(request),
|
||||
).rejects.toThrow(
|
||||
'Failed to obtain valid Qwen access token. Please re-authenticate.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should update endpoint when token is refreshed', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'https://new-endpoint.com',
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
expect(mockQwenClient.getCredentials).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Endpoint URL Normalization', () => {
|
||||
it('should use default endpoint when no custom endpoint provided', async () => {
|
||||
let capturedBaseURL = '';
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
access_token: 'test-token',
|
||||
refresh_token: 'test-refresh',
|
||||
// No resource_url provided
|
||||
});
|
||||
|
||||
// Mock the parent's generateContent to capture the baseURL during the call
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
// Should use default endpoint with /v1 suffix
|
||||
expect(capturedBaseURL).toBe(
|
||||
'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
);
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
|
||||
it('should normalize hostname-only endpoints by adding https protocol', async () => {
|
||||
let capturedBaseURL = '';
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'custom-endpoint.com',
|
||||
});
|
||||
|
||||
// Mock the parent's generateContent to capture the baseURL during the call
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
// Should add https:// and /v1
|
||||
expect(capturedBaseURL).toBe('https://custom-endpoint.com/v1');
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
|
||||
it('should preserve existing protocol in endpoint URLs', async () => {
|
||||
let capturedBaseURL = '';
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'https://custom-endpoint.com',
|
||||
});
|
||||
|
||||
// Mock the parent's generateContent to capture the baseURL during the call
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
// Should preserve https:// and add /v1
|
||||
expect(capturedBaseURL).toBe('https://custom-endpoint.com/v1');
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
|
||||
it('should not duplicate /v1 suffix if already present', async () => {
|
||||
let capturedBaseURL = '';
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'https://custom-endpoint.com/v1',
|
||||
});
|
||||
|
||||
// Mock the parent's generateContent to capture the baseURL during the call
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(function (
|
||||
this: QwenContentGenerator,
|
||||
) {
|
||||
capturedBaseURL = (this as unknown as { client: { baseURL: string } })
|
||||
.client.baseURL;
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
// Should not duplicate /v1
|
||||
expect(capturedBaseURL).toBe('https://custom-endpoint.com/v1');
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
});
|
||||
|
||||
describe('Client State Management', () => {
|
||||
it('should restore original client credentials after operations', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
}
|
||||
).client;
|
||||
const originalApiKey = client.apiKey;
|
||||
const originalBaseURL = client.baseURL;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'temp-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'https://temp-endpoint.com',
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
|
||||
// Should restore original values after operation
|
||||
expect(client.apiKey).toBe(originalApiKey);
|
||||
expect(client.baseURL).toBe(originalBaseURL);
|
||||
});
|
||||
|
||||
it('should restore credentials even when operation throws', async () => {
|
||||
const client = (
|
||||
qwenContentGenerator as unknown as {
|
||||
client: { apiKey: string; baseURL: string };
|
||||
}
|
||||
).client;
|
||||
const originalApiKey = client.apiKey;
|
||||
const originalBaseURL = client.baseURL;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'temp-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
// Mock the parent method to throw an error
|
||||
const mockError = new Error('Network error');
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockRejectedValue(mockError);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
try {
|
||||
await qwenContentGenerator.generateContent(request);
|
||||
} catch (error) {
|
||||
expect(error).toBe(mockError);
|
||||
}
|
||||
|
||||
// Credentials should still be restored
|
||||
expect(client.apiKey).toBe(originalApiKey);
|
||||
expect(client.baseURL).toBe(originalBaseURL);
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling and Retry Logic', () => {
|
||||
it('should retry once on authentication errors', async () => {
|
||||
const authError = { status: 401, message: 'Unauthorized' };
|
||||
|
||||
// Mock first call to fail with auth error
|
||||
const mockGenerateContent = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(authError)
|
||||
.mockResolvedValueOnce(createMockResponse('Success after retry'));
|
||||
|
||||
// Replace the parent method
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = mockGenerateContent;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'initial-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
|
||||
access_token: 'refreshed-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.generateContent(request);
|
||||
|
||||
expect(result.text).toBe('Success after retry');
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
|
||||
it('should not retry non-authentication errors', async () => {
|
||||
const networkError = new Error('Network timeout');
|
||||
|
||||
const mockGenerateContent = vi.fn().mockRejectedValue(networkError);
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = mockGenerateContent;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await expect(
|
||||
qwenContentGenerator.generateContent(request),
|
||||
).rejects.toThrow('Network timeout');
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockQwenClient.refreshAccessToken).not.toHaveBeenCalled();
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
|
||||
it('should handle error response from token refresh', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue(
|
||||
new Error('Token expired'),
|
||||
);
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
|
||||
error: 'invalid_grant',
|
||||
error_description: 'Refresh token expired',
|
||||
} as ErrorData);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
await expect(
|
||||
qwenContentGenerator.generateContent(request),
|
||||
).rejects.toThrow('Failed to obtain valid Qwen access token');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token State Management', () => {
|
||||
it('should cache and return current token', () => {
|
||||
expect(qwenContentGenerator.getCurrentToken()).toBeNull();
|
||||
|
||||
// Simulate setting a token internally
|
||||
(
|
||||
qwenContentGenerator as unknown as { currentToken: string }
|
||||
).currentToken = 'cached-token';
|
||||
|
||||
expect(qwenContentGenerator.getCurrentToken()).toBe('cached-token');
|
||||
});
|
||||
|
||||
it('should clear token and endpoint on clearToken()', () => {
|
||||
// Simulate having cached values
|
||||
const qwenInstance = qwenContentGenerator as unknown as {
|
||||
currentToken: string;
|
||||
currentEndpoint: string;
|
||||
refreshPromise: Promise<string>;
|
||||
};
|
||||
qwenInstance.currentToken = 'cached-token';
|
||||
qwenInstance.currentEndpoint = 'https://cached-endpoint.com';
|
||||
qwenInstance.refreshPromise = Promise.resolve('token');
|
||||
|
||||
qwenContentGenerator.clearToken();
|
||||
|
||||
expect(qwenContentGenerator.getCurrentToken()).toBeNull();
|
||||
expect(
|
||||
(qwenContentGenerator as unknown as { currentEndpoint: string | null })
|
||||
.currentEndpoint,
|
||||
).toBeNull();
|
||||
expect(
|
||||
(
|
||||
qwenContentGenerator as unknown as {
|
||||
refreshPromise: Promise<string> | null;
|
||||
}
|
||||
).refreshPromise,
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle concurrent token refresh requests', async () => {
|
||||
let refreshCallCount = 0;
|
||||
|
||||
// Clear any existing cached token first
|
||||
qwenContentGenerator.clearToken();
|
||||
|
||||
// Mock to simulate auth error on first parent call, which should trigger refresh
|
||||
const authError = { status: 401, message: 'Unauthorized' };
|
||||
let parentCallCount = 0;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'initial-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockImplementation(
|
||||
async () => {
|
||||
refreshCallCount++;
|
||||
await new Promise((resolve) => setTimeout(resolve, 50)); // Longer delay to ensure concurrency
|
||||
return {
|
||||
access_token: 'refreshed-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
// Mock the parent method to fail first then succeed
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
const originalGenerateContent = parentPrototype.generateContent;
|
||||
parentPrototype.generateContent = vi.fn().mockImplementation(async () => {
|
||||
parentCallCount++;
|
||||
if (parentCallCount === 1) {
|
||||
throw authError; // First call triggers auth error
|
||||
}
|
||||
return createMockResponse('Generated content');
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
// Make multiple concurrent requests - should all use the same refresh promise
|
||||
const promises = [
|
||||
qwenContentGenerator.generateContent(request),
|
||||
qwenContentGenerator.generateContent(request),
|
||||
qwenContentGenerator.generateContent(request),
|
||||
];
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
|
||||
// All should succeed
|
||||
results.forEach((result) => {
|
||||
expect(result.text).toBe('Generated content');
|
||||
});
|
||||
|
||||
// The main test is that all requests succeed without crashing
|
||||
expect(results).toHaveLength(3);
|
||||
expect(refreshCallCount).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Restore original method
|
||||
parentPrototype.generateContent = originalGenerateContent;
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Logging Suppression', () => {
|
||||
it('should suppress logging for authentication errors', () => {
|
||||
const authErrors = [
|
||||
{ status: 401 },
|
||||
{ code: 403 },
|
||||
new Error('Unauthorized access'),
|
||||
new Error('Token expired'),
|
||||
new Error('Invalid API key'),
|
||||
];
|
||||
|
||||
authErrors.forEach((error) => {
|
||||
const shouldSuppress = (
|
||||
qwenContentGenerator as unknown as {
|
||||
shouldSuppressErrorLogging: (
|
||||
error: unknown,
|
||||
request: GenerateContentParameters,
|
||||
) => boolean;
|
||||
}
|
||||
).shouldSuppressErrorLogging(error, {} as GenerateContentParameters);
|
||||
expect(shouldSuppress).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not suppress logging for non-auth errors', () => {
|
||||
const nonAuthErrors = [
|
||||
new Error('Network timeout'),
|
||||
new Error('Rate limit exceeded'),
|
||||
{ status: 500 },
|
||||
new Error('Internal server error'),
|
||||
];
|
||||
|
||||
nonAuthErrors.forEach((error) => {
|
||||
const shouldSuppress = (
|
||||
qwenContentGenerator as unknown as {
|
||||
shouldSuppressErrorLogging: (
|
||||
error: unknown,
|
||||
request: GenerateContentParameters,
|
||||
) => boolean;
|
||||
}
|
||||
).shouldSuppressErrorLogging(error, {} as GenerateContentParameters);
|
||||
expect(shouldSuppress).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Integration Tests', () => {
|
||||
it('should handle complete workflow: get token, use it, refresh on auth error, retry', async () => {
|
||||
const authError = { status: 401, message: 'Token expired' };
|
||||
|
||||
// Setup complex scenario
|
||||
let callCount = 0;
|
||||
const mockGenerateContent = vi.fn().mockImplementation(async () => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
throw authError; // First call fails
|
||||
}
|
||||
return createMockResponse('Success after refresh'); // Second call succeeds
|
||||
});
|
||||
|
||||
const parentPrototype = Object.getPrototypeOf(
|
||||
Object.getPrototypeOf(qwenContentGenerator),
|
||||
);
|
||||
parentPrototype.generateContent = mockGenerateContent;
|
||||
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'initial-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue({
|
||||
...mockCredentials,
|
||||
resource_url: 'custom-endpoint.com',
|
||||
});
|
||||
vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({
|
||||
access_token: 'new-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 7200,
|
||||
resource_url: 'https://new-endpoint.com',
|
||||
});
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'qwen-turbo',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Test message' }] }],
|
||||
};
|
||||
|
||||
const result = await qwenContentGenerator.generateContent(request);
|
||||
|
||||
expect(result.text).toBe('Success after refresh');
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled();
|
||||
expect(callCount).toBe(2); // Initial call + retry
|
||||
});
|
||||
});
|
||||
});
|
||||
356
packages/core/src/core/qwenContentGenerator.ts
Normal file
356
packages/core/src/core/qwenContentGenerator.ts
Normal file
@@ -0,0 +1,356 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
|
||||
import {
|
||||
IQwenOAuth2Client,
|
||||
type TokenRefreshData,
|
||||
type ErrorData,
|
||||
isErrorResponse,
|
||||
} from '../code_assist/qwenOAuth2.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
} from '@google/genai';
|
||||
|
||||
// Default fallback base URL if no endpoint is provided
|
||||
const DEFAULT_QWEN_BASE_URL =
|
||||
'https://dashscope.aliyuncs.com/compatible-mode/v1';
|
||||
|
||||
/**
|
||||
* Qwen Content Generator that uses Qwen OAuth tokens with automatic refresh
|
||||
*/
|
||||
export class QwenContentGenerator extends OpenAIContentGenerator {
|
||||
private qwenClient: IQwenOAuth2Client;
|
||||
|
||||
// Token management (integrated from QwenTokenManager)
|
||||
private currentToken: string | null = null;
|
||||
private currentEndpoint: string | null = null;
|
||||
private refreshPromise: Promise<string> | null = null;
|
||||
|
||||
constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) {
|
||||
// Initialize with empty API key, we'll override it dynamically
|
||||
super('', model, config);
|
||||
this.qwenClient = qwenClient;
|
||||
|
||||
// Set default base URL, will be updated dynamically
|
||||
this.client.baseURL = DEFAULT_QWEN_BASE_URL;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current endpoint URL with proper protocol and /v1 suffix
|
||||
*/
|
||||
private getCurrentEndpoint(): string {
|
||||
const baseEndpoint = this.currentEndpoint || DEFAULT_QWEN_BASE_URL;
|
||||
const suffix = '/v1';
|
||||
|
||||
// Normalize the URL: add protocol if missing, ensure /v1 suffix
|
||||
const normalizedUrl = baseEndpoint.startsWith('http')
|
||||
? baseEndpoint
|
||||
: `https://${baseEndpoint}`;
|
||||
|
||||
return normalizedUrl.endsWith(suffix)
|
||||
? normalizedUrl
|
||||
: `${normalizedUrl}${suffix}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Override error logging behavior to suppress auth errors during token refresh
|
||||
*/
|
||||
protected shouldSuppressErrorLogging(
|
||||
error: unknown,
|
||||
_request: GenerateContentParameters,
|
||||
): boolean {
|
||||
// Suppress logging for authentication errors that we handle with token refresh
|
||||
return this.isAuthError(error);
|
||||
}
|
||||
|
||||
/**
|
||||
* Override to use dynamic token and endpoint
|
||||
*/
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return this.withValidToken(async (token) => {
|
||||
// Temporarily update the API key and base URL
|
||||
const originalApiKey = this.client.apiKey;
|
||||
const originalBaseURL = this.client.baseURL;
|
||||
this.client.apiKey = token;
|
||||
this.client.baseURL = this.getCurrentEndpoint();
|
||||
|
||||
try {
|
||||
return await super.generateContent(request);
|
||||
} finally {
|
||||
// Restore original values
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Override to use dynamic token and endpoint
|
||||
*/
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
return this.withValidTokenForStream(async (token) => {
|
||||
// Update the API key and base URL before streaming
|
||||
const originalApiKey = this.client.apiKey;
|
||||
const originalBaseURL = this.client.baseURL;
|
||||
this.client.apiKey = token;
|
||||
this.client.baseURL = this.getCurrentEndpoint();
|
||||
|
||||
try {
|
||||
return await super.generateContentStream(request);
|
||||
} catch (error) {
|
||||
// Restore original values on error
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
throw error;
|
||||
}
|
||||
// Note: We don't restore the values in finally for streaming because
|
||||
// the generator may continue to be used after this method returns
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Override to use dynamic token and endpoint
|
||||
*/
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.withValidToken(async (token) => {
|
||||
const originalApiKey = this.client.apiKey;
|
||||
const originalBaseURL = this.client.baseURL;
|
||||
this.client.apiKey = token;
|
||||
this.client.baseURL = this.getCurrentEndpoint();
|
||||
|
||||
try {
|
||||
return await super.countTokens(request);
|
||||
} finally {
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Override to use dynamic token and endpoint
|
||||
*/
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return this.withValidToken(async (token) => {
|
||||
const originalApiKey = this.client.apiKey;
|
||||
const originalBaseURL = this.client.baseURL;
|
||||
this.client.apiKey = token;
|
||||
this.client.baseURL = this.getCurrentEndpoint();
|
||||
|
||||
try {
|
||||
return await super.embedContent(request);
|
||||
} finally {
|
||||
this.client.apiKey = originalApiKey;
|
||||
this.client.baseURL = originalBaseURL;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute operation with a valid token, with retry on auth failure
|
||||
*/
|
||||
private async withValidToken<T>(
|
||||
operation: (token: string) => Promise<T>,
|
||||
): Promise<T> {
|
||||
const token = await this.getTokenWithRetry();
|
||||
|
||||
try {
|
||||
return await operation(token);
|
||||
} catch (error) {
|
||||
// Check if this is an authentication error
|
||||
if (this.isAuthError(error)) {
|
||||
// Refresh token and retry once silently
|
||||
const newToken = await this.refreshToken();
|
||||
return await operation(newToken);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute operation with a valid token for streaming, with retry on auth failure
|
||||
*/
|
||||
private async withValidTokenForStream<T>(
|
||||
operation: (token: string) => Promise<T>,
|
||||
): Promise<T> {
|
||||
const token = await this.getTokenWithRetry();
|
||||
|
||||
try {
|
||||
return await operation(token);
|
||||
} catch (error) {
|
||||
// Check if this is an authentication error
|
||||
if (this.isAuthError(error)) {
|
||||
// Refresh token and retry once silently
|
||||
const newToken = await this.refreshToken();
|
||||
return await operation(newToken);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get token with retry logic
|
||||
*/
|
||||
private async getTokenWithRetry(): Promise<string> {
|
||||
try {
|
||||
return await this.getValidToken();
|
||||
} catch (error) {
|
||||
console.error('Failed to get valid token:', error);
|
||||
throw new Error(
|
||||
'Failed to obtain valid Qwen access token. Please re-authenticate.',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Token management methods (integrated from QwenTokenManager)
|
||||
|
||||
/**
|
||||
* Get a valid access token, refreshing if necessary
|
||||
*/
|
||||
private async getValidToken(): Promise<string> {
|
||||
// If there's already a refresh in progress, wait for it
|
||||
if (this.refreshPromise) {
|
||||
return this.refreshPromise;
|
||||
}
|
||||
|
||||
try {
|
||||
const { token } = await this.qwenClient.getAccessToken();
|
||||
if (token) {
|
||||
this.currentToken = token;
|
||||
// Also update endpoint from current credentials
|
||||
const credentials = this.qwenClient.getCredentials();
|
||||
if (credentials.resource_url) {
|
||||
this.currentEndpoint = credentials.resource_url;
|
||||
}
|
||||
return token;
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to get access token, attempting refresh:', error);
|
||||
}
|
||||
|
||||
// Start a new refresh operation
|
||||
this.refreshPromise = this.performTokenRefresh();
|
||||
|
||||
try {
|
||||
const newToken = await this.refreshPromise;
|
||||
return newToken;
|
||||
} finally {
|
||||
this.refreshPromise = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Force refresh the access token
|
||||
*/
|
||||
private async refreshToken(): Promise<string> {
|
||||
this.refreshPromise = this.performTokenRefresh();
|
||||
|
||||
try {
|
||||
const newToken = await this.refreshPromise;
|
||||
return newToken;
|
||||
} finally {
|
||||
this.refreshPromise = null;
|
||||
}
|
||||
}
|
||||
|
||||
private async performTokenRefresh(): Promise<string> {
|
||||
try {
|
||||
const response = await this.qwenClient.refreshAccessToken();
|
||||
|
||||
if (isErrorResponse(response)) {
|
||||
const errorData = response as ErrorData;
|
||||
throw new Error(
|
||||
`${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`,
|
||||
);
|
||||
}
|
||||
|
||||
const tokenData = response as TokenRefreshData;
|
||||
|
||||
if (!tokenData.access_token) {
|
||||
throw new Error('Failed to refresh access token: no token returned');
|
||||
}
|
||||
|
||||
this.currentToken = tokenData.access_token;
|
||||
|
||||
// Update endpoint if provided
|
||||
if (tokenData.resource_url) {
|
||||
this.currentEndpoint = tokenData.resource_url;
|
||||
}
|
||||
|
||||
return tokenData.access_token;
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an error is related to authentication/authorization
|
||||
*/
|
||||
private isAuthError(error: unknown): boolean {
|
||||
if (!error) return false;
|
||||
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message.toLowerCase()
|
||||
: String(error).toLowerCase();
|
||||
|
||||
// Define a type for errors that might have status or code properties
|
||||
const errorWithCode = error as {
|
||||
status?: number | string;
|
||||
code?: number | string;
|
||||
};
|
||||
const errorCode = errorWithCode?.status || errorWithCode?.code;
|
||||
|
||||
return (
|
||||
errorCode === 400 ||
|
||||
errorCode === 401 ||
|
||||
errorCode === 403 ||
|
||||
errorMessage.includes('unauthorized') ||
|
||||
errorMessage.includes('forbidden') ||
|
||||
errorMessage.includes('invalid api key') ||
|
||||
errorMessage.includes('invalid access token') ||
|
||||
errorMessage.includes('token expired') ||
|
||||
errorMessage.includes('authentication') ||
|
||||
errorMessage.includes('access denied') ||
|
||||
(errorMessage.includes('token') && errorMessage.includes('expired'))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current cached token (may be expired)
|
||||
*/
|
||||
getCurrentToken(): string | null {
|
||||
return this.currentToken;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the cached token and endpoint
|
||||
*/
|
||||
clearToken(): void {
|
||||
this.currentToken = null;
|
||||
this.currentEndpoint = null;
|
||||
this.refreshPromise = null;
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,7 @@ export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
export * from './code_assist/codeAssist.js';
|
||||
export * from './code_assist/oauth2.js';
|
||||
export * from './code_assist/qwenOAuth2.js';
|
||||
export * from './code_assist/server.js';
|
||||
export * from './code_assist/types.js';
|
||||
|
||||
|
||||
194
packages/core/src/utils/quotaErrorDetection.test.ts
Normal file
194
packages/core/src/utils/quotaErrorDetection.test.ts
Normal file
@@ -0,0 +1,194 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
isQwenQuotaExceededError,
|
||||
isQwenThrottlingError,
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
isApiError,
|
||||
isStructuredError,
|
||||
type ApiError,
|
||||
} from './quotaErrorDetection.js';
|
||||
|
||||
describe('quotaErrorDetection', () => {
|
||||
describe('isQwenQuotaExceededError', () => {
|
||||
it('should detect insufficient_quota error message', () => {
|
||||
const error = new Error('insufficient_quota');
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect free allocated quota exceeded error message', () => {
|
||||
const error = new Error('Free allocated quota exceeded.');
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect quota exceeded error message', () => {
|
||||
const error = new Error('quota exceeded');
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect quota exceeded in string error', () => {
|
||||
const error = 'insufficient_quota';
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect quota exceeded in structured error', () => {
|
||||
const error = { message: 'Free allocated quota exceeded.', status: 429 };
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect quota exceeded in API error', () => {
|
||||
const error: ApiError = {
|
||||
error: {
|
||||
code: 429,
|
||||
message: 'insufficient_quota',
|
||||
status: 'RESOURCE_EXHAUSTED',
|
||||
details: [],
|
||||
},
|
||||
};
|
||||
expect(isQwenQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect throttling errors as quota exceeded', () => {
|
||||
const error = new Error('requests throttling triggered');
|
||||
expect(isQwenQuotaExceededError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('should not detect unrelated errors', () => {
|
||||
const error = new Error('Network error');
|
||||
expect(isQwenQuotaExceededError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isQwenThrottlingError', () => {
|
||||
it('should detect throttling error with 429 status', () => {
|
||||
const error = { message: 'throttling', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect requests throttling triggered with 429 status', () => {
|
||||
const error = { message: 'requests throttling triggered', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect rate limit error with 429 status', () => {
|
||||
const error = { message: 'rate limit exceeded', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect too many requests with 429 status', () => {
|
||||
const error = { message: 'too many requests', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect throttling in string error', () => {
|
||||
const error = 'throttling';
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect throttling in structured error with 429', () => {
|
||||
const error = { message: 'requests throttling triggered', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect throttling in API error with 429', () => {
|
||||
const error: ApiError = {
|
||||
error: {
|
||||
code: 429,
|
||||
message: 'throttling',
|
||||
status: 'RESOURCE_EXHAUSTED',
|
||||
details: [],
|
||||
},
|
||||
};
|
||||
expect(isQwenThrottlingError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect throttling without 429 status in structured error', () => {
|
||||
const error = { message: 'throttling', status: 500 };
|
||||
expect(isQwenThrottlingError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('should not detect quota exceeded as throttling', () => {
|
||||
const error = { message: 'insufficient_quota', status: 429 };
|
||||
expect(isQwenThrottlingError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('should not detect unrelated errors as throttling', () => {
|
||||
const error = { message: 'Network error', status: 500 };
|
||||
expect(isQwenThrottlingError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isProQuotaExceededError', () => {
|
||||
it('should detect Gemini Pro quota exceeded error', () => {
|
||||
const error = new Error(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5 Pro Requests'",
|
||||
);
|
||||
expect(isProQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect Gemini preview Pro quota exceeded error', () => {
|
||||
const error = new Error(
|
||||
"Quota exceeded for quota metric 'Gemini 2.5-preview Pro Requests'",
|
||||
);
|
||||
expect(isProQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect non-Pro quota errors', () => {
|
||||
const error = new Error(
|
||||
"Quota exceeded for quota metric 'Gemini 1.5 Flash Requests'",
|
||||
);
|
||||
expect(isProQuotaExceededError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isGenericQuotaExceededError', () => {
|
||||
it('should detect generic quota exceeded error', () => {
|
||||
const error = new Error('Quota exceeded for quota metric');
|
||||
expect(isGenericQuotaExceededError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect non-quota errors', () => {
|
||||
const error = new Error('Network error');
|
||||
expect(isGenericQuotaExceededError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('type guards', () => {
|
||||
describe('isApiError', () => {
|
||||
it('should detect valid API error', () => {
|
||||
const error: ApiError = {
|
||||
error: {
|
||||
code: 429,
|
||||
message: 'test error',
|
||||
status: 'RESOURCE_EXHAUSTED',
|
||||
details: [],
|
||||
},
|
||||
};
|
||||
expect(isApiError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect invalid API error', () => {
|
||||
const error = { message: 'test error' };
|
||||
expect(isApiError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isStructuredError', () => {
|
||||
it('should detect valid structured error', () => {
|
||||
const error = { message: 'test error', status: 429 };
|
||||
expect(isStructuredError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not detect invalid structured error', () => {
|
||||
const error = { code: 429 };
|
||||
expect(isStructuredError(error)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -101,3 +101,70 @@ export function isGenericQuotaExceededError(error: unknown): boolean {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export function isQwenQuotaExceededError(error: unknown): boolean {
|
||||
// Check for Qwen insufficient quota errors (should not retry)
|
||||
const checkMessage = (message: string): boolean => {
|
||||
const lowerMessage = message.toLowerCase();
|
||||
return (
|
||||
lowerMessage.includes('insufficient_quota') ||
|
||||
lowerMessage.includes('free allocated quota exceeded') ||
|
||||
(lowerMessage.includes('quota') && lowerMessage.includes('exceeded'))
|
||||
);
|
||||
};
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return checkMessage(error);
|
||||
}
|
||||
|
||||
if (isStructuredError(error)) {
|
||||
return checkMessage(error.message);
|
||||
}
|
||||
|
||||
if (isApiError(error)) {
|
||||
return checkMessage(error.error.message);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export function isQwenThrottlingError(error: unknown): boolean {
|
||||
// Check for Qwen throttling errors (should retry)
|
||||
const checkMessage = (message: string): boolean => {
|
||||
const lowerMessage = message.toLowerCase();
|
||||
return (
|
||||
lowerMessage.includes('throttling') ||
|
||||
lowerMessage.includes('requests throttling triggered') ||
|
||||
lowerMessage.includes('rate limit') ||
|
||||
lowerMessage.includes('too many requests')
|
||||
);
|
||||
};
|
||||
|
||||
// Check status code
|
||||
const getStatusCode = (error: unknown): number | undefined => {
|
||||
if (error && typeof error === 'object') {
|
||||
const errorObj = error as { status?: number; code?: number };
|
||||
return errorObj.status || errorObj.code;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const statusCode = getStatusCode(error);
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return (
|
||||
(statusCode === 429 && checkMessage(error)) ||
|
||||
error.includes('throttling')
|
||||
);
|
||||
}
|
||||
|
||||
if (isStructuredError(error)) {
|
||||
return statusCode === 429 && checkMessage(error.message);
|
||||
}
|
||||
|
||||
if (isApiError(error)) {
|
||||
return error.error.code === 429 && checkMessage(error.error.message);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { retryWithBackoff, HttpError } from './retry.js';
|
||||
import { setSimulate429 } from './testUtils.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
|
||||
// Helper to create a mock function that fails a certain number of times
|
||||
const createFailingFunction = (
|
||||
@@ -399,4 +400,173 @@ describe('retryWithBackoff', () => {
|
||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Qwen OAuth 429 error handling', () => {
|
||||
it('should retry for Qwen OAuth 429 errors that are throttling-related', async () => {
|
||||
const errorWith429: HttpError = new Error('Rate limit exceeded');
|
||||
errorWith429.status = 429;
|
||||
|
||||
const fn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(errorWith429)
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
maxDelayMs: 1000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
// Fast-forward time for delays
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Should be called twice (1 failure + 1 success)
|
||||
expect(fn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw immediately for Qwen OAuth with insufficient_quota message', async () => {
|
||||
const errorWithInsufficientQuota = new Error('insufficient_quota');
|
||||
|
||||
const fn = vi.fn().mockRejectedValue(errorWithInsufficientQuota);
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 5000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
|
||||
|
||||
// Should be called only once (no retries)
|
||||
expect(fn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should throw immediately for Qwen OAuth with free allocated quota exceeded message', async () => {
|
||||
const errorWithQuotaExceeded = new Error(
|
||||
'Free allocated quota exceeded.',
|
||||
);
|
||||
|
||||
const fn = vi.fn().mockRejectedValue(errorWithQuotaExceeded);
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 5000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
|
||||
|
||||
// Should be called only once (no retries)
|
||||
expect(fn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should retry for Qwen OAuth with throttling message', async () => {
|
||||
const throttlingError: HttpError = new Error(
|
||||
'requests throttling triggered',
|
||||
);
|
||||
throttlingError.status = 429;
|
||||
|
||||
const fn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(throttlingError)
|
||||
.mockRejectedValueOnce(throttlingError)
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
maxDelayMs: 1000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
// Fast-forward time for delays
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Should be called 3 times (2 failures + 1 success)
|
||||
expect(fn).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should retry for Qwen OAuth with throttling error', async () => {
|
||||
const throttlingError: HttpError = new Error('throttling');
|
||||
throttlingError.status = 429;
|
||||
|
||||
const fn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(throttlingError)
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
maxDelayMs: 1000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
// Fast-forward time for delays
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Should be called 2 times (1 failure + 1 success)
|
||||
expect(fn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw immediately for Qwen OAuth with quota message', async () => {
|
||||
const errorWithQuota = new Error('quota exceeded');
|
||||
|
||||
const fn = vi.fn().mockRejectedValue(errorWithQuota);
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 5000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(/Qwen API quota exceeded/);
|
||||
|
||||
// Should be called only once (no retries)
|
||||
expect(fn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should retry normal errors for Qwen OAuth (not quota-related)', async () => {
|
||||
const normalError: HttpError = new Error('Network error');
|
||||
normalError.status = 500;
|
||||
|
||||
const fn = createFailingFunction(2, 'success');
|
||||
// Replace the default 500 error with our normal error
|
||||
fn.mockRejectedValueOnce(normalError)
|
||||
.mockRejectedValueOnce(normalError)
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
maxDelayMs: 1000,
|
||||
shouldRetry: () => true,
|
||||
authType: AuthType.QWEN_OAUTH,
|
||||
});
|
||||
|
||||
// Fast-forward time for delays
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await expect(promise).resolves.toBe('success');
|
||||
|
||||
// Should be called 3 times (2 failures + 1 success)
|
||||
expect(fn).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -8,6 +8,8 @@ import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
isQwenQuotaExceededError,
|
||||
isQwenThrottlingError,
|
||||
} from './quotaErrorDetection.js';
|
||||
|
||||
export interface HttpError extends Error {
|
||||
@@ -150,9 +152,23 @@ export async function retryWithBackoff<T>(
|
||||
}
|
||||
}
|
||||
|
||||
// Track consecutive 429 errors
|
||||
// Check for Qwen OAuth quota exceeded error - throw immediately without retry
|
||||
if (authType === AuthType.QWEN_OAUTH && isQwenQuotaExceededError(error)) {
|
||||
throw new Error(
|
||||
`Qwen API quota exceeded: Your Qwen API quota has been exhausted. Please wait for your quota to reset.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Track consecutive 429 errors, but handle Qwen throttling differently
|
||||
if (errorStatus === 429) {
|
||||
consecutive429Count++;
|
||||
// For Qwen throttling errors, we still want to track them for exponential backoff
|
||||
// but not for quota fallback logic (since Qwen doesn't have model fallback)
|
||||
if (authType === AuthType.QWEN_OAUTH && isQwenThrottlingError(error)) {
|
||||
// Keep track of 429s but reset the consecutive count to avoid fallback logic
|
||||
consecutive429Count = 0;
|
||||
} else {
|
||||
consecutive429Count++;
|
||||
}
|
||||
} else {
|
||||
consecutive429Count = 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user