Load and use MCP server prompts as slash commands in the CLI (#4828)

Co-authored-by: harold <haroldmciver@google.com>
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
christine betts
2025-07-25 20:56:33 +00:00
committed by GitHub
parent de96887789
commit eb65034117
19 changed files with 761 additions and 100 deletions

View File

@@ -15,12 +15,20 @@ import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
Prompt,
ListPromptsResultSchema,
GetPromptResult,
GetPromptResultSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
export type DiscoveredMCPPrompt = Prompt & {
serverName: string;
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
};
/**
* Enum representing the connection status of an MCP server
*/
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
/**
* Map to track the status of each MCP server within the core package
*/
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
const serverStatuses: Map<string, MCPServerStatus> = new Map();
/**
* Track the overall MCP discovery state
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
mcpServerStatusesInternal.set(serverName, status);
serverStatuses.set(serverName, status);
// Notify all listeners
for (const listener of statusChangeListeners) {
listener(serverName, status);
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
* Get the current status of an MCP server
*/
export function getMCPServerStatus(serverName: string): MCPServerStatus {
return (
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
);
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
}
/**
* Get all MCP server statuses
*/
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
return new Map(mcpServerStatusesInternal);
return new Map(serverStatuses);
}
/**
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
mcpServerName,
mcpServerConfig,
toolRegistry,
promptRegistry,
debugMode,
),
);
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
const tools = await discoverTools(
mcpServerName,
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
}
} catch (error) {
console.error(
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
error,
)}`,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
@@ -441,15 +458,97 @@ export async function discoverTools(
),
);
}
if (discoveredTools.length === 0) {
throw Error('No enabled tools found');
}
return discoveredTools;
} catch (error) {
throw new Error(`Error discovering tools: ${error}`);
}
}
/**
* Discovers and logs prompts from a connected MCP client.
* It retrieves prompt declarations from the client and logs their names.
*
* @param mcpServerName The name of the MCP server.
* @param mcpClient The active MCP client instance.
*/
export async function discoverPrompts(
mcpServerName: string,
mcpClient: Client,
promptRegistry: PromptRegistry,
): Promise<void> {
try {
const response = await mcpClient.request(
{ method: 'prompts/list', params: {} },
ListPromptsResultSchema,
);
for (const prompt of response.prompts) {
promptRegistry.registerPrompt({
...prompt,
serverName: mcpServerName,
invoke: (params: Record<string, unknown>) =>
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
});
}
} catch (error) {
// It's okay if this fails, not all servers will have prompts.
// Don't log an error if the method is not found, which is a common case.
if (
error instanceof Error &&
!error.message?.includes('Method not found')
) {
console.error(
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
error,
)}`,
);
}
}
}
/**
* Invokes a prompt on a connected MCP client.
*
* @param mcpServerName The name of the MCP server.
* @param mcpClient The active MCP client instance.
* @param promptName The name of the prompt to invoke.
* @param promptParams The parameters to pass to the prompt.
* @returns A promise that resolves to the result of the prompt invocation.
*/
export async function invokeMcpPrompt(
mcpServerName: string,
mcpClient: Client,
promptName: string,
promptParams: Record<string, unknown>,
): Promise<GetPromptResult> {
try {
const response = await mcpClient.request(
{
method: 'prompts/get',
params: {
name: promptName,
arguments: promptParams,
},
},
GetPromptResultSchema,
);
return response;
} catch (error) {
if (
error instanceof Error &&
!error.message?.includes('Method not found')
) {
console.error(
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
error,
)}`,
);
}
throw error;
}
}
/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and