mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00
Sync upstream Gemini-CLI v0.8.2 (#838)
This commit is contained in:
@@ -5,41 +5,42 @@
|
||||
*/
|
||||
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import type { SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import type { StreamableHTTPClientTransportOptions } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type {
|
||||
Prompt,
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
ListPromptsResultSchema,
|
||||
GetPromptResultSchema,
|
||||
ListPromptsResultSchema,
|
||||
ListRootsRequestSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import type { FunctionDeclaration } from '@google/genai';
|
||||
import { mcpToTool } from '@google/genai';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { basename } from 'node:path';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type {
|
||||
Unsubscribe,
|
||||
WorkspaceContext,
|
||||
} from '../utils/workspaceContext.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
@@ -93,7 +94,7 @@ export class McpClient {
|
||||
private readonly debugMode: boolean,
|
||||
) {
|
||||
this.client = new Client({
|
||||
name: `gemini-cli-mcp-client-${this.serverName}`,
|
||||
name: `qwen-cli-mcp-client-${this.serverName}`,
|
||||
version: '0.0.1',
|
||||
});
|
||||
}
|
||||
@@ -146,13 +147,13 @@ export class McpClient {
|
||||
/**
|
||||
* Discovers tools and prompts from the MCP server.
|
||||
*/
|
||||
async discover(): Promise<void> {
|
||||
async discover(cliConfig: Config): Promise<void> {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
throw new Error('Client is not connected.');
|
||||
}
|
||||
|
||||
const prompts = await this.discoverPrompts();
|
||||
const tools = await this.discoverTools();
|
||||
const tools = await this.discoverTools(cliConfig);
|
||||
|
||||
if (prompts.length === 0 && tools.length === 0) {
|
||||
throw new Error('No prompts or tools found on the server.');
|
||||
@@ -191,8 +192,13 @@ export class McpClient {
|
||||
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
||||
}
|
||||
|
||||
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
|
||||
return discoverTools(this.serverName, this.serverConfig, this.client);
|
||||
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
||||
return discoverTools(
|
||||
this.serverName,
|
||||
this.serverConfig,
|
||||
this.client,
|
||||
cliConfig,
|
||||
);
|
||||
}
|
||||
|
||||
private async discoverPrompts(): Promise<Prompt[]> {
|
||||
@@ -360,11 +366,8 @@ async function handleAutomaticOAuth(
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(
|
||||
mcpServerName,
|
||||
oauthAuthConfig,
|
||||
serverUrl,
|
||||
);
|
||||
const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage());
|
||||
await authProvider.authenticate(mcpServerName, oauthAuthConfig, serverUrl);
|
||||
|
||||
console.log(
|
||||
`OAuth authentication successful for server '${mcpServerName}'`,
|
||||
@@ -438,6 +441,7 @@ async function createTransportWithOAuth(
|
||||
* @param toolRegistry The central registry where discovered tools will be registered.
|
||||
* @returns A promise that resolves when the discovery process has been attempted for all servers.
|
||||
*/
|
||||
|
||||
export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
@@ -445,6 +449,7 @@ export async function discoverMcpTools(
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
cliConfig: Config,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
try {
|
||||
@@ -459,6 +464,7 @@ export async function discoverMcpTools(
|
||||
promptRegistry,
|
||||
debugMode,
|
||||
workspaceContext,
|
||||
cliConfig,
|
||||
),
|
||||
);
|
||||
await Promise.all(discoveryPromises);
|
||||
@@ -504,6 +510,7 @@ export async function connectAndDiscover(
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
cliConfig: Config,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
|
||||
@@ -531,6 +538,7 @@ export async function connectAndDiscover(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
mcpClient,
|
||||
cliConfig,
|
||||
);
|
||||
|
||||
// If we have neither prompts nor tools, it's a failed discovery
|
||||
@@ -558,65 +566,6 @@ export async function connectAndDiscover(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively validates that a JSON schema and all its nested properties and
|
||||
* items have a `type` defined.
|
||||
*
|
||||
* @param schema The JSON schema to validate.
|
||||
* @returns `true` if the schema is valid, `false` otherwise.
|
||||
*
|
||||
* @visiblefortesting
|
||||
*/
|
||||
export function hasValidTypes(schema: unknown): boolean {
|
||||
if (typeof schema !== 'object' || schema === null) {
|
||||
// Not a schema object we can validate, or not a schema at all.
|
||||
// Treat as valid as it has no properties to be invalid.
|
||||
return true;
|
||||
}
|
||||
|
||||
const s = schema as Record<string, unknown>;
|
||||
|
||||
if (!s['type']) {
|
||||
// These keywords contain an array of schemas that should be validated.
|
||||
//
|
||||
// If no top level type was given, then they must each have a type.
|
||||
let hasSubSchema = false;
|
||||
const schemaArrayKeywords = ['anyOf', 'allOf', 'oneOf'];
|
||||
for (const keyword of schemaArrayKeywords) {
|
||||
const subSchemas = s[keyword];
|
||||
if (Array.isArray(subSchemas)) {
|
||||
hasSubSchema = true;
|
||||
for (const subSchema of subSchemas) {
|
||||
if (!hasValidTypes(subSchema)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the node itself is missing a type and had no subschemas, then it isn't valid.
|
||||
if (!hasSubSchema) return false;
|
||||
}
|
||||
|
||||
if (s['type'] === 'object' && s['properties']) {
|
||||
if (typeof s['properties'] === 'object' && s['properties'] !== null) {
|
||||
for (const prop of Object.values(s['properties'])) {
|
||||
if (!hasValidTypes(prop)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (s['type'] === 'array' && s['items']) {
|
||||
if (!hasValidTypes(s['items'])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and sanitizes tools from a connected MCP client.
|
||||
* It retrieves function declarations from the client, filters out disabled tools,
|
||||
@@ -632,9 +581,12 @@ export async function discoverTools(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
mcpClient: Client,
|
||||
cliConfig: Config,
|
||||
): Promise<DiscoveredMCPTool[]> {
|
||||
try {
|
||||
const mcpCallableTool = mcpToTool(mcpClient);
|
||||
const mcpCallableTool = mcpToTool(mcpClient, {
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
const tool = await mcpCallableTool.tool();
|
||||
|
||||
if (!Array.isArray(tool.functionDeclarations)) {
|
||||
@@ -649,15 +601,6 @@ export async function discoverTools(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!hasValidTypes(funcDecl.parametersJsonSchema)) {
|
||||
console.warn(
|
||||
`Skipping tool '${funcDecl.name}' from MCP server '${mcpServerName}' ` +
|
||||
`because it has missing types in its parameter schema. Please file an ` +
|
||||
`issue with the owner of the MCP server.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
discoveredTools.push(
|
||||
new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
@@ -665,8 +608,9 @@ export async function discoverTools(
|
||||
funcDecl.name!,
|
||||
funcDecl.description ?? '',
|
||||
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
mcpServerConfig.trust,
|
||||
undefined,
|
||||
cliConfig,
|
||||
),
|
||||
);
|
||||
} catch (error) {
|
||||
@@ -859,18 +803,6 @@ export async function connectToMcpServer(
|
||||
unlistenDirectories = undefined;
|
||||
};
|
||||
|
||||
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
|
||||
// TODO: remove this hack once GenAI SDK does callTool with request options
|
||||
if ('callTool' in mcpClient) {
|
||||
const origCallTool = mcpClient.callTool.bind(mcpClient);
|
||||
mcpClient.callTool = function (params, resultSchema, options) {
|
||||
return origCallTool(params, resultSchema, {
|
||||
...options,
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const transport = await createTransport(
|
||||
mcpServerName,
|
||||
@@ -898,9 +830,11 @@ export async function connectToMcpServer(
|
||||
|
||||
if (!shouldTriggerOAuth) {
|
||||
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const hasStoredTokens = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
@@ -981,10 +915,11 @@ export async function connectToMcpServer(
|
||||
|
||||
// Get the valid token - we need to create a proper OAuth config
|
||||
// The token should already be available from the authentication process
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
@@ -1055,10 +990,11 @@ export async function connectToMcpServer(
|
||||
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (!shouldTryDiscovery) {
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const hasStoredTokens = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
@@ -1115,17 +1051,22 @@ export async function connectToMcpServer(
|
||||
console.log(
|
||||
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||
);
|
||||
await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider(
|
||||
new MCPOAuthTokenStorage(),
|
||||
);
|
||||
await authProvider.authenticate(
|
||||
mcpServerName,
|
||||
oauthAuthConfig,
|
||||
authServerUrl,
|
||||
);
|
||||
|
||||
// Retry connection with OAuth token
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials =
|
||||
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
@@ -1232,6 +1173,34 @@ export async function createTransport(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
): Promise<Transport> {
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
) {
|
||||
const provider = new ServiceAccountImpersonationProvider(mcpServerConfig);
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
authProvider: provider,
|
||||
};
|
||||
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Default to SSE if only url is provided
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
'No URL configured for ServiceAccountImpersonation MCP Server',
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
@@ -1260,7 +1229,9 @@ export async function createTransport(
|
||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||
accessToken = await MCPOAuthProvider.getValidToken(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
mcpServerConfig.oauth,
|
||||
);
|
||||
@@ -1277,9 +1248,11 @@ export async function createTransport(
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await authProvider.getValidToken(mcpServerName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user