fix sampling params

This commit is contained in:
koalazf.99
2025-12-11 13:46:37 +08:00
parent 354c85bcff
commit 5854ac67c6

View File

@@ -5,6 +5,7 @@
*/ */
import type { import type {
ContentGeneratorConfig,
FileFilteringOptions, FileFilteringOptions,
MCPServerConfig, MCPServerConfig,
OutputFormat, OutputFormat,
@@ -123,6 +124,24 @@ export interface CliArgs {
outputFormat: string | undefined; outputFormat: string | undefined;
} }
type LegacySamplingSettings = {
sampling_params?: ContentGeneratorConfig['samplingParams'];
};
function getLegacySamplingParams(
settings: Settings,
): ContentGeneratorConfig['samplingParams'] | undefined {
if (
typeof settings !== 'object' ||
settings === null ||
!('sampling_params' in (settings as Record<string, unknown>))
) {
return undefined;
}
return (settings as Settings & LegacySamplingSettings).sampling_params;
}
export async function parseArguments(settings: Settings): Promise<CliArgs> { export async function parseArguments(settings: Settings): Promise<CliArgs> {
const rawArgv = hideBin(process.argv); const rawArgv = hideBin(process.argv);
const yargsInstance = yargs(rawArgv) const yargsInstance = yargs(rawArgv)
@@ -685,6 +704,7 @@ export async function loadCliConfig(
const vlmSwitchMode = const vlmSwitchMode =
argv.vlmSwitchMode || settings.experimental?.vlmSwitchMode; argv.vlmSwitchMode || settings.experimental?.vlmSwitchMode;
const legacySamplingParams = getLegacySamplingParams(settings);
return new Config({ return new Config({
sessionId, sessionId,
embeddingModel: DEFAULT_QWEN_EMBEDDING_MODEL, embeddingModel: DEFAULT_QWEN_EMBEDDING_MODEL,
@@ -745,6 +765,8 @@ export async function loadCliConfig(
(typeof argv.openaiLogging === 'undefined' (typeof argv.openaiLogging === 'undefined'
? settings.model?.enableOpenAILogging ? settings.model?.enableOpenAILogging
: argv.openaiLogging) ?? false, : argv.openaiLogging) ?? false,
// Include sampling_params from root level settings
...(legacySamplingParams ? { samplingParams: legacySamplingParams } : {}),
}, },
cliVersion: await getCliVersion(), cliVersion: await getCliVersion(),
tavilyApiKey: tavilyApiKey: