diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts index b8ce2a68..f7a61c7f 100644 --- a/packages/core/src/core/baseLlmClient.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -65,10 +65,7 @@ export interface GenerateJsonOptions { */ export class BaseLlmClient { // Default configuration for utility tasks - private readonly defaultUtilityConfig: GenerateContentConfig = { - temperature: 0, - topP: 1, - }; + private readonly defaultUtilityConfig: GenerateContentConfig = {}; constructor( private readonly contentGenerator: ContentGenerator, diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 5b35e88f..7c700cfb 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -149,10 +149,7 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3; export class GeminiClient { private chat?: GeminiChat; - private readonly generateContentConfig: GenerateContentConfig = { - temperature: 0, - topP: 1, - }; + private readonly generateContentConfig: GenerateContentConfig; private sessionTurnCount = 0; private readonly loopDetector: LoopDetectionService; @@ -169,6 +166,44 @@ export class GeminiClient { constructor(private readonly config: Config) { this.loopDetector = new LoopDetectionService(config); this.lastPromptId = this.config.getSessionId(); + this.generateContentConfig = this.buildDefaultGenerateContentConfig(); + } + + private buildDefaultGenerateContentConfig(): GenerateContentConfig { + const samplingParams = + this.config.getContentGeneratorConfig()?.samplingParams; + + if (!samplingParams) { + return {}; + } + + const config: GenerateContentConfig = {}; + + if (samplingParams.temperature !== undefined) { + config.temperature = samplingParams.temperature; + } + + if (samplingParams.top_p !== undefined) { + config.topP = samplingParams.top_p; + } + + if (samplingParams.top_k !== undefined) { + config.topK = samplingParams.top_k; + } + + if (samplingParams.max_tokens !== undefined) { + config.maxOutputTokens = samplingParams.max_tokens; + } + + if (samplingParams.presence_penalty !== undefined) { + config.presencePenalty = samplingParams.presence_penalty; + } + + if (samplingParams.frequency_penalty !== undefined) { + config.frequencyPenalty = samplingParams.frequency_penalty; + } + + return config; } async initialize() {