feat: restart MCP servers on /mcp refresh (#5479)

Co-authored-by: Brian Ray <62354532+emeryray2002@users.noreply.github.com>
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
Ramón Medrano Llamas
2025-08-19 21:03:19 +02:00
committed by GitHub
parent 4828e4daf1
commit b24c5887c4
9 changed files with 447 additions and 467 deletions

View File

@@ -4,16 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
populateMcpServerCommand,
createTransport,
isEnabled,
discoverTools,
discoverPrompts,
hasValidTypes,
connectToMcpServer,
McpClient,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { ToolRegistry } from './tool-registry.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
vi.mock('./mcp-tool.js');
describe('mcp-client', () => {
afterEach(() => {
vi.restoreAllMocks();
});
describe('discoverTools', () => {
describe('McpClient', () => {
it('should discover tools', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
@@ -51,62 +59,43 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
const mockedToolRegistry = {
registerTool: vi.fn(),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
);
await client.connect();
await client.discover();
expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
it('should log an error if there is an error discovering a tool', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
const testError = new Error('Invalid tool name');
vi.mocked(DiscoveredMCPTool).mockImplementation(
(
_mcpCallableTool: GenAiLib.CallableTool,
_serverName: string,
name: string,
) => {
if (name === 'invalid tool name') {
throw testError;
}
return { name: 'validTool' } as DiscoveredMCPTool;
},
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
},
{
name: 'invalid tool name', // this will fail validation
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(tools[0].name).toBe('validTool');
expect(consoleErrorSpy).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
);
});
it('should skip tools if a parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
tool: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
@@ -132,352 +121,73 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should skip tools if a nested parameter is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'object',
properties: {
nestedParam: {
description: 'a nested param with no type',
},
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should skip tool if an array item is missing a type', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: {
type: 'array',
items: {
description: 'an array item with no type',
},
},
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(0);
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should discover tool with no properties in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
it('should discover tool with empty properties object in schema', async () => {
const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
properties: {},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const tools = await discoverTools('test-server', {}, mockedClient);
expect(tools.length).toBe(1);
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
});
});
describe('connectToMcpServer', () => {
it('should send a notification when directories change', async () => {
const mockedClient = {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
notification: vi.fn(),
callTool: vi.fn(),
connect: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
let onDirectoriesChangedCallback: () => void = () => {};
const mockWorkspaceContext = {
getDirectories: vi
.fn()
.mockReturnValue(['/test/dir', '/another/project']),
onDirectoriesChanged: vi.fn().mockImplementation((callback) => {
onDirectoriesChangedCallback = callback;
}),
} as unknown as WorkspaceContext;
await connectToMcpServer(
const mockedToolRegistry = {
registerTool: vi.fn(),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
mockWorkspaceContext,
);
onDirectoriesChangedCallback();
expect(mockedClient.notification).toHaveBeenCalledWith({
method: 'notifications/roots/list_changed',
});
await client.connect();
await client.discover();
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
);
consoleWarnSpy.mockRestore();
});
it('should register a roots/list handler', async () => {
const mockedClient = {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
callTool: vi.fn(),
connect: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockWorkspaceContext = {
getDirectories: vi
.fn()
.mockReturnValue(['/test/dir', '/another/project']),
onDirectoriesChanged: vi.fn(),
} as unknown as WorkspaceContext;
await connectToMcpServer(
'test-server',
{
command: 'test-command',
},
false,
mockWorkspaceContext,
);
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
roots: {
listChanged: true,
},
});
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
const roots = await handler();
expect(roots).toEqual({
roots: [
{
uri: pathToFileURL('/test/dir').toString(),
name: 'dir',
},
{
uri: pathToFileURL('/another/project').toString(),
name: 'project',
},
],
});
});
});
describe('discoverPrompts', () => {
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
} as unknown as PromptRegistry;
it('should discover and log prompts', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [
{ name: 'prompt1', description: 'desc1' },
{ name: 'prompt2' },
],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledWith(
{ method: 'prompts/list', params: {} },
expect.anything(),
);
});
it('should do nothing if no prompts are discovered', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should do nothing if the server has no prompt support', async () => {
const mockRequest = vi.fn().mockResolvedValue({
prompts: [],
});
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
const consoleLogSpy = vi
.spyOn(console, 'debug')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
expect(mockRequest).not.toHaveBeenCalled();
expect(consoleLogSpy).not.toHaveBeenCalled();
consoleLogSpy.mockRestore();
});
it('should log an error if discovery fails', async () => {
const testError = new Error('test error');
testError.message = 'test error';
const mockRequest = vi.fn().mockRejectedValue(testError);
const mockGetServerCapabilities = vi.fn().mockReturnValue({
prompts: {},
});
const mockedClient = {
getServerCapabilities: mockGetServerCapabilities,
request: mockRequest,
} as unknown as ClientLib.Client;
it('should handle errors when discovering prompts', async () => {
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering prompts from test-server: ${testError.message}`,
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockRejectedValue(new Error('Test error')),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => Promise.resolve({ functionDeclarations: [] }),
} as unknown as GenAiLib.CallableTool);
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
{} as ToolRegistry,
{} as PromptRegistry,
{} as WorkspaceContext,
false,
);
await client.connect();
await expect(client.discover()).rejects.toThrow(
'No prompts or tools found on the server.',
);
expect(consoleErrorSpy).toHaveBeenCalledWith(
`Error discovering prompts from test-server: Test error`,
);
consoleErrorSpy.mockRestore();
});
});
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
@@ -501,17 +211,6 @@ describe('mcp-client', () => {
});
describe('createTransport', () => {
const originalEnv = process.env;
beforeEach(() => {
vi.resetModules();
process.env = {};
});
afterEach(() => {
process.env = originalEnv;
});
describe('should connect via httpUrl', () => {
it('without headers', async () => {
const transport = await createTransport(
@@ -601,7 +300,7 @@ describe('mcp-client', () => {
command: 'test-command',
args: ['--foo', 'bar'],
cwd: 'test/cwd',
env: { FOO: 'bar' },
env: { ...process.env, FOO: 'bar' },
stderr: 'pipe',
});
});