fix(core): Sanitize tool parameters to fix 400 API errors (#3300)

This commit is contained in:
BigUncle
2025-07-06 05:58:51 +08:00
committed by GitHub
parent 5c9372372c
commit b564d4a088
8 changed files with 438 additions and 176 deletions

View File

@@ -4,12 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { FunctionDeclaration } from '@google/genai';
import { FunctionDeclaration, Schema, Type } from '@google/genai';
import { Tool, ToolResult, BaseTool } from './tools.js';
import { Config } from '../config/config.js';
import { spawn, execSync } from 'node:child_process';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
import { discoverMcpTools } from './mcp-client.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { parse } from 'shell-quote';
type ToolParams = Record<string, unknown>;
@@ -157,32 +159,9 @@ export class ToolRegistry {
// Keep manually registered tools
}
}
// discover tools using discovery command, if configured
const discoveryCmd = this.config.getToolDiscoveryCommand();
if (discoveryCmd) {
// execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
const functions: FunctionDeclaration[] = [];
for (const tool of JSON.parse(execSync(discoveryCmd).toString().trim())) {
if (tool['function_declarations']) {
functions.push(...tool['function_declarations']);
} else if (tool['functionDeclarations']) {
functions.push(...tool['functionDeclarations']);
} else if (tool['name']) {
functions.push(tool);
}
}
// register each function as a tool
for (const func of functions) {
this.registerTool(
new DiscoveredTool(
this.config,
func.name!,
func.description!,
func.parameters! as Record<string, unknown>,
),
);
}
}
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await discoverMcpTools(
this.config.getMcpServers() ?? {},
@@ -191,6 +170,128 @@ export class ToolRegistry {
);
}
/**
 * Runs the configured tool discovery command (if any) and registers every
 * function declaration found in its JSON output as a DiscoveredTool.
 *
 * Steps:
 * 1. Read the command from the config; return immediately when none is set.
 * 2. Split the command with shell-quote's `parse` and spawn it, collecting
 *    stdout and stderr with a 10MB cap on each stream.
 * 3. Wait for the process to close; fail on spawn error, on exceeding a size
 *    cap, or on a non-zero exit code.
 * 4. Parse stdout as JSON (must be an array), collect function declarations
 *    from `function_declarations` / `functionDeclarations` wrappers or from
 *    bare objects that have a `name`.
 * 5. Sanitize each declaration's parameters and register it as a tool.
 *
 * @throws Any error raised along the way is logged together with the
 *   command string and then re-thrown to the caller.
 */
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
const discoveryCmd = this.config.getToolDiscoveryCommand();
if (!discoveryCmd) {
// No discovery command configured: nothing to do.
return;
}
try {
// Split the command string into an executable and its arguments.
const cmdParts = parse(discoveryCmd);
if (cmdParts.length === 0) {
throw new Error(
'Tool discovery command is empty or contains only whitespace.',
);
}
// NOTE(review): the casts assume every parsed entry is a plain string;
// shell-quote's parse can presumably also yield non-string entries
// (e.g. for shell operators or globs) — confirm such commands are unsupported.
const proc = spawn(cmdParts[0] as string, cmdParts.slice(1) as string[]);
// Output is decoded incrementally, chunk by chunk, via StringDecoder;
// the decoders are flushed with end() once the process closes.
let stdout = '';
const stdoutDecoder = new StringDecoder('utf8');
let stderr = '';
const stderrDecoder = new StringDecoder('utf8');
// Set once either stream would exceed its cap; later chunks are ignored.
let sizeLimitExceeded = false;
const MAX_STDOUT_SIZE = 10 * 1024 * 1024; // 10MB limit
const MAX_STDERR_SIZE = 10 * 1024 * 1024; // 10MB limit
let stdoutByteLength = 0;
let stderrByteLength = 0;
proc.stdout.on('data', (data) => {
if (sizeLimitExceeded) return;
// Check the cap before appending so the offending chunk is never stored.
if (stdoutByteLength + data.length > MAX_STDOUT_SIZE) {
sizeLimitExceeded = true;
proc.kill();
return;
}
stdoutByteLength += data.length;
stdout += stdoutDecoder.write(data);
});
proc.stderr.on('data', (data) => {
if (sizeLimitExceeded) return;
// Same cap handling as stdout, tracked against the stderr limit.
if (stderrByteLength + data.length > MAX_STDERR_SIZE) {
sizeLimitExceeded = true;
proc.kill();
return;
}
stderrByteLength += data.length;
stderr += stderrDecoder.write(data);
});
// Wait for the process to finish; the promise settles on 'error' or 'close'.
await new Promise<void>((resolve, reject) => {
proc.on('error', reject);
proc.on('close', (code) => {
// Flush whatever the decoders are still holding.
stdout += stdoutDecoder.end();
stderr += stderrDecoder.end();
if (sizeLimitExceeded) {
// The message reports MAX_STDOUT_SIZE even when stderr tripped the
// limit; both limits currently have the same value.
return reject(
new Error(
`Tool discovery command output exceeded size limit of ${MAX_STDOUT_SIZE} bytes.`,
),
);
}
// Any exit code other than 0 is treated as failure; the collected stderr
// is logged to help diagnose the command.
if (code !== 0) {
console.error(`Command failed with code ${code}`);
console.error(stderr);
return reject(
new Error(`Tool discovery command failed with exit code ${code}`),
);
}
resolve();
});
});
// Parse the command's stdout and extract function declarations (w/ or w/o "tool" wrappers)
const functions: FunctionDeclaration[] = [];
const discoveredItems = JSON.parse(stdout.trim());
if (!discoveredItems || !Array.isArray(discoveredItems)) {
throw new Error(
'Tool discovery command did not return a JSON array of tools.',
);
}
for (const tool of discoveredItems) {
// Entries that are not objects are silently ignored.
if (tool && typeof tool === 'object') {
if (Array.isArray(tool['function_declarations'])) {
// snake_case wrapper: { function_declarations: [...] }
functions.push(...tool['function_declarations']);
} else if (Array.isArray(tool['functionDeclarations'])) {
// camelCase wrapper: { functionDeclarations: [...] }
functions.push(...tool['functionDeclarations']);
} else if (tool['name']) {
// Bare declaration: the entry itself is the function declaration.
functions.push(tool as FunctionDeclaration);
}
}
}
// register each function as a tool
for (const func of functions) {
if (!func.name) {
console.warn('Discovered a tool with no name. Skipping.');
continue;
}
// Sanitize the parameters before registering the tool.
// Anything that is not a plain (non-array) object is replaced by an
// empty schema. sanitizeParameters mutates the object in place.
const parameters =
func.parameters &&
typeof func.parameters === 'object' &&
!Array.isArray(func.parameters)
? (func.parameters as Schema)
: {};
sanitizeParameters(parameters);
this.registerTool(
new DiscoveredTool(
this.config,
func.name,
func.description ?? '',
parameters as Record<string, unknown>,
),
);
}
} catch (e) {
// Log with the command for context, then propagate to the caller.
console.error(`Tool discovery command "${discoveryCmd}" failed:`, e);
throw e;
}
}
/**
* Retrieves the list of tool schemas (FunctionDeclaration array).
* Extracts the declarations from the ToolListUnion structure.
@@ -232,3 +333,62 @@ export class ToolRegistry {
return this.tools.get(name);
}
}
/**
 * Makes a tool parameter schema acceptable to the Gemini API by editing it in place.
 *
 * NOTE: The object passed in is mutated; no copy is made and nothing is returned.
 *
 * The following is applied to the schema and to every schema reachable from it
 * through `anyOf`, `items` and `properties`:
 * - When `anyOf` is present, `default` is cleared (set to `undefined`).
 * - For STRING-typed schemas, any `format` other than 'enum' or 'date-time' is cleared.
 * - Each schema object is processed at most once, so circular references terminate.
 *
 * @param schema The schema to sanitize. When omitted, the call does nothing.
 */
export function sanitizeParameters(schema?: Schema) {
  // A fresh visited-set per top-level call guards the recursion against cycles.
  const seen = new Set<Schema>();
  _sanitizeParameters(schema, seen);
}
/**
 * Recursive worker behind sanitizeParameters; mutates `schema` in place.
 *
 * @param schema The schema node to sanitize. A missing node is ignored.
 * @param visited Schema objects already processed in this traversal. The
 *   current node is added to it, and nodes already present are skipped.
 */
function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
  if (!schema) {
    return;
  }
  if (visited.has(schema)) {
    // Already handled (or currently being handled further up the stack).
    return;
  }
  visited.add(schema);

  const alternatives = schema.anyOf;
  if (alternatives) {
    // Vertex AI gets confused if both anyOf and default are set.
    schema.default = undefined;
    for (const alternative of alternatives) {
      if (typeof alternative === 'boolean') {
        continue;
      }
      _sanitizeParameters(alternative, visited);
    }
  }

  const elementSchema = schema.items;
  if (elementSchema && typeof elementSchema !== 'boolean') {
    _sanitizeParameters(elementSchema, visited);
  }

  const propertySchemas = schema.properties;
  if (propertySchemas) {
    for (const key of Object.keys(propertySchemas)) {
      const propertySchema = propertySchemas[key];
      if (typeof propertySchema === 'boolean') {
        continue;
      }
      _sanitizeParameters(propertySchema, visited);
    }
  }

  // Vertex AI only supports 'enum' and 'date-time' for STRING format.
  if (schema.type !== Type.STRING) {
    return;
  }
  const currentFormat = schema.format;
  const isSupportedFormat =
    currentFormat === 'enum' || currentFormat === 'date-time';
  if (currentFormat && !isSupportedFormat) {
    schema.format = undefined;
  }
}