mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00
fix(core): Sanitize tool parameters to fix 400 API errors (#3300)
This commit is contained in:
@@ -14,22 +14,22 @@ import {
|
||||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import {
|
||||
Config,
|
||||
ConfigParameters,
|
||||
MCPServerConfig,
|
||||
ApprovalMode,
|
||||
} from '../config/config.js';
|
||||
ToolRegistry,
|
||||
DiscoveredTool,
|
||||
sanitizeParameters,
|
||||
} from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import {
|
||||
FunctionDeclaration,
|
||||
CallableTool,
|
||||
mcpToTool,
|
||||
Type,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { execSync } from 'node:child_process';
|
||||
import { spawn } from 'node:child_process';
|
||||
|
||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
||||
@@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
||||
set onerror(handler: any) {
|
||||
mockMcpClientOnError(handler);
|
||||
},
|
||||
// listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools
|
||||
}));
|
||||
return { Client: MockClient };
|
||||
});
|
||||
@@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => {
|
||||
return {
|
||||
...actualGenai,
|
||||
mcpToTool: vi.fn().mockImplementation(() => ({
|
||||
// Default mock implementation
|
||||
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
|
||||
callTool: vi.fn(),
|
||||
})),
|
||||
@@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = {
|
||||
describe('ToolRegistry', () => {
|
||||
let config: Config;
|
||||
let toolRegistry: ToolRegistry;
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
config = new Config(baseConfigParams);
|
||||
@@ -148,13 +147,19 @@ describe('ToolRegistry', () => {
|
||||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
|
||||
// Reset mocks for MCP parts
|
||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success
|
||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
|
||||
mockStdioTransportClose.mockReset();
|
||||
mockSseTransportClose.mockReset();
|
||||
vi.mocked(mcpToTool).mockClear();
|
||||
// Default mcpToTool to return a callable tool that returns no functions
|
||||
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
|
||||
|
||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
vi.spyOn(config, 'getMcpServers');
|
||||
vi.spyOn(config, 'getMcpServerCommand');
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -167,21 +172,18 @@ describe('ToolRegistry', () => {
|
||||
toolRegistry.registerTool(tool);
|
||||
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
|
||||
});
|
||||
// ... other registerTool tests
|
||||
});
|
||||
|
||||
describe('getToolsByServer', () => {
|
||||
it('should return an empty array if no tools match the server name', () => {
|
||||
toolRegistry.registerTool(new MockTool()); // A non-MCP tool
|
||||
toolRegistry.registerTool(new MockTool());
|
||||
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return only tools matching the server name', async () => {
|
||||
const server1Name = 'mcp-server-uno';
|
||||
const server2Name = 'mcp-server-dos';
|
||||
|
||||
// Manually register mock MCP tools for this test
|
||||
const mockCallable = {} as CallableTool; // Minimal mock callable
|
||||
const mockCallable = {} as CallableTool;
|
||||
const mcpTool1 = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
@@ -207,73 +209,87 @@ describe('ToolRegistry', () => {
|
||||
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
||||
expect(toolsFromServer1).toHaveLength(1);
|
||||
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
||||
expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe(
|
||||
server1Name,
|
||||
);
|
||||
|
||||
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
||||
expect(toolsFromServer2).toHaveLength(1);
|
||||
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
||||
expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe(
|
||||
server2Name,
|
||||
);
|
||||
|
||||
expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverTools', () => {
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers');
|
||||
mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand');
|
||||
mockExecSync = vi.mocked(execSync);
|
||||
toolRegistry = new ToolRegistry(config); // Reset registry
|
||||
// Reset the mock for discoverMcpTools before each test in this suite
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
it('should discover tools using discovery command', async () => {
|
||||
// ... this test remains largely the same
|
||||
it('should sanitize tool parameters during discovery from command', async () => {
|
||||
const discoveryCommand = 'my-discovery-command';
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
||||
const mockToolDeclarations: FunctionDeclaration[] = [
|
||||
{
|
||||
name: 'discovered-tool-1',
|
||||
description: 'A discovered tool',
|
||||
parameters: { type: Type.OBJECT, properties: {} },
|
||||
|
||||
const unsanitizedToolDeclaration: FunctionDeclaration = {
|
||||
name: 'tool-with-bad-format',
|
||||
description: 'A tool with an invalid format property',
|
||||
parameters: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
some_string: {
|
||||
type: Type.STRING,
|
||||
format: 'uuid', // This is an unsupported format
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
mockExecSync.mockReturnValue(
|
||||
Buffer.from(
|
||||
JSON.stringify([{ function_declarations: mockToolDeclarations }]),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
const mockSpawn = vi.mocked(spawn);
|
||||
const mockChildProcess = {
|
||||
stdout: { on: vi.fn() },
|
||||
stderr: { on: vi.fn() },
|
||||
on: vi.fn(),
|
||||
};
|
||||
mockSpawn.mockReturnValue(mockChildProcess as any);
|
||||
|
||||
// Simulate stdout data
|
||||
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
|
||||
if (event === 'data') {
|
||||
callback(
|
||||
Buffer.from(
|
||||
JSON.stringify([
|
||||
{ function_declarations: [unsanitizedToolDeclaration] },
|
||||
]),
|
||||
),
|
||||
);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
// Simulate process close
|
||||
mockChildProcess.on.mockImplementation((event, callback) => {
|
||||
if (event === 'close') {
|
||||
callback(0);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(execSync).toHaveBeenCalledWith(discoveryCommand);
|
||||
const discoveredTool = toolRegistry.getTool('discovered-tool-1');
|
||||
expect(discoveredTool).toBeInstanceOf(DiscoveredTool);
|
||||
|
||||
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
|
||||
expect(discoveredTool).toBeDefined();
|
||||
|
||||
const registeredParams = (discoveredTool as DiscoveredTool).schema
|
||||
.parameters as Schema;
|
||||
expect(registeredParams.properties?.['some_string']).toBeDefined();
|
||||
expect(registeredParams.properties?.['some_string']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServerCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
} as MCPServerConfig,
|
||||
},
|
||||
};
|
||||
mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal);
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
@@ -282,56 +298,166 @@ describe('ToolRegistry', () => {
|
||||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
// We no longer check these as discoverMcpTools is mocked
|
||||
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
||||
// expect(Client).toHaveBeenCalledTimes(1);
|
||||
// expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
// command: 'mcp-server-cmd',
|
||||
// args: ['--port', '1234'],
|
||||
// env: expect.any(Object),
|
||||
// stderr: 'pipe',
|
||||
// });
|
||||
// expect(mockMcpClientConnect).toHaveBeenCalled();
|
||||
|
||||
// To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools
|
||||
// to call toolRegistry.registerTool, or we test that separately.
|
||||
// For now, we just check that the delegation happened.
|
||||
});
|
||||
|
||||
it('should discover tools using MCP server command from getMcpServerCommand', async () => {
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServers.mockReturnValue({});
|
||||
mockConfigGetMcpServerCommand.mockReturnValue(
|
||||
'mcp-server-start-command --param',
|
||||
);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{},
|
||||
'mcp-server-start-command --param',
|
||||
toolRegistry,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServers.mockReturnValue({
|
||||
'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig,
|
||||
});
|
||||
|
||||
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{
|
||||
'failing-mcp': { command: 'fail-cmd' },
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
// Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed
|
||||
// if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts)
|
||||
});
|
||||
|
||||
describe('sanitizeParameters', () => {
|
||||
it('should remove unsupported format from a simple string property', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
name: { type: Type.STRING },
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.properties?.['id']).toHaveProperty('format', undefined);
|
||||
expect(schema.properties?.['name']).not.toHaveProperty('format');
|
||||
});
|
||||
|
||||
it('should NOT remove supported format values', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
date: { type: Type.STRING, format: 'date-time' },
|
||||
role: {
|
||||
type: Type.STRING,
|
||||
format: 'enum',
|
||||
enum: ['admin', 'user'],
|
||||
},
|
||||
},
|
||||
};
|
||||
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||
sanitizeParameters(schema);
|
||||
expect(schema).toEqual(originalSchema);
|
||||
});
|
||||
|
||||
it('should handle nested objects recursively', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
user: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
email: { type: Type.STRING, format: 'email' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle arrays of objects', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
items: {
|
||||
type: Type.ARRAY,
|
||||
items: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
itemId: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(
|
||||
(schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
|
||||
).toHaveProperty('format', undefined);
|
||||
});
|
||||
|
||||
it('should handle schemas with no properties to sanitize', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
count: { type: Type.NUMBER },
|
||||
isActive: { type: Type.BOOLEAN },
|
||||
},
|
||||
};
|
||||
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||
sanitizeParameters(schema);
|
||||
expect(schema).toEqual(originalSchema);
|
||||
});
|
||||
|
||||
it('should not crash on an empty or undefined schema', () => {
|
||||
expect(() => sanitizeParameters({})).not.toThrow();
|
||||
expect(() => sanitizeParameters(undefined)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle cyclic schemas without crashing', () => {
|
||||
const schema: any = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
name: { type: Type.STRING, format: 'hostname' },
|
||||
},
|
||||
};
|
||||
schema.properties.self = schema;
|
||||
|
||||
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||
expect(schema.properties.name).toHaveProperty('format', undefined);
|
||||
});
|
||||
|
||||
it('should handle complex nested schemas with cycles', () => {
|
||||
const userNode: any = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
name: { type: Type.STRING },
|
||||
manager: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
userNode.properties.reports = {
|
||||
type: Type.ARRAY,
|
||||
items: userNode,
|
||||
};
|
||||
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
ceo: userNode,
|
||||
},
|
||||
};
|
||||
|
||||
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||
expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
expect(
|
||||
schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
|
||||
).toHaveProperty('format', undefined);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user