mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
Load and use MCP server prompts as slash commands in the CLI (#4828)
Co-authored-by: harold <haroldmciver@google.com> Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
@@ -15,12 +15,20 @@ import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
Prompt,
|
||||
ListPromptsResultSchema,
|
||||
GetPromptResult,
|
||||
GetPromptResultSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
export type DiscoveredMCPPrompt = Prompt & {
|
||||
serverName: string;
|
||||
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum representing the connection status of an MCP server
|
||||
*/
|
||||
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
|
||||
/**
|
||||
* Map to track the status of each MCP server within the core package
|
||||
*/
|
||||
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
|
||||
const serverStatuses: Map<string, MCPServerStatus> = new Map();
|
||||
|
||||
/**
|
||||
* Track the overall MCP discovery state
|
||||
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
|
||||
serverName: string,
|
||||
status: MCPServerStatus,
|
||||
): void {
|
||||
mcpServerStatusesInternal.set(serverName, status);
|
||||
serverStatuses.set(serverName, status);
|
||||
// Notify all listeners
|
||||
for (const listener of statusChangeListeners) {
|
||||
listener(serverName, status);
|
||||
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
|
||||
* Get the current status of an MCP server
|
||||
*/
|
||||
export function getMCPServerStatus(serverName: string): MCPServerStatus {
|
||||
return (
|
||||
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
|
||||
);
|
||||
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP server statuses
|
||||
*/
|
||||
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
|
||||
return new Map(mcpServerStatusesInternal);
|
||||
return new Map(serverStatuses);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
toolRegistry,
|
||||
promptRegistry,
|
||||
debugMode,
|
||||
),
|
||||
);
|
||||
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
|
||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
};
|
||||
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
|
||||
|
||||
const tools = await discoverTools(
|
||||
mcpServerName,
|
||||
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
}
|
||||
@@ -441,15 +458,97 @@ export async function discoverTools(
|
||||
),
|
||||
);
|
||||
}
|
||||
if (discoveredTools.length === 0) {
|
||||
throw Error('No enabled tools found');
|
||||
}
|
||||
return discoveredTools;
|
||||
} catch (error) {
|
||||
throw new Error(`Error discovering tools: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and logs prompts from a connected MCP client.
|
||||
* It retrieves prompt declarations from the client and logs their names.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
*/
|
||||
export async function discoverPrompts(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptRegistry: PromptRegistry,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
ListPromptsResultSchema,
|
||||
);
|
||||
|
||||
for (const prompt of response.prompts) {
|
||||
promptRegistry.registerPrompt({
|
||||
...prompt,
|
||||
serverName: mcpServerName,
|
||||
invoke: (params: Record<string, unknown>) =>
|
||||
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// It's okay if this fails, not all servers will have prompts.
|
||||
// Don't log an error if the method is not found, which is a common case.
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes a prompt on a connected MCP client.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
* @param promptName The name of the prompt to invoke.
|
||||
* @param promptParams The parameters to pass to the prompt.
|
||||
* @returns A promise that resolves to the result of the prompt invocation.
|
||||
*/
|
||||
export async function invokeMcpPrompt(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptName: string,
|
||||
promptParams: Record<string, unknown>,
|
||||
): Promise<GetPromptResult> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{
|
||||
method: 'prompts/get',
|
||||
params: {
|
||||
name: promptName,
|
||||
arguments: promptParams,
|
||||
},
|
||||
},
|
||||
GetPromptResultSchema,
|
||||
);
|
||||
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||
|
||||
Reference in New Issue
Block a user