Sync upstream Gemini-CLI v0.8.2 (#838)

This commit is contained in:
tanzhenxin
2025-10-23 09:27:04 +08:00
committed by GitHub
parent 096fabb5d6
commit eb95c131be
644 changed files with 70389 additions and 23709 deletions

View File

@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely,
}));
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
vi.mock('./oauth-token-storage.js', () => {
const mockSaveToken = vi.fn();
const mockGetCredentials = vi.fn();
const mockIsTokenExpired = vi.fn();
const mockdeleteCredentials = vi.fn();
return {
MCPOAuthTokenStorage: vi.fn(() => ({
saveToken: mockSaveToken,
getCredentials: mockGetCredentials,
isTokenExpired: mockIsTokenExpired,
deleteCredentials: mockdeleteCredentials,
})),
};
});
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http';
@@ -25,6 +39,10 @@ import type {
import { MCPOAuthProvider } from './oauth-provider.js';
import type { OAuthToken } from './token-storage/types.js';
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
import type {
OAuthAuthorizationServerMetadata,
OAuthProtectedResourceMetadata,
} from './oauth-utils.js';
// Mock fetch globally
const mockFetch = vi.fn();
@@ -147,8 +165,9 @@ describe('MCPOAuthProvider', () => {
});
// Mock token storage
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
});
afterEach(() => {
@@ -192,10 +211,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
'test-server',
mockConfig,
);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate('test-server', mockConfig);
expect(result).toEqual({
accessToken: 'access_token_123',
@@ -208,7 +225,8 @@ describe('MCPOAuthProvider', () => {
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
expect.stringContaining('authorize'),
);
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
const tokenStorage = new MCPOAuthTokenStorage();
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'access_token_123' }),
'test-client-id',
@@ -296,7 +314,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutAuth,
'https://api.example.com',
@@ -314,8 +333,11 @@ describe('MCPOAuthProvider', () => {
);
});
it('should perform dynamic client registration when no client ID provided', async () => {
const configWithoutClient = { ...mockConfig };
it('should perform dynamic client registration when no client ID is provided but registration URL is provided', async () => {
const configWithoutClient: MCPOAuthConfig = {
...mockConfig,
registrationUrl: 'https://auth.example.com/register',
};
delete configWithoutClient.clientId;
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
@@ -327,7 +349,80 @@ describe('MCPOAuthProvider', () => {
token_endpoint_auth_method: 'none',
};
const mockAuthServerMetadata = {
mockFetch.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockRegistrationResponse),
json: mockRegistrationResponse,
}),
);
// Setup callback handler
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
// Mock token exchange
mockFetch.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockTokenResponse),
json: mockTokenResponse,
}),
);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutClient,
);
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/register',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/json' },
}),
);
});
it('should perform OAuth discovery and dynamic client registration when no client ID or registration URL provided', async () => {
const configWithoutClient: MCPOAuthConfig = { ...mockConfig };
delete configWithoutClient.clientId;
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
client_id: 'dynamic_client_id',
client_secret: 'dynamic_client_secret',
redirect_uris: ['http://localhost:7777/oauth/callback'],
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
token_endpoint_auth_method: 'none',
};
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
registration_endpoint: 'https://auth.example.com/register',
@@ -385,7 +480,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutClient,
);
@@ -400,6 +496,117 @@ describe('MCPOAuthProvider', () => {
);
});
it('should perform OAuth discovery once and dynamic client registration when no client ID, authorization URL or registration URL provided', async () => {
const configWithoutClientAndAuthorizationUrl: MCPOAuthConfig = {
...mockConfig,
};
delete configWithoutClientAndAuthorizationUrl.clientId;
delete configWithoutClientAndAuthorizationUrl.authorizationUrl;
const mockResourceMetadata: OAuthProtectedResourceMetadata = {
resource: 'https://api.example.com',
authorization_servers: ['https://auth.example.com'],
};
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
registration_endpoint: 'https://auth.example.com/register',
};
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
client_id: 'dynamic_client_id',
client_secret: 'dynamic_client_secret',
redirect_uris: ['http://localhost:7777/oauth/callback'],
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
token_endpoint_auth_method: 'none',
};
mockFetch
.mockResolvedValueOnce(
createMockResponse({
ok: true,
status: 200,
}),
)
.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockResourceMetadata),
json: mockResourceMetadata,
}),
)
.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockAuthServerMetadata),
json: mockAuthServerMetadata,
}),
)
.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockRegistrationResponse),
json: mockRegistrationResponse,
}),
);
// Setup callback handler
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
// Mock token exchange
mockFetch.mockResolvedValueOnce(
createMockResponse({
ok: true,
contentType: 'application/json',
text: JSON.stringify(mockTokenResponse),
json: mockTokenResponse,
}),
);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutClientAndAuthorizationUrl,
'https://api.example.com',
);
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/register',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/json' },
}),
);
});
it('should handle OAuth callback errors', async () => {
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
@@ -424,8 +631,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth error: access_denied');
});
@@ -453,8 +661,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('State mismatch - possible CSRF attack');
});
@@ -491,8 +700,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
});
@@ -516,8 +726,9 @@ describe('MCPOAuthProvider', () => {
return originalSetTimeout(callback, 0);
}) as unknown as typeof setTimeout;
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth callback timeout');
global.setTimeout = originalSetTimeout;
@@ -542,7 +753,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.refreshAccessToken(
mockConfig,
'old_refresh_token',
'https://auth.example.com/token',
@@ -572,7 +784,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
await authProvider.refreshAccessToken(
mockConfig,
'refresh_token',
'https://auth.example.com/token',
@@ -592,8 +805,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.refreshAccessToken(
authProvider.refreshAccessToken(
mockConfig,
'invalid_refresh_token',
'https://auth.example.com/token',
@@ -614,12 +828,14 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
validCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -636,10 +852,11 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = {
access_token: 'new_access_token',
@@ -657,13 +874,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBe('new_access_token');
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'new_access_token' }),
'test-client-id',
@@ -673,9 +891,11 @@ describe('MCPOAuthProvider', () => {
});
it('should return null when no credentials exist', async () => {
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -692,11 +912,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.deleteCredentials).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce(
createMockResponse({
@@ -707,13 +928,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(console.error).toHaveBeenCalledWith(
@@ -734,12 +956,14 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
tokenWithoutRefresh,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -784,7 +1008,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate('test-server', mockConfig);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', mockConfig);
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
@@ -833,7 +1058,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate(
'test-server',
mockConfig,
'https://auth.example.com',
@@ -894,7 +1120,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
};
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithParamsInUrl);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('audience')).toBe('1234');
@@ -947,7 +1174,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize#login',
};
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithFragment);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('client_id')).toBe('test-client-id');

View File

@@ -7,12 +7,15 @@
import * as http from 'node:http';
import * as crypto from 'node:crypto';
import { URL } from 'node:url';
import type { EventEmitter } from 'node:events';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import type { OAuthToken } from './token-storage/types.js';
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
import { OAuthUtils } from './oauth-utils.js';
export const OAUTH_DISPLAY_MESSAGE_EVENT = 'oauth-display-message' as const;
/**
* OAuth configuration for an MCP server.
*/
@@ -26,6 +29,7 @@ export interface MCPOAuthConfig {
audiences?: string[];
redirectUri?: string;
tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token
registrationUrl?: string;
}
/**
@@ -85,13 +89,19 @@ interface PKCEParams {
state: string;
}
const REDIRECT_PORT = 7777;
const REDIRECT_PATH = '/oauth/callback';
const HTTP_OK = 200;
/**
* Provider for handling OAuth authentication for MCP servers.
*/
export class MCPOAuthProvider {
private static readonly REDIRECT_PORT = 7777;
private static readonly REDIRECT_PATH = '/oauth/callback';
private static readonly HTTP_OK = 200;
private readonly tokenStorage: MCPOAuthTokenStorage;
constructor(tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage()) {
this.tokenStorage = tokenStorage;
}
/**
* Register a client dynamically with the OAuth server.
@@ -100,13 +110,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration
* @returns The registered client information
*/
private static async registerClient(
private async registerClient(
registrationUrl: string,
config: MCPOAuthConfig,
): Promise<OAuthClientRegistrationResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const registrationRequest: OAuthClientRegistrationRequest = {
client_name: 'Gemini CLI (Google ADC)',
@@ -142,7 +151,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL
* @returns OAuth configuration if discovered, null otherwise
*/
private static async discoverOAuthFromMCPServer(
private async discoverOAuthFromMCPServer(
mcpServerUrl: string,
): Promise<MCPOAuthConfig | null> {
// Use the full URL with path preserved for OAuth discovery
@@ -154,7 +163,7 @@ export class MCPOAuthProvider {
*
* @returns PKCE parameters including code verifier, challenge, and state
*/
private static generatePKCEParams(): PKCEParams {
private generatePKCEParams(): PKCEParams {
// Generate code verifier (43-128 characters)
const codeVerifier = crypto.randomBytes(32).toString('base64url');
@@ -176,19 +185,16 @@ export class MCPOAuthProvider {
* @param expectedState The state parameter to validate
* @returns Promise that resolves with the authorization code
*/
private static async startCallbackServer(
private async startCallbackServer(
expectedState: string,
): Promise<OAuthAuthorizationResponse> {
return new Promise((resolve, reject) => {
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(
req.url!,
`http://localhost:${this.REDIRECT_PORT}`,
);
const url = new URL(req.url!, `http://localhost:${REDIRECT_PORT}`);
if (url.pathname !== this.REDIRECT_PATH) {
if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
@@ -199,7 +205,7 @@ export class MCPOAuthProvider {
const error = url.searchParams.get('error');
if (error) {
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
@@ -230,7 +236,7 @@ export class MCPOAuthProvider {
}
// Send success response to browser
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
@@ -251,10 +257,8 @@ export class MCPOAuthProvider {
);
server.on('error', reject);
server.listen(this.REDIRECT_PORT, () => {
console.log(
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
);
server.listen(REDIRECT_PORT, () => {
console.log(`OAuth callback server listening on port ${REDIRECT_PORT}`);
});
// Timeout after 5 minutes
@@ -276,14 +280,13 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The authorization URL
*/
private static buildAuthorizationUrl(
private buildAuthorizationUrl(
config: MCPOAuthConfig,
pkceParams: PKCEParams,
mcpServerUrl?: string,
): string {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const params = new URLSearchParams({
client_id: config.clientId!,
@@ -333,15 +336,14 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The token response
*/
private static async exchangeCodeForToken(
private async exchangeCodeForToken(
config: MCPOAuthConfig,
code: string,
codeVerifier: string,
mcpServerUrl?: string,
): Promise<OAuthTokenResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const params = new URLSearchParams({
grant_type: 'authorization_code',
@@ -458,7 +460,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The new token response
*/
static async refreshAccessToken(
async refreshAccessToken(
config: MCPOAuthConfig,
refreshToken: string,
tokenUrl: string,
@@ -577,18 +579,28 @@ export class MCPOAuthProvider {
* @param serverName The name of the MCP server
* @param config OAuth configuration
* @param mcpServerUrl Optional MCP server URL for OAuth discovery
* @param messageHandler Optional handler for displaying user-facing messages
* @returns The obtained OAuth token
*/
static async authenticate(
async authenticate(
serverName: string,
config: MCPOAuthConfig,
mcpServerUrl?: string,
events?: EventEmitter,
): Promise<OAuthToken> {
// Helper function to display messages through handler or fallback to console.log
const displayMessage = (message: string) => {
if (events) {
events.emit(OAUTH_DISPLAY_MESSAGE_EVENT, message);
} else {
console.log(message);
}
};
// If no authorization URL is provided, try to discover OAuth configuration
if (!config.authorizationUrl && mcpServerUrl) {
console.log(
'No authorization URL provided, attempting OAuth discovery...',
);
console.debug(`Starting OAuth for MCP server "${serverName}"…
No authorization URL; using OAuth discovery`);
// First check if the server requires authentication via WWW-Authenticate header
try {
@@ -640,6 +652,7 @@ export class MCPOAuthProvider {
authorizationUrl: discoveredConfig.authorizationUrl,
tokenUrl: discoveredConfig.tokenUrl,
scopes: discoveredConfig.scopes || config.scopes || [],
registrationUrl: discoveredConfig.registrationUrl,
// Preserve existing client credentials
clientId: config.clientId,
clientSecret: config.clientSecret,
@@ -654,40 +667,38 @@ export class MCPOAuthProvider {
// If no client ID is provided, try dynamic client registration
if (!config.clientId) {
// Extract server URL from authorization URL
if (!config.authorizationUrl) {
throw new Error(
'Cannot perform dynamic registration without authorization URL',
);
}
let registrationUrl = config.registrationUrl;
const authUrl = new URL(config.authorizationUrl);
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
// If no registration URL was previously discovered, try to discover it
if (!registrationUrl) {
// Extract server URL from authorization URL
if (!config.authorizationUrl) {
throw new Error(
'Cannot perform dynamic registration without authorization URL',
);
}
console.log(
'No client ID provided, attempting dynamic client registration...',
);
const authUrl = new URL(config.authorizationUrl);
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
// Get the authorization server metadata for registration
const authServerMetadataUrl = new URL(
'/.well-known/oauth-authorization-server',
serverUrl,
).toString();
console.debug('→ Attempting dynamic client registration...');
const authServerMetadata =
await OAuthUtils.fetchAuthorizationServerMetadata(
authServerMetadataUrl,
);
if (!authServerMetadata) {
throw new Error(
'Failed to fetch authorization server metadata for client registration',
);
// Get the authorization server metadata for registration
const authServerMetadata =
await OAuthUtils.discoverAuthorizationServerMetadata(serverUrl);
if (!authServerMetadata) {
throw new Error(
'Failed to fetch authorization server metadata for client registration',
);
}
registrationUrl = authServerMetadata.registration_endpoint;
}
// Register client if registration endpoint is available
if (authServerMetadata.registration_endpoint) {
if (registrationUrl) {
const clientRegistration = await this.registerClient(
authServerMetadata.registration_endpoint,
registrationUrl,
config,
);
@@ -696,7 +707,7 @@ export class MCPOAuthProvider {
config.clientSecret = clientRegistration.client_secret;
}
console.log('Dynamic client registration successful');
console.debug('Dynamic client registration successful');
} else {
throw new Error(
'No client ID provided and dynamic registration not supported',
@@ -721,30 +732,13 @@ export class MCPOAuthProvider {
mcpServerUrl,
);
console.log('\nOpening browser for OAuth authentication...');
console.log('If the browser does not open, please visit:');
console.log('');
displayMessage(`Opening your browser for OAuth sign-in...
// Get terminal width or default to 80
const terminalWidth = process.stdout.columns || 80;
const separatorLength = Math.min(terminalWidth - 2, 80);
const separator = '━'.repeat(separatorLength);
If the browser does not open, copy and paste this URL into your browser:
${authUrl}
console.log(separator);
console.log(
'COPY THE ENTIRE URL BELOW (select all text between the lines):',
);
console.log(separator);
console.log(authUrl);
console.log(separator);
console.log('');
console.log(
'💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.',
);
console.log(
'⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.',
);
console.log('');
💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.
⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`);
// Start callback server
const callbackPromise = this.startCallbackServer(pkceParams.state);
@@ -762,7 +756,7 @@ export class MCPOAuthProvider {
// Wait for callback
const { code } = await callbackPromise;
console.log('\nAuthorization code received, exchanging for tokens...');
console.debug('Authorization code received, exchanging for tokens...');
// Exchange code for tokens
const tokenResponse = await this.exchangeCodeForToken(
@@ -790,23 +784,27 @@ export class MCPOAuthProvider {
// Save token
try {
await MCPOAuthTokenStorage.saveToken(
await this.tokenStorage.saveToken(
serverName,
token,
config.clientId,
config.tokenUrl,
mcpServerUrl,
);
console.log('Authentication successful! Token saved.');
console.debug('Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await MCPOAuthTokenStorage.getToken(serverName);
const savedToken = await this.tokenStorage.getCredentials(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) {
const tokenPreview =
savedToken.token.accessToken.length > 20
? `${savedToken.token.accessToken.substring(0, 20)}...`
: '[token]';
console.log(`Token verification successful: ${tokenPreview}`);
// Avoid leaking token material; log a short SHA-256 fingerprint instead.
const tokenFingerprint = crypto
.createHash('sha256')
.update(savedToken.token.accessToken)
.digest('hex')
.slice(0, 8);
console.debug(
`✓ Token verification successful (fingerprint: ${tokenFingerprint})`,
);
} else {
console.error(
'Token verification failed: token not found or invalid after save',
@@ -827,12 +825,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration
* @returns A valid access token or null if not authenticated
*/
static async getValidToken(
async getValidToken(
serverName: string,
config: MCPOAuthConfig,
): Promise<string | null> {
console.debug(`Getting valid token for server: ${serverName}`);
const credentials = await MCPOAuthTokenStorage.getToken(serverName);
const credentials = await this.tokenStorage.getCredentials(serverName);
if (!credentials) {
console.debug(`No credentials found for server: ${serverName}`);
@@ -841,11 +839,11 @@ export class MCPOAuthProvider {
const { token } = credentials;
console.debug(
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`,
`Found token for server: ${serverName}, expired: ${this.tokenStorage.isTokenExpired(token)}`,
);
// Check if token is expired
if (!MCPOAuthTokenStorage.isTokenExpired(token)) {
if (!this.tokenStorage.isTokenExpired(token)) {
console.debug(`Returning valid token for server: ${serverName}`);
return token.accessToken;
}
@@ -874,7 +872,7 @@ export class MCPOAuthProvider {
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
}
await MCPOAuthTokenStorage.saveToken(
await this.tokenStorage.saveToken(
serverName,
newToken,
config.clientId,
@@ -886,7 +884,7 @@ export class MCPOAuthProvider {
} catch (error) {
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
// Remove invalid token
await MCPOAuthTokenStorage.removeToken(serverName);
await this.tokenStorage.deleteCredentials(serverName);
}
}

View File

@@ -4,13 +4,15 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import { MCPOAuthTokenStorage } from './oauth-token-storage.js';
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from './token-storage/index.js';
import type { OAuthCredentials, OAuthToken } from './token-storage/types.js';
import { QWEN_DIR } from '../utils/paths.js';
// Mock file system operations
// Mock dependencies
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
@@ -18,20 +20,42 @@ vi.mock('node:fs', () => ({
mkdir: vi.fn(),
unlink: vi.fn(),
},
mkdirSync: vi.fn(),
}));
vi.mock('node:os', () => ({
homedir: vi.fn(() => '/mock/home'),
vi.mock('node:path', () => ({
dirname: vi.fn(),
join: vi.fn(),
}));
vi.mock('../config/storage.js', () => ({
Storage: {
getMcpOAuthTokensPath: vi.fn(),
},
}));
const mockHybridTokenStorage = {
listServers: vi.fn(),
setCredentials: vi.fn(),
getCredentials: vi.fn(),
deleteCredentials: vi.fn(),
clearAll: vi.fn(),
getAllCredentials: vi.fn(),
};
vi.mock('./token-storage/hybrid-token-storage.js', () => ({
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
}));
const ONE_HR_MS = 3600000;
describe('MCPOAuthTokenStorage', () => {
let tokenStorage: MCPOAuthTokenStorage;
const mockToken: OAuthToken = {
accessToken: 'access_token_123',
refreshToken: 'refresh_token_456',
tokenType: 'Bearer',
scope: 'read write',
expiresAt: Date.now() + 3600000, // 1 hour from now
expiresAt: Date.now() + ONE_HR_MS,
};
const mockCredentials: OAuthCredentials = {
@@ -42,281 +66,385 @@ describe('MCPOAuthTokenStorage', () => {
updatedAt: Date.now(),
};
beforeEach(() => {
vi.clearAllMocks();
vi.spyOn(console, 'error').mockImplementation(() => {});
});
describe('with encrypted flag false', () => {
beforeEach(() => {
vi.stubEnv(FORCE_ENCRYPTED_FILE_ENV_VAR, 'false');
tokenStorage = new MCPOAuthTokenStorage();
afterEach(() => {
vi.restoreAllMocks();
});
describe('loadTokens', () => {
it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled();
vi.clearAllMocks();
vi.spyOn(console, 'error');
});
it('should load tokens from file successfully', async () => {
const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials);
expect(fs.readFile).toHaveBeenCalledWith(
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
'utf-8',
);
afterEach(() => {
vi.unstubAllEnvs();
vi.restoreAllMocks();
});
it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
describe('getAllCredentials', () => {
it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await MCPOAuthTokenStorage.loadTokens();
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
it('should handle file read errors other than ENOENT', async () => {
const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
});
describe('saveToken', () => {
it('should save token with restricted permissions', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.saveToken(
'test-server',
mockToken,
'client-id',
'https://token.url',
);
expect(fs.mkdir).toHaveBeenCalledWith(path.join('/mock/home', '.qwen'), {
recursive: true,
expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled();
});
it('should load tokens from file successfully', async () => {
const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials);
expect(fs.readFile).toHaveBeenCalledWith(
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
'utf-8',
);
});
it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
it('should handle file read errors other than ENOENT', async () => {
const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
expect(fs.writeFile).toHaveBeenCalledWith(
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
expect.stringContaining('test-server'),
{ mode: 0o600 },
);
});
it('should update existing token for same server', async () => {
const existingCredentials = {
...mockCredentials,
serverName: 'existing-server',
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([existingCredentials]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
describe('saveToken', () => {
it('should save token with restricted permissions', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken = { ...mockToken, accessToken: 'new_access_token' };
await MCPOAuthTokenStorage.saveToken('existing-server', newToken);
await tokenStorage.saveToken(
'test-server',
mockToken,
'client-id',
'https://token.url',
);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
expect(fs.mkdir).toHaveBeenCalledWith(
path.join('/mock/home', QWEN_DIR),
{ recursive: true },
);
expect(fs.writeFile).toHaveBeenCalledWith(
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
expect.stringContaining('test-server'),
{ mode: 0o600 },
);
});
expect(savedData).toHaveLength(1);
expect(savedData[0].token.accessToken).toBe('new_access_token');
expect(savedData[0].serverName).toBe('existing-server');
it('should update existing token for same server', async () => {
const existingCredentials: OAuthCredentials = {
...mockCredentials,
serverName: 'existing-server',
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([existingCredentials]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken: OAuthToken = {
...mockToken,
accessToken: 'new_access_token',
};
await tokenStorage.saveToken('existing-server', newToken);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(
writeCall[1] as string,
) as OAuthCredentials[];
expect(savedData).toHaveLength(1);
expect(savedData[0].token.accessToken).toBe('new_access_token');
expect(savedData[0].serverName).toBe('existing-server');
});
it('should handle write errors gracefully', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
const writeError = new Error('Disk full');
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
await expect(
tokenStorage.saveToken('test-server', mockToken),
).rejects.toThrow('Disk full');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to save MCP OAuth token'),
);
});
});
it('should handle write errors gracefully', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
const writeError = new Error('Disk full');
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
describe('getCredentials', () => {
it('should return token for existing server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
await expect(
MCPOAuthTokenStorage.saveToken('test-server', mockToken),
).rejects.toThrow('Disk full');
const result = await tokenStorage.getCredentials('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to save MCP OAuth token'),
);
expect(result).toEqual(mockCredentials);
});
it('should return null for non-existent server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
const result = await tokenStorage.getCredentials('non-existent');
expect(result).toBeNull();
});
it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await tokenStorage.getCredentials('test-server');
expect(result).toBeNull();
});
});
describe('deleteCredentials', () => {
it('should remove token for specific server', async () => {
const credentials1: OAuthCredentials = {
...mockCredentials,
serverName: 'server1',
};
const credentials2: OAuthCredentials = {
...mockCredentials,
serverName: 'server2',
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([credentials1, credentials2]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await tokenStorage.deleteCredentials('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
expect(savedData).toHaveLength(1);
expect(savedData[0].serverName).toBe('server2');
});
it('should remove token file when no tokens remain', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await tokenStorage.deleteCredentials('test-server');
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
);
expect(fs.writeFile).not.toHaveBeenCalled();
});
it('should handle removal of non-existent token gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
await tokenStorage.deleteCredentials('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled();
});
it('should handle file operation errors gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await tokenStorage.deleteCredentials('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'),
);
});
});
describe('isTokenExpired', () => {
it('should return false for token without expiry', () => {
const tokenWithoutExpiry: OAuthToken = { ...mockToken };
delete tokenWithoutExpiry.expiresAt;
const result = tokenStorage.isTokenExpired(tokenWithoutExpiry);
expect(result).toBe(false);
});
it('should return false for valid token', () => {
const futureToken: OAuthToken = {
...mockToken,
expiresAt: Date.now() + ONE_HR_MS,
};
const result = tokenStorage.isTokenExpired(futureToken);
expect(result).toBe(false);
});
it('should return true for expired token', () => {
const expiredToken: OAuthToken = {
...mockToken,
expiresAt: Date.now() - ONE_HR_MS,
};
const result = tokenStorage.isTokenExpired(expiredToken);
expect(result).toBe(true);
});
it('should return true for token expiring within buffer time', () => {
const soonToExpireToken: OAuthToken = {
...mockToken,
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
};
const result = tokenStorage.isTokenExpired(soonToExpireToken);
expect(result).toBe(true);
});
});
describe('clearAll', () => {
it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await tokenStorage.clearAll();
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', QWEN_DIR, 'mcp-oauth-tokens.json'),
);
});
it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await tokenStorage.clearAll();
expect(console.error).not.toHaveBeenCalled();
});
it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await tokenStorage.clearAll();
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'),
);
});
});
});
describe('getToken', () => {
it('should return token for existing server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
describe('with encrypted flag true', () => {
beforeEach(() => {
vi.stubEnv(FORCE_ENCRYPTED_FILE_ENV_VAR, 'true');
tokenStorage = new MCPOAuthTokenStorage();
const result = await MCPOAuthTokenStorage.getToken('test-server');
expect(result).toEqual(mockCredentials);
vi.clearAllMocks();
vi.spyOn(console, 'error');
});
it('should return null for non-existent server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
const result = await MCPOAuthTokenStorage.getToken('non-existent');
expect(result).toBeNull();
afterEach(() => {
vi.unstubAllEnvs();
vi.restoreAllMocks();
});
it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await MCPOAuthTokenStorage.getToken('test-server');
expect(result).toBeNull();
});
});
describe('removeToken', () => {
it('should remove token for specific server', async () => {
const credentials1 = { ...mockCredentials, serverName: 'server1' };
const credentials2 = { ...mockCredentials, serverName: 'server2' };
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([credentials1, credentials2]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
expect(savedData).toHaveLength(1);
expect(savedData[0].serverName).toBe('server2');
it('should use HybridTokenStorage to list all credentials', async () => {
mockHybridTokenStorage.getAllCredentials.mockResolvedValue(new Map());
const servers = await tokenStorage.getAllCredentials();
expect(mockHybridTokenStorage.getAllCredentials).toHaveBeenCalled();
expect(servers).toEqual(new Map());
});
it('should remove token file when no tokens remain', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('test-server');
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
);
expect(fs.writeFile).not.toHaveBeenCalled();
it('should use HybridTokenStorage to list servers', async () => {
mockHybridTokenStorage.listServers.mockResolvedValue(['server1']);
const servers = await tokenStorage.listServers();
expect(mockHybridTokenStorage.listServers).toHaveBeenCalled();
expect(servers).toEqual(['server1']);
});
it('should handle removal of non-existent token gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
await MCPOAuthTokenStorage.removeToken('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled();
});
it('should handle file operation errors gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.removeToken('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'),
it('should use HybridTokenStorage to set credentials', async () => {
await tokenStorage.setCredentials(mockCredentials);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
mockCredentials,
);
});
});
describe('isTokenExpired', () => {
it('should return false for token without expiry', () => {
const tokenWithoutExpiry = { ...mockToken };
delete tokenWithoutExpiry.expiresAt;
it('should use HybridTokenStorage to save a token', async () => {
const serverName = 'server1';
const now = Date.now();
vi.spyOn(Date, 'now').mockReturnValue(now);
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry);
await tokenStorage.saveToken(
serverName,
mockToken,
'clientId',
'tokenUrl',
'mcpUrl',
);
expect(result).toBe(false);
});
it('should return false for valid token', () => {
const futureToken = {
...mockToken,
expiresAt: Date.now() + 3600000, // 1 hour from now
const expectedCredential: OAuthCredentials = {
serverName,
token: mockToken,
clientId: 'clientId',
tokenUrl: 'tokenUrl',
mcpServerUrl: 'mcpUrl',
updatedAt: now,
};
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken);
expect(result).toBe(false);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
expectedCredential,
);
expect(path.dirname).toHaveBeenCalled();
expect(fs.mkdir).toHaveBeenCalled();
});
it('should return true for expired token', () => {
const expiredToken = {
...mockToken,
expiresAt: Date.now() - 3600000, // 1 hour ago
};
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken);
expect(result).toBe(true);
it('should use HybridTokenStorage to get credentials', async () => {
mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials);
const result = await tokenStorage.getCredentials('server1');
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'server1',
);
expect(result).toBe(mockCredentials);
});
it('should return true for token expiring within buffer time', () => {
const soonToExpireToken = {
...mockToken,
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
};
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken);
expect(result).toBe(true);
});
});
describe('clearAllTokens', () => {
it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.clearAllTokens();
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.qwen', 'mcp-oauth-tokens.json'),
it('should use HybridTokenStorage to delete credentials', async () => {
await tokenStorage.deleteCredentials('server1');
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
'server1',
);
});
it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await MCPOAuthTokenStorage.clearAllTokens();
expect(console.error).not.toHaveBeenCalled();
});
it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.clearAllTokens();
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'),
);
it('should use HybridTokenStorage to clear all tokens', async () => {
await tokenStorage.clearAll();
expect(mockHybridTokenStorage.clearAll).toHaveBeenCalled();
});
});
});

View File

@@ -8,25 +8,40 @@ import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import { Storage } from '../config/storage.js';
import { getErrorMessage } from '../utils/errors.js';
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
import type {
OAuthToken,
OAuthCredentials,
TokenStorage,
} from './token-storage/types.js';
import { HybridTokenStorage } from './token-storage/hybrid-token-storage.js';
import {
DEFAULT_SERVICE_NAME,
FORCE_ENCRYPTED_FILE_ENV_VAR,
} from './token-storage/index.js';
/**
* Class for managing MCP OAuth token storage and retrieval.
*/
export class MCPOAuthTokenStorage {
export class MCPOAuthTokenStorage implements TokenStorage {
private readonly hybridTokenStorage = new HybridTokenStorage(
DEFAULT_SERVICE_NAME,
);
private readonly useEncryptedFile =
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
/**
* Get the path to the token storage file.
*
* @returns The full path to the token storage file
*/
private static getTokenFilePath(): string {
private getTokenFilePath(): string {
return Storage.getMcpOAuthTokensPath();
}
/**
* Ensure the config directory exists.
*/
private static async ensureConfigDir(): Promise<void> {
private async ensureConfigDir(): Promise<void> {
const configDir = path.dirname(this.getTokenFilePath());
await fs.mkdir(configDir, { recursive: true });
}
@@ -36,7 +51,10 @@ export class MCPOAuthTokenStorage {
*
* @returns A map of server names to credentials
*/
static async loadTokens(): Promise<Map<string, OAuthCredentials>> {
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.getAllCredentials();
}
const tokenMap = new Map<string, OAuthCredentials>();
try {
@@ -59,36 +77,20 @@ export class MCPOAuthTokenStorage {
return tokenMap;
}
/**
* Save a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @param token The OAuth token to save
* @param clientId Optional client ID used for this token
* @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL
*/
static async saveToken(
serverName: string,
token: OAuthToken,
clientId?: string,
tokenUrl?: string,
mcpServerUrl?: string,
): Promise<void> {
await this.ensureConfigDir();
async listServers(): Promise<string[]> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.listServers();
}
const tokens = await this.getAllCredentials();
return Array.from(tokens.keys());
}
const tokens = await this.loadTokens();
const credential: OAuthCredentials = {
serverName,
token,
clientId,
tokenUrl,
mcpServerUrl,
updatedAt: Date.now(),
};
tokens.set(serverName, credential);
async setCredentials(credentials: OAuthCredentials): Promise<void> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.setCredentials(credentials);
}
const tokens = await this.getAllCredentials();
tokens.set(credentials.serverName, credentials);
const tokenArray = Array.from(tokens.values());
const tokenFile = this.getTokenFilePath();
@@ -107,14 +109,50 @@ export class MCPOAuthTokenStorage {
}
}
/**
* Save a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @param token The OAuth token to save
* @param clientId Optional client ID used for this token
* @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL
*/
async saveToken(
serverName: string,
token: OAuthToken,
clientId?: string,
tokenUrl?: string,
mcpServerUrl?: string,
): Promise<void> {
await this.ensureConfigDir();
const credential: OAuthCredentials = {
serverName,
token,
clientId,
tokenUrl,
mcpServerUrl,
updatedAt: Date.now(),
};
if (this.useEncryptedFile) {
return this.hybridTokenStorage.setCredentials(credential);
}
await this.setCredentials(credential);
}
/**
* Get a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @returns The stored credentials or null if not found
*/
static async getToken(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.loadTokens();
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.getCredentials(serverName);
}
const tokens = await this.getAllCredentials();
return tokens.get(serverName) || null;
}
@@ -123,8 +161,11 @@ export class MCPOAuthTokenStorage {
*
* @param serverName The name of the MCP server
*/
static async removeToken(serverName: string): Promise<void> {
const tokens = await this.loadTokens();
async deleteCredentials(serverName: string): Promise<void> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.deleteCredentials(serverName);
}
const tokens = await this.getAllCredentials();
if (tokens.delete(serverName)) {
const tokenArray = Array.from(tokens.values());
@@ -153,7 +194,7 @@ export class MCPOAuthTokenStorage {
* @param token The token to check
* @returns True if the token is expired
*/
static isTokenExpired(token: OAuthToken): boolean {
isTokenExpired(token: OAuthToken): boolean {
if (!token.expiresAt) {
return false; // No expiry, assume valid
}
@@ -166,7 +207,10 @@ export class MCPOAuthTokenStorage {
/**
* Clear all stored MCP OAuth tokens.
*/
static async clearAllTokens(): Promise<void> {
async clearAll(): Promise<void> {
if (this.useEncryptedFile) {
return this.hybridTokenStorage.clearAll();
}
try {
const tokenFile = this.getTokenFilePath();
await fs.unlink(tokenFile);

View File

@@ -137,6 +137,7 @@ export class OAuthUtils {
authorizationUrl: metadata.authorization_endpoint,
tokenUrl: metadata.token_endpoint,
scopes: metadata.scopes_supported || [],
registrationUrl: metadata.registration_endpoint,
};
}

View File

@@ -0,0 +1,153 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ServiceAccountImpersonationProvider } from './sa-impersonation-provider.js';
import type { MCPServerConfig } from '../config/config.js';
const mockRequest = vi.fn();
const mockGetClient = vi.fn(() => ({
request: mockRequest,
}));
// Mock the google-auth-library to use a shared mock function
vi.mock('google-auth-library', async (importOriginal) => {
const actual = await importOriginal<typeof import('google-auth-library')>();
return {
...actual,
GoogleAuth: vi.fn().mockImplementation(() => ({
getClient: mockGetClient,
})),
};
});
const defaultSAConfig: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
targetAudience: 'my-audience',
targetServiceAccount: 'my-sa',
};
describe('ServiceAccountImpersonationProvider', () => {
beforeEach(() => {
// Reset mocks before each test
vi.clearAllMocks();
});
it('should throw an error if no URL is provided', () => {
const config: MCPServerConfig = {};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'A url or httpUrl must be provided for the Service Account Impersonation provider',
);
});
it('should throw an error if no targetAudience is provided', () => {
const config: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'targetAudience must be provided for the Service Account Impersonation provider',
);
});
it('should throw an error if no targetSA is provided', () => {
const config: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
targetAudience: 'my-audience',
};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'targetServiceAccount must be provided for the Service Account Impersonation provider',
);
});
it('should correctly get tokens for a valid config', async () => {
const mockToken = 'mock-id-token-123';
mockRequest.mockResolvedValue({ data: { token: mockToken } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
const tokens = await provider.tokens();
expect(tokens).toBeDefined();
expect(tokens?.access_token).toBe(mockToken);
expect(tokens?.token_type).toBe('Bearer');
});
it('should return undefined if token acquisition fails', async () => {
mockRequest.mockResolvedValue({ data: { token: null } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
const tokens = await provider.tokens();
expect(tokens).toBeUndefined();
});
it('should make a request with the correct parameters', async () => {
mockRequest.mockResolvedValue({ data: { token: 'test-token' } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
await provider.tokens();
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa:generateIdToken',
method: 'POST',
data: {
audience: 'my-audience',
includeEmail: true,
},
});
});
it('should return a cached token if it is not expired', async () => {
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
vi.useFakeTimers();
// jwt payload with exp set to 1 hour from now
const payload = { exp: Math.floor(Date.now() / 1000) + 3600 };
const jwt = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: jwt } });
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe(jwt);
expect(mockRequest).toHaveBeenCalledTimes(1);
// Advance time by 30 minutes
vi.advanceTimersByTime(1800 * 1000);
// Seturn cached token
const secondTokens = await provider.tokens();
expect(secondTokens).toBe(firstTokens);
expect(mockRequest).toHaveBeenCalledTimes(1);
vi.useRealTimers();
});
it('should fetch a new token if the cached token is expired (using fake timers)', async () => {
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
vi.useFakeTimers();
// Get and cache a token that expires in 1 second
const expiredPayload = { exp: Math.floor(Date.now() / 1000) + 1 };
const expiredJwt = `header.${Buffer.from(JSON.stringify(expiredPayload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: expiredJwt } });
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe(expiredJwt);
expect(mockRequest).toHaveBeenCalledTimes(1);
// Prepare the mock for the *next* call
const newPayload = { exp: Math.floor(Date.now() / 1000) + 3600 };
const newJwt = `header.${Buffer.from(JSON.stringify(newPayload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: newJwt } });
vi.advanceTimersByTime(1001);
const newTokens = await provider.tokens();
expect(newTokens?.access_token).toBe(newJwt);
expect(newTokens?.access_token).not.toBe(expiredJwt);
expect(mockRequest).toHaveBeenCalledTimes(2); // Confirms a new fetch
vi.useRealTimers();
});
});

View File

@@ -0,0 +1,171 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
} from '@modelcontextprotocol/sdk/shared/auth.js';
import { GoogleAuth } from 'google-auth-library';
import type { MCPServerConfig } from '../config/config.js';
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
const fiveMinBufferMs = 5 * 60 * 1000;
function createIamApiUrl(targetSA: string): string {
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
}
export class ServiceAccountImpersonationProvider
implements OAuthClientProvider
{
private readonly targetServiceAccount: string;
private readonly targetAudience: string; // OAuth Client Id
private readonly auth: GoogleAuth;
private cachedToken?: OAuthTokens;
private tokenExpiryTime?: number;
// Properties required by OAuthClientProvider, with no-op values
readonly redirectUrl = '';
readonly clientMetadata: OAuthClientMetadata = {
client_name: 'Gemini CLI (Service Account Impersonation)',
redirect_uris: [],
grant_types: [],
response_types: [],
token_endpoint_auth_method: 'none',
};
private _clientInformation?: OAuthClientInformationFull;
constructor(private readonly config: MCPServerConfig) {
// This check is done in mcp-client.ts. This is just an additional check.
if (!this.config.httpUrl && !this.config.url) {
throw new Error(
'A url or httpUrl must be provided for the Service Account Impersonation provider',
);
}
if (!config.targetAudience) {
throw new Error(
'targetAudience must be provided for the Service Account Impersonation provider',
);
}
this.targetAudience = config.targetAudience;
if (!config.targetServiceAccount) {
throw new Error(
'targetServiceAccount must be provided for the Service Account Impersonation provider',
);
}
this.targetServiceAccount = config.targetServiceAccount;
this.auth = new GoogleAuth();
}
clientInformation(): OAuthClientInformation | undefined {
return this._clientInformation;
}
saveClientInformation(clientInformation: OAuthClientInformationFull): void {
this._clientInformation = clientInformation;
}
async tokens(): Promise<OAuthTokens | undefined> {
// 1. Check if we have a valid, non-expired cached token.
if (
this.cachedToken &&
this.tokenExpiryTime &&
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
) {
return this.cachedToken;
}
// 2. Clear any invalid/expired cache.
this.cachedToken = undefined;
this.tokenExpiryTime = undefined;
// 3. Fetch a new ID token.
const client = await this.auth.getClient();
const url = createIamApiUrl(this.targetServiceAccount);
let idToken: string;
try {
const res = await client.request<{ token: string }>({
url,
method: 'POST',
data: {
audience: this.targetAudience,
includeEmail: true,
},
});
idToken = res.data.token;
if (!idToken || idToken.length === 0) {
console.error('Failed to get ID token from Google');
return undefined;
}
} catch (e) {
console.error('Failed to fetch ID token from Google:', e);
return undefined;
}
const expiryTime = this.parseTokenExpiry(idToken);
// Note: We are placing the OIDC ID Token into the `access_token` field.
// This is because the CLI uses this field to construct the
// `Authorization: Bearer <token>` header, which is the correct way to
// present an ID token.
const newTokens: OAuthTokens = {
access_token: idToken,
token_type: 'Bearer',
};
if (expiryTime) {
this.tokenExpiryTime = expiryTime;
this.cachedToken = newTokens;
}
return newTokens;
}
saveTokens(_tokens: OAuthTokens): void {
// No-op
}
redirectToAuthorization(_authorizationUrl: URL): void {
// No-op
}
saveCodeVerifier(_codeVerifier: string): void {
// No-op
}
codeVerifier(): string {
// No-op
return '';
}
/**
* Parses a JWT string to extract its expiry time.
* @param idToken The JWT ID token.
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
*/
private parseTokenExpiry(idToken: string): number | undefined {
try {
const payload = JSON.parse(
Buffer.from(idToken.split('.')[1], 'base64').toString(),
);
if (payload && typeof payload.exp === 'number') {
return payload.exp * 1000; // Convert seconds to milliseconds
}
} catch (e) {
console.error('Failed to parse ID token for expiry time with error:', e);
}
// Return undefined if try block fails or 'exp' is missing/invalid
return undefined;
}
}

View File

@@ -53,7 +53,7 @@ describe('BaseTokenStorage', () => {
let storage: TestTokenStorage;
beforeEach(() => {
storage = new TestTokenStorage();
storage = new TestTokenStorage('gemini-cli-mcp-oauth');
});
describe('validateCredentials', () => {

View File

@@ -9,7 +9,7 @@ import type { TokenStorage, OAuthCredentials } from './types.js';
export abstract class BaseTokenStorage implements TokenStorage {
protected readonly serviceName: string;
constructor(serviceName: string = 'gemini-cli-mcp-oauth') {
constructor(serviceName: string) {
this.serviceName = serviceName;
}

View File

@@ -0,0 +1,323 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import { FileTokenStorage } from './file-token-storage.js';
import type { OAuthCredentials } from './types.js';
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
writeFile: vi.fn(),
unlink: vi.fn(),
mkdir: vi.fn(),
},
}));
vi.mock('node:os', () => ({
default: {
homedir: vi.fn(() => '/home/test'),
hostname: vi.fn(() => 'test-host'),
userInfo: vi.fn(() => ({ username: 'test-user' })),
},
homedir: vi.fn(() => '/home/test'),
hostname: vi.fn(() => 'test-host'),
userInfo: vi.fn(() => ({ username: 'test-user' })),
}));
describe('FileTokenStorage', () => {
let storage: FileTokenStorage;
const mockFs = fs as unknown as {
readFile: ReturnType<typeof vi.fn>;
writeFile: ReturnType<typeof vi.fn>;
unlink: ReturnType<typeof vi.fn>;
mkdir: ReturnType<typeof vi.fn>;
};
const existingCredentials: OAuthCredentials = {
serverName: 'existing-server',
token: {
accessToken: 'existing-token',
tokenType: 'Bearer',
},
updatedAt: Date.now() - 10000,
};
beforeEach(() => {
vi.clearAllMocks();
storage = new FileTokenStorage('test-storage');
});
afterEach(() => {
vi.clearAllMocks();
});
describe('getCredentials', () => {
it('should throw error when file does not exist', async () => {
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
await expect(storage.getCredentials('test-server')).rejects.toThrow(
'Token file does not exist',
);
});
it('should return null for expired tokens', async () => {
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
expiresAt: Date.now() - 3600000,
},
updatedAt: Date.now(),
};
const encryptedData = storage['encrypt'](
JSON.stringify({ 'test-server': credentials }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
const result = await storage.getCredentials('test-server');
expect(result).toBeNull();
});
it('should return credentials for valid tokens', async () => {
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 3600000,
},
updatedAt: Date.now(),
};
const encryptedData = storage['encrypt'](
JSON.stringify({ 'test-server': credentials }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
const result = await storage.getCredentials('test-server');
expect(result).toEqual(credentials);
});
it('should throw error for corrupted files', async () => {
mockFs.readFile.mockResolvedValue('corrupted-data');
await expect(storage.getCredentials('test-server')).rejects.toThrow(
'Token file corrupted',
);
});
});
describe('setCredentials', () => {
it('should save credentials with encryption', async () => {
const encryptedData = storage['encrypt'](
JSON.stringify({ 'existing-server': existingCredentials }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
mockFs.mkdir.mockResolvedValue(undefined);
mockFs.writeFile.mockResolvedValue(undefined);
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
await storage.setCredentials(credentials);
expect(mockFs.mkdir).toHaveBeenCalledWith(
path.join('/home/test', '.qwen'),
{ recursive: true, mode: 0o700 },
);
expect(mockFs.writeFile).toHaveBeenCalled();
const writeCall = mockFs.writeFile.mock.calls[0];
expect(writeCall[1]).toMatch(/^[0-9a-f]+:[0-9a-f]+:[0-9a-f]+$/);
expect(writeCall[2]).toEqual({ mode: 0o600 });
});
it('should update existing credentials', async () => {
const encryptedData = storage['encrypt'](
JSON.stringify({ 'existing-server': existingCredentials }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
mockFs.writeFile.mockResolvedValue(undefined);
const newCredentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'new-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
await storage.setCredentials(newCredentials);
expect(mockFs.writeFile).toHaveBeenCalled();
const writeCall = mockFs.writeFile.mock.calls[0];
const decrypted = storage['decrypt'](writeCall[1]);
const saved = JSON.parse(decrypted);
expect(saved['existing-server']).toEqual(existingCredentials);
expect(saved['test-server'].token.accessToken).toBe('new-token');
});
});
describe('deleteCredentials', () => {
it('should throw when credentials do not exist', async () => {
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
'Token file does not exist',
);
});
it('should delete file when last credential is removed', async () => {
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
const encryptedData = storage['encrypt'](
JSON.stringify({ 'test-server': credentials }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
mockFs.unlink.mockResolvedValue(undefined);
await storage.deleteCredentials('test-server');
expect(mockFs.unlink).toHaveBeenCalledWith(
path.join('/home/test', '.qwen', 'mcp-oauth-tokens-v2.json'),
);
});
it('should update file when other credentials remain', async () => {
const credentials1: OAuthCredentials = {
serverName: 'server1',
token: {
accessToken: 'token1',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
const credentials2: OAuthCredentials = {
serverName: 'server2',
token: {
accessToken: 'token2',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
const encryptedData = storage['encrypt'](
JSON.stringify({ server1: credentials1, server2: credentials2 }),
);
mockFs.readFile.mockResolvedValue(encryptedData);
mockFs.writeFile.mockResolvedValue(undefined);
await storage.deleteCredentials('server1');
expect(mockFs.writeFile).toHaveBeenCalled();
expect(mockFs.unlink).not.toHaveBeenCalled();
const writeCall = mockFs.writeFile.mock.calls[0];
const decrypted = storage['decrypt'](writeCall[1]);
const saved = JSON.parse(decrypted);
expect(saved['server1']).toBeUndefined();
expect(saved['server2']).toEqual(credentials2);
});
});
describe('listServers', () => {
it('should throw error when file does not exist', async () => {
mockFs.readFile.mockRejectedValue({ code: 'ENOENT' });
await expect(storage.listServers()).rejects.toThrow(
'Token file does not exist',
);
});
it('should return list of server names', async () => {
const credentials: Record<string, OAuthCredentials> = {
server1: {
serverName: 'server1',
token: { accessToken: 'token1', tokenType: 'Bearer' },
updatedAt: Date.now(),
},
server2: {
serverName: 'server2',
token: { accessToken: 'token2', tokenType: 'Bearer' },
updatedAt: Date.now(),
},
};
const encryptedData = storage['encrypt'](JSON.stringify(credentials));
mockFs.readFile.mockResolvedValue(encryptedData);
const result = await storage.listServers();
expect(result).toEqual(['server1', 'server2']);
});
});
describe('clearAll', () => {
it('should delete the token file', async () => {
mockFs.unlink.mockResolvedValue(undefined);
await storage.clearAll();
expect(mockFs.unlink).toHaveBeenCalledWith(
path.join('/home/test', '.qwen', 'mcp-oauth-tokens-v2.json'),
);
});
it('should not throw when file does not exist', async () => {
mockFs.unlink.mockRejectedValue({ code: 'ENOENT' });
await expect(storage.clearAll()).resolves.not.toThrow();
});
});
describe('encryption', () => {
it('should encrypt and decrypt data correctly', () => {
const original = 'test-data-123';
const encrypted = storage['encrypt'](original);
const decrypted = storage['decrypt'](encrypted);
expect(decrypted).toBe(original);
expect(encrypted).not.toBe(original);
expect(encrypted).toMatch(/^[0-9a-f]+:[0-9a-f]+:[0-9a-f]+$/);
});
it('should produce different encrypted output each time', () => {
const original = 'test-data';
const encrypted1 = storage['encrypt'](original);
const encrypted2 = storage['encrypt'](original);
expect(encrypted1).not.toBe(encrypted2);
expect(storage['decrypt'](encrypted1)).toBe(original);
expect(storage['decrypt'](encrypted2)).toBe(original);
});
it('should throw on invalid encrypted data format', () => {
expect(() => storage['decrypt']('invalid-data')).toThrow(
'Invalid encrypted data format',
);
});
});
});

View File

@@ -0,0 +1,184 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import * as os from 'node:os';
import * as crypto from 'node:crypto';
import { BaseTokenStorage } from './base-token-storage.js';
import type { OAuthCredentials } from './types.js';
export class FileTokenStorage extends BaseTokenStorage {
private readonly tokenFilePath: string;
private readonly encryptionKey: Buffer;
constructor(serviceName: string) {
super(serviceName);
const configDir = path.join(os.homedir(), '.qwen');
this.tokenFilePath = path.join(configDir, 'mcp-oauth-tokens-v2.json');
this.encryptionKey = this.deriveEncryptionKey();
}
private deriveEncryptionKey(): Buffer {
const salt = `${os.hostname()}-${os.userInfo().username}-qwen-code`;
return crypto.scryptSync('qwen-code-oauth', salt, 32);
}
private encrypt(text: string): string {
const iv = crypto.randomBytes(16);
const cipher = crypto.createCipheriv('aes-256-gcm', this.encryptionKey, iv);
let encrypted = cipher.update(text, 'utf8', 'hex');
encrypted += cipher.final('hex');
const authTag = cipher.getAuthTag();
return iv.toString('hex') + ':' + authTag.toString('hex') + ':' + encrypted;
}
private decrypt(encryptedData: string): string {
const parts = encryptedData.split(':');
if (parts.length !== 3) {
throw new Error('Invalid encrypted data format');
}
const iv = Buffer.from(parts[0], 'hex');
const authTag = Buffer.from(parts[1], 'hex');
const encrypted = parts[2];
const decipher = crypto.createDecipheriv(
'aes-256-gcm',
this.encryptionKey,
iv,
);
decipher.setAuthTag(authTag);
let decrypted = decipher.update(encrypted, 'hex', 'utf8');
decrypted += decipher.final('utf8');
return decrypted;
}
private async ensureDirectoryExists(): Promise<void> {
const dir = path.dirname(this.tokenFilePath);
await fs.mkdir(dir, { recursive: true, mode: 0o700 });
}
private async loadTokens(): Promise<Map<string, OAuthCredentials>> {
try {
const data = await fs.readFile(this.tokenFilePath, 'utf-8');
const decrypted = this.decrypt(data);
const tokens = JSON.parse(decrypted) as Record<string, OAuthCredentials>;
return new Map(Object.entries(tokens));
} catch (error: unknown) {
const err = error as NodeJS.ErrnoException & { message?: string };
if (err.code === 'ENOENT') {
throw new Error('Token file does not exist');
}
if (
err.message?.includes('Invalid encrypted data format') ||
err.message?.includes(
'Unsupported state or unable to authenticate data',
)
) {
throw new Error('Token file corrupted');
}
throw error;
}
}
private async saveTokens(
tokens: Map<string, OAuthCredentials>,
): Promise<void> {
await this.ensureDirectoryExists();
const data = Object.fromEntries(tokens);
const json = JSON.stringify(data, null, 2);
const encrypted = this.encrypt(json);
await fs.writeFile(this.tokenFilePath, encrypted, { mode: 0o600 });
}
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.loadTokens();
const credentials = tokens.get(serverName);
if (!credentials) {
return null;
}
if (this.isTokenExpired(credentials)) {
return null;
}
return credentials;
}
async setCredentials(credentials: OAuthCredentials): Promise<void> {
this.validateCredentials(credentials);
const tokens = await this.loadTokens();
const updatedCredentials: OAuthCredentials = {
...credentials,
updatedAt: Date.now(),
};
tokens.set(credentials.serverName, updatedCredentials);
await this.saveTokens(tokens);
}
async deleteCredentials(serverName: string): Promise<void> {
const tokens = await this.loadTokens();
if (!tokens.has(serverName)) {
throw new Error(`No credentials found for ${serverName}`);
}
tokens.delete(serverName);
if (tokens.size === 0) {
try {
await fs.unlink(this.tokenFilePath);
} catch (error: unknown) {
const err = error as NodeJS.ErrnoException;
if (err.code !== 'ENOENT') {
throw error;
}
}
} else {
await this.saveTokens(tokens);
}
}
async listServers(): Promise<string[]> {
const tokens = await this.loadTokens();
return Array.from(tokens.keys());
}
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
const tokens = await this.loadTokens();
const result = new Map<string, OAuthCredentials>();
for (const [serverName, credentials] of tokens) {
if (!this.isTokenExpired(credentials)) {
result.set(serverName, credentials);
}
}
return result;
}
async clearAll(): Promise<void> {
try {
await fs.unlink(this.tokenFilePath);
} catch (error: unknown) {
const err = error as NodeJS.ErrnoException;
if (err.code !== 'ENOENT') {
throw error;
}
}
}
}

View File

@@ -0,0 +1,274 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { HybridTokenStorage } from './hybrid-token-storage.js';
import { KeychainTokenStorage } from './keychain-token-storage.js';
import { FileTokenStorage } from './file-token-storage.js';
import { type OAuthCredentials, TokenStorageType } from './types.js';
vi.mock('./keychain-token-storage.js', () => ({
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
isAvailable: vi.fn(),
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
listServers: vi.fn(),
getAllCredentials: vi.fn(),
clearAll: vi.fn(),
})),
}));
vi.mock('./file-token-storage.js', () => ({
FileTokenStorage: vi.fn().mockImplementation(() => ({
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
listServers: vi.fn(),
getAllCredentials: vi.fn(),
clearAll: vi.fn(),
})),
}));
interface MockStorage {
isAvailable?: ReturnType<typeof vi.fn>;
getCredentials: ReturnType<typeof vi.fn>;
setCredentials: ReturnType<typeof vi.fn>;
deleteCredentials: ReturnType<typeof vi.fn>;
listServers: ReturnType<typeof vi.fn>;
getAllCredentials: ReturnType<typeof vi.fn>;
clearAll: ReturnType<typeof vi.fn>;
}
describe('HybridTokenStorage', () => {
let storage: HybridTokenStorage;
let mockKeychainStorage: MockStorage;
let mockFileStorage: MockStorage;
const originalEnv = process.env;
beforeEach(() => {
vi.clearAllMocks();
process.env = { ...originalEnv };
// Create mock instances before creating HybridTokenStorage
mockKeychainStorage = {
isAvailable: vi.fn(),
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
listServers: vi.fn(),
getAllCredentials: vi.fn(),
clearAll: vi.fn(),
};
mockFileStorage = {
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
listServers: vi.fn(),
getAllCredentials: vi.fn(),
clearAll: vi.fn(),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
(
FileTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockFileStorage);
storage = new HybridTokenStorage('test-service');
});
afterEach(() => {
process.env = originalEnv;
});
describe('storage selection', () => {
it('should use keychain when available', async () => {
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.getCredentials.mockResolvedValue(null);
await storage.getCredentials('test-server');
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(await storage.getStorageType()).toBe(TokenStorageType.KEYCHAIN);
});
it('should use file storage when GEMINI_FORCE_FILE_STORAGE is set', async () => {
process.env['GEMINI_FORCE_FILE_STORAGE'] = 'true';
mockFileStorage.getCredentials.mockResolvedValue(null);
await storage.getCredentials('test-server');
expect(mockKeychainStorage.isAvailable).not.toHaveBeenCalled();
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(await storage.getStorageType()).toBe(
TokenStorageType.ENCRYPTED_FILE,
);
});
it('should fall back to file storage when keychain is unavailable', async () => {
mockKeychainStorage.isAvailable!.mockResolvedValue(false);
mockFileStorage.getCredentials.mockResolvedValue(null);
await storage.getCredentials('test-server');
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(await storage.getStorageType()).toBe(
TokenStorageType.ENCRYPTED_FILE,
);
});
it('should fall back to file storage when keychain throws error', async () => {
mockKeychainStorage.isAvailable!.mockRejectedValue(
new Error('Keychain error'),
);
mockFileStorage.getCredentials.mockResolvedValue(null);
await storage.getCredentials('test-server');
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(await storage.getStorageType()).toBe(
TokenStorageType.ENCRYPTED_FILE,
);
});
it('should cache storage selection', async () => {
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.getCredentials.mockResolvedValue(null);
await storage.getCredentials('test-server');
await storage.getCredentials('another-server');
expect(mockKeychainStorage.isAvailable).toHaveBeenCalledTimes(1);
});
});
describe('getCredentials', () => {
it('should delegate to selected storage', async () => {
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.getCredentials.mockResolvedValue(credentials);
const result = await storage.getCredentials('test-server');
expect(result).toEqual(credentials);
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
'test-server',
);
});
});
describe('setCredentials', () => {
it('should delegate to selected storage', async () => {
const credentials: OAuthCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
};
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.setCredentials.mockResolvedValue(undefined);
await storage.setCredentials(credentials);
expect(mockKeychainStorage.setCredentials).toHaveBeenCalledWith(
credentials,
);
});
});
describe('deleteCredentials', () => {
it('should delegate to selected storage', async () => {
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.deleteCredentials.mockResolvedValue(undefined);
await storage.deleteCredentials('test-server');
expect(mockKeychainStorage.deleteCredentials).toHaveBeenCalledWith(
'test-server',
);
});
});
describe('listServers', () => {
it('should delegate to selected storage', async () => {
const servers = ['server1', 'server2'];
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.listServers.mockResolvedValue(servers);
const result = await storage.listServers();
expect(result).toEqual(servers);
expect(mockKeychainStorage.listServers).toHaveBeenCalled();
});
});
describe('getAllCredentials', () => {
it('should delegate to selected storage', async () => {
const credentialsMap = new Map([
[
'server1',
{
serverName: 'server1',
token: { accessToken: 'token1', tokenType: 'Bearer' },
updatedAt: Date.now(),
},
],
[
'server2',
{
serverName: 'server2',
token: { accessToken: 'token2', tokenType: 'Bearer' },
updatedAt: Date.now(),
},
],
]);
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.getAllCredentials.mockResolvedValue(credentialsMap);
const result = await storage.getAllCredentials();
expect(result).toEqual(credentialsMap);
expect(mockKeychainStorage.getAllCredentials).toHaveBeenCalled();
});
});
describe('clearAll', () => {
it('should delegate to selected storage', async () => {
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
mockKeychainStorage.clearAll.mockResolvedValue(undefined);
await storage.clearAll();
expect(mockKeychainStorage.clearAll).toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,97 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseTokenStorage } from './base-token-storage.js';
import { FileTokenStorage } from './file-token-storage.js';
import type { TokenStorage, OAuthCredentials } from './types.js';
import { TokenStorageType } from './types.js';
const FORCE_FILE_STORAGE_ENV_VAR = 'GEMINI_FORCE_FILE_STORAGE';
export class HybridTokenStorage extends BaseTokenStorage {
private storage: TokenStorage | null = null;
private storageType: TokenStorageType | null = null;
private storageInitPromise: Promise<TokenStorage> | null = null;
constructor(serviceName: string) {
super(serviceName);
}
private async initializeStorage(): Promise<TokenStorage> {
const forceFileStorage = process.env[FORCE_FILE_STORAGE_ENV_VAR] === 'true';
if (!forceFileStorage) {
try {
const { KeychainTokenStorage } = await import(
'./keychain-token-storage.js'
);
const keychainStorage = new KeychainTokenStorage(this.serviceName);
const isAvailable = await keychainStorage.isAvailable();
if (isAvailable) {
this.storage = keychainStorage;
this.storageType = TokenStorageType.KEYCHAIN;
return this.storage;
}
} catch (_e) {
// Fallback to file storage if keychain fails to initialize
}
}
this.storage = new FileTokenStorage(this.serviceName);
this.storageType = TokenStorageType.ENCRYPTED_FILE;
return this.storage;
}
private async getStorage(): Promise<TokenStorage> {
if (this.storage !== null) {
return this.storage;
}
// Use a single initialization promise to avoid race conditions
if (!this.storageInitPromise) {
this.storageInitPromise = this.initializeStorage();
}
// Wait for initialization to complete
return await this.storageInitPromise;
}
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
const storage = await this.getStorage();
return storage.getCredentials(serverName);
}
async setCredentials(credentials: OAuthCredentials): Promise<void> {
const storage = await this.getStorage();
await storage.setCredentials(credentials);
}
async deleteCredentials(serverName: string): Promise<void> {
const storage = await this.getStorage();
await storage.deleteCredentials(serverName);
}
async listServers(): Promise<string[]> {
const storage = await this.getStorage();
return storage.listServers();
}
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
const storage = await this.getStorage();
return storage.getAllCredentials();
}
async clearAll(): Promise<void> {
const storage = await this.getStorage();
await storage.clearAll();
}
async getStorageType(): Promise<TokenStorageType> {
await this.getStorage();
return this.storageType!;
}
}

View File

@@ -0,0 +1,14 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export * from './types.js';
export * from './base-token-storage.js';
export * from './file-token-storage.js';
export * from './hybrid-token-storage.js';
export const DEFAULT_SERVICE_NAME = 'qwen-code-oauth';
export const FORCE_ENCRYPTED_FILE_ENV_VAR =
'QWEN_CODE_FORCE_ENCRYPTED_FILE_STORAGE';

View File

@@ -0,0 +1,352 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import type { KeychainTokenStorage } from './keychain-token-storage.js';
import type { OAuthCredentials } from './types.js';
// Hoist the mock to be available in the vi.mock factory
const mockKeytar = vi.hoisted(() => ({
getPassword: vi.fn(),
setPassword: vi.fn(),
deletePassword: vi.fn(),
findCredentials: vi.fn(),
}));
const mockServiceName = 'service-name';
const mockCryptoRandomBytesString = 'random-string';
// Mock the dynamic import of 'keytar'
vi.mock('keytar', () => ({
default: mockKeytar,
}));
vi.mock('node:crypto', () => ({
randomBytes: vi.fn(() => ({
toString: vi.fn(() => mockCryptoRandomBytesString),
})),
}));
describe('KeychainTokenStorage', () => {
let storage: KeychainTokenStorage;
beforeEach(async () => {
vi.resetAllMocks();
// Reset the internal state of the keychain-token-storage module
vi.resetModules();
const { KeychainTokenStorage } = await import(
'./keychain-token-storage.js'
);
storage = new KeychainTokenStorage(mockServiceName);
});
afterEach(() => {
vi.restoreAllMocks();
vi.useRealTimers();
});
const validCredentials = {
serverName: 'test-server',
token: {
accessToken: 'access-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 3600000,
},
updatedAt: Date.now(),
} as OAuthCredentials;
describe('checkKeychainAvailability', () => {
it('should return true if keytar is available and functional', async () => {
mockKeytar.setPassword.mockResolvedValue(undefined);
mockKeytar.getPassword.mockResolvedValue('test');
mockKeytar.deletePassword.mockResolvedValue(true);
const isAvailable = await storage.checkKeychainAvailability();
expect(isAvailable).toBe(true);
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
mockServiceName,
`__keychain_test__${mockCryptoRandomBytesString}`,
'test',
);
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
mockServiceName,
`__keychain_test__${mockCryptoRandomBytesString}`,
);
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
mockServiceName,
`__keychain_test__${mockCryptoRandomBytesString}`,
);
});
it('should return false if keytar fails to set password', async () => {
mockKeytar.setPassword.mockRejectedValue(new Error('write error'));
const isAvailable = await storage.checkKeychainAvailability();
expect(isAvailable).toBe(false);
});
it('should return false if retrieved password does not match', async () => {
mockKeytar.setPassword.mockResolvedValue(undefined);
mockKeytar.getPassword.mockResolvedValue('wrong-password');
mockKeytar.deletePassword.mockResolvedValue(true);
const isAvailable = await storage.checkKeychainAvailability();
expect(isAvailable).toBe(false);
});
it('should cache the availability result', async () => {
mockKeytar.setPassword.mockResolvedValue(undefined);
mockKeytar.getPassword.mockResolvedValue('test');
mockKeytar.deletePassword.mockResolvedValue(true);
await storage.checkKeychainAvailability();
await storage.checkKeychainAvailability();
expect(mockKeytar.setPassword).toHaveBeenCalledTimes(1);
});
});
describe('with keychain unavailable', () => {
beforeEach(async () => {
// Force keychain to be unavailable
mockKeytar.setPassword.mockRejectedValue(new Error('keychain error'));
await storage.checkKeychainAvailability();
});
it('getCredentials should throw', async () => {
await expect(storage.getCredentials('server')).rejects.toThrow(
'Keychain is not available',
);
});
it('setCredentials should throw', async () => {
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
'Keychain is not available',
);
});
it('deleteCredentials should throw', async () => {
await expect(storage.deleteCredentials('server')).rejects.toThrow(
'Keychain is not available',
);
});
it('listServers should throw', async () => {
await expect(storage.listServers()).rejects.toThrow(
'Keychain is not available',
);
});
it('getAllCredentials should throw', async () => {
await expect(storage.getAllCredentials()).rejects.toThrow(
'Keychain is not available',
);
});
});
describe('with keychain available', () => {
beforeEach(async () => {
mockKeytar.setPassword.mockResolvedValue(undefined);
mockKeytar.getPassword.mockResolvedValue('test');
mockKeytar.deletePassword.mockResolvedValue(true);
await storage.checkKeychainAvailability();
// Reset mocks after availability check
vi.resetAllMocks();
});
describe('getCredentials', () => {
it('should return null if no credentials are found', async () => {
mockKeytar.getPassword.mockResolvedValue(null);
const result = await storage.getCredentials('test-server');
expect(result).toBeNull();
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
mockServiceName,
'test-server',
);
});
it('should return credentials if found and not expired', async () => {
mockKeytar.getPassword.mockResolvedValue(
JSON.stringify(validCredentials),
);
const result = await storage.getCredentials('test-server');
expect(result).toEqual(validCredentials);
});
it('should return null if credentials have expired', async () => {
const expiredCreds = {
...validCredentials,
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
};
mockKeytar.getPassword.mockResolvedValue(JSON.stringify(expiredCreds));
const result = await storage.getCredentials('test-server');
expect(result).toBeNull();
});
it('should throw if stored data is corrupted JSON', async () => {
mockKeytar.getPassword.mockResolvedValue('not-json');
await expect(storage.getCredentials('test-server')).rejects.toThrow(
'Failed to parse stored credentials for test-server',
);
});
});
describe('setCredentials', () => {
it('should save credentials to keychain', async () => {
vi.useFakeTimers();
mockKeytar.setPassword.mockResolvedValue(undefined);
await storage.setCredentials(validCredentials);
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
mockServiceName,
'test-server',
JSON.stringify({ ...validCredentials, updatedAt: Date.now() }),
);
});
it('should throw if saving to keychain fails', async () => {
mockKeytar.setPassword.mockRejectedValue(
new Error('keychain write error'),
);
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
'keychain write error',
);
});
});
describe('deleteCredentials', () => {
it('should delete credentials from keychain', async () => {
mockKeytar.deletePassword.mockResolvedValue(true);
await storage.deleteCredentials('test-server');
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
mockServiceName,
'test-server',
);
});
it('should throw if no credentials were found to delete', async () => {
mockKeytar.deletePassword.mockResolvedValue(false);
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
'No credentials found for test-server',
);
});
it('should throw if deleting from keychain fails', async () => {
mockKeytar.deletePassword.mockRejectedValue(
new Error('keychain delete error'),
);
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
'keychain delete error',
);
});
});
describe('listServers', () => {
it('should return a list of server names', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: 'server1', password: '' },
{ account: 'server2', password: '' },
]);
const result = await storage.listServers();
expect(result).toEqual(['server1', 'server2']);
});
it('should not include internal test keys in the server list', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: 'server1', password: '' },
{
account: `__keychain_test__${mockCryptoRandomBytesString}`,
password: '',
},
{ account: 'server2', password: '' },
]);
const result = await storage.listServers();
expect(result).toEqual(['server1', 'server2']);
});
it('should return an empty array on error', async () => {
mockKeytar.findCredentials.mockRejectedValue(new Error('find error'));
const result = await storage.listServers();
expect(result).toEqual([]);
});
});
describe('getAllCredentials', () => {
it('should return a map of all valid credentials', async () => {
const creds2 = {
...validCredentials,
serverName: 'server2',
};
const expiredCreds = {
...validCredentials,
serverName: 'expired-server',
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
};
const structurallyInvalidCreds = {
serverName: 'invalid-server',
};
mockKeytar.findCredentials.mockResolvedValue([
{
account: 'test-server',
password: JSON.stringify(validCredentials),
},
{ account: 'server2', password: JSON.stringify(creds2) },
{
account: 'expired-server',
password: JSON.stringify(expiredCreds),
},
{ account: 'bad-server', password: 'not-json' },
{
account: 'invalid-server',
password: JSON.stringify(structurallyInvalidCreds),
},
]);
const result = await storage.getAllCredentials();
expect(result.size).toBe(2);
expect(result.get('test-server')).toEqual(validCredentials);
expect(result.get('server2')).toEqual(creds2);
expect(result.has('expired-server')).toBe(false);
expect(result.has('bad-server')).toBe(false);
expect(result.has('invalid-server')).toBe(false);
});
});
describe('clearAll', () => {
it('should delete all credentials for the service', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: 'server1', password: '' },
{ account: 'server2', password: '' },
]);
mockKeytar.deletePassword.mockResolvedValue(true);
await storage.clearAll();
expect(mockKeytar.deletePassword).toHaveBeenCalledTimes(2);
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
mockServiceName,
'server1',
);
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
mockServiceName,
'server2',
);
});
it('should throw an aggregated error if deletions fail', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: 'server1', password: '' },
{ account: 'server2', password: '' },
]);
mockKeytar.deletePassword
.mockResolvedValueOnce(true)
.mockRejectedValueOnce(new Error('delete failed'));
await expect(storage.clearAll()).rejects.toThrow(
'Failed to clear some credentials: delete failed',
);
});
});
});
});

View File

@@ -0,0 +1,251 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as crypto from 'node:crypto';
import { BaseTokenStorage } from './base-token-storage.js';
import type { OAuthCredentials } from './types.js';
interface Keytar {
getPassword(service: string, account: string): Promise<string | null>;
setPassword(
service: string,
account: string,
password: string,
): Promise<void>;
deletePassword(service: string, account: string): Promise<boolean>;
findCredentials(
service: string,
): Promise<Array<{ account: string; password: string }>>;
}
const KEYCHAIN_TEST_PREFIX = '__keychain_test__';
export class KeychainTokenStorage extends BaseTokenStorage {
private keychainAvailable: boolean | null = null;
private keytarModule: Keytar | null = null;
private keytarLoadAttempted = false;
async getKeytar(): Promise<Keytar | null> {
// If we've already tried loading (successfully or not), return the result
if (this.keytarLoadAttempted) {
return this.keytarModule;
}
this.keytarLoadAttempted = true;
try {
// Try to import keytar without any timeout - let the OS handle it
const moduleName = 'keytar';
const module = await import(moduleName);
this.keytarModule = module.default || module;
} catch (error) {
console.error(error);
}
return this.keytarModule;
}
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
try {
const sanitizedName = this.sanitizeServerName(serverName);
const data = await keytar.getPassword(this.serviceName, sanitizedName);
if (!data) {
return null;
}
const credentials = JSON.parse(data) as OAuthCredentials;
if (this.isTokenExpired(credentials)) {
return null;
}
return credentials;
} catch (error) {
if (error instanceof SyntaxError) {
throw new Error(`Failed to parse stored credentials for ${serverName}`);
}
throw error;
}
}
async setCredentials(credentials: OAuthCredentials): Promise<void> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
this.validateCredentials(credentials);
const sanitizedName = this.sanitizeServerName(credentials.serverName);
const updatedCredentials: OAuthCredentials = {
...credentials,
updatedAt: Date.now(),
};
const data = JSON.stringify(updatedCredentials);
await keytar.setPassword(this.serviceName, sanitizedName, data);
}
async deleteCredentials(serverName: string): Promise<void> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
const sanitizedName = this.sanitizeServerName(serverName);
const deleted = await keytar.deletePassword(
this.serviceName,
sanitizedName,
);
if (!deleted) {
throw new Error(`No credentials found for ${serverName}`);
}
}
async listServers(): Promise<string[]> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
try {
const credentials = await keytar.findCredentials(this.serviceName);
return credentials
.filter((cred) => !cred.account.startsWith(KEYCHAIN_TEST_PREFIX))
.map((cred: { account: string }) => cred.account);
} catch (error) {
console.error('Failed to list servers from keychain:', error);
return [];
}
}
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
const result = new Map<string, OAuthCredentials>();
try {
const credentials = (
await keytar.findCredentials(this.serviceName)
).filter((c) => !c.account.startsWith(KEYCHAIN_TEST_PREFIX));
for (const cred of credentials) {
try {
const data = JSON.parse(cred.password) as OAuthCredentials;
if (!this.isTokenExpired(data)) {
result.set(cred.account, data);
}
} catch (error) {
console.error(
`Failed to parse credentials for ${cred.account}:`,
error,
);
}
}
} catch (error) {
console.error('Failed to get all credentials from keychain:', error);
}
return result;
}
async clearAll(): Promise<void> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const servers = this.keytarModule
? await this.keytarModule
.findCredentials(this.serviceName)
.then((creds) => creds.map((c) => c.account))
.catch((error: Error) => {
throw new Error(
`Failed to list servers for clearing: ${error.message}`,
);
})
: [];
const errors: Error[] = [];
for (const server of servers) {
try {
await this.deleteCredentials(server);
} catch (error) {
errors.push(error as Error);
}
}
if (errors.length > 0) {
throw new Error(
`Failed to clear some credentials: ${errors.map((e) => e.message).join(', ')}`,
);
}
}
// Checks whether or not a set-get-delete cycle with the keychain works.
// Returns false if any operation fails.
async checkKeychainAvailability(): Promise<boolean> {
if (this.keychainAvailable !== null) {
return this.keychainAvailable;
}
try {
const keytar = await this.getKeytar();
if (!keytar) {
this.keychainAvailable = false;
return false;
}
const testAccount = `${KEYCHAIN_TEST_PREFIX}${crypto.randomBytes(8).toString('hex')}`;
const testPassword = 'test';
await keytar.setPassword(this.serviceName, testAccount, testPassword);
const retrieved = await keytar.getPassword(this.serviceName, testAccount);
const deleted = await keytar.deletePassword(
this.serviceName,
testAccount,
);
const success = deleted && retrieved === testPassword;
this.keychainAvailable = success;
return success;
} catch (_error) {
this.keychainAvailable = false;
return false;
}
}
async isAvailable(): Promise<boolean> {
return this.checkKeychainAvailability();
}
}

View File

@@ -35,3 +35,8 @@ export interface TokenStorage {
getAllCredentials(): Promise<Map<string, OAuthCredentials>>;
clearAll(): Promise<void>;
}
export enum TokenStorageType {
KEYCHAIN = 'keychain',
ENCRYPTED_FILE = 'encrypted_file',
}