mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
refactor: Update MCP tool discovery to use @google/genai - Also fixes JSON schema issues. (#682)
This commit is contained in:
@@ -16,7 +16,6 @@ import {
|
||||
} from 'vitest';
|
||||
import { discoverMcpTools } from './mcp-client.js';
|
||||
import { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
@@ -51,33 +50,56 @@ vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
|
||||
// Always return a new object with a fresh reference to the global mock for .on
|
||||
this.options = options;
|
||||
this.stderr = { on: mockGlobalStdioStderrOn };
|
||||
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
|
||||
return this;
|
||||
});
|
||||
return { StdioClientTransport: MockedStdioTransport };
|
||||
});
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
|
||||
const MockedSSETransport = vi.fn();
|
||||
const MockedSSETransport = vi.fn().mockImplementation(function (this: any) {
|
||||
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
|
||||
return this;
|
||||
});
|
||||
return { SSEClientTransport: MockedSSETransport };
|
||||
});
|
||||
|
||||
vi.mock('./tool-registry.js');
|
||||
const mockToolRegistryInstance = {
|
||||
registerTool: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
|
||||
// Add other methods if they are called by the code under test, with default mocks
|
||||
getTool: vi.fn(),
|
||||
getAllTools: vi.fn().mockReturnValue([]),
|
||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||
discoverTools: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
vi.mock('./tool-registry.js', () => ({
|
||||
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
|
||||
}));
|
||||
|
||||
describe('discoverMcpTools', () => {
|
||||
let mockConfig: Mocked<Config>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
// Use the instance from the module mock
|
||||
let mockToolRegistry: typeof mockToolRegistryInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
// Assign the shared mock instance to the test-scoped variable
|
||||
mockToolRegistry = mockToolRegistryInstance;
|
||||
// Reset individual spies on the shared instance before each test
|
||||
mockToolRegistry.registerTool.mockClear();
|
||||
mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default
|
||||
mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool
|
||||
mockToolRegistry.getAllTools.mockClear().mockReturnValue([]);
|
||||
mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]);
|
||||
mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined);
|
||||
|
||||
mockConfig = {
|
||||
getMcpServers: vi.fn().mockReturnValue({}),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(undefined),
|
||||
// getToolRegistry should now return the same shared mock instance
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry),
|
||||
} as any;
|
||||
|
||||
mockToolRegistry = new (ToolRegistry as any)(
|
||||
mockConfig,
|
||||
) as Mocked<ToolRegistry>;
|
||||
mockToolRegistry.registerTool = vi.fn();
|
||||
|
||||
vi.mocked(parse).mockClear();
|
||||
vi.mocked(Client).mockClear();
|
||||
vi.mocked(Client.prototype.connect)
|
||||
@@ -88,9 +110,24 @@ describe('discoverMcpTools', () => {
|
||||
.mockResolvedValue({ tools: [] });
|
||||
|
||||
vi.mocked(StdioClientTransport).mockClear();
|
||||
// Ensure the StdioClientTransport mock constructor returns an object with a close method
|
||||
vi.mocked(StdioClientTransport).mockImplementation(function (
|
||||
this: any,
|
||||
options: any,
|
||||
) {
|
||||
this.options = options;
|
||||
this.stderr = { on: mockGlobalStdioStderrOn };
|
||||
this.close = vi.fn().mockResolvedValue(undefined);
|
||||
return this;
|
||||
});
|
||||
mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach
|
||||
|
||||
vi.mocked(SSEClientTransport).mockClear();
|
||||
// Ensure the SSEClientTransport mock constructor returns an object with a close method
|
||||
vi.mocked(SSEClientTransport).mockImplementation(function (this: any) {
|
||||
this.close = vi.fn().mockResolvedValue(undefined);
|
||||
return this;
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -98,7 +135,7 @@ describe('discoverMcpTools', () => {
|
||||
});
|
||||
|
||||
it('should do nothing if no MCP servers or command are configured', async () => {
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
|
||||
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
|
||||
expect(Client).not.toHaveBeenCalled();
|
||||
@@ -120,7 +157,11 @@ describe('discoverMcpTools', () => {
|
||||
tools: [mockTool],
|
||||
});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
// PRE-MOCK getToolsByServer for the expected server name
|
||||
// In this case, listTools fails, so no tools are registered.
|
||||
// The default mock `mockReturnValue([])` from beforeEach should apply.
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(parse).toHaveBeenCalledWith(commandString, process.env);
|
||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
@@ -158,7 +199,12 @@ describe('discoverMcpTools', () => {
|
||||
tools: [mockTool],
|
||||
});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
// PRE-MOCK getToolsByServer for the expected server name
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
command: serverConfig.command,
|
||||
@@ -188,7 +234,12 @@ describe('discoverMcpTools', () => {
|
||||
tools: [mockTool],
|
||||
});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
// PRE-MOCK getToolsByServer for the expected server name
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
|
||||
@@ -208,32 +259,96 @@ describe('discoverMcpTools', () => {
|
||||
});
|
||||
|
||||
const mockTool1 = {
|
||||
name: 'toolA',
|
||||
name: 'toolA', // Same original name
|
||||
description: 'd1',
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
};
|
||||
const mockTool2 = {
|
||||
name: 'toolB',
|
||||
name: 'toolA', // Same original name
|
||||
description: 'd2',
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
};
|
||||
const mockToolB = {
|
||||
name: 'toolB',
|
||||
description: 'dB',
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
};
|
||||
|
||||
vi.mocked(Client.prototype.listTools)
|
||||
.mockResolvedValueOnce({ tools: [mockTool1] })
|
||||
.mockResolvedValueOnce({ tools: [mockTool2] });
|
||||
.mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1
|
||||
.mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA)
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
const effectivelyRegisteredTools = new Map<string, any>();
|
||||
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
||||
const registeredTool1 = mockToolRegistry.registerTool.mock
|
||||
.calls[0][0] as DiscoveredMCPTool;
|
||||
const registeredTool2 = mockToolRegistry.registerTool.mock
|
||||
.calls[1][0] as DiscoveredMCPTool;
|
||||
mockToolRegistry.getTool.mockImplementation((toolName: string) =>
|
||||
effectivelyRegisteredTools.get(toolName),
|
||||
);
|
||||
|
||||
expect(registeredTool1.name).toBe('server1__toolA');
|
||||
expect(registeredTool1.serverToolName).toBe('toolA');
|
||||
expect(registeredTool2.name).toBe('server2__toolB');
|
||||
expect(registeredTool2.serverToolName).toBe('toolB');
|
||||
// Store the original spy implementation if needed, or just let the new one be the behavior.
|
||||
// The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance.
|
||||
// We are setting its behavior for this test.
|
||||
mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => {
|
||||
// Simulate the actual registration name being stored for getTool to find
|
||||
effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister);
|
||||
// If it's the first time toolA is registered (from server1, not prefixed),
|
||||
// also make it findable by its original name for the prefixing check of server2/toolA.
|
||||
if (
|
||||
toolToRegister.serverName === 'server1' &&
|
||||
toolToRegister.serverToolName === 'toolA' &&
|
||||
toolToRegister.name === 'toolA'
|
||||
) {
|
||||
effectivelyRegisteredTools.set('toolA', toolToRegister);
|
||||
}
|
||||
// The spy call count is inherently tracked by mockToolRegistry.registerTool itself.
|
||||
});
|
||||
|
||||
// PRE-MOCK getToolsByServer for the expected server names
|
||||
// This is for the final check in connectAndDiscover to see if any tools were registered *from that server*
|
||||
mockToolRegistry.getToolsByServer.mockImplementation(
|
||||
(serverName: string) => {
|
||||
if (serverName === 'server1')
|
||||
return [
|
||||
expect.objectContaining({ name: 'toolA' }),
|
||||
expect.objectContaining({ name: 'toolB' }),
|
||||
];
|
||||
if (serverName === 'server2')
|
||||
return [expect.objectContaining({ name: 'server2__toolA' })];
|
||||
return [];
|
||||
},
|
||||
);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
|
||||
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
|
||||
(call) => call[0],
|
||||
) as DiscoveredMCPTool[];
|
||||
|
||||
// The order of server processing by Promise.all is not guaranteed.
|
||||
// One 'toolA' will be unprefixed, the other will be prefixed.
|
||||
const toolA_from_server1 = registeredArgs.find(
|
||||
(t) => t.serverToolName === 'toolA' && t.serverName === 'server1',
|
||||
);
|
||||
const toolA_from_server2 = registeredArgs.find(
|
||||
(t) => t.serverToolName === 'toolA' && t.serverName === 'server2',
|
||||
);
|
||||
const toolB_from_server1 = registeredArgs.find(
|
||||
(t) => t.serverToolName === 'toolB' && t.serverName === 'server1',
|
||||
);
|
||||
|
||||
expect(toolA_from_server1).toBeDefined();
|
||||
expect(toolA_from_server2).toBeDefined();
|
||||
expect(toolB_from_server1).toBeDefined();
|
||||
|
||||
expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique
|
||||
|
||||
// Check that one of toolA is prefixed and the other is not, and the prefixed one is correct.
|
||||
if (toolA_from_server1?.name === 'toolA') {
|
||||
expect(toolA_from_server2?.name).toBe('server2__toolA');
|
||||
} else {
|
||||
expect(toolA_from_server1?.name).toBe('server1__toolA');
|
||||
expect(toolA_from_server2?.name).toBe('toolA');
|
||||
}
|
||||
});
|
||||
|
||||
it('should clean schema properties ($schema, additionalProperties)', async () => {
|
||||
@@ -261,8 +376,12 @@ describe('discoverMcpTools', () => {
|
||||
vi.mocked(Client.prototype.listTools).mockResolvedValue({
|
||||
tools: [mockTool],
|
||||
});
|
||||
// PRE-MOCK getToolsByServer for the expected server name
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
|
||||
const registeredTool = mockToolRegistry.registerTool.mock
|
||||
@@ -291,9 +410,9 @@ describe('discoverMcpTools', () => {
|
||||
});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await expect(
|
||||
discoverMcpTools(mockConfig, mockToolRegistry),
|
||||
).rejects.toThrow('Parsing failed');
|
||||
await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
|
||||
'Parsing failed',
|
||||
);
|
||||
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
@@ -302,7 +421,7 @@ describe('discoverMcpTools', () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
@@ -323,7 +442,7 @@ describe('discoverMcpTools', () => {
|
||||
);
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
@@ -344,7 +463,7 @@ describe('discoverMcpTools', () => {
|
||||
);
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
@@ -359,8 +478,12 @@ describe('discoverMcpTools', () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'onerror-server': serverConfig,
|
||||
});
|
||||
// PRE-MOCK getToolsByServer for the expected server name
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig, mockToolRegistry);
|
||||
await discoverMcpTools(mockConfig);
|
||||
|
||||
const clientInstances = vi.mocked(Client).mock.results;
|
||||
expect(clientInstances.length).toBeGreaterThan(0);
|
||||
|
||||
Reference in New Issue
Block a user