mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
Sync upstream Gemini-CLI v0.8.2 (#838)
This commit is contained in:
@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: mockOpenBrowserSecurely,
|
||||
}));
|
||||
vi.mock('node:crypto');
|
||||
vi.mock('./oauth-token-storage.js');
|
||||
vi.mock('./oauth-token-storage.js', () => {
|
||||
const mockSaveToken = vi.fn();
|
||||
const mockGetCredentials = vi.fn();
|
||||
const mockIsTokenExpired = vi.fn();
|
||||
const mockdeleteCredentials = vi.fn();
|
||||
|
||||
return {
|
||||
MCPOAuthTokenStorage: vi.fn(() => ({
|
||||
saveToken: mockSaveToken,
|
||||
getCredentials: mockGetCredentials,
|
||||
isTokenExpired: mockIsTokenExpired,
|
||||
deleteCredentials: mockdeleteCredentials,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import * as http from 'node:http';
|
||||
@@ -25,6 +39,10 @@ import type {
|
||||
import { MCPOAuthProvider } from './oauth-provider.js';
|
||||
import type { OAuthToken } from './token-storage/types.js';
|
||||
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
|
||||
import type {
|
||||
OAuthAuthorizationServerMetadata,
|
||||
OAuthProtectedResourceMetadata,
|
||||
} from './oauth-utils.js';
|
||||
|
||||
// Mock fetch globally
|
||||
const mockFetch = vi.fn();
|
||||
@@ -147,8 +165,9 @@ describe('MCPOAuthProvider', () => {
|
||||
});
|
||||
|
||||
// Mock token storage
|
||||
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -192,10 +211,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(result).toEqual({
|
||||
accessToken: 'access_token_123',
|
||||
@@ -208,7 +225,8 @@ describe('MCPOAuthProvider', () => {
|
||||
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
|
||||
expect.stringContaining('authorize'),
|
||||
);
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'access_token_123' }),
|
||||
'test-client-id',
|
||||
@@ -296,7 +314,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutAuth,
|
||||
'https://api.example.com',
|
||||
@@ -314,8 +333,11 @@ describe('MCPOAuthProvider', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should perform dynamic client registration when no client ID provided', async () => {
|
||||
const configWithoutClient = { ...mockConfig };
|
||||
it('should perform dynamic client registration when no client ID is provided but registration URL is provided', async () => {
|
||||
const configWithoutClient: MCPOAuthConfig = {
|
||||
...mockConfig,
|
||||
registrationUrl: 'https://auth.example.com/register',
|
||||
};
|
||||
delete configWithoutClient.clientId;
|
||||
|
||||
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
|
||||
@@ -327,7 +349,80 @@ describe('MCPOAuthProvider', () => {
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
|
||||
const mockAuthServerMetadata = {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockRegistrationResponse),
|
||||
json: mockRegistrationResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
// Setup callback handler
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
// Mock token exchange
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockTokenResponse),
|
||||
json: mockTokenResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutClient,
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://auth.example.com/register',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should perform OAuth discovery and dynamic client registration when no client ID or registration URL provided', async () => {
|
||||
const configWithoutClient: MCPOAuthConfig = { ...mockConfig };
|
||||
delete configWithoutClient.clientId;
|
||||
|
||||
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
|
||||
client_id: 'dynamic_client_id',
|
||||
client_secret: 'dynamic_client_secret',
|
||||
redirect_uris: ['http://localhost:7777/oauth/callback'],
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
response_types: ['code'],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
|
||||
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
|
||||
issuer: 'https://auth.example.com',
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
registration_endpoint: 'https://auth.example.com/register',
|
||||
@@ -385,7 +480,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutClient,
|
||||
);
|
||||
@@ -400,6 +496,117 @@ describe('MCPOAuthProvider', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should perform OAuth discovery once and dynamic client registration when no client ID, authorization URL or registration URL provided', async () => {
|
||||
const configWithoutClientAndAuthorizationUrl: MCPOAuthConfig = {
|
||||
...mockConfig,
|
||||
};
|
||||
delete configWithoutClientAndAuthorizationUrl.clientId;
|
||||
delete configWithoutClientAndAuthorizationUrl.authorizationUrl;
|
||||
|
||||
const mockResourceMetadata: OAuthProtectedResourceMetadata = {
|
||||
resource: 'https://api.example.com',
|
||||
authorization_servers: ['https://auth.example.com'],
|
||||
};
|
||||
|
||||
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
|
||||
issuer: 'https://auth.example.com',
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
registration_endpoint: 'https://auth.example.com/register',
|
||||
};
|
||||
|
||||
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
|
||||
client_id: 'dynamic_client_id',
|
||||
client_secret: 'dynamic_client_secret',
|
||||
redirect_uris: ['http://localhost:7777/oauth/callback'],
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
response_types: ['code'],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
|
||||
mockFetch
|
||||
.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
status: 200,
|
||||
}),
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockResourceMetadata),
|
||||
json: mockResourceMetadata,
|
||||
}),
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockAuthServerMetadata),
|
||||
json: mockAuthServerMetadata,
|
||||
}),
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockRegistrationResponse),
|
||||
json: mockRegistrationResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
// Setup callback handler
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
// Mock token exchange
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockTokenResponse),
|
||||
json: mockTokenResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutClientAndAuthorizationUrl,
|
||||
'https://api.example.com',
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://auth.example.com/register',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle OAuth callback errors', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
@@ -424,8 +631,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}, 10);
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth error: access_denied');
|
||||
});
|
||||
|
||||
@@ -453,8 +661,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}, 10);
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('State mismatch - possible CSRF attack');
|
||||
});
|
||||
|
||||
@@ -491,8 +700,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
|
||||
});
|
||||
|
||||
@@ -516,8 +726,9 @@ describe('MCPOAuthProvider', () => {
|
||||
return originalSetTimeout(callback, 0);
|
||||
}) as unknown as typeof setTimeout;
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth callback timeout');
|
||||
|
||||
global.setTimeout = originalSetTimeout;
|
||||
@@ -542,7 +753,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.refreshAccessToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'old_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -572,7 +784,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.refreshAccessToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -592,8 +805,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.refreshAccessToken(
|
||||
authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'invalid_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -614,12 +828,14 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||
validCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -636,10 +852,11 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const refreshResponse = {
|
||||
access_token: 'new_access_token',
|
||||
@@ -657,13 +874,14 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBe('new_access_token');
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'new_access_token' }),
|
||||
'test-client-id',
|
||||
@@ -673,9 +891,11 @@ describe('MCPOAuthProvider', () => {
|
||||
});
|
||||
|
||||
it('should return null when no credentials exist', async () => {
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -692,11 +912,12 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(tokenStorage.deleteCredentials).mockResolvedValue(undefined);
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
@@ -707,13 +928,14 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
|
||||
expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
@@ -734,12 +956,14 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||
tokenWithoutRefresh,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -784,7 +1008,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', mockConfig);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
|
||||
@@ -833,7 +1058,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
'https://auth.example.com',
|
||||
@@ -894,7 +1120,8 @@ describe('MCPOAuthProvider', () => {
|
||||
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
|
||||
};
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', configWithParamsInUrl);
|
||||
|
||||
const url = new URL(capturedUrl!);
|
||||
expect(url.searchParams.get('audience')).toBe('1234');
|
||||
@@ -947,7 +1174,8 @@ describe('MCPOAuthProvider', () => {
|
||||
authorizationUrl: 'https://auth.example.com/authorize#login',
|
||||
};
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', configWithFragment);
|
||||
|
||||
const url = new URL(capturedUrl!);
|
||||
expect(url.searchParams.get('client_id')).toBe('test-client-id');
|
||||
|
||||
@@ -7,12 +7,15 @@
|
||||
import * as http from 'node:http';
|
||||
import * as crypto from 'node:crypto';
|
||||
import { URL } from 'node:url';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
|
||||
import type { OAuthToken } from './token-storage/types.js';
|
||||
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { OAuthUtils } from './oauth-utils.js';
|
||||
|
||||
export const OAUTH_DISPLAY_MESSAGE_EVENT = 'oauth-display-message' as const;
|
||||
|
||||
/**
|
||||
* OAuth configuration for an MCP server.
|
||||
*/
|
||||
@@ -26,6 +29,7 @@ export interface MCPOAuthConfig {
|
||||
audiences?: string[];
|
||||
redirectUri?: string;
|
||||
tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token
|
||||
registrationUrl?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -85,13 +89,19 @@ interface PKCEParams {
|
||||
state: string;
|
||||
}
|
||||
|
||||
const REDIRECT_PORT = 7777;
|
||||
const REDIRECT_PATH = '/oauth/callback';
|
||||
const HTTP_OK = 200;
|
||||
|
||||
/**
|
||||
* Provider for handling OAuth authentication for MCP servers.
|
||||
*/
|
||||
export class MCPOAuthProvider {
|
||||
private static readonly REDIRECT_PORT = 7777;
|
||||
private static readonly REDIRECT_PATH = '/oauth/callback';
|
||||
private static readonly HTTP_OK = 200;
|
||||
private readonly tokenStorage: MCPOAuthTokenStorage;
|
||||
|
||||
constructor(tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage()) {
|
||||
this.tokenStorage = tokenStorage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a client dynamically with the OAuth server.
|
||||
@@ -100,13 +110,12 @@ export class MCPOAuthProvider {
|
||||
* @param config OAuth configuration
|
||||
* @returns The registered client information
|
||||
*/
|
||||
private static async registerClient(
|
||||
private async registerClient(
|
||||
registrationUrl: string,
|
||||
config: MCPOAuthConfig,
|
||||
): Promise<OAuthClientRegistrationResponse> {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
|
||||
|
||||
const registrationRequest: OAuthClientRegistrationRequest = {
|
||||
client_name: 'Gemini CLI (Google ADC)',
|
||||
@@ -142,7 +151,7 @@ export class MCPOAuthProvider {
|
||||
* @param mcpServerUrl The MCP server URL
|
||||
* @returns OAuth configuration if discovered, null otherwise
|
||||
*/
|
||||
private static async discoverOAuthFromMCPServer(
|
||||
private async discoverOAuthFromMCPServer(
|
||||
mcpServerUrl: string,
|
||||
): Promise<MCPOAuthConfig | null> {
|
||||
// Use the full URL with path preserved for OAuth discovery
|
||||
@@ -154,7 +163,7 @@ export class MCPOAuthProvider {
|
||||
*
|
||||
* @returns PKCE parameters including code verifier, challenge, and state
|
||||
*/
|
||||
private static generatePKCEParams(): PKCEParams {
|
||||
private generatePKCEParams(): PKCEParams {
|
||||
// Generate code verifier (43-128 characters)
|
||||
const codeVerifier = crypto.randomBytes(32).toString('base64url');
|
||||
|
||||
@@ -176,19 +185,16 @@ export class MCPOAuthProvider {
|
||||
* @param expectedState The state parameter to validate
|
||||
* @returns Promise that resolves with the authorization code
|
||||
*/
|
||||
private static async startCallbackServer(
|
||||
private async startCallbackServer(
|
||||
expectedState: string,
|
||||
): Promise<OAuthAuthorizationResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const server = http.createServer(
|
||||
async (req: http.IncomingMessage, res: http.ServerResponse) => {
|
||||
try {
|
||||
const url = new URL(
|
||||
req.url!,
|
||||
`http://localhost:${this.REDIRECT_PORT}`,
|
||||
);
|
||||
const url = new URL(req.url!, `http://localhost:${REDIRECT_PORT}`);
|
||||
|
||||
if (url.pathname !== this.REDIRECT_PATH) {
|
||||
if (url.pathname !== REDIRECT_PATH) {
|
||||
res.writeHead(404);
|
||||
res.end('Not found');
|
||||
return;
|
||||
@@ -199,7 +205,7 @@ export class MCPOAuthProvider {
|
||||
const error = url.searchParams.get('error');
|
||||
|
||||
if (error) {
|
||||
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.end(`
|
||||
<html>
|
||||
<body>
|
||||
@@ -230,7 +236,7 @@ export class MCPOAuthProvider {
|
||||
}
|
||||
|
||||
// Send success response to browser
|
||||
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.end(`
|
||||
<html>
|
||||
<body>
|
||||
@@ -251,10 +257,8 @@ export class MCPOAuthProvider {
|
||||
);
|
||||
|
||||
server.on('error', reject);
|
||||
server.listen(this.REDIRECT_PORT, () => {
|
||||
console.log(
|
||||
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
|
||||
);
|
||||
server.listen(REDIRECT_PORT, () => {
|
||||
console.log(`OAuth callback server listening on port ${REDIRECT_PORT}`);
|
||||
});
|
||||
|
||||
// Timeout after 5 minutes
|
||||
@@ -276,14 +280,13 @@ export class MCPOAuthProvider {
|
||||
* @param mcpServerUrl The MCP server URL to use as the resource parameter
|
||||
* @returns The authorization URL
|
||||
*/
|
||||
private static buildAuthorizationUrl(
|
||||
private buildAuthorizationUrl(
|
||||
config: MCPOAuthConfig,
|
||||
pkceParams: PKCEParams,
|
||||
mcpServerUrl?: string,
|
||||
): string {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
client_id: config.clientId!,
|
||||
@@ -333,15 +336,14 @@ export class MCPOAuthProvider {
|
||||
* @param mcpServerUrl The MCP server URL to use as the resource parameter
|
||||
* @returns The token response
|
||||
*/
|
||||
private static async exchangeCodeForToken(
|
||||
private async exchangeCodeForToken(
|
||||
config: MCPOAuthConfig,
|
||||
code: string,
|
||||
codeVerifier: string,
|
||||
mcpServerUrl?: string,
|
||||
): Promise<OAuthTokenResponse> {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
@@ -458,7 +460,7 @@ export class MCPOAuthProvider {
|
||||
* @param mcpServerUrl The MCP server URL to use as the resource parameter
|
||||
* @returns The new token response
|
||||
*/
|
||||
static async refreshAccessToken(
|
||||
async refreshAccessToken(
|
||||
config: MCPOAuthConfig,
|
||||
refreshToken: string,
|
||||
tokenUrl: string,
|
||||
@@ -577,18 +579,28 @@ export class MCPOAuthProvider {
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config OAuth configuration
|
||||
* @param mcpServerUrl Optional MCP server URL for OAuth discovery
|
||||
* @param messageHandler Optional handler for displaying user-facing messages
|
||||
* @returns The obtained OAuth token
|
||||
*/
|
||||
static async authenticate(
|
||||
async authenticate(
|
||||
serverName: string,
|
||||
config: MCPOAuthConfig,
|
||||
mcpServerUrl?: string,
|
||||
events?: EventEmitter,
|
||||
): Promise<OAuthToken> {
|
||||
// Helper function to display messages through handler or fallback to console.log
|
||||
const displayMessage = (message: string) => {
|
||||
if (events) {
|
||||
events.emit(OAUTH_DISPLAY_MESSAGE_EVENT, message);
|
||||
} else {
|
||||
console.log(message);
|
||||
}
|
||||
};
|
||||
|
||||
// If no authorization URL is provided, try to discover OAuth configuration
|
||||
if (!config.authorizationUrl && mcpServerUrl) {
|
||||
console.log(
|
||||
'No authorization URL provided, attempting OAuth discovery...',
|
||||
);
|
||||
console.debug(`Starting OAuth for MCP server "${serverName}"…
|
||||
✓ No authorization URL; using OAuth discovery`);
|
||||
|
||||
// First check if the server requires authentication via WWW-Authenticate header
|
||||
try {
|
||||
@@ -640,6 +652,7 @@ export class MCPOAuthProvider {
|
||||
authorizationUrl: discoveredConfig.authorizationUrl,
|
||||
tokenUrl: discoveredConfig.tokenUrl,
|
||||
scopes: discoveredConfig.scopes || config.scopes || [],
|
||||
registrationUrl: discoveredConfig.registrationUrl,
|
||||
// Preserve existing client credentials
|
||||
clientId: config.clientId,
|
||||
clientSecret: config.clientSecret,
|
||||
@@ -654,40 +667,38 @@ export class MCPOAuthProvider {
|
||||
|
||||
// If no client ID is provided, try dynamic client registration
|
||||
if (!config.clientId) {
|
||||
// Extract server URL from authorization URL
|
||||
if (!config.authorizationUrl) {
|
||||
throw new Error(
|
||||
'Cannot perform dynamic registration without authorization URL',
|
||||
);
|
||||
}
|
||||
let registrationUrl = config.registrationUrl;
|
||||
|
||||
const authUrl = new URL(config.authorizationUrl);
|
||||
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
|
||||
// If no registration URL was previously discovered, try to discover it
|
||||
if (!registrationUrl) {
|
||||
// Extract server URL from authorization URL
|
||||
if (!config.authorizationUrl) {
|
||||
throw new Error(
|
||||
'Cannot perform dynamic registration without authorization URL',
|
||||
);
|
||||
}
|
||||
|
||||
console.log(
|
||||
'No client ID provided, attempting dynamic client registration...',
|
||||
);
|
||||
const authUrl = new URL(config.authorizationUrl);
|
||||
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
|
||||
|
||||
// Get the authorization server metadata for registration
|
||||
const authServerMetadataUrl = new URL(
|
||||
'/.well-known/oauth-authorization-server',
|
||||
serverUrl,
|
||||
).toString();
|
||||
console.debug('→ Attempting dynamic client registration...');
|
||||
|
||||
const authServerMetadata =
|
||||
await OAuthUtils.fetchAuthorizationServerMetadata(
|
||||
authServerMetadataUrl,
|
||||
);
|
||||
if (!authServerMetadata) {
|
||||
throw new Error(
|
||||
'Failed to fetch authorization server metadata for client registration',
|
||||
);
|
||||
// Get the authorization server metadata for registration
|
||||
const authServerMetadata =
|
||||
await OAuthUtils.discoverAuthorizationServerMetadata(serverUrl);
|
||||
|
||||
if (!authServerMetadata) {
|
||||
throw new Error(
|
||||
'Failed to fetch authorization server metadata for client registration',
|
||||
);
|
||||
}
|
||||
registrationUrl = authServerMetadata.registration_endpoint;
|
||||
}
|
||||
|
||||
// Register client if registration endpoint is available
|
||||
if (authServerMetadata.registration_endpoint) {
|
||||
if (registrationUrl) {
|
||||
const clientRegistration = await this.registerClient(
|
||||
authServerMetadata.registration_endpoint,
|
||||
registrationUrl,
|
||||
config,
|
||||
);
|
||||
|
||||
@@ -696,7 +707,7 @@ export class MCPOAuthProvider {
|
||||
config.clientSecret = clientRegistration.client_secret;
|
||||
}
|
||||
|
||||
console.log('Dynamic client registration successful');
|
||||
console.debug('✓ Dynamic client registration successful');
|
||||
} else {
|
||||
throw new Error(
|
||||
'No client ID provided and dynamic registration not supported',
|
||||
@@ -721,30 +732,13 @@ export class MCPOAuthProvider {
|
||||
mcpServerUrl,
|
||||
);
|
||||
|
||||
console.log('\nOpening browser for OAuth authentication...');
|
||||
console.log('If the browser does not open, please visit:');
|
||||
console.log('');
|
||||
displayMessage(`→ Opening your browser for OAuth sign-in...
|
||||
|
||||
// Get terminal width or default to 80
|
||||
const terminalWidth = process.stdout.columns || 80;
|
||||
const separatorLength = Math.min(terminalWidth - 2, 80);
|
||||
const separator = '━'.repeat(separatorLength);
|
||||
If the browser does not open, copy and paste this URL into your browser:
|
||||
${authUrl}
|
||||
|
||||
console.log(separator);
|
||||
console.log(
|
||||
'COPY THE ENTIRE URL BELOW (select all text between the lines):',
|
||||
);
|
||||
console.log(separator);
|
||||
console.log(authUrl);
|
||||
console.log(separator);
|
||||
console.log('');
|
||||
console.log(
|
||||
'💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.',
|
||||
);
|
||||
console.log(
|
||||
'⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.',
|
||||
);
|
||||
console.log('');
|
||||
💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.
|
||||
⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`);
|
||||
|
||||
// Start callback server
|
||||
const callbackPromise = this.startCallbackServer(pkceParams.state);
|
||||
@@ -762,7 +756,7 @@ export class MCPOAuthProvider {
|
||||
// Wait for callback
|
||||
const { code } = await callbackPromise;
|
||||
|
||||
console.log('\nAuthorization code received, exchanging for tokens...');
|
||||
console.debug('✓ Authorization code received, exchanging for tokens...');
|
||||
|
||||
// Exchange code for tokens
|
||||
const tokenResponse = await this.exchangeCodeForToken(
|
||||
@@ -790,23 +784,27 @@ export class MCPOAuthProvider {
|
||||
|
||||
// Save token
|
||||
try {
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
await this.tokenStorage.saveToken(
|
||||
serverName,
|
||||
token,
|
||||
config.clientId,
|
||||
config.tokenUrl,
|
||||
mcpServerUrl,
|
||||
);
|
||||
console.log('Authentication successful! Token saved.');
|
||||
console.debug('✓ Authentication successful! Token saved.');
|
||||
|
||||
// Verify token was saved
|
||||
const savedToken = await MCPOAuthTokenStorage.getToken(serverName);
|
||||
const savedToken = await this.tokenStorage.getCredentials(serverName);
|
||||
if (savedToken && savedToken.token && savedToken.token.accessToken) {
|
||||
const tokenPreview =
|
||||
savedToken.token.accessToken.length > 20
|
||||
? `${savedToken.token.accessToken.substring(0, 20)}...`
|
||||
: '[token]';
|
||||
console.log(`Token verification successful: ${tokenPreview}`);
|
||||
// Avoid leaking token material; log a short SHA-256 fingerprint instead.
|
||||
const tokenFingerprint = crypto
|
||||
.createHash('sha256')
|
||||
.update(savedToken.token.accessToken)
|
||||
.digest('hex')
|
||||
.slice(0, 8);
|
||||
console.debug(
|
||||
`✓ Token verification successful (fingerprint: ${tokenFingerprint})`,
|
||||
);
|
||||
} else {
|
||||
console.error(
|
||||
'Token verification failed: token not found or invalid after save',
|
||||
@@ -827,12 +825,12 @@ export class MCPOAuthProvider {
|
||||
* @param config OAuth configuration
|
||||
* @returns A valid access token or null if not authenticated
|
||||
*/
|
||||
static async getValidToken(
|
||||
async getValidToken(
|
||||
serverName: string,
|
||||
config: MCPOAuthConfig,
|
||||
): Promise<string | null> {
|
||||
console.debug(`Getting valid token for server: ${serverName}`);
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(serverName);
|
||||
const credentials = await this.tokenStorage.getCredentials(serverName);
|
||||
|
||||
if (!credentials) {
|
||||
console.debug(`No credentials found for server: ${serverName}`);
|
||||
@@ -841,11 +839,11 @@ export class MCPOAuthProvider {
|
||||
|
||||
const { token } = credentials;
|
||||
console.debug(
|
||||
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`,
|
||||
`Found token for server: ${serverName}, expired: ${this.tokenStorage.isTokenExpired(token)}`,
|
||||
);
|
||||
|
||||
// Check if token is expired
|
||||
if (!MCPOAuthTokenStorage.isTokenExpired(token)) {
|
||||
if (!this.tokenStorage.isTokenExpired(token)) {
|
||||
console.debug(`Returning valid token for server: ${serverName}`);
|
||||
return token.accessToken;
|
||||
}
|
||||
@@ -874,7 +872,7 @@ export class MCPOAuthProvider {
|
||||
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
|
||||
}
|
||||
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
await this.tokenStorage.saveToken(
|
||||
serverName,
|
||||
newToken,
|
||||
config.clientId,
|
||||
@@ -886,7 +884,7 @@ export class MCPOAuthProvider {
|
||||
} catch (error) {
|
||||
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
|
||||
// Remove invalid token
|
||||
await MCPOAuthTokenStorage.removeToken(serverName);
|
||||
await this.tokenStorage.deleteCredentials(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,15 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
|
||||
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
|
||||
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from './token-storage/index.js';
|
||||
import type { OAuthCredentials, OAuthToken } from './token-storage/types.js';
|
||||
import { QWEN_DIR } from '../utils/paths.js';
|
||||
|
||||
// Mock file system operations
|
||||
// Mock dependencies
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
@@ -18,20 +20,42 @@ vi.mock('node:fs', () => ({
|
||||
mkdir: vi.fn(),
|
||||
unlink: vi.fn(),
|
||||
},
|
||||
mkdirSync: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('node:os', () => ({
|
||||
homedir: vi.fn(() => '/mock/home'),
|
||||
vi.mock('node:path', () => ({
|
||||
dirname: vi.fn(),
|
||||
join: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../config/storage.js', () => ({
|
||||
Storage: {
|
||||
getMcpOAuthTokensPath: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockHybridTokenStorage = {
|
||||
listServers: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
};
|
||||
vi.mock('./token-storage/hybrid-token-storage.js', () => ({
|
||||
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
|
||||
}));
|
||||
|
||||
const ONE_HR_MS = 3600000;
|
||||
|
||||
describe('MCPOAuthTokenStorage', () => {
|
||||
let tokenStorage: MCPOAuthTokenStorage;
|
||||
|
||||
const mockToken: OAuthToken = {
|
||||
accessToken: 'access_token_123',
|
||||
refreshToken: 'refresh_token_456',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'read write',
|
||||
expiresAt: Date.now() + 3600000, // 1 hour from now
|
||||
expiresAt: Date.now() + ONE_HR_MS,
|
||||
};
|
||||
|
||||
const mockCredentials: OAuthCredentials = {
|
||||
@@ -42,281 +66,385 @@ describe('MCPOAuthTokenStorage', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
});
|
||||
describe('with encrypted flag false', () => {
|
||||
beforeEach(() => {
|
||||
vi.stubEnv(FORCE_ENCRYPTED_FILE_ENV_VAR, 'false');
|
||||
tokenStorage = new MCPOAuthTokenStorage();
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('loadTokens', () => {
|
||||
it('should return empty map when token file does not exist', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'error');
|
||||
});
|
||||
|
||||
it('should load tokens from file successfully', async () => {
|
||||
const tokensArray = [mockCredentials];
|
||||
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(1);
|
||||
expect(tokens.get('test-server')).toEqual(mockCredentials);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
|
||||
'utf-8',
|
||||
);
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should handle corrupted token file gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
|
||||
describe('getAllCredentials', () => {
|
||||
it('should return empty map when token file does not exist', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
const tokens = await tokenStorage.getAllCredentials();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle file read errors other than ENOENT', async () => {
|
||||
const error = new Error('Permission denied');
|
||||
vi.mocked(fs.readFile).mockRejectedValue(error);
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveToken', () => {
|
||||
it('should save token with restricted permissions', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
'test-server',
|
||||
mockToken,
|
||||
'client-id',
|
||||
'https://token.url',
|
||||
);
|
||||
|
||||
expect(fs.mkdir).toHaveBeenCalledWith(path.join('/mock/home', '.qwen'), {
|
||||
recursive: true,
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should load tokens from file successfully', async () => {
|
||||
const tokensArray = [mockCredentials];
|
||||
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
|
||||
|
||||
const tokens = await tokenStorage.getAllCredentials();
|
||||
|
||||
expect(tokens.size).toBe(1);
|
||||
expect(tokens.get('test-server')).toEqual(mockCredentials);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
|
||||
'utf-8',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle corrupted token file gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
|
||||
|
||||
const tokens = await tokenStorage.getAllCredentials();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle file read errors other than ENOENT', async () => {
|
||||
const error = new Error('Permission denied');
|
||||
vi.mocked(fs.readFile).mockRejectedValue(error);
|
||||
|
||||
const tokens = await tokenStorage.getAllCredentials();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
expect(fs.writeFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
|
||||
expect.stringContaining('test-server'),
|
||||
{ mode: 0o600 },
|
||||
);
|
||||
});
|
||||
|
||||
it('should update existing token for same server', async () => {
|
||||
const existingCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'existing-server',
|
||||
};
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([existingCredentials]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
describe('saveToken', () => {
|
||||
it('should save token with restricted permissions', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
const newToken = { ...mockToken, accessToken: 'new_access_token' };
|
||||
await MCPOAuthTokenStorage.saveToken('existing-server', newToken);
|
||||
await tokenStorage.saveToken(
|
||||
'test-server',
|
||||
mockToken,
|
||||
'client-id',
|
||||
'https://token.url',
|
||||
);
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(writeCall[1] as string);
|
||||
expect(fs.mkdir).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', QWEN_DIR),
|
||||
{ recursive: true },
|
||||
);
|
||||
expect(fs.writeFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
|
||||
expect.stringContaining('test-server'),
|
||||
{ mode: 0o600 },
|
||||
);
|
||||
});
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].token.accessToken).toBe('new_access_token');
|
||||
expect(savedData[0].serverName).toBe('existing-server');
|
||||
it('should update existing token for same server', async () => {
|
||||
const existingCredentials: OAuthCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'existing-server',
|
||||
};
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([existingCredentials]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
const newToken: OAuthToken = {
|
||||
...mockToken,
|
||||
accessToken: 'new_access_token',
|
||||
};
|
||||
await tokenStorage.saveToken('existing-server', newToken);
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(
|
||||
writeCall[1] as string,
|
||||
) as OAuthCredentials[];
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].token.accessToken).toBe('new_access_token');
|
||||
expect(savedData[0].serverName).toBe('existing-server');
|
||||
});
|
||||
|
||||
it('should handle write errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
const writeError = new Error('Disk full');
|
||||
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
|
||||
|
||||
await expect(
|
||||
tokenStorage.saveToken('test-server', mockToken),
|
||||
).rejects.toThrow('Disk full');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to save MCP OAuth token'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle write errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
const writeError = new Error('Disk full');
|
||||
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
|
||||
describe('getCredentials', () => {
|
||||
it('should return token for existing server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
await expect(
|
||||
MCPOAuthTokenStorage.saveToken('test-server', mockToken),
|
||||
).rejects.toThrow('Disk full');
|
||||
const result = await tokenStorage.getCredentials('test-server');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to save MCP OAuth token'),
|
||||
);
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should return null for non-existent server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
const result = await tokenStorage.getCredentials('non-existent');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no tokens file exists', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const result = await tokenStorage.getCredentials('test-server');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should remove token for specific server', async () => {
|
||||
const credentials1: OAuthCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'server1',
|
||||
};
|
||||
const credentials2: OAuthCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'server2',
|
||||
};
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([credentials1, credentials2]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
await tokenStorage.deleteCredentials('server1');
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(writeCall[1] as string);
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].serverName).toBe('server2');
|
||||
});
|
||||
|
||||
it('should remove token file when no tokens remain', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await tokenStorage.deleteCredentials('test-server');
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
|
||||
);
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle removal of non-existent token gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
await tokenStorage.deleteCredentials('non-existent');
|
||||
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
expect(fs.unlink).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle file operation errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await tokenStorage.deleteCredentials('test-server');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to remove MCP OAuth token'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTokenExpired', () => {
|
||||
it('should return false for token without expiry', () => {
|
||||
const tokenWithoutExpiry: OAuthToken = { ...mockToken };
|
||||
delete tokenWithoutExpiry.expiresAt;
|
||||
|
||||
const result = tokenStorage.isTokenExpired(tokenWithoutExpiry);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for valid token', () => {
|
||||
const futureToken: OAuthToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + ONE_HR_MS,
|
||||
};
|
||||
|
||||
const result = tokenStorage.isTokenExpired(futureToken);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for expired token', () => {
|
||||
const expiredToken: OAuthToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() - ONE_HR_MS,
|
||||
};
|
||||
|
||||
const result = tokenStorage.isTokenExpired(expiredToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for token expiring within buffer time', () => {
|
||||
const soonToExpireToken: OAuthToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
|
||||
};
|
||||
|
||||
const result = tokenStorage.isTokenExpired(soonToExpireToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should remove token file successfully', async () => {
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await tokenStorage.clearAll();
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle non-existent file gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await tokenStorage.clearAll();
|
||||
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle other file errors gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await tokenStorage.clearAll();
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to clear MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToken', () => {
|
||||
it('should return token for existing server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
describe('with encrypted flag true', () => {
|
||||
beforeEach(() => {
|
||||
vi.stubEnv(FORCE_ENCRYPTED_FILE_ENV_VAR, 'true');
|
||||
tokenStorage = new MCPOAuthTokenStorage();
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('test-server');
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'error');
|
||||
});
|
||||
|
||||
it('should return null for non-existent server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('non-existent');
|
||||
|
||||
expect(result).toBeNull();
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should return null when no tokens file exists', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('test-server');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeToken', () => {
|
||||
it('should remove token for specific server', async () => {
|
||||
const credentials1 = { ...mockCredentials, serverName: 'server1' };
|
||||
const credentials2 = { ...mockCredentials, serverName: 'server2' };
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([credentials1, credentials2]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('server1');
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(writeCall[1] as string);
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].serverName).toBe('server2');
|
||||
it('should use HybridTokenStorage to list all credentials', async () => {
|
||||
mockHybridTokenStorage.getAllCredentials.mockResolvedValue(new Map());
|
||||
const servers = await tokenStorage.getAllCredentials();
|
||||
expect(mockHybridTokenStorage.getAllCredentials).toHaveBeenCalled();
|
||||
expect(servers).toEqual(new Map());
|
||||
});
|
||||
|
||||
it('should remove token file when no tokens remain', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('test-server');
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
|
||||
);
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
it('should use HybridTokenStorage to list servers', async () => {
|
||||
mockHybridTokenStorage.listServers.mockResolvedValue(['server1']);
|
||||
const servers = await tokenStorage.listServers();
|
||||
expect(mockHybridTokenStorage.listServers).toHaveBeenCalled();
|
||||
expect(servers).toEqual(['server1']);
|
||||
});
|
||||
|
||||
it('should handle removal of non-existent token gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('non-existent');
|
||||
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
expect(fs.unlink).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle file operation errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('test-server');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to remove MCP OAuth token'),
|
||||
it('should use HybridTokenStorage to set credentials', async () => {
|
||||
await tokenStorage.setCredentials(mockCredentials);
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
mockCredentials,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTokenExpired', () => {
|
||||
it('should return false for token without expiry', () => {
|
||||
const tokenWithoutExpiry = { ...mockToken };
|
||||
delete tokenWithoutExpiry.expiresAt;
|
||||
it('should use HybridTokenStorage to save a token', async () => {
|
||||
const serverName = 'server1';
|
||||
const now = Date.now();
|
||||
vi.spyOn(Date, 'now').mockReturnValue(now);
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry);
|
||||
await tokenStorage.saveToken(
|
||||
serverName,
|
||||
mockToken,
|
||||
'clientId',
|
||||
'tokenUrl',
|
||||
'mcpUrl',
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for valid token', () => {
|
||||
const futureToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + 3600000, // 1 hour from now
|
||||
const expectedCredential: OAuthCredentials = {
|
||||
serverName,
|
||||
token: mockToken,
|
||||
clientId: 'clientId',
|
||||
tokenUrl: 'tokenUrl',
|
||||
mcpServerUrl: 'mcpUrl',
|
||||
updatedAt: now,
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
expectedCredential,
|
||||
);
|
||||
expect(path.dirname).toHaveBeenCalled();
|
||||
expect(fs.mkdir).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return true for expired token', () => {
|
||||
const expiredToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() - 3600000, // 1 hour ago
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
it('should use HybridTokenStorage to get credentials', async () => {
|
||||
mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials);
|
||||
const result = await tokenStorage.getCredentials('server1');
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'server1',
|
||||
);
|
||||
expect(result).toBe(mockCredentials);
|
||||
});
|
||||
|
||||
it('should return true for token expiring within buffer time', () => {
|
||||
const soonToExpireToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAllTokens', () => {
|
||||
it('should remove token file successfully', async () => {
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
|
||||
it('should use HybridTokenStorage to delete credentials', async () => {
|
||||
await tokenStorage.deleteCredentials('server1');
|
||||
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'server1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle non-existent file gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle other file errors gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to clear MCP OAuth tokens'),
|
||||
);
|
||||
it('should use HybridTokenStorage to clear all tokens', async () => {
|
||||
await tokenStorage.clearAll();
|
||||
expect(mockHybridTokenStorage.clearAll).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -8,25 +8,40 @@ import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
|
||||
import type {
|
||||
OAuthToken,
|
||||
OAuthCredentials,
|
||||
TokenStorage,
|
||||
} from './token-storage/types.js';
|
||||
import { HybridTokenStorage } from './token-storage/hybrid-token-storage.js';
|
||||
import {
|
||||
DEFAULT_SERVICE_NAME,
|
||||
FORCE_ENCRYPTED_FILE_ENV_VAR,
|
||||
} from './token-storage/index.js';
|
||||
|
||||
/**
|
||||
* Class for managing MCP OAuth token storage and retrieval.
|
||||
*/
|
||||
export class MCPOAuthTokenStorage {
|
||||
export class MCPOAuthTokenStorage implements TokenStorage {
|
||||
private readonly hybridTokenStorage = new HybridTokenStorage(
|
||||
DEFAULT_SERVICE_NAME,
|
||||
);
|
||||
private readonly useEncryptedFile =
|
||||
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
|
||||
/**
|
||||
* Get the path to the token storage file.
|
||||
*
|
||||
* @returns The full path to the token storage file
|
||||
*/
|
||||
private static getTokenFilePath(): string {
|
||||
private getTokenFilePath(): string {
|
||||
return Storage.getMcpOAuthTokensPath();
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure the config directory exists.
|
||||
*/
|
||||
private static async ensureConfigDir(): Promise<void> {
|
||||
private async ensureConfigDir(): Promise<void> {
|
||||
const configDir = path.dirname(this.getTokenFilePath());
|
||||
await fs.mkdir(configDir, { recursive: true });
|
||||
}
|
||||
@@ -36,7 +51,10 @@ export class MCPOAuthTokenStorage {
|
||||
*
|
||||
* @returns A map of server names to credentials
|
||||
*/
|
||||
static async loadTokens(): Promise<Map<string, OAuthCredentials>> {
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.getAllCredentials();
|
||||
}
|
||||
const tokenMap = new Map<string, OAuthCredentials>();
|
||||
|
||||
try {
|
||||
@@ -59,36 +77,20 @@ export class MCPOAuthTokenStorage {
|
||||
return tokenMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @param token The OAuth token to save
|
||||
* @param clientId Optional client ID used for this token
|
||||
* @param tokenUrl Optional token URL used for this token
|
||||
* @param mcpServerUrl Optional MCP server URL
|
||||
*/
|
||||
static async saveToken(
|
||||
serverName: string,
|
||||
token: OAuthToken,
|
||||
clientId?: string,
|
||||
tokenUrl?: string,
|
||||
mcpServerUrl?: string,
|
||||
): Promise<void> {
|
||||
await this.ensureConfigDir();
|
||||
async listServers(): Promise<string[]> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.listServers();
|
||||
}
|
||||
const tokens = await this.getAllCredentials();
|
||||
return Array.from(tokens.keys());
|
||||
}
|
||||
|
||||
const tokens = await this.loadTokens();
|
||||
|
||||
const credential: OAuthCredentials = {
|
||||
serverName,
|
||||
token,
|
||||
clientId,
|
||||
tokenUrl,
|
||||
mcpServerUrl,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
tokens.set(serverName, credential);
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.setCredentials(credentials);
|
||||
}
|
||||
const tokens = await this.getAllCredentials();
|
||||
tokens.set(credentials.serverName, credentials);
|
||||
|
||||
const tokenArray = Array.from(tokens.values());
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
@@ -107,14 +109,50 @@ export class MCPOAuthTokenStorage {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @param token The OAuth token to save
|
||||
* @param clientId Optional client ID used for this token
|
||||
* @param tokenUrl Optional token URL used for this token
|
||||
* @param mcpServerUrl Optional MCP server URL
|
||||
*/
|
||||
async saveToken(
|
||||
serverName: string,
|
||||
token: OAuthToken,
|
||||
clientId?: string,
|
||||
tokenUrl?: string,
|
||||
mcpServerUrl?: string,
|
||||
): Promise<void> {
|
||||
await this.ensureConfigDir();
|
||||
|
||||
const credential: OAuthCredentials = {
|
||||
serverName,
|
||||
token,
|
||||
clientId,
|
||||
tokenUrl,
|
||||
mcpServerUrl,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.setCredentials(credential);
|
||||
}
|
||||
await this.setCredentials(credential);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @returns The stored credentials or null if not found
|
||||
*/
|
||||
static async getToken(serverName: string): Promise<OAuthCredentials | null> {
|
||||
const tokens = await this.loadTokens();
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.getCredentials(serverName);
|
||||
}
|
||||
const tokens = await this.getAllCredentials();
|
||||
return tokens.get(serverName) || null;
|
||||
}
|
||||
|
||||
@@ -123,8 +161,11 @@ export class MCPOAuthTokenStorage {
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
*/
|
||||
static async removeToken(serverName: string): Promise<void> {
|
||||
const tokens = await this.loadTokens();
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.deleteCredentials(serverName);
|
||||
}
|
||||
const tokens = await this.getAllCredentials();
|
||||
|
||||
if (tokens.delete(serverName)) {
|
||||
const tokenArray = Array.from(tokens.values());
|
||||
@@ -153,7 +194,7 @@ export class MCPOAuthTokenStorage {
|
||||
* @param token The token to check
|
||||
* @returns True if the token is expired
|
||||
*/
|
||||
static isTokenExpired(token: OAuthToken): boolean {
|
||||
isTokenExpired(token: OAuthToken): boolean {
|
||||
if (!token.expiresAt) {
|
||||
return false; // No expiry, assume valid
|
||||
}
|
||||
@@ -166,7 +207,10 @@ export class MCPOAuthTokenStorage {
|
||||
/**
|
||||
* Clear all stored MCP OAuth tokens.
|
||||
*/
|
||||
static async clearAllTokens(): Promise<void> {
|
||||
async clearAll(): Promise<void> {
|
||||
if (this.useEncryptedFile) {
|
||||
return this.hybridTokenStorage.clearAll();
|
||||
}
|
||||
try {
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
await fs.unlink(tokenFile);
|
||||
|
||||
@@ -137,6 +137,7 @@ export class OAuthUtils {
|
||||
authorizationUrl: metadata.authorization_endpoint,
|
||||
tokenUrl: metadata.token_endpoint,
|
||||
scopes: metadata.scopes_supported || [],
|
||||
registrationUrl: metadata.registration_endpoint,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
153
packages/core/src/mcp/sa-impersonation-provider.test.ts
Normal file
153
packages/core/src/mcp/sa-impersonation-provider.test.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ServiceAccountImpersonationProvider } from './sa-impersonation-provider.js';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
|
||||
const mockRequest = vi.fn();
|
||||
const mockGetClient = vi.fn(() => ({
|
||||
request: mockRequest,
|
||||
}));
|
||||
|
||||
// Mock the google-auth-library to use a shared mock function
|
||||
vi.mock('google-auth-library', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('google-auth-library')>();
|
||||
return {
|
||||
...actual,
|
||||
GoogleAuth: vi.fn().mockImplementation(() => ({
|
||||
getClient: mockGetClient,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
const defaultSAConfig: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
targetAudience: 'my-audience',
|
||||
targetServiceAccount: 'my-sa',
|
||||
};
|
||||
|
||||
describe('ServiceAccountImpersonationProvider', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should throw an error if no URL is provided', () => {
|
||||
const config: MCPServerConfig = {};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'A url or httpUrl must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if no targetAudience is provided', () => {
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'targetAudience must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if no targetSA is provided', () => {
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
targetAudience: 'my-audience',
|
||||
};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'targetServiceAccount must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly get tokens for a valid config', async () => {
|
||||
const mockToken = 'mock-id-token-123';
|
||||
mockRequest.mockResolvedValue({ data: { token: mockToken } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
const tokens = await provider.tokens();
|
||||
|
||||
expect(tokens).toBeDefined();
|
||||
expect(tokens?.access_token).toBe(mockToken);
|
||||
expect(tokens?.token_type).toBe('Bearer');
|
||||
});
|
||||
|
||||
it('should return undefined if token acquisition fails', async () => {
|
||||
mockRequest.mockResolvedValue({ data: { token: null } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
const tokens = await provider.tokens();
|
||||
|
||||
expect(tokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should make a request with the correct parameters', async () => {
|
||||
mockRequest.mockResolvedValue({ data: { token: 'test-token' } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
await provider.tokens();
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: 'https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa:generateIdToken',
|
||||
method: 'POST',
|
||||
data: {
|
||||
audience: 'my-audience',
|
||||
includeEmail: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should return a cached token if it is not expired', async () => {
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
vi.useFakeTimers();
|
||||
|
||||
// jwt payload with exp set to 1 hour from now
|
||||
const payload = { exp: Math.floor(Date.now() / 1000) + 3600 };
|
||||
const jwt = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||
mockRequest.mockResolvedValue({ data: { token: jwt } });
|
||||
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe(jwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Advance time by 30 minutes
|
||||
vi.advanceTimersByTime(1800 * 1000);
|
||||
|
||||
// Seturn cached token
|
||||
const secondTokens = await provider.tokens();
|
||||
expect(secondTokens).toBe(firstTokens);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should fetch a new token if the cached token is expired (using fake timers)', async () => {
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
vi.useFakeTimers();
|
||||
|
||||
// Get and cache a token that expires in 1 second
|
||||
const expiredPayload = { exp: Math.floor(Date.now() / 1000) + 1 };
|
||||
const expiredJwt = `header.${Buffer.from(JSON.stringify(expiredPayload)).toString('base64')}.signature`;
|
||||
|
||||
mockRequest.mockResolvedValue({ data: { token: expiredJwt } });
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe(expiredJwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Prepare the mock for the *next* call
|
||||
const newPayload = { exp: Math.floor(Date.now() / 1000) + 3600 };
|
||||
const newJwt = `header.${Buffer.from(JSON.stringify(newPayload)).toString('base64')}.signature`;
|
||||
mockRequest.mockResolvedValue({ data: { token: newJwt } });
|
||||
|
||||
vi.advanceTimersByTime(1001);
|
||||
|
||||
const newTokens = await provider.tokens();
|
||||
expect(newTokens?.access_token).toBe(newJwt);
|
||||
expect(newTokens?.access_token).not.toBe(expiredJwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(2); // Confirms a new fetch
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
});
|
||||
171
packages/core/src/mcp/sa-impersonation-provider.ts
Normal file
171
packages/core/src/mcp/sa-impersonation-provider.ts
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
|
||||
const fiveMinBufferMs = 5 * 60 * 1000;
|
||||
|
||||
function createIamApiUrl(targetSA: string): string {
|
||||
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
|
||||
}
|
||||
|
||||
export class ServiceAccountImpersonationProvider
|
||||
implements OAuthClientProvider
|
||||
{
|
||||
private readonly targetServiceAccount: string;
|
||||
private readonly targetAudience: string; // OAuth Client Id
|
||||
private readonly auth: GoogleAuth;
|
||||
private cachedToken?: OAuthTokens;
|
||||
private tokenExpiryTime?: number;
|
||||
|
||||
// Properties required by OAuthClientProvider, with no-op values
|
||||
readonly redirectUrl = '';
|
||||
readonly clientMetadata: OAuthClientMetadata = {
|
||||
client_name: 'Gemini CLI (Service Account Impersonation)',
|
||||
redirect_uris: [],
|
||||
grant_types: [],
|
||||
response_types: [],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
private _clientInformation?: OAuthClientInformationFull;
|
||||
|
||||
constructor(private readonly config: MCPServerConfig) {
|
||||
// This check is done in mcp-client.ts. This is just an additional check.
|
||||
if (!this.config.httpUrl && !this.config.url) {
|
||||
throw new Error(
|
||||
'A url or httpUrl must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
|
||||
if (!config.targetAudience) {
|
||||
throw new Error(
|
||||
'targetAudience must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
this.targetAudience = config.targetAudience;
|
||||
|
||||
if (!config.targetServiceAccount) {
|
||||
throw new Error(
|
||||
'targetServiceAccount must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
this.targetServiceAccount = config.targetServiceAccount;
|
||||
|
||||
this.auth = new GoogleAuth();
|
||||
}
|
||||
|
||||
clientInformation(): OAuthClientInformation | undefined {
|
||||
return this._clientInformation;
|
||||
}
|
||||
|
||||
saveClientInformation(clientInformation: OAuthClientInformationFull): void {
|
||||
this._clientInformation = clientInformation;
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
// 1. Check if we have a valid, non-expired cached token.
|
||||
if (
|
||||
this.cachedToken &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
|
||||
) {
|
||||
return this.cachedToken;
|
||||
}
|
||||
|
||||
// 2. Clear any invalid/expired cache.
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
|
||||
// 3. Fetch a new ID token.
|
||||
const client = await this.auth.getClient();
|
||||
const url = createIamApiUrl(this.targetServiceAccount);
|
||||
|
||||
let idToken: string;
|
||||
try {
|
||||
const res = await client.request<{ token: string }>({
|
||||
url,
|
||||
method: 'POST',
|
||||
data: {
|
||||
audience: this.targetAudience,
|
||||
includeEmail: true,
|
||||
},
|
||||
});
|
||||
idToken = res.data.token;
|
||||
|
||||
if (!idToken || idToken.length === 0) {
|
||||
console.error('Failed to get ID token from Google');
|
||||
return undefined;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch ID token from Google:', e);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const expiryTime = this.parseTokenExpiry(idToken);
|
||||
// Note: We are placing the OIDC ID Token into the `access_token` field.
|
||||
// This is because the CLI uses this field to construct the
|
||||
// `Authorization: Bearer <token>` header, which is the correct way to
|
||||
// present an ID token.
|
||||
const newTokens: OAuthTokens = {
|
||||
access_token: idToken,
|
||||
token_type: 'Bearer',
|
||||
};
|
||||
|
||||
if (expiryTime) {
|
||||
this.tokenExpiryTime = expiryTime;
|
||||
this.cachedToken = newTokens;
|
||||
}
|
||||
|
||||
return newTokens;
|
||||
}
|
||||
|
||||
saveTokens(_tokens: OAuthTokens): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
redirectToAuthorization(_authorizationUrl: URL): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
saveCodeVerifier(_codeVerifier: string): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
codeVerifier(): string {
|
||||
// No-op
|
||||
return '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a JWT string to extract its expiry time.
|
||||
* @param idToken The JWT ID token.
|
||||
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
||||
*/
|
||||
private parseTokenExpiry(idToken: string): number | undefined {
|
||||
try {
|
||||
const payload = JSON.parse(
|
||||
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
||||
);
|
||||
|
||||
if (payload && typeof payload.exp === 'number') {
|
||||
return payload.exp * 1000; // Convert seconds to milliseconds
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to parse ID token for expiry time with error:', e);
|
||||
}
|
||||
|
||||
// Return undefined if try block fails or 'exp' is missing/invalid
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
@@ -53,7 +53,7 @@ describe('BaseTokenStorage', () => {
|
||||
let storage: TestTokenStorage;
|
||||
|
||||
beforeEach(() => {
|
||||
storage = new TestTokenStorage();
|
||||
storage = new TestTokenStorage('gemini-cli-mcp-oauth');
|
||||
});
|
||||
|
||||
describe('validateCredentials', () => {
|
||||
|
||||
@@ -9,7 +9,7 @@ import type { TokenStorage, OAuthCredentials } from './types.js';
|
||||
export abstract class BaseTokenStorage implements TokenStorage {
|
||||
protected readonly serviceName: string;
|
||||
|
||||
constructor(serviceName: string = 'gemini-cli-mcp-oauth') {
|
||||
constructor(serviceName: string) {
|
||||
this.serviceName = serviceName;
|
||||
}
|
||||
|
||||
|
||||
323
packages/core/src/mcp/token-storage/file-token-storage.test.ts
Normal file
323
packages/core/src/mcp/token-storage/file-token-storage.test.ts
Normal file
@@ -0,0 +1,323 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { FileTokenStorage } from './file-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
unlink: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('node:os', () => ({
|
||||
default: {
|
||||
homedir: vi.fn(() => '/home/test'),
|
||||
hostname: vi.fn(() => 'test-host'),
|
||||
userInfo: vi.fn(() => ({ username: 'test-user' })),
|
||||
},
|
||||
homedir: vi.fn(() => '/home/test'),
|
||||
hostname: vi.fn(() => 'test-host'),
|
||||
userInfo: vi.fn(() => ({ username: 'test-user' })),
|
||||
}));
|
||||
|
||||
describe('FileTokenStorage', () => {
|
||||
let storage: FileTokenStorage;
|
||||
const mockFs = fs as unknown as {
|
||||
readFile: ReturnType<typeof vi.fn>;
|
||||
writeFile: ReturnType<typeof vi.fn>;
|
||||
unlink: ReturnType<typeof vi.fn>;
|
||||
mkdir: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
const existingCredentials: OAuthCredentials = {
|
||||
serverName: 'existing-server',
|
||||
token: {
|
||||
accessToken: 'existing-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now() - 10000,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
storage = new FileTokenStorage('test-storage');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getCredentials', () => {
|
||||
it('should throw error when file does not exist', async () => {
|
||||
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await expect(storage.getCredentials('test-server')).rejects.toThrow(
|
||||
'Token file does not exist',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null for expired tokens', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() - 3600000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ 'test-server': credentials }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return credentials for valid tokens', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 3600000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ 'test-server': credentials }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toEqual(credentials);
|
||||
});
|
||||
|
||||
it('should throw error for corrupted files', async () => {
|
||||
mockFs.readFile.mockResolvedValue('corrupted-data');
|
||||
|
||||
await expect(storage.getCredentials('test-server')).rejects.toThrow(
|
||||
'Token file corrupted',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setCredentials', () => {
|
||||
it('should save credentials with encryption', async () => {
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ 'existing-server': existingCredentials }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
mockFs.mkdir.mockResolvedValue(undefined);
|
||||
mockFs.writeFile.mockResolvedValue(undefined);
|
||||
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
await storage.setCredentials(credentials);
|
||||
|
||||
expect(mockFs.mkdir).toHaveBeenCalledWith(
|
||||
path.join('/home/test', '.qwen'),
|
||||
{ recursive: true, mode: 0o700 },
|
||||
);
|
||||
expect(mockFs.writeFile).toHaveBeenCalled();
|
||||
|
||||
const writeCall = mockFs.writeFile.mock.calls[0];
|
||||
expect(writeCall[1]).toMatch(/^[0-9a-f]+:[0-9a-f]+:[0-9a-f]+$/);
|
||||
expect(writeCall[2]).toEqual({ mode: 0o600 });
|
||||
});
|
||||
|
||||
it('should update existing credentials', async () => {
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ 'existing-server': existingCredentials }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
mockFs.writeFile.mockResolvedValue(undefined);
|
||||
|
||||
const newCredentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'new-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
await storage.setCredentials(newCredentials);
|
||||
|
||||
expect(mockFs.writeFile).toHaveBeenCalled();
|
||||
const writeCall = mockFs.writeFile.mock.calls[0];
|
||||
const decrypted = storage['decrypt'](writeCall[1]);
|
||||
const saved = JSON.parse(decrypted);
|
||||
|
||||
expect(saved['existing-server']).toEqual(existingCredentials);
|
||||
expect(saved['test-server'].token.accessToken).toBe('new-token');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should throw when credentials do not exist', async () => {
|
||||
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
|
||||
'Token file does not exist',
|
||||
);
|
||||
});
|
||||
|
||||
it('should delete file when last credential is removed', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ 'test-server': credentials }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
mockFs.unlink.mockResolvedValue(undefined);
|
||||
|
||||
await storage.deleteCredentials('test-server');
|
||||
|
||||
expect(mockFs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/home/test', '.qwen', 'mcp-oauth-tokens-v2.json'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should update file when other credentials remain', async () => {
|
||||
const credentials1: OAuthCredentials = {
|
||||
serverName: 'server1',
|
||||
token: {
|
||||
accessToken: 'token1',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const credentials2: OAuthCredentials = {
|
||||
serverName: 'server2',
|
||||
token: {
|
||||
accessToken: 'token2',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const encryptedData = storage['encrypt'](
|
||||
JSON.stringify({ server1: credentials1, server2: credentials2 }),
|
||||
);
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
mockFs.writeFile.mockResolvedValue(undefined);
|
||||
|
||||
await storage.deleteCredentials('server1');
|
||||
|
||||
expect(mockFs.writeFile).toHaveBeenCalled();
|
||||
expect(mockFs.unlink).not.toHaveBeenCalled();
|
||||
|
||||
const writeCall = mockFs.writeFile.mock.calls[0];
|
||||
const decrypted = storage['decrypt'](writeCall[1]);
|
||||
const saved = JSON.parse(decrypted);
|
||||
|
||||
expect(saved['server1']).toBeUndefined();
|
||||
expect(saved['server2']).toEqual(credentials2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('listServers', () => {
|
||||
it('should throw error when file does not exist', async () => {
|
||||
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await expect(storage.listServers()).rejects.toThrow(
|
||||
'Token file does not exist',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return list of server names', async () => {
|
||||
const credentials: Record<string, OAuthCredentials> = {
|
||||
server1: {
|
||||
serverName: 'server1',
|
||||
token: { accessToken: 'token1', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
server2: {
|
||||
serverName: 'server2',
|
||||
token: { accessToken: 'token2', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
};
|
||||
|
||||
const encryptedData = storage['encrypt'](JSON.stringify(credentials));
|
||||
mockFs.readFile.mockResolvedValue(encryptedData);
|
||||
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual(['server1', 'server2']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should delete the token file', async () => {
|
||||
mockFs.unlink.mockResolvedValue(undefined);
|
||||
|
||||
await storage.clearAll();
|
||||
|
||||
expect(mockFs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/home/test', '.qwen', 'mcp-oauth-tokens-v2.json'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not throw when file does not exist', async () => {
|
||||
mockFs.unlink.mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await expect(storage.clearAll()).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('encryption', () => {
|
||||
it('should encrypt and decrypt data correctly', () => {
|
||||
const original = 'test-data-123';
|
||||
const encrypted = storage['encrypt'](original);
|
||||
const decrypted = storage['decrypt'](encrypted);
|
||||
|
||||
expect(decrypted).toBe(original);
|
||||
expect(encrypted).not.toBe(original);
|
||||
expect(encrypted).toMatch(/^[0-9a-f]+:[0-9a-f]+:[0-9a-f]+$/);
|
||||
});
|
||||
|
||||
it('should produce different encrypted output each time', () => {
|
||||
const original = 'test-data';
|
||||
const encrypted1 = storage['encrypt'](original);
|
||||
const encrypted2 = storage['encrypt'](original);
|
||||
|
||||
expect(encrypted1).not.toBe(encrypted2);
|
||||
expect(storage['decrypt'](encrypted1)).toBe(original);
|
||||
expect(storage['decrypt'](encrypted2)).toBe(original);
|
||||
});
|
||||
|
||||
it('should throw on invalid encrypted data format', () => {
|
||||
expect(() => storage['decrypt']('invalid-data')).toThrow(
|
||||
'Invalid encrypted data format',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
184
packages/core/src/mcp/token-storage/file-token-storage.ts
Normal file
184
packages/core/src/mcp/token-storage/file-token-storage.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import * as crypto from 'node:crypto';
|
||||
import { BaseTokenStorage } from './base-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
export class FileTokenStorage extends BaseTokenStorage {
|
||||
private readonly tokenFilePath: string;
|
||||
private readonly encryptionKey: Buffer;
|
||||
|
||||
constructor(serviceName: string) {
|
||||
super(serviceName);
|
||||
const configDir = path.join(os.homedir(), '.qwen');
|
||||
this.tokenFilePath = path.join(configDir, 'mcp-oauth-tokens-v2.json');
|
||||
this.encryptionKey = this.deriveEncryptionKey();
|
||||
}
|
||||
|
||||
private deriveEncryptionKey(): Buffer {
|
||||
const salt = `${os.hostname()}-${os.userInfo().username}-qwen-code`;
|
||||
return crypto.scryptSync('qwen-code-oauth', salt, 32);
|
||||
}
|
||||
|
||||
private encrypt(text: string): string {
|
||||
const iv = crypto.randomBytes(16);
|
||||
const cipher = crypto.createCipheriv('aes-256-gcm', this.encryptionKey, iv);
|
||||
|
||||
let encrypted = cipher.update(text, 'utf8', 'hex');
|
||||
encrypted += cipher.final('hex');
|
||||
|
||||
const authTag = cipher.getAuthTag();
|
||||
|
||||
return iv.toString('hex') + ':' + authTag.toString('hex') + ':' + encrypted;
|
||||
}
|
||||
|
||||
private decrypt(encryptedData: string): string {
|
||||
const parts = encryptedData.split(':');
|
||||
if (parts.length !== 3) {
|
||||
throw new Error('Invalid encrypted data format');
|
||||
}
|
||||
|
||||
const iv = Buffer.from(parts[0], 'hex');
|
||||
const authTag = Buffer.from(parts[1], 'hex');
|
||||
const encrypted = parts[2];
|
||||
|
||||
const decipher = crypto.createDecipheriv(
|
||||
'aes-256-gcm',
|
||||
this.encryptionKey,
|
||||
iv,
|
||||
);
|
||||
decipher.setAuthTag(authTag);
|
||||
|
||||
let decrypted = decipher.update(encrypted, 'hex', 'utf8');
|
||||
decrypted += decipher.final('utf8');
|
||||
|
||||
return decrypted;
|
||||
}
|
||||
|
||||
private async ensureDirectoryExists(): Promise<void> {
|
||||
const dir = path.dirname(this.tokenFilePath);
|
||||
await fs.mkdir(dir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
|
||||
private async loadTokens(): Promise<Map<string, OAuthCredentials>> {
|
||||
try {
|
||||
const data = await fs.readFile(this.tokenFilePath, 'utf-8');
|
||||
const decrypted = this.decrypt(data);
|
||||
const tokens = JSON.parse(decrypted) as Record<string, OAuthCredentials>;
|
||||
return new Map(Object.entries(tokens));
|
||||
} catch (error: unknown) {
|
||||
const err = error as NodeJS.ErrnoException & { message?: string };
|
||||
if (err.code === 'ENOENT') {
|
||||
throw new Error('Token file does not exist');
|
||||
}
|
||||
if (
|
||||
err.message?.includes('Invalid encrypted data format') ||
|
||||
err.message?.includes(
|
||||
'Unsupported state or unable to authenticate data',
|
||||
)
|
||||
) {
|
||||
throw new Error('Token file corrupted');
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private async saveTokens(
|
||||
tokens: Map<string, OAuthCredentials>,
|
||||
): Promise<void> {
|
||||
await this.ensureDirectoryExists();
|
||||
|
||||
const data = Object.fromEntries(tokens);
|
||||
const json = JSON.stringify(data, null, 2);
|
||||
const encrypted = this.encrypt(json);
|
||||
|
||||
await fs.writeFile(this.tokenFilePath, encrypted, { mode: 0o600 });
|
||||
}
|
||||
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
const tokens = await this.loadTokens();
|
||||
const credentials = tokens.get(serverName);
|
||||
|
||||
if (!credentials) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (this.isTokenExpired(credentials)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return credentials;
|
||||
}
|
||||
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
this.validateCredentials(credentials);
|
||||
|
||||
const tokens = await this.loadTokens();
|
||||
const updatedCredentials: OAuthCredentials = {
|
||||
...credentials,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
tokens.set(credentials.serverName, updatedCredentials);
|
||||
await this.saveTokens(tokens);
|
||||
}
|
||||
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
const tokens = await this.loadTokens();
|
||||
|
||||
if (!tokens.has(serverName)) {
|
||||
throw new Error(`No credentials found for ${serverName}`);
|
||||
}
|
||||
|
||||
tokens.delete(serverName);
|
||||
|
||||
if (tokens.size === 0) {
|
||||
try {
|
||||
await fs.unlink(this.tokenFilePath);
|
||||
} catch (error: unknown) {
|
||||
const err = error as NodeJS.ErrnoException;
|
||||
if (err.code !== 'ENOENT') {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await this.saveTokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
async listServers(): Promise<string[]> {
|
||||
const tokens = await this.loadTokens();
|
||||
return Array.from(tokens.keys());
|
||||
}
|
||||
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
const tokens = await this.loadTokens();
|
||||
const result = new Map<string, OAuthCredentials>();
|
||||
|
||||
for (const [serverName, credentials] of tokens) {
|
||||
if (!this.isTokenExpired(credentials)) {
|
||||
result.set(serverName, credentials);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async clearAll(): Promise<void> {
|
||||
try {
|
||||
await fs.unlink(this.tokenFilePath);
|
||||
} catch (error: unknown) {
|
||||
const err = error as NodeJS.ErrnoException;
|
||||
if (err.code !== 'ENOENT') {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
274
packages/core/src/mcp/token-storage/hybrid-token-storage.test.ts
Normal file
274
packages/core/src/mcp/token-storage/hybrid-token-storage.test.ts
Normal file
@@ -0,0 +1,274 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { HybridTokenStorage } from './hybrid-token-storage.js';
|
||||
import { KeychainTokenStorage } from './keychain-token-storage.js';
|
||||
import { FileTokenStorage } from './file-token-storage.js';
|
||||
import { type OAuthCredentials, TokenStorageType } from './types.js';
|
||||
|
||||
vi.mock('./keychain-token-storage.js', () => ({
|
||||
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
|
||||
isAvailable: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./file-token-storage.js', () => ({
|
||||
FileTokenStorage: vi.fn().mockImplementation(() => ({
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
interface MockStorage {
|
||||
isAvailable?: ReturnType<typeof vi.fn>;
|
||||
getCredentials: ReturnType<typeof vi.fn>;
|
||||
setCredentials: ReturnType<typeof vi.fn>;
|
||||
deleteCredentials: ReturnType<typeof vi.fn>;
|
||||
listServers: ReturnType<typeof vi.fn>;
|
||||
getAllCredentials: ReturnType<typeof vi.fn>;
|
||||
clearAll: ReturnType<typeof vi.fn>;
|
||||
}
|
||||
|
||||
describe('HybridTokenStorage', () => {
|
||||
let storage: HybridTokenStorage;
|
||||
let mockKeychainStorage: MockStorage;
|
||||
let mockFileStorage: MockStorage;
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
process.env = { ...originalEnv };
|
||||
|
||||
// Create mock instances before creating HybridTokenStorage
|
||||
mockKeychainStorage = {
|
||||
isAvailable: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
};
|
||||
|
||||
mockFileStorage = {
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
};
|
||||
|
||||
(
|
||||
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
|
||||
).mockImplementation(() => mockKeychainStorage);
|
||||
(
|
||||
FileTokenStorage as unknown as ReturnType<typeof vi.fn>
|
||||
).mockImplementation(() => mockFileStorage);
|
||||
|
||||
storage = new HybridTokenStorage('test-service');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('storage selection', () => {
|
||||
it('should use keychain when available', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(TokenStorageType.KEYCHAIN);
|
||||
});
|
||||
|
||||
it('should use file storage when GEMINI_FORCE_FILE_STORAGE is set', async () => {
|
||||
process.env['GEMINI_FORCE_FILE_STORAGE'] = 'true';
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).not.toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to file storage when keychain is unavailable', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(false);
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to file storage when keychain throws error', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockRejectedValue(
|
||||
new Error('Keychain error'),
|
||||
);
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should cache storage selection', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
await storage.getCredentials('another-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(credentials);
|
||||
|
||||
const result = await storage.getCredentials('test-server');
|
||||
|
||||
expect(result).toEqual(credentials);
|
||||
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.setCredentials.mockResolvedValue(undefined);
|
||||
|
||||
await storage.setCredentials(credentials);
|
||||
|
||||
expect(mockKeychainStorage.setCredentials).toHaveBeenCalledWith(
|
||||
credentials,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.deleteCredentials.mockResolvedValue(undefined);
|
||||
|
||||
await storage.deleteCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('listServers', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const servers = ['server1', 'server2'];
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.listServers.mockResolvedValue(servers);
|
||||
|
||||
const result = await storage.listServers();
|
||||
|
||||
expect(result).toEqual(servers);
|
||||
expect(mockKeychainStorage.listServers).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentialsMap = new Map([
|
||||
[
|
||||
'server1',
|
||||
{
|
||||
serverName: 'server1',
|
||||
token: { accessToken: 'token1', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
],
|
||||
[
|
||||
'server2',
|
||||
{
|
||||
serverName: 'server2',
|
||||
token: { accessToken: 'token2', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
],
|
||||
]);
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getAllCredentials.mockResolvedValue(credentialsMap);
|
||||
|
||||
const result = await storage.getAllCredentials();
|
||||
|
||||
expect(result).toEqual(credentialsMap);
|
||||
expect(mockKeychainStorage.getAllCredentials).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.clearAll.mockResolvedValue(undefined);
|
||||
|
||||
await storage.clearAll();
|
||||
|
||||
expect(mockKeychainStorage.clearAll).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
97
packages/core/src/mcp/token-storage/hybrid-token-storage.ts
Normal file
97
packages/core/src/mcp/token-storage/hybrid-token-storage.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTokenStorage } from './base-token-storage.js';
|
||||
import { FileTokenStorage } from './file-token-storage.js';
|
||||
import type { TokenStorage, OAuthCredentials } from './types.js';
|
||||
import { TokenStorageType } from './types.js';
|
||||
|
||||
const FORCE_FILE_STORAGE_ENV_VAR = 'GEMINI_FORCE_FILE_STORAGE';
|
||||
|
||||
export class HybridTokenStorage extends BaseTokenStorage {
|
||||
private storage: TokenStorage | null = null;
|
||||
private storageType: TokenStorageType | null = null;
|
||||
private storageInitPromise: Promise<TokenStorage> | null = null;
|
||||
|
||||
constructor(serviceName: string) {
|
||||
super(serviceName);
|
||||
}
|
||||
|
||||
private async initializeStorage(): Promise<TokenStorage> {
|
||||
const forceFileStorage = process.env[FORCE_FILE_STORAGE_ENV_VAR] === 'true';
|
||||
|
||||
if (!forceFileStorage) {
|
||||
try {
|
||||
const { KeychainTokenStorage } = await import(
|
||||
'./keychain-token-storage.js'
|
||||
);
|
||||
const keychainStorage = new KeychainTokenStorage(this.serviceName);
|
||||
|
||||
const isAvailable = await keychainStorage.isAvailable();
|
||||
if (isAvailable) {
|
||||
this.storage = keychainStorage;
|
||||
this.storageType = TokenStorageType.KEYCHAIN;
|
||||
return this.storage;
|
||||
}
|
||||
} catch (_e) {
|
||||
// Fallback to file storage if keychain fails to initialize
|
||||
}
|
||||
}
|
||||
|
||||
this.storage = new FileTokenStorage(this.serviceName);
|
||||
this.storageType = TokenStorageType.ENCRYPTED_FILE;
|
||||
return this.storage;
|
||||
}
|
||||
|
||||
private async getStorage(): Promise<TokenStorage> {
|
||||
if (this.storage !== null) {
|
||||
return this.storage;
|
||||
}
|
||||
|
||||
// Use a single initialization promise to avoid race conditions
|
||||
if (!this.storageInitPromise) {
|
||||
this.storageInitPromise = this.initializeStorage();
|
||||
}
|
||||
|
||||
// Wait for initialization to complete
|
||||
return await this.storageInitPromise;
|
||||
}
|
||||
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.getCredentials(serverName);
|
||||
}
|
||||
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.setCredentials(credentials);
|
||||
}
|
||||
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.deleteCredentials(serverName);
|
||||
}
|
||||
|
||||
async listServers(): Promise<string[]> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.listServers();
|
||||
}
|
||||
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.getAllCredentials();
|
||||
}
|
||||
|
||||
async clearAll(): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.clearAll();
|
||||
}
|
||||
|
||||
async getStorageType(): Promise<TokenStorageType> {
|
||||
await this.getStorage();
|
||||
return this.storageType!;
|
||||
}
|
||||
}
|
||||
14
packages/core/src/mcp/token-storage/index.ts
Normal file
14
packages/core/src/mcp/token-storage/index.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export * from './types.js';
|
||||
export * from './base-token-storage.js';
|
||||
export * from './file-token-storage.js';
|
||||
export * from './hybrid-token-storage.js';
|
||||
|
||||
export const DEFAULT_SERVICE_NAME = 'qwen-code-oauth';
|
||||
export const FORCE_ENCRYPTED_FILE_ENV_VAR =
|
||||
'QWEN_CODE_FORCE_ENCRYPTED_FILE_STORAGE';
|
||||
@@ -0,0 +1,352 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import type { KeychainTokenStorage } from './keychain-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
// Hoist the mock to be available in the vi.mock factory
|
||||
const mockKeytar = vi.hoisted(() => ({
|
||||
getPassword: vi.fn(),
|
||||
setPassword: vi.fn(),
|
||||
deletePassword: vi.fn(),
|
||||
findCredentials: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockServiceName = 'service-name';
|
||||
const mockCryptoRandomBytesString = 'random-string';
|
||||
|
||||
// Mock the dynamic import of 'keytar'
|
||||
vi.mock('keytar', () => ({
|
||||
default: mockKeytar,
|
||||
}));
|
||||
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomBytes: vi.fn(() => ({
|
||||
toString: vi.fn(() => mockCryptoRandomBytesString),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('KeychainTokenStorage', () => {
|
||||
let storage: KeychainTokenStorage;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
// Reset the internal state of the keychain-token-storage module
|
||||
vi.resetModules();
|
||||
const { KeychainTokenStorage } = await import(
|
||||
'./keychain-token-storage.js'
|
||||
);
|
||||
storage = new KeychainTokenStorage(mockServiceName);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
const validCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 3600000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
} as OAuthCredentials;
|
||||
|
||||
describe('checkKeychainAvailability', () => {
|
||||
it('should return true if keytar is available and functional', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(true);
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
'test',
|
||||
);
|
||||
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false if keytar fails to set password', async () => {
|
||||
mockKeytar.setPassword.mockRejectedValue(new Error('write error'));
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if retrieved password does not match', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('wrong-password');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(false);
|
||||
});
|
||||
|
||||
it('should cache the availability result', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
await storage.checkKeychainAvailability();
|
||||
await storage.checkKeychainAvailability();
|
||||
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('with keychain unavailable', () => {
|
||||
beforeEach(async () => {
|
||||
// Force keychain to be unavailable
|
||||
mockKeytar.setPassword.mockRejectedValue(new Error('keychain error'));
|
||||
await storage.checkKeychainAvailability();
|
||||
});
|
||||
|
||||
it('getCredentials should throw', async () => {
|
||||
await expect(storage.getCredentials('server')).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('setCredentials should throw', async () => {
|
||||
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('deleteCredentials should throw', async () => {
|
||||
await expect(storage.deleteCredentials('server')).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('listServers should throw', async () => {
|
||||
await expect(storage.listServers()).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('getAllCredentials should throw', async () => {
|
||||
await expect(storage.getAllCredentials()).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('with keychain available', () => {
|
||||
beforeEach(async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
await storage.checkKeychainAvailability();
|
||||
// Reset mocks after availability check
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('getCredentials', () => {
|
||||
it('should return null if no credentials are found', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue(null);
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toBeNull();
|
||||
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return credentials if found and not expired', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue(
|
||||
JSON.stringify(validCredentials),
|
||||
);
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toEqual(validCredentials);
|
||||
});
|
||||
|
||||
it('should return null if credentials have expired', async () => {
|
||||
const expiredCreds = {
|
||||
...validCredentials,
|
||||
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
|
||||
};
|
||||
mockKeytar.getPassword.mockResolvedValue(JSON.stringify(expiredCreds));
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw if stored data is corrupted JSON', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue('not-json');
|
||||
await expect(storage.getCredentials('test-server')).rejects.toThrow(
|
||||
'Failed to parse stored credentials for test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setCredentials', () => {
|
||||
it('should save credentials to keychain', async () => {
|
||||
vi.useFakeTimers();
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
await storage.setCredentials(validCredentials);
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
JSON.stringify({ ...validCredentials, updatedAt: Date.now() }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if saving to keychain fails', async () => {
|
||||
mockKeytar.setPassword.mockRejectedValue(
|
||||
new Error('keychain write error'),
|
||||
);
|
||||
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
|
||||
'keychain write error',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should delete credentials from keychain', async () => {
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
await storage.deleteCredentials('test-server');
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if no credentials were found to delete', async () => {
|
||||
mockKeytar.deletePassword.mockResolvedValue(false);
|
||||
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
|
||||
'No credentials found for test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if deleting from keychain fails', async () => {
|
||||
mockKeytar.deletePassword.mockRejectedValue(
|
||||
new Error('keychain delete error'),
|
||||
);
|
||||
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
|
||||
'keychain delete error',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('listServers', () => {
|
||||
it('should return a list of server names', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should not include internal test keys in the server list', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{
|
||||
account: `__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
password: '',
|
||||
},
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should return an empty array on error', async () => {
|
||||
mockKeytar.findCredentials.mockRejectedValue(new Error('find error'));
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCredentials', () => {
|
||||
it('should return a map of all valid credentials', async () => {
|
||||
const creds2 = {
|
||||
...validCredentials,
|
||||
serverName: 'server2',
|
||||
};
|
||||
const expiredCreds = {
|
||||
...validCredentials,
|
||||
serverName: 'expired-server',
|
||||
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
|
||||
};
|
||||
const structurallyInvalidCreds = {
|
||||
serverName: 'invalid-server',
|
||||
};
|
||||
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{
|
||||
account: 'test-server',
|
||||
password: JSON.stringify(validCredentials),
|
||||
},
|
||||
{ account: 'server2', password: JSON.stringify(creds2) },
|
||||
{
|
||||
account: 'expired-server',
|
||||
password: JSON.stringify(expiredCreds),
|
||||
},
|
||||
{ account: 'bad-server', password: 'not-json' },
|
||||
{
|
||||
account: 'invalid-server',
|
||||
password: JSON.stringify(structurallyInvalidCreds),
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await storage.getAllCredentials();
|
||||
expect(result.size).toBe(2);
|
||||
expect(result.get('test-server')).toEqual(validCredentials);
|
||||
expect(result.get('server2')).toEqual(creds2);
|
||||
expect(result.has('expired-server')).toBe(false);
|
||||
expect(result.has('bad-server')).toBe(false);
|
||||
expect(result.has('invalid-server')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should delete all credentials for the service', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
await storage.clearAll();
|
||||
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledTimes(2);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'server1',
|
||||
);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'server2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an aggregated error if deletions fail', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
mockKeytar.deletePassword
|
||||
.mockResolvedValueOnce(true)
|
||||
.mockRejectedValueOnce(new Error('delete failed'));
|
||||
|
||||
await expect(storage.clearAll()).rejects.toThrow(
|
||||
'Failed to clear some credentials: delete failed',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
251
packages/core/src/mcp/token-storage/keychain-token-storage.ts
Normal file
251
packages/core/src/mcp/token-storage/keychain-token-storage.ts
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as crypto from 'node:crypto';
|
||||
import { BaseTokenStorage } from './base-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
interface Keytar {
|
||||
getPassword(service: string, account: string): Promise<string | null>;
|
||||
setPassword(
|
||||
service: string,
|
||||
account: string,
|
||||
password: string,
|
||||
): Promise<void>;
|
||||
deletePassword(service: string, account: string): Promise<boolean>;
|
||||
findCredentials(
|
||||
service: string,
|
||||
): Promise<Array<{ account: string; password: string }>>;
|
||||
}
|
||||
|
||||
const KEYCHAIN_TEST_PREFIX = '__keychain_test__';
|
||||
|
||||
export class KeychainTokenStorage extends BaseTokenStorage {
|
||||
private keychainAvailable: boolean | null = null;
|
||||
private keytarModule: Keytar | null = null;
|
||||
private keytarLoadAttempted = false;
|
||||
|
||||
async getKeytar(): Promise<Keytar | null> {
|
||||
// If we've already tried loading (successfully or not), return the result
|
||||
if (this.keytarLoadAttempted) {
|
||||
return this.keytarModule;
|
||||
}
|
||||
|
||||
this.keytarLoadAttempted = true;
|
||||
|
||||
try {
|
||||
// Try to import keytar without any timeout - let the OS handle it
|
||||
const moduleName = 'keytar';
|
||||
const module = await import(moduleName);
|
||||
this.keytarModule = module.default || module;
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
return this.keytarModule;
|
||||
}
|
||||
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
try {
|
||||
const sanitizedName = this.sanitizeServerName(serverName);
|
||||
const data = await keytar.getPassword(this.serviceName, sanitizedName);
|
||||
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const credentials = JSON.parse(data) as OAuthCredentials;
|
||||
|
||||
if (this.isTokenExpired(credentials)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return credentials;
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
throw new Error(`Failed to parse stored credentials for ${serverName}`);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
this.validateCredentials(credentials);
|
||||
|
||||
const sanitizedName = this.sanitizeServerName(credentials.serverName);
|
||||
const updatedCredentials: OAuthCredentials = {
|
||||
...credentials,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const data = JSON.stringify(updatedCredentials);
|
||||
await keytar.setPassword(this.serviceName, sanitizedName, data);
|
||||
}
|
||||
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
const sanitizedName = this.sanitizeServerName(serverName);
|
||||
const deleted = await keytar.deletePassword(
|
||||
this.serviceName,
|
||||
sanitizedName,
|
||||
);
|
||||
|
||||
if (!deleted) {
|
||||
throw new Error(`No credentials found for ${serverName}`);
|
||||
}
|
||||
}
|
||||
|
||||
async listServers(): Promise<string[]> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
try {
|
||||
const credentials = await keytar.findCredentials(this.serviceName);
|
||||
return credentials
|
||||
.filter((cred) => !cred.account.startsWith(KEYCHAIN_TEST_PREFIX))
|
||||
.map((cred: { account: string }) => cred.account);
|
||||
} catch (error) {
|
||||
console.error('Failed to list servers from keychain:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
const result = new Map<string, OAuthCredentials>();
|
||||
try {
|
||||
const credentials = (
|
||||
await keytar.findCredentials(this.serviceName)
|
||||
).filter((c) => !c.account.startsWith(KEYCHAIN_TEST_PREFIX));
|
||||
|
||||
for (const cred of credentials) {
|
||||
try {
|
||||
const data = JSON.parse(cred.password) as OAuthCredentials;
|
||||
if (!this.isTokenExpired(data)) {
|
||||
result.set(cred.account, data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to parse credentials for ${cred.account}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to get all credentials from keychain:', error);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async clearAll(): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const servers = this.keytarModule
|
||||
? await this.keytarModule
|
||||
.findCredentials(this.serviceName)
|
||||
.then((creds) => creds.map((c) => c.account))
|
||||
.catch((error: Error) => {
|
||||
throw new Error(
|
||||
`Failed to list servers for clearing: ${error.message}`,
|
||||
);
|
||||
})
|
||||
: [];
|
||||
const errors: Error[] = [];
|
||||
|
||||
for (const server of servers) {
|
||||
try {
|
||||
await this.deleteCredentials(server);
|
||||
} catch (error) {
|
||||
errors.push(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
if (errors.length > 0) {
|
||||
throw new Error(
|
||||
`Failed to clear some credentials: ${errors.map((e) => e.message).join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether or not a set-get-delete cycle with the keychain works.
|
||||
// Returns false if any operation fails.
|
||||
async checkKeychainAvailability(): Promise<boolean> {
|
||||
if (this.keychainAvailable !== null) {
|
||||
return this.keychainAvailable;
|
||||
}
|
||||
|
||||
try {
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
this.keychainAvailable = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
const testAccount = `${KEYCHAIN_TEST_PREFIX}${crypto.randomBytes(8).toString('hex')}`;
|
||||
const testPassword = 'test';
|
||||
|
||||
await keytar.setPassword(this.serviceName, testAccount, testPassword);
|
||||
const retrieved = await keytar.getPassword(this.serviceName, testAccount);
|
||||
const deleted = await keytar.deletePassword(
|
||||
this.serviceName,
|
||||
testAccount,
|
||||
);
|
||||
|
||||
const success = deleted && retrieved === testPassword;
|
||||
this.keychainAvailable = success;
|
||||
return success;
|
||||
} catch (_error) {
|
||||
this.keychainAvailable = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async isAvailable(): Promise<boolean> {
|
||||
return this.checkKeychainAvailability();
|
||||
}
|
||||
}
|
||||
@@ -35,3 +35,8 @@ export interface TokenStorage {
|
||||
getAllCredentials(): Promise<Map<string, OAuthCredentials>>;
|
||||
clearAll(): Promise<void>;
|
||||
}
|
||||
|
||||
export enum TokenStorageType {
|
||||
KEYCHAIN = 'keychain',
|
||||
ENCRYPTED_FILE = 'encrypted_file',
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user