mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
Fix oauth credential caching. (#2709)
This commit is contained in:
committed by
GitHub
parent
f3849627fc
commit
5c4c833ddd
@@ -64,6 +64,7 @@ describe('oauth2', () => {
|
|||||||
setCredentials: mockSetCredentials,
|
setCredentials: mockSetCredentials,
|
||||||
getAccessToken: mockGetAccessToken,
|
getAccessToken: mockGetAccessToken,
|
||||||
credentials: mockTokens,
|
credentials: mockTokens,
|
||||||
|
on: vi.fn(),
|
||||||
} as unknown as OAuth2Client;
|
} as unknown as OAuth2Client;
|
||||||
vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client);
|
vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client);
|
||||||
|
|
||||||
@@ -136,10 +137,6 @@ describe('oauth2', () => {
|
|||||||
});
|
});
|
||||||
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
|
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
|
||||||
|
|
||||||
const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
|
|
||||||
const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8'));
|
|
||||||
expect(tokenData).toEqual(mockTokens);
|
|
||||||
|
|
||||||
// Verify Google Account ID was cached
|
// Verify Google Account ID was cached
|
||||||
const googleAccountIdPath = path.join(
|
const googleAccountIdPath = path.join(
|
||||||
tempHomeDir,
|
tempHomeDir,
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ export async function getOauthClient(): Promise<OAuth2Client> {
|
|||||||
clientId: OAUTH_CLIENT_ID,
|
clientId: OAUTH_CLIENT_ID,
|
||||||
clientSecret: OAUTH_CLIENT_SECRET,
|
clientSecret: OAUTH_CLIENT_SECRET,
|
||||||
});
|
});
|
||||||
|
client.on('tokens', async (tokens: Credentials) => {
|
||||||
|
await cacheCredentials(tokens);
|
||||||
|
});
|
||||||
|
|
||||||
if (await loadCachedCredentials(client)) {
|
if (await loadCachedCredentials(client)) {
|
||||||
// Found valid cached credentials.
|
// Found valid cached credentials.
|
||||||
@@ -130,8 +133,6 @@ async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
|
|||||||
redirect_uri: redirectUri,
|
redirect_uri: redirectUri,
|
||||||
});
|
});
|
||||||
client.setCredentials(tokens);
|
client.setCredentials(tokens);
|
||||||
await cacheCredentials(client.credentials);
|
|
||||||
|
|
||||||
// Retrieve and cache Google Account ID during authentication
|
// Retrieve and cache Google Account ID during authentication
|
||||||
try {
|
try {
|
||||||
const googleAccountId = await getGoogleAccountId(client);
|
const googleAccountId = await getGoogleAccountId(client);
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should call the generateContent endpoint', async () => {
|
it('should call the generateContent endpoint', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
response: {
|
response: {
|
||||||
candidates: [
|
candidates: [
|
||||||
@@ -53,8 +53,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should call the generateContentStream endpoint', async () => {
|
it('should call the generateContentStream endpoint', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
const mockResponse = (async function* () {
|
const mockResponse = (async function* () {
|
||||||
yield {
|
yield {
|
||||||
response: {
|
response: {
|
||||||
@@ -90,8 +90,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should call the onboardUser endpoint', async () => {
|
it('should call the onboardUser endpoint', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
name: 'operations/123',
|
name: 'operations/123',
|
||||||
done: true,
|
done: true,
|
||||||
@@ -112,8 +112,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should call the loadCodeAssist endpoint', async () => {
|
it('should call the loadCodeAssist endpoint', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
// TODO: Add mock response
|
// TODO: Add mock response
|
||||||
};
|
};
|
||||||
@@ -131,8 +131,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should return 0 for countTokens', async () => {
|
it('should return 0 for countTokens', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
totalTokens: 100,
|
totalTokens: 100,
|
||||||
};
|
};
|
||||||
@@ -146,8 +146,8 @@ describe('CodeAssistServer', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should throw an error for embedContent', async () => {
|
it('should throw an error for embedContent', async () => {
|
||||||
const auth = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(client, 'test-project');
|
||||||
await expect(
|
await expect(
|
||||||
server.embedContent({
|
server.embedContent({
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { AuthClient } from 'google-auth-library';
|
import { OAuth2Client } from 'google-auth-library';
|
||||||
import {
|
import {
|
||||||
CodeAssistGlobalUserSettingResponse,
|
CodeAssistGlobalUserSettingResponse,
|
||||||
LoadCodeAssistRequest,
|
LoadCodeAssistRequest,
|
||||||
@@ -46,7 +46,7 @@ export const CODE_ASSIST_API_VERSION = 'v1internal';
|
|||||||
|
|
||||||
export class CodeAssistServer implements ContentGenerator {
|
export class CodeAssistServer implements ContentGenerator {
|
||||||
constructor(
|
constructor(
|
||||||
readonly auth: AuthClient,
|
readonly client: OAuth2Client,
|
||||||
readonly projectId?: string,
|
readonly projectId?: string,
|
||||||
readonly httpOptions: HttpOptions = {},
|
readonly httpOptions: HttpOptions = {},
|
||||||
) {}
|
) {}
|
||||||
@@ -129,7 +129,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||||||
req: object,
|
req: object,
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
const res = await this.auth.request({
|
const res = await this.client.request({
|
||||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@@ -144,7 +144,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async getEndpoint<T>(method: string, signal?: AbortSignal): Promise<T> {
|
async getEndpoint<T>(method: string, signal?: AbortSignal): Promise<T> {
|
||||||
const res = await this.auth.request({
|
const res = await this.client.request({
|
||||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
@@ -162,7 +162,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||||||
req: object,
|
req: object,
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
): Promise<AsyncGenerator<T>> {
|
): Promise<AsyncGenerator<T>> {
|
||||||
const res = await this.auth.request({
|
const res = await this.client.request({
|
||||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
params: {
|
params: {
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ export class ProjectIdRequiredError extends Error {
|
|||||||
* @param projectId the user's project id, if any
|
* @param projectId the user's project id, if any
|
||||||
* @returns the user's actual project id
|
* @returns the user's actual project id
|
||||||
*/
|
*/
|
||||||
export async function setupUser(authClient: OAuth2Client): Promise<string> {
|
export async function setupUser(client: OAuth2Client): Promise<string> {
|
||||||
let projectId = process.env.GOOGLE_CLOUD_PROJECT;
|
let projectId = process.env.GOOGLE_CLOUD_PROJECT;
|
||||||
const caServer = new CodeAssistServer(authClient, projectId);
|
const caServer = new CodeAssistServer(client, projectId);
|
||||||
|
|
||||||
const clientMetadata: ClientMetadata = {
|
const clientMetadata: ClientMetadata = {
|
||||||
ideType: 'IDE_UNSPECIFIED',
|
ideType: 'IDE_UNSPECIFIED',
|
||||||
|
|||||||
Reference in New Issue
Block a user