fix(core): Sanitize tool parameters to fix 400 API errors (#3300)

This commit is contained in:
BigUncle
2025-07-06 05:58:51 +08:00
committed by GitHub
parent 5c9372372c
commit b564d4a088
8 changed files with 438 additions and 176 deletions

View File

@@ -14,22 +14,22 @@ import {
afterEach,
Mocked,
} from 'vitest';
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import {
Config,
ConfigParameters,
MCPServerConfig,
ApprovalMode,
} from '../config/config.js';
ToolRegistry,
DiscoveredTool,
sanitizeParameters,
} from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
import { BaseTool, ToolResult } from './tools.js';
import {
FunctionDeclaration,
CallableTool,
mcpToTool,
Type,
Schema,
} from '@google/genai';
import { execSync } from 'node:child_process';
import { spawn } from 'node:child_process';
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
@@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
set onerror(handler: any) {
mockMcpClientOnError(handler);
},
// listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools
}));
return { Client: MockClient };
});
@@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => {
return {
...actualGenai,
mcpToTool: vi.fn().mockImplementation(() => ({
// Default mock implementation
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
callTool: vi.fn(),
})),
@@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = {
describe('ToolRegistry', () => {
let config: Config;
let toolRegistry: ToolRegistry;
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
config = new Config(baseConfigParams);
@@ -148,13 +147,19 @@ describe('ToolRegistry', () => {
vi.spyOn(console, 'debug').mockImplementation(() => {});
vi.spyOn(console, 'log').mockImplementation(() => {});
// Reset mocks for MCP parts
mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success
mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
mockStdioTransportClose.mockReset();
mockSseTransportClose.mockReset();
vi.mocked(mcpToTool).mockClear();
// Default mcpToTool to return a callable tool that returns no functions
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
mockConfigGetToolDiscoveryCommand = vi.spyOn(
config,
'getToolDiscoveryCommand',
);
vi.spyOn(config, 'getMcpServers');
vi.spyOn(config, 'getMcpServerCommand');
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
afterEach(() => {
@@ -167,21 +172,18 @@ describe('ToolRegistry', () => {
toolRegistry.registerTool(tool);
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
});
// ... other registerTool tests
});
describe('getToolsByServer', () => {
it('should return an empty array if no tools match the server name', () => {
toolRegistry.registerTool(new MockTool()); // A non-MCP tool
toolRegistry.registerTool(new MockTool());
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
});
it('should return only tools matching the server name', async () => {
const server1Name = 'mcp-server-uno';
const server2Name = 'mcp-server-dos';
// Manually register mock MCP tools for this test
const mockCallable = {} as CallableTool; // Minimal mock callable
const mockCallable = {} as CallableTool;
const mcpTool1 = new DiscoveredMCPTool(
mockCallable,
server1Name,
@@ -207,73 +209,87 @@ describe('ToolRegistry', () => {
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
expect(toolsFromServer1).toHaveLength(1);
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe(
server1Name,
);
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
expect(toolsFromServer2).toHaveLength(1);
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe(
server2Name,
);
expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]);
});
});
describe('discoverTools', () => {
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>;
let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>;
let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>;
beforeEach(() => {
mockConfigGetToolDiscoveryCommand = vi.spyOn(
config,
'getToolDiscoveryCommand',
);
mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers');
mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand');
mockExecSync = vi.mocked(execSync);
toolRegistry = new ToolRegistry(config); // Reset registry
// Reset the mock for discoverMcpTools before each test in this suite
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
it('should discover tools using discovery command', async () => {
// ... this test remains largely the same
it('should sanitize tool parameters during discovery from command', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
const mockToolDeclarations: FunctionDeclaration[] = [
{
name: 'discovered-tool-1',
description: 'A discovered tool',
parameters: { type: Type.OBJECT, properties: {} },
const unsanitizedToolDeclaration: FunctionDeclaration = {
name: 'tool-with-bad-format',
description: 'A tool with an invalid format property',
parameters: {
type: Type.OBJECT,
properties: {
some_string: {
type: Type.STRING,
format: 'uuid', // This is an unsupported format
},
},
},
];
mockExecSync.mockReturnValue(
Buffer.from(
JSON.stringify([{ function_declarations: mockToolDeclarations }]),
),
);
};
const mockSpawn = vi.mocked(spawn);
const mockChildProcess = {
stdout: { on: vi.fn() },
stderr: { on: vi.fn() },
on: vi.fn(),
};
mockSpawn.mockReturnValue(mockChildProcess as any);
// Simulate stdout data
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([
{ function_declarations: [unsanitizedToolDeclaration] },
]),
),
);
}
return mockChildProcess as any;
});
// Simulate process close
mockChildProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
return mockChildProcess as any;
});
await toolRegistry.discoverTools();
expect(execSync).toHaveBeenCalledWith(discoveryCommand);
const discoveredTool = toolRegistry.getTool('discovered-tool-1');
expect(discoveredTool).toBeInstanceOf(DiscoveredTool);
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
expect(discoveredTool).toBeDefined();
const registeredParams = (discoveredTool as DiscoveredTool).schema
.parameters as Schema;
expect(registeredParams.properties?.['some_string']).toBeDefined();
expect(registeredParams.properties?.['some_string']).toHaveProperty(
'format',
undefined,
);
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
mockConfigGetMcpServerCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
} as MCPServerConfig,
},
};
mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal);
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
@@ -282,56 +298,166 @@ describe('ToolRegistry', () => {
undefined,
toolRegistry,
);
// We no longer check these as discoverMcpTools is mocked
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
// expect(Client).toHaveBeenCalledTimes(1);
// expect(StdioClientTransport).toHaveBeenCalledWith({
// command: 'mcp-server-cmd',
// args: ['--port', '1234'],
// env: expect.any(Object),
// stderr: 'pipe',
// });
// expect(mockMcpClientConnect).toHaveBeenCalled();
// To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools
// to call toolRegistry.registerTool, or we test that separately.
// For now, we just check that the delegation happened.
});
it('should discover tools using MCP server command from getMcpServerCommand', async () => {
it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
mockConfigGetMcpServers.mockReturnValue({});
mockConfigGetMcpServerCommand.mockReturnValue(
'mcp-server-start-command --param',
);
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
{},
'mcp-server-start-command --param',
toolRegistry,
);
});
it('should handle errors during MCP client connection gracefully and close transport', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
mockConfigGetMcpServers.mockReturnValue({
'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig,
});
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
{
'failing-mcp': { command: 'fail-cmd' },
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
},
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
);
expect(toolRegistry.getAllTools()).toHaveLength(0);
});
});
// Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed
// if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts)
});
describe('sanitizeParameters', () => {
it('should remove unsupported format from a simple string property', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
name: { type: Type.STRING },
id: { type: Type.STRING, format: 'uuid' },
},
};
sanitizeParameters(schema);
expect(schema.properties?.['id']).toHaveProperty('format', undefined);
expect(schema.properties?.['name']).not.toHaveProperty('format');
});
it('should NOT remove supported format values', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
date: { type: Type.STRING, format: 'date-time' },
role: {
type: Type.STRING,
format: 'enum',
enum: ['admin', 'user'],
},
},
};
const originalSchema = JSON.parse(JSON.stringify(schema));
sanitizeParameters(schema);
expect(schema).toEqual(originalSchema);
});
it('should handle nested objects recursively', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
user: {
type: Type.OBJECT,
properties: {
email: { type: Type.STRING, format: 'email' },
},
},
},
};
sanitizeParameters(schema);
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
'format',
undefined,
);
});
it('should handle arrays of objects', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
items: {
type: Type.ARRAY,
items: {
type: Type.OBJECT,
properties: {
itemId: { type: Type.STRING, format: 'uuid' },
},
},
},
},
};
sanitizeParameters(schema);
expect(
(schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
).toHaveProperty('format', undefined);
});
it('should handle schemas with no properties to sanitize', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
count: { type: Type.NUMBER },
isActive: { type: Type.BOOLEAN },
},
};
const originalSchema = JSON.parse(JSON.stringify(schema));
sanitizeParameters(schema);
expect(schema).toEqual(originalSchema);
});
it('should not crash on an empty or undefined schema', () => {
expect(() => sanitizeParameters({})).not.toThrow();
expect(() => sanitizeParameters(undefined)).not.toThrow();
});
it('should handle cyclic schemas without crashing', () => {
const schema: any = {
type: Type.OBJECT,
properties: {
name: { type: Type.STRING, format: 'hostname' },
},
};
schema.properties.self = schema;
expect(() => sanitizeParameters(schema)).not.toThrow();
expect(schema.properties.name).toHaveProperty('format', undefined);
});
it('should handle complex nested schemas with cycles', () => {
const userNode: any = {
type: Type.OBJECT,
properties: {
id: { type: Type.STRING, format: 'uuid' },
name: { type: Type.STRING },
manager: {
type: Type.OBJECT,
properties: {
id: { type: Type.STRING, format: 'uuid' },
},
},
},
};
userNode.properties.reports = {
type: Type.ARRAY,
items: userNode,
};
const schema: Schema = {
type: Type.OBJECT,
properties: {
ceo: userNode,
},
};
expect(() => sanitizeParameters(schema)).not.toThrow();
expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
'format',
undefined,
);
expect(
schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
).toHaveProperty('format', undefined);
});
});