Refactor: Centralize GeminiClient in Config (#693)

Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
Scott Densmore
2025-06-02 14:55:51 -07:00
committed by GitHub
parent 1dcf0a4cbd
commit e428707e07
4 changed files with 79 additions and 50 deletions

View File

@@ -4,13 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { GoogleGenAI, GroundingMetadata } from '@google/genai';
import { GroundingMetadata } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, ToolResult } from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { retryWithBackoff } from '../utils/retry.js';
// Interfaces for grounding metadata (similar to web-search.ts)
interface GroundingChunkWeb {
@@ -49,9 +48,6 @@ export interface WebFetchToolParams {
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
static readonly Name: string = 'web_fetch';
private ai: GoogleGenAI;
private modelName: string;
constructor(private readonly config: Config) {
super(
WebFetchTool.Name,
@@ -69,12 +65,6 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
type: 'object',
},
);
const apiKeyFromConfig = this.config.getApiKey();
this.ai = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
});
this.modelName = this.config.getModel();
}
validateParams(params: WebFetchToolParams): string | null {
@@ -109,7 +99,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
async execute(
params: WebFetchToolParams,
_signal: AbortSignal,
signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
@@ -120,23 +110,14 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
}
const userPrompt = params.prompt;
const geminiClient = this.config.getGeminiClient();
try {
const apiCall = () =>
this.ai.models.generateContent({
model: this.modelName,
contents: [
{
role: 'user',
parts: [{ text: userPrompt }],
},
],
config: {
tools: [{ urlContext: {} }],
},
});
const response = await retryWithBackoff(apiCall);
const response = await geminiClient.generateContent(
[{ role: 'user', parts: [{ text: userPrompt }] }],
{ tools: [{ urlContext: {} }] },
signal, // Pass signal
);
console.debug(
`[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`,

View File

@@ -4,14 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { GoogleGenAI, GroundingMetadata } from '@google/genai';
import { GroundingMetadata } from '@google/genai';
import { BaseTool, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { retryWithBackoff } from '../utils/retry.js';
interface GroundingChunkWeb {
uri?: string;
@@ -64,9 +63,6 @@ export class WebSearchTool extends BaseTool<
> {
static readonly Name: string = 'google_web_search';
private ai: GoogleGenAI;
private modelName: string;
constructor(private readonly config: Config) {
super(
WebSearchTool.Name,
@@ -83,13 +79,6 @@ export class WebSearchTool extends BaseTool<
required: ['query'],
},
);
const apiKeyFromConfig = this.config.getApiKey();
// Initialize GoogleGenAI, allowing fallback to environment variables for API key
this.ai = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
});
this.modelName = this.config.getModel();
}
validateParams(params: WebSearchToolParams): string | null {
@@ -112,7 +101,10 @@ export class WebSearchTool extends BaseTool<
return `Searching the web for: "${params.query}"`;
}
async execute(params: WebSearchToolParams): Promise<WebSearchToolResult> {
async execute(
params: WebSearchToolParams,
signal: AbortSignal,
): Promise<WebSearchToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
@@ -120,18 +112,14 @@ export class WebSearchTool extends BaseTool<
returnDisplay: validationError,
};
}
const geminiClient = this.config.getGeminiClient();
try {
const apiCall = () =>
this.ai.models.generateContent({
model: this.modelName,
contents: [{ role: 'user', parts: [{ text: params.query }] }],
config: {
tools: [{ googleSearch: {} }],
},
});
const response = await retryWithBackoff(apiCall);
const response = await geminiClient.generateContent(
[{ role: 'user', parts: [{ text: params.query }] }],
{ tools: [{ googleSearch: {} }] },
signal,
);
const responseText = getResponseText(response);
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;