refactor: Update MCP tool discovery to use @google/genai - Also fixes JSON schema issues. (#682)

This commit is contained in:
N. Taylor Mullen
2025-06-02 13:39:25 -07:00
committed by GitHub
parent 0795e55f0e
commit 58597c29d3
7 changed files with 744 additions and 812 deletions

View File

@@ -16,7 +16,6 @@ import {
} from 'vitest';
import { discoverMcpTools } from './mcp-client.js';
import { Config, MCPServerConfig } from '../config/config.js';
import { ToolRegistry } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -51,33 +50,56 @@ vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
// Always return a new object with a fresh reference to the global mock for .on
this.options = options;
this.stderr = { on: mockGlobalStdioStderrOn };
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
return this;
});
return { StdioClientTransport: MockedStdioTransport };
});
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
const MockedSSETransport = vi.fn();
const MockedSSETransport = vi.fn().mockImplementation(function (this: any) {
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
return this;
});
return { SSEClientTransport: MockedSSETransport };
});
vi.mock('./tool-registry.js');
const mockToolRegistryInstance = {
registerTool: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
// Add other methods if they are called by the code under test, with default mocks
getTool: vi.fn(),
getAllTools: vi.fn().mockReturnValue([]),
getFunctionDeclarations: vi.fn().mockReturnValue([]),
discoverTools: vi.fn().mockResolvedValue(undefined),
};
vi.mock('./tool-registry.js', () => ({
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
}));
describe('discoverMcpTools', () => {
let mockConfig: Mocked<Config>;
let mockToolRegistry: Mocked<ToolRegistry>;
// Use the instance from the module mock
let mockToolRegistry: typeof mockToolRegistryInstance;
beforeEach(() => {
// Assign the shared mock instance to the test-scoped variable
mockToolRegistry = mockToolRegistryInstance;
// Reset individual spies on the shared instance before each test
mockToolRegistry.registerTool.mockClear();
mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default
mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool
mockToolRegistry.getAllTools.mockClear().mockReturnValue([]);
mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]);
mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined);
mockConfig = {
getMcpServers: vi.fn().mockReturnValue({}),
getMcpServerCommand: vi.fn().mockReturnValue(undefined),
// getToolRegistry should now return the same shared mock instance
getToolRegistry: vi.fn(() => mockToolRegistry),
} as any;
mockToolRegistry = new (ToolRegistry as any)(
mockConfig,
) as Mocked<ToolRegistry>;
mockToolRegistry.registerTool = vi.fn();
vi.mocked(parse).mockClear();
vi.mocked(Client).mockClear();
vi.mocked(Client.prototype.connect)
@@ -88,9 +110,24 @@ describe('discoverMcpTools', () => {
.mockResolvedValue({ tools: [] });
vi.mocked(StdioClientTransport).mockClear();
// Ensure the StdioClientTransport mock constructor returns an object with a close method
vi.mocked(StdioClientTransport).mockImplementation(function (
this: any,
options: any,
) {
this.options = options;
this.stderr = { on: mockGlobalStdioStderrOn };
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach
vi.mocked(SSEClientTransport).mockClear();
// Ensure the SSEClientTransport mock constructor returns an object with a close method
vi.mocked(SSEClientTransport).mockImplementation(function (this: any) {
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
});
afterEach(() => {
@@ -98,7 +135,7 @@ describe('discoverMcpTools', () => {
});
it('should do nothing if no MCP servers or command are configured', async () => {
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
expect(Client).not.toHaveBeenCalled();
@@ -120,7 +157,11 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
await discoverMcpTools(mockConfig, mockToolRegistry);
// PRE-MOCK getToolsByServer for the expected server name
// In this case, listTools fails, so no tools are registered.
// The default mock `mockReturnValue([])` from beforeEach should apply.
await discoverMcpTools(mockConfig);
expect(parse).toHaveBeenCalledWith(commandString, process.env);
expect(StdioClientTransport).toHaveBeenCalledWith({
@@ -158,7 +199,12 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
await discoverMcpTools(mockConfig, mockToolRegistry);
// PRE-MOCK getToolsByServer for the expected server name
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
expect(StdioClientTransport).toHaveBeenCalledWith({
command: serverConfig.command,
@@ -188,7 +234,12 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
await discoverMcpTools(mockConfig, mockToolRegistry);
// PRE-MOCK getToolsByServer for the expected server name
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
@@ -208,32 +259,96 @@ describe('discoverMcpTools', () => {
});
const mockTool1 = {
name: 'toolA',
name: 'toolA', // Same original name
description: 'd1',
inputSchema: { type: 'object' as const, properties: {} },
};
const mockTool2 = {
name: 'toolB',
name: 'toolA', // Same original name
description: 'd2',
inputSchema: { type: 'object' as const, properties: {} },
};
const mockToolB = {
name: 'toolB',
description: 'dB',
inputSchema: { type: 'object' as const, properties: {} },
};
vi.mocked(Client.prototype.listTools)
.mockResolvedValueOnce({ tools: [mockTool1] })
.mockResolvedValueOnce({ tools: [mockTool2] });
.mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1
.mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA)
await discoverMcpTools(mockConfig, mockToolRegistry);
const effectivelyRegisteredTools = new Map<string, any>();
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
const registeredTool1 = mockToolRegistry.registerTool.mock
.calls[0][0] as DiscoveredMCPTool;
const registeredTool2 = mockToolRegistry.registerTool.mock
.calls[1][0] as DiscoveredMCPTool;
mockToolRegistry.getTool.mockImplementation((toolName: string) =>
effectivelyRegisteredTools.get(toolName),
);
expect(registeredTool1.name).toBe('server1__toolA');
expect(registeredTool1.serverToolName).toBe('toolA');
expect(registeredTool2.name).toBe('server2__toolB');
expect(registeredTool2.serverToolName).toBe('toolB');
// Store the original spy implementation if needed, or just let the new one be the behavior.
// The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance.
// We are setting its behavior for this test.
mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => {
// Simulate the actual registration name being stored for getTool to find
effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister);
// If it's the first time toolA is registered (from server1, not prefixed),
// also make it findable by its original name for the prefixing check of server2/toolA.
if (
toolToRegister.serverName === 'server1' &&
toolToRegister.serverToolName === 'toolA' &&
toolToRegister.name === 'toolA'
) {
effectivelyRegisteredTools.set('toolA', toolToRegister);
}
// The spy call count is inherently tracked by mockToolRegistry.registerTool itself.
});
// PRE-MOCK getToolsByServer for the expected server names
// This is for the final check in connectAndDiscover to see if any tools were registered *from that server*
mockToolRegistry.getToolsByServer.mockImplementation(
(serverName: string) => {
if (serverName === 'server1')
return [
expect.objectContaining({ name: 'toolA' }),
expect.objectContaining({ name: 'toolB' }),
];
if (serverName === 'server2')
return [expect.objectContaining({ name: 'server2__toolA' })];
return [];
},
);
await discoverMcpTools(mockConfig);
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
(call) => call[0],
) as DiscoveredMCPTool[];
// The order of server processing by Promise.all is not guaranteed.
// One 'toolA' will be unprefixed, the other will be prefixed.
const toolA_from_server1 = registeredArgs.find(
(t) => t.serverToolName === 'toolA' && t.serverName === 'server1',
);
const toolA_from_server2 = registeredArgs.find(
(t) => t.serverToolName === 'toolA' && t.serverName === 'server2',
);
const toolB_from_server1 = registeredArgs.find(
(t) => t.serverToolName === 'toolB' && t.serverName === 'server1',
);
expect(toolA_from_server1).toBeDefined();
expect(toolA_from_server2).toBeDefined();
expect(toolB_from_server1).toBeDefined();
expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique
// Check that one of toolA is prefixed and the other is not, and the prefixed one is correct.
if (toolA_from_server1?.name === 'toolA') {
expect(toolA_from_server2?.name).toBe('server2__toolA');
} else {
expect(toolA_from_server1?.name).toBe('server1__toolA');
expect(toolA_from_server2?.name).toBe('toolA');
}
});
it('should clean schema properties ($schema, additionalProperties)', async () => {
@@ -261,8 +376,12 @@ describe('discoverMcpTools', () => {
vi.mocked(Client.prototype.listTools).mockResolvedValue({
tools: [mockTool],
});
// PRE-MOCK getToolsByServer for the expected server name
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
const registeredTool = mockToolRegistry.registerTool.mock
@@ -291,9 +410,9 @@ describe('discoverMcpTools', () => {
});
vi.spyOn(console, 'error').mockImplementation(() => {});
await expect(
discoverMcpTools(mockConfig, mockToolRegistry),
).rejects.toThrow('Parsing failed');
await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
'Parsing failed',
);
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
expect(console.error).not.toHaveBeenCalled();
});
@@ -302,7 +421,7 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -323,7 +442,7 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -344,7 +463,7 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -359,8 +478,12 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({
'onerror-server': serverConfig,
});
// PRE-MOCK getToolsByServer for the expected server name
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig, mockToolRegistry);
await discoverMcpTools(mockConfig);
const clientInstances = vi.mocked(Client).mock.results;
expect(clientInstances.length).toBeGreaterThan(0);