mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
Load and use MCP server prompts as slash commands in the CLI (#4828)
Co-authored-by: harold <haroldmciver@google.com> Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
@@ -11,6 +11,7 @@ import {
|
||||
createTransport,
|
||||
isEnabled,
|
||||
discoverTools,
|
||||
discoverPrompts,
|
||||
} from './mcp-client.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
@@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import * as GenAiLib from '@google/genai';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||
@@ -50,6 +52,77 @@ describe('mcp-client', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverPrompts', () => {
|
||||
const mockedPromptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
|
||||
it('should discover and log prompts', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [
|
||||
{ name: 'prompt1', description: 'desc1' },
|
||||
{ name: 'prompt2' },
|
||||
],
|
||||
});
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should do nothing if no prompts are discovered', async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
});
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleLogSpy = vi
|
||||
.spyOn(console, 'debug')
|
||||
.mockImplementation(() => {
|
||||
// no-op
|
||||
});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||
|
||||
consoleLogSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error if discovery fails', async () => {
|
||||
const testError = new Error('test error');
|
||||
testError.message = 'test error';
|
||||
const mockRequest = vi.fn().mockRejectedValue(testError);
|
||||
const mockedClient = {
|
||||
request: mockRequest,
|
||||
} as unknown as ClientLib.Client;
|
||||
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {
|
||||
// no-op
|
||||
});
|
||||
|
||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledOnce();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
`Error discovering prompts from test-server: ${testError.message}`,
|
||||
);
|
||||
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
const out = populateMcpServerCommand({}, undefined);
|
||||
|
||||
@@ -15,12 +15,20 @@ import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
Prompt,
|
||||
ListPromptsResultSchema,
|
||||
GetPromptResult,
|
||||
GetPromptResultSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
export type DiscoveredMCPPrompt = Prompt & {
|
||||
serverName: string;
|
||||
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum representing the connection status of an MCP server
|
||||
*/
|
||||
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
|
||||
/**
|
||||
* Map to track the status of each MCP server within the core package
|
||||
*/
|
||||
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
|
||||
const serverStatuses: Map<string, MCPServerStatus> = new Map();
|
||||
|
||||
/**
|
||||
* Track the overall MCP discovery state
|
||||
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
|
||||
serverName: string,
|
||||
status: MCPServerStatus,
|
||||
): void {
|
||||
mcpServerStatusesInternal.set(serverName, status);
|
||||
serverStatuses.set(serverName, status);
|
||||
// Notify all listeners
|
||||
for (const listener of statusChangeListeners) {
|
||||
listener(serverName, status);
|
||||
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
|
||||
* Get the current status of an MCP server
|
||||
*/
|
||||
export function getMCPServerStatus(serverName: string): MCPServerStatus {
|
||||
return (
|
||||
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
|
||||
);
|
||||
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP server statuses
|
||||
*/
|
||||
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
|
||||
return new Map(mcpServerStatusesInternal);
|
||||
return new Map(serverStatuses);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
toolRegistry,
|
||||
promptRegistry,
|
||||
debugMode,
|
||||
),
|
||||
);
|
||||
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
toolRegistry: ToolRegistry,
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
|
||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
};
|
||||
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
|
||||
|
||||
const tools = await discoverTools(
|
||||
mcpServerName,
|
||||
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
}
|
||||
@@ -441,15 +458,97 @@ export async function discoverTools(
|
||||
),
|
||||
);
|
||||
}
|
||||
if (discoveredTools.length === 0) {
|
||||
throw Error('No enabled tools found');
|
||||
}
|
||||
return discoveredTools;
|
||||
} catch (error) {
|
||||
throw new Error(`Error discovering tools: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and logs prompts from a connected MCP client.
|
||||
* It retrieves prompt declarations from the client and logs their names.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
*/
|
||||
export async function discoverPrompts(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptRegistry: PromptRegistry,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
ListPromptsResultSchema,
|
||||
);
|
||||
|
||||
for (const prompt of response.prompts) {
|
||||
promptRegistry.registerPrompt({
|
||||
...prompt,
|
||||
serverName: mcpServerName,
|
||||
invoke: (params: Record<string, unknown>) =>
|
||||
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// It's okay if this fails, not all servers will have prompts.
|
||||
// Don't log an error if the method is not found, which is a common case.
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes a prompt on a connected MCP client.
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server.
|
||||
* @param mcpClient The active MCP client instance.
|
||||
* @param promptName The name of the prompt to invoke.
|
||||
* @param promptParams The parameters to pass to the prompt.
|
||||
* @returns A promise that resolves to the result of the prompt invocation.
|
||||
*/
|
||||
export async function invokeMcpPrompt(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptName: string,
|
||||
promptParams: Record<string, unknown>,
|
||||
): Promise<GetPromptResult> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{
|
||||
method: 'prompts/get',
|
||||
params: {
|
||||
name: promptName,
|
||||
arguments: promptParams,
|
||||
},
|
||||
},
|
||||
GetPromptResultSchema,
|
||||
);
|
||||
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
console.error(
|
||||
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||
|
||||
@@ -344,6 +344,7 @@ describe('ToolRegistry', () => {
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
undefined,
|
||||
false,
|
||||
);
|
||||
});
|
||||
@@ -366,6 +367,7 @@ describe('ToolRegistry', () => {
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
undefined,
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -170,6 +170,7 @@ export class ToolRegistry {
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
@@ -192,6 +193,7 @@ export class ToolRegistry {
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
@@ -215,6 +217,7 @@ export class ToolRegistry {
|
||||
{ [serverName]: serverConfig },
|
||||
undefined,
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user