fix: use server-returned project for gca free tier auth (#6113)

This commit is contained in:
Gaurav
2025-08-13 14:04:58 -07:00
committed by GitHub
parent 2dbd5ecdc8
commit f9a1e8eb6f
2 changed files with 175 additions and 22 deletions

View File

@@ -16,9 +16,17 @@ const mockPaidTier: GeminiUserTier = {
id: UserTierId.STANDARD, id: UserTierId.STANDARD,
name: 'paid', name: 'paid',
description: 'Paid tier', description: 'Paid tier',
isDefault: true,
}; };
describe('setupUser', () => { const mockFreeTier: GeminiUserTier = {
id: UserTierId.FREE,
name: 'free',
description: 'Free tier',
isDefault: true,
};
describe('setupUser for existing user', () => {
let mockLoad: ReturnType<typeof vi.fn>; let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>; let mockOnboardUser: ReturnType<typeof vi.fn>;
@@ -42,7 +50,7 @@ describe('setupUser', () => {
); );
}); });
it('should use GOOGLE_CLOUD_PROJECT when set', async () => { it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project'; process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({ mockLoad.mockResolvedValue({
currentTier: mockPaidTier, currentTier: mockPaidTier,
@@ -57,8 +65,8 @@ describe('setupUser', () => {
); );
}); });
it('should treat empty GOOGLE_CLOUD_PROJECT as undefined and use project from server', async () => { it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
process.env.GOOGLE_CLOUD_PROJECT = ''; process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({ mockLoad.mockResolvedValue({
cloudaicompanionProject: 'server-project', cloudaicompanionProject: 'server-project',
currentTier: mockPaidTier, currentTier: mockPaidTier,
@@ -66,7 +74,7 @@ describe('setupUser', () => {
const projectId = await setupUser({} as OAuth2Client); const projectId = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith( expect(CodeAssistServer).toHaveBeenCalledWith(
{}, {},
undefined, 'test-project',
{}, {},
'', '',
undefined, undefined,
@@ -89,3 +97,119 @@ describe('setupUser', () => {
); );
}); });
}); });
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'standard-tier',
cloudaicompanionProject: 'test-project',
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: 'test-project',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
delete process.env.GOOGLE_CLOUD_PROJECT;
mockLoad.mockResolvedValue({
allowedTiers: [mockFreeTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'free-tier',
cloudaicompanionProject: undefined,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'free-tier',
});
});
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
process.env.GOOGLE_CLOUD_PROJECT = 'test-project';
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: undefined,
},
});
const userData = await setupUser({} as OAuth2Client);
expect(userData).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
delete process.env.GOOGLE_CLOUD_PROJECT;
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {},
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});

View File

@@ -33,32 +33,58 @@ export interface UserData {
* @returns the user's actual project id * @returns the user's actual project id
*/ */
export async function setupUser(client: OAuth2Client): Promise<UserData> { export async function setupUser(client: OAuth2Client): Promise<UserData> {
let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined; const projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined); const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
const coreClientMetadata: ClientMetadata = {
const clientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED', ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED', platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI', pluginType: 'GEMINI',
duetProject: projectId,
}; };
const loadRes = await caServer.loadCodeAssist({ const loadRes = await caServer.loadCodeAssist({
cloudaicompanionProject: projectId, cloudaicompanionProject: projectId,
metadata: clientMetadata, metadata: {
...coreClientMetadata,
duetProject: projectId,
},
}); });
if (!projectId && loadRes.cloudaicompanionProject) { if (loadRes.currentTier) {
projectId = loadRes.cloudaicompanionProject; if (!loadRes.cloudaicompanionProject) {
if (projectId) {
return {
projectId,
userTier: loadRes.currentTier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: loadRes.cloudaicompanionProject,
userTier: loadRes.currentTier.id,
};
} }
const tier = getOnboardTier(loadRes); const tier = getOnboardTier(loadRes);
const onboardReq: OnboardUserRequest = { let onboardReq: OnboardUserRequest;
tierId: tier.id, if (tier.id === UserTierId.FREE) {
cloudaicompanionProject: projectId, // The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
metadata: clientMetadata, onboardReq = {
}; tierId: tier.id,
cloudaicompanionProject: undefined,
metadata: coreClientMetadata,
};
} else {
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
};
}
// Poll onboardUser until long running operation is complete. // Poll onboardUser until long running operation is complete.
let lroRes = await caServer.onboardUser(onboardReq); let lroRes = await caServer.onboardUser(onboardReq);
@@ -67,20 +93,23 @@ export async function setupUser(client: OAuth2Client): Promise<UserData> {
lroRes = await caServer.onboardUser(onboardReq); lroRes = await caServer.onboardUser(onboardReq);
} }
if (!lroRes.response?.cloudaicompanionProject?.id && !projectId) { if (!lroRes.response?.cloudaicompanionProject?.id) {
if (projectId) {
return {
projectId,
userTier: tier.id,
};
}
throw new ProjectIdRequiredError(); throw new ProjectIdRequiredError();
} }
return { return {
projectId: lroRes.response?.cloudaicompanionProject?.id || projectId!, projectId: lroRes.response.cloudaicompanionProject.id,
userTier: tier.id, userTier: tier.id,
}; };
} }
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier { function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
if (res.currentTier) {
return res.currentTier;
}
for (const tier of res.allowedTiers || []) { for (const tier of res.allowedTiers || []) {
if (tier.isDefault) { if (tier.isDefault) {
return tier; return tier;