From a7c8c4c2fa2583cba41d566c6fdfb97e9270f61d Mon Sep 17 00:00:00 2001 From: Mingholy Date: Wed, 20 Aug 2025 17:21:44 +0800 Subject: [PATCH 1/7] fix: allow query submission abort (#392) --- packages/cli/src/ui/hooks/useGeminiStream.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 9972bbdc..51492909 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -192,6 +192,7 @@ export const useGeminiStream = ( return; } turnCancelledRef.current = true; + isSubmittingQueryRef.current = false; abortControllerRef.current?.abort(); if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, Date.now()); From 1cf23c34e7e0124a5096ef7b48682b83d7a9fa1a Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Wed, 20 Aug 2025 17:35:00 +0800 Subject: [PATCH 2/7] fix: workflow automated issue triage --- .../gemini-automated-issue-triage.yml | 155 +----------------- 1 file changed, 1 insertion(+), 154 deletions(-) diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index b9343033..bb22e65b 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -34,13 +34,6 @@ jobs: triage-issue: timeout-minutes: 5 if: ${{ github.repository == 'QwenLM/qwen-code' }} - permissions: - issues: write - contents: read - id-token: write - concurrency: - group: ${{ github.workflow }}-${{ github.event.issue.number }} - cancel-in-progress: true runs-on: ubuntu-latest steps: - name: Run Qwen Issue Triage @@ -177,151 +170,5 @@ jobs: owner: '${{ github.repository }}'.split('/')[0], repo: '${{ github.repository }}'.split('/')[1], issue_number: '${{ github.event.issue.number }}', - body: 'There is a problem with the Gemini CLI issue triaging. Please check the [action logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details.' + body: 'There is a problem with the Qwen Code issue triaging. Please check the [action logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details.' }) - - deduplicate-issues: - if: > - github.repository == 'google-gemini/gemini-cli' && - vars.TRIAGE_DEDUPLICATE_ISSUES != '' && - (github.event_name == 'issues' || - github.event_name == 'workflow_dispatch' || - (github.event_name == 'issue_comment' && - contains(github.event.comment.body, '@gemini-cli /deduplicate') && - (github.event.comment.author_association == 'OWNER' || - github.event.comment.author_association == 'MEMBER' || - github.event.comment.author_association == 'COLLABORATOR'))) - - timeout-minutes: 20 - runs-on: 'ubuntu-latest' - steps: - - name: 'Checkout repository' - uses: 'actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683' - - - name: 'Generate GitHub App Token' - id: 'generate_token' - uses: 'actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e' - with: - app-id: '${{ secrets.APP_ID }}' - private-key: '${{ secrets.PRIVATE_KEY }}' - - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: 'Run Gemini Issue Deduplication' - uses: 'google-github-actions/run-gemini-cli@20351b5ea2b4179431f1ae8918a246a0808f8747' - id: 'gemini_issue_deduplication' - env: - GITHUB_TOKEN: '${{ steps.generate_token.outputs.token }}' - ISSUE_TITLE: '${{ github.event.issue.title }}' - ISSUE_BODY: '${{ github.event.issue.body }}' - ISSUE_NUMBER: '${{ github.event.issue.number }}' - REPOSITORY: '${{ github.repository }}' - FIRESTORE_PROJECT: '${{ vars.FIRESTORE_PROJECT }}' - with: - gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' - gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' - gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' - gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' - use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' - use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' - settings: |- - { - "mcpServers": { - "issue_deduplication": { - "command": "docker", - "args": [ - "run", - "-i", - "--rm", - "--network", "host", - "-e", "GITHUB_TOKEN", - "-e", "GEMINI_API_KEY", - "-e", "DATABASE_TYPE", - "-e", "FIRESTORE_DATABASE_ID", - "-e", "GCP_PROJECT", - "-e", "GOOGLE_APPLICATION_CREDENTIALS=/app/gcp-credentials.json", - "-v", "${GOOGLE_APPLICATION_CREDENTIALS}:/app/gcp-credentials.json", - "ghcr.io/google-gemini/gemini-cli-issue-triage@sha256:e3de1523f6c83aabb3c54b76d08940a2bf42febcb789dd2da6f95169641f94d3" - ], - "env": { - "GITHUB_TOKEN": "${GITHUB_TOKEN}", - "GEMINI_API_KEY": "${{ secrets.GEMINI_API_KEY }}", - "DATABASE_TYPE":"firestore", - "GCP_PROJECT": "${FIRESTORE_PROJECT}", - "FIRESTORE_DATABASE_ID": "(default)", - "GOOGLE_APPLICATION_CREDENTIALS": "${GOOGLE_APPLICATION_CREDENTIALS}" - }, - "enabled": true, - "timeout": 600000 - } - }, - "maxSessionTurns": 25, - "coreTools": [ - "run_shell_command(echo)", - "run_shell_command(gh issue comment)", - "run_shell_command(gh issue view)", - "run_shell_command(gh issue edit)" - ], - "telemetry": { - "enabled": true, - "target": "gcp" - } - } - prompt: |- - ## Role - You are an issue de-duplication assistant. Your goal is to find - duplicate issues, label the current issue as a duplicate, and notify - the user by commenting on the current issue, while avoiding - duplicate comments. - ## Steps - 1. **Find Potential Duplicates:** - - The repository is ${{ github.repository }} and the issue number is ${{ github.event.issue.number }}. - - Use the `duplicates` tool with the `repo` and `issue_number` to find potential duplicates for the current issue. Do not use the `threshold` parameter. - - If no duplicates are found, you are done. - - Print the JSON output from the `duplicates` tool to the logs. - 2. **Refine Duplicates List (if necessary):** - - If the `duplicates` tool returns between 1 and 14 results, you must refine the list. - - For each potential duplicate issue, run `gh issue view --json title,body,comments` to fetch its content. - - Also fetch the content of the original issue: `gh issue view "${ISSUE_NUMBER}" --json title,body,comments`. - - Carefully analyze the content (title, body, comments) of the original issue and all potential duplicates. - - It is very important if the comments on either issue mention that they are not duplicates of each other, to treat them as not duplicates. - - Based on your analysis, create a final list containing only the issues you are highly confident are actual duplicates. - - If your final list is empty, you are done. - - Print to the logs if you omitted any potential duplicates based on your analysis. - - If the `duplicates` tool returned 15+ results, use the top 15 matches (based on descending similarity score value) to perform this step. - 3. **Format Final Duplicates List:** - Format the final list of duplicates into a markdown string. - The format should be: - "Found possible duplicate issues:\n\n- #${issue_number}\n\nIf you believe this is not a duplicate, please remove the `status/possible-duplicate` label." - Add an HTML comment to the end for identification: `` - 4. **Check for Existing Comment:** - - Run `gh issue view "${ISSUE_NUMBER}" --json comments` to get all - comments on the issue. - - Look for a comment made by a bot (the author's login often ends in `[bot]`) that contains ``. - - If you find such a comment, store its `id` and `body`. - 5. **Decide Action:** - - **If an existing comment is found:** - - Compare the new list of duplicate issues with the list from the existing comment's body. - - If they are the same, do nothing. - - If they are different, edit the existing comment. Use - `gh issue comment "${ISSUE_NUMBER}" --edit-comment --body "..."`. - The new body should be the new list of duplicates, but with the header "Found possible duplicate issues (updated):". - - **If no existing comment is found:** - - Create a new comment with the list of duplicates. - - Use `gh issue comment "${ISSUE_NUMBER}" --body "..."`. - 6. **Add Duplicate Label:** - - If you created or updated a comment in the previous step, add the `duplicate` label to the current issue. - - Use `gh issue edit "${ISSUE_NUMBER}" --add-label "status/possible-duplicate"`. - ## Guidelines - - Only use the `duplicates` and `run_shell_command` tools. - - The `run_shell_command` tool can be used with `gh issue view`, `gh issue comment`, and `gh issue edit`. - - Do not download or read media files like images, videos, or links. The `--json` flag for `gh issue view` will prevent this. - - Do not modify the issue content or status. - - Only comment on and label the current issue. - - Reference all shell variables as "${VAR}" (with quotes and braces). From 201b476a3e2e3163bd9ff314c344aa6295fa0293 Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Wed, 20 Aug 2025 17:52:25 +0800 Subject: [PATCH 3/7] fix: workflow automated issue triage --- .github/workflows/gemini-automated-issue-triage.yml | 3 ++- .github/workflows/gemini-scheduled-issue-triage.yml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index bb22e65b..3d711264 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -43,8 +43,9 @@ jobs: ISSUE_TITLE: ${{ github.event.issue.title }} ISSUE_BODY: ${{ github.event.issue.body }} with: - version: 0.0.7 OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} + OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} settings_json: | { "maxSessionTurns": 25, diff --git a/.github/workflows/gemini-scheduled-issue-triage.yml b/.github/workflows/gemini-scheduled-issue-triage.yml index ac698dd7..ca11b308 100644 --- a/.github/workflows/gemini-scheduled-issue-triage.yml +++ b/.github/workflows/gemini-scheduled-issue-triage.yml @@ -60,7 +60,6 @@ jobs: ISSUES_TO_TRIAGE: ${{ steps.find_issues.outputs.issues_to_triage }} REPOSITORY: ${{ github.repository }} with: - version: 0.0.7 OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} From 9a28f04acedc38dde6767f43a6c325d681507fee Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Wed, 20 Aug 2025 19:41:17 +0800 Subject: [PATCH 4/7] chore: update issue triage prompt --- .github/workflows/gemini-automated-issue-triage.yml | 2 +- .github/workflows/gemini-scheduled-issue-triage.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index 3d711264..cc73d9d0 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -66,7 +66,7 @@ jobs: ## Steps 1. Run: `gh label list --repo ${{ github.repository }} --limit 100` to get all available labels. - 2. Review the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". + 2. Use right tool to review the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". 3. Ignore any existing priorities or tags on the issue. Just report your findings. 4. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. For area/* and kind/* limit yourself to only the single most applicable label in each case. 6. Apply the selected labels to this issue using: `gh issue edit ${{ github.event.issue.number }} --repo ${{ github.repository }} --add-label "label1,label2"`. diff --git a/.github/workflows/gemini-scheduled-issue-triage.yml b/.github/workflows/gemini-scheduled-issue-triage.yml index ca11b308..66c0167b 100644 --- a/.github/workflows/gemini-scheduled-issue-triage.yml +++ b/.github/workflows/gemini-scheduled-issue-triage.yml @@ -85,7 +85,7 @@ jobs: ## Steps 1. Run: `gh label list --repo ${{ github.repository }} --limit 100` to get all available labels. - 2. Check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) + 2. Use right tool to check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) 3. Review the issue title, body and any comments provided in the environment variables. 4. Ignore any existing priorities or tags on the issue. 5. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. From 64ce8c1d1e3996c95af35a5dcf6bec6b57ca5da8 Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Wed, 20 Aug 2025 20:23:02 +0800 Subject: [PATCH 5/7] fix: revert trimEnd on LLM response content --- packages/core/src/core/openaiContentGenerator.ts | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index eeba5db7..e24bd0c3 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -563,7 +563,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Add combined text if any if (combinedText) { - combinedParts.push({ text: combinedText.trimEnd() }); + combinedParts.push({ text: combinedText }); } // Add function calls @@ -1164,11 +1164,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Handle text content if (choice.message.content) { - if (typeof choice.message.content === 'string') { - parts.push({ text: choice.message.content.trimEnd() }); - } else { - parts.push({ text: choice.message.content }); - } + parts.push({ text: choice.message.content }); } // Handle tool calls @@ -1253,11 +1249,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Handle text content if (choice.delta?.content) { - if (typeof choice.delta.content === 'string') { - parts.push({ text: choice.delta.content.trimEnd() }); - } else { - parts.push({ text: choice.delta.content }); - } + parts.push({ text: choice.delta.content }); } // Handle tool calls - only accumulate during streaming, emit when complete @@ -1776,7 +1768,7 @@ export class OpenAIContentGenerator implements ContentGenerator { } } - messageContent = textParts.join('').trimEnd(); + messageContent = textParts.join(''); } const choice: OpenAIChoice = { From 742337c390e29f6cf1a96f7c23dc429ec700f11b Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Thu, 21 Aug 2025 18:33:13 +0800 Subject: [PATCH 6/7] feat: Add deterministic cache control (#411) * feat: add deterministic cache control --- package-lock.json | 1 + packages/cli/src/config/config.ts | 1 - packages/cli/src/config/settingsSchema.ts | 10 - packages/core/package.json | 1 + packages/core/src/config/config.ts | 17 +- .../__tests__/openaiTimeoutHandling.test.ts | 48 +- packages/core/src/core/client.ts | 10 +- .../core/src/core/contentGenerator.test.ts | 1 + packages/core/src/core/contentGenerator.ts | 16 +- .../src/core/openaiContentGenerator.test.ts | 515 ++++++++++++-- .../core/src/core/openaiContentGenerator.ts | 639 ++++++------------ .../src/qwen/qwenContentGenerator.test.ts | 16 +- .../core/src/qwen/qwenContentGenerator.ts | 9 +- 13 files changed, 757 insertions(+), 527 deletions(-) diff --git a/package-lock.json b/package-lock.json index ef109c01..4aabf3cf 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11997,6 +11997,7 @@ "strip-ansi": "^7.1.0", "tiktoken": "^1.0.21", "undici": "^7.10.0", + "uuid": "^9.0.1", "ws": "^8.18.0" }, "devDependencies": { diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 0ec6bd07..c141be39 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -517,7 +517,6 @@ export async function loadCliConfig( (typeof argv.openaiLogging === 'undefined' ? settings.enableOpenAILogging : argv.openaiLogging) ?? false, - sampling_params: settings.sampling_params, systemPromptMappings: (settings.systemPromptMappings ?? [ { baseUrls: [ diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 73ffebdc..30f16bf7 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -503,7 +503,6 @@ export const SETTINGS_SCHEMA = { description: 'Show line numbers in the chat.', showInDialog: true, }, - contentGenerator: { type: 'object', label: 'Content Generator', @@ -513,15 +512,6 @@ export const SETTINGS_SCHEMA = { description: 'Content generator settings.', showInDialog: false, }, - sampling_params: { - type: 'object', - label: 'Sampling Params', - category: 'General', - requiresRestart: false, - default: undefined as Record | undefined, - description: 'Sampling parameters for the model.', - showInDialog: false, - }, enableOpenAILogging: { type: 'boolean', label: 'Enable OpenAI Logging', diff --git a/packages/core/package.json b/packages/core/package.json index 7b84fd01..0555bf99 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -52,6 +52,7 @@ "strip-ansi": "^7.1.0", "tiktoken": "^1.0.21", "undici": "^7.10.0", + "uuid": "^9.0.1", "ws": "^8.18.0" }, "devDependencies": { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index b1a2a096..f474a2dc 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -204,7 +204,6 @@ export interface ConfigParameters { folderTrust?: boolean; ideMode?: boolean; enableOpenAILogging?: boolean; - sampling_params?: Record; systemPromptMappings?: Array<{ baseUrls: string[]; modelNames: string[]; @@ -213,6 +212,9 @@ export interface ConfigParameters { contentGenerator?: { timeout?: number; maxRetries?: number; + samplingParams?: { + [key: string]: unknown; + }; }; cliVersion?: string; loadMemoryFromIncludeDirectories?: boolean; @@ -289,10 +291,10 @@ export class Config { | undefined; private readonly experimentalAcp: boolean = false; private readonly enableOpenAILogging: boolean; - private readonly sampling_params?: Record; private readonly contentGenerator?: { timeout?: number; maxRetries?: number; + samplingParams?: Record; }; private readonly cliVersion?: string; private readonly loadMemoryFromIncludeDirectories: boolean = false; @@ -367,7 +369,6 @@ export class Config { this.ideClient = IdeClient.getInstance(); this.systemPromptMappings = params.systemPromptMappings; this.enableOpenAILogging = params.enableOpenAILogging ?? false; - this.sampling_params = params.sampling_params; this.contentGenerator = params.contentGenerator; this.cliVersion = params.cliVersion; @@ -757,10 +758,6 @@ export class Config { return this.enableOpenAILogging; } - getSamplingParams(): Record | undefined { - return this.sampling_params; - } - getContentGeneratorTimeout(): number | undefined { return this.contentGenerator?.timeout; } @@ -769,6 +766,12 @@ export class Config { return this.contentGenerator?.maxRetries; } + getContentGeneratorSamplingParams(): ContentGeneratorConfig['samplingParams'] { + return this.contentGenerator?.samplingParams as + | ContentGeneratorConfig['samplingParams'] + | undefined; + } + getCliVersion(): string | undefined { return this.cliVersion; } diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts index c743c9b5..bb46b09b 100644 --- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts +++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts @@ -7,6 +7,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { OpenAIContentGenerator } from '../openaiContentGenerator.js'; import { Config } from '../../config/config.js'; +import { AuthType } from '../contentGenerator.js'; import OpenAI from 'openai'; // Mock OpenAI @@ -41,9 +42,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => { mockConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -60,7 +58,12 @@ describe('OpenAIContentGenerator Timeout Handling', () => { vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); // Create generator instance - generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + }; + generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); }); afterEach(() => { @@ -237,12 +240,18 @@ describe('OpenAIContentGenerator Timeout Handling', () => { describe('timeout configuration', () => { it('should use default timeout configuration', () => { - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + baseUrl: 'http://localhost:8080', + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); // Verify OpenAI client was created with timeout config expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 120000, maxRetries: 3, defaultHeaders: { @@ -253,18 +262,23 @@ describe('OpenAIContentGenerator Timeout Handling', () => { it('should use custom timeout from config', () => { const customConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - timeout: 300000, // 5 minutes - maxRetries: 5, - }), + getContentGeneratorConfig: vi.fn().mockReturnValue({}), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', customConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'http://localhost:8080', + authType: AuthType.USE_OPENAI, + timeout: 300000, + maxRetries: 5, + }; + new OpenAIContentGenerator(contentGeneratorConfig, customConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 300000, maxRetries: 5, defaultHeaders: { @@ -279,11 +293,17 @@ describe('OpenAIContentGenerator Timeout Handling', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', noTimeoutConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + baseUrl: 'http://localhost:8080', + }; + new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 120000, // default maxRetries: 3, // default defaultHeaders: { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index f190df10..876e9027 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -565,10 +565,7 @@ export class GeminiClient { model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL; try { const userMemory = this.config.getUserMemory(); - const systemPromptMappings = this.config.getSystemPromptMappings(); - const systemInstruction = getCoreSystemPrompt(userMemory, { - systemPromptMappings, - }); + const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { abortSignal, ...this.generateContentConfig, @@ -656,10 +653,7 @@ export class GeminiClient { try { const userMemory = this.config.getUserMemory(); - const systemPromptMappings = this.config.getSystemPromptMappings(); - const systemInstruction = getCoreSystemPrompt(userMemory, { - systemPromptMappings, - }); + const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { abortSignal, diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 5d735beb..2761c0c5 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -84,6 +84,7 @@ describe('createContentGeneratorConfig', () => { getSamplingParams: vi.fn().mockReturnValue(undefined), getContentGeneratorTimeout: vi.fn().mockReturnValue(undefined), getContentGeneratorMaxRetries: vi.fn().mockReturnValue(undefined), + getContentGeneratorSamplingParams: vi.fn().mockReturnValue(undefined), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 2c90e9c6..582ffbe4 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -53,6 +53,7 @@ export enum AuthType { export type ContentGeneratorConfig = { model: string; apiKey?: string; + baseUrl?: string; vertexai?: boolean; authType?: AuthType | undefined; enableOpenAILogging?: boolean; @@ -76,11 +77,16 @@ export function createContentGeneratorConfig( config: Config, authType: AuthType | undefined, ): ContentGeneratorConfig { + // google auth const geminiApiKey = process.env.GEMINI_API_KEY || undefined; const googleApiKey = process.env.GOOGLE_API_KEY || undefined; const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined; const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined; + + // openai auth const openaiApiKey = process.env.OPENAI_API_KEY; + const openaiBaseUrl = process.env.OPENAI_BASE_URL || undefined; + const openaiModel = process.env.OPENAI_MODEL || undefined; // Use runtime model from config if available; otherwise, fall back to parameter or default const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL; @@ -92,7 +98,7 @@ export function createContentGeneratorConfig( enableOpenAILogging: config.getEnableOpenAILogging(), timeout: config.getContentGeneratorTimeout(), maxRetries: config.getContentGeneratorMaxRetries(), - samplingParams: config.getSamplingParams(), + samplingParams: config.getContentGeneratorSamplingParams(), }; // If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now @@ -127,8 +133,8 @@ export function createContentGeneratorConfig( if (authType === AuthType.USE_OPENAI && openaiApiKey) { contentGeneratorConfig.apiKey = openaiApiKey; - contentGeneratorConfig.model = - process.env.OPENAI_MODEL || DEFAULT_GEMINI_MODEL; + contentGeneratorConfig.baseUrl = openaiBaseUrl; + contentGeneratorConfig.model = openaiModel || DEFAULT_QWEN_MODEL; return contentGeneratorConfig; } @@ -196,7 +202,7 @@ export async function createContentGenerator( ); // Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag - return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig); + return new OpenAIContentGenerator(config, gcConfig); } if (config.authType === AuthType.QWEN_OAUTH) { @@ -217,7 +223,7 @@ export async function createContentGenerator( const qwenClient = await getQwenOauthClient(gcConfig); // Create the content generator with dynamic token management - return new QwenContentGenerator(qwenClient, config.model, gcConfig); + return new QwenContentGenerator(qwenClient, config, gcConfig); } catch (error) { throw new Error( `Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`, diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts index 92de6235..ac255a42 100644 --- a/packages/core/src/core/openaiContentGenerator.test.ts +++ b/packages/core/src/core/openaiContentGenerator.test.ts @@ -7,6 +7,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { OpenAIContentGenerator } from './openaiContentGenerator.js'; import { Config } from '../config/config.js'; +import { AuthType } from './contentGenerator.js'; import OpenAI from 'openai'; import type { GenerateContentParameters, @@ -84,7 +85,20 @@ describe('OpenAIContentGenerator', () => { vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); // Create generator instance - generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + samplingParams: { + temperature: 0.7, + max_tokens: 1000, + top_p: 0.9, + }, + }; + generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); }); afterEach(() => { @@ -95,7 +109,7 @@ describe('OpenAIContentGenerator', () => { it('should initialize with basic configuration', () => { expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: undefined, timeout: 120000, maxRetries: 3, defaultHeaders: { @@ -105,9 +119,16 @@ describe('OpenAIContentGenerator', () => { }); it('should handle custom base URL', () => { - vi.stubEnv('OPENAI_BASE_URL', 'https://api.custom.com'); - - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'https://api.custom.com', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', @@ -121,9 +142,16 @@ describe('OpenAIContentGenerator', () => { }); it('should configure OpenRouter headers when using OpenRouter', () => { - vi.stubEnv('OPENAI_BASE_URL', 'https://openrouter.ai/api/v1'); - - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'https://openrouter.ai/api/v1', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', @@ -147,11 +175,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', customConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + timeout: 300000, + maxRetries: 5, + }; + new OpenAIContentGenerator(contentGeneratorConfig, customConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: undefined, timeout: 300000, maxRetries: 5, defaultHeaders: { @@ -906,9 +941,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -1029,9 +1069,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -1587,7 +1632,23 @@ describe('OpenAIContentGenerator', () => { } } - const testGenerator = new TestGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + samplingParams: { + temperature: 0.7, + max_tokens: 1000, + top_p: 0.9, + }, + }; + const testGenerator = new TestGenerator( + contentGeneratorConfig, + mockConfig, + ); const consoleSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); @@ -1908,9 +1969,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + samplingParams: { + temperature: 0.8, + max_tokens: 500, + }, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -2093,9 +2163,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -2350,9 +2425,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + samplingParams: { + temperature: undefined, + max_tokens: undefined, + top_p: undefined, + }, + }; const testGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, configWithUndefined, ); @@ -2408,9 +2492,22 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + samplingParams: { + temperature: 0.8, + max_tokens: 1500, + top_p: 0.95, + top_k: 40, + repetition_penalty: 1.1, + presence_penalty: 0.5, + frequency_penalty: 0.3, + }, + }; const testGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, fullSamplingConfig, ); @@ -2489,9 +2586,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + apiKey: 'test-key', + authType: AuthType.QWEN_OAUTH, + enableOpenAILogging: false, + }; const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2528,12 +2630,6 @@ describe('OpenAIContentGenerator', () => { }); it('should include metadata when baseURL is dashscope openai compatible mode', async () => { - // Mock environment to set dashscope base URL BEFORE creating the generator - vi.stubEnv( - 'OPENAI_BASE_URL', - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - ); - const dashscopeConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', // Not QWEN_OAUTH @@ -2543,9 +2639,15 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + apiKey: 'test-key', + baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + }; const dashscopeGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, dashscopeConfig, ); @@ -2604,9 +2706,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const regularGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, regularConfig, ); @@ -2650,9 +2761,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const otherGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, otherAuthConfig, ); @@ -2699,9 +2819,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const otherBaseUrlGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, otherBaseUrlConfig, ); @@ -2748,9 +2877,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2804,8 +2942,6 @@ describe('OpenAIContentGenerator', () => { sessionId: 'streaming-session-id', promptId: 'streaming-prompt-id', }, - stream: true, - stream_options: { include_usage: true }, }), ); @@ -2827,9 +2963,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const regularGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, regularConfig, ); @@ -2901,9 +3046,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2955,9 +3109,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const noBaseUrlGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, noBaseUrlConfig, ); @@ -3004,9 +3167,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const undefinedAuthGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, undefinedAuthConfig, ); @@ -3050,9 +3222,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const undefinedConfigGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, undefinedConfig, ); @@ -3089,4 +3270,232 @@ describe('OpenAIContentGenerator', () => { ); }); }); + + describe('cache control for DashScope', () => { + it('should add cache control to system message for DashScope providers', async () => { + // Mock environment to set dashscope base URL + vi.stubEnv( + 'OPENAI_BASE_URL', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + + const dashscopeConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + + const dashscopeGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + dashscopeConfig, + ); + + // Mock the client's baseURL property to return the expected value + Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { + value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + writable: true, + }); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'qwen-turbo', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + config: { + systemInstruction: 'You are a helpful assistant.', + }, + model: 'qwen-turbo', + }; + + await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); + + // Should include cache control in system message + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: 'You are a helpful assistant.', + cache_control: { type: 'ephemeral' }, + }), + ]), + }), + ]), + }), + ); + }); + + it('should add cache control to last message for DashScope providers', async () => { + // Mock environment to set dashscope base URL + vi.stubEnv( + 'OPENAI_BASE_URL', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + + const dashscopeConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + + const dashscopeGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + dashscopeConfig, + ); + + // Mock the client's baseURL property to return the expected value + Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { + value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + writable: true, + }); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'qwen-turbo', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }], + model: 'qwen-turbo', + }; + + await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); + + // Should include cache control in last message + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'user', + content: expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: 'Hello, how are you?', + cache_control: { type: 'ephemeral' }, + }), + ]), + }), + ]), + }), + ); + }); + + it('should NOT add cache control for non-DashScope providers', async () => { + const regularConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('regular-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + + const regularGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + regularConfig, + ); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'gpt-4', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + config: { + systemInstruction: 'You are a helpful assistant.', + }, + model: 'gpt-4', + }; + + await regularGenerator.generateContent(request, 'regular-prompt-id'); + + // Should NOT include cache control (messages should be strings, not arrays) + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: 'You are a helpful assistant.', + }), + expect.objectContaining({ + role: 'user', + content: 'Hello', + }), + ]), + }), + ); + }); + }); }); diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index e24bd0c3..3f223f3e 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -20,7 +20,11 @@ import { FunctionCall, FunctionResponse, } from '@google/genai'; -import { AuthType, ContentGenerator } from './contentGenerator.js'; +import { + AuthType, + ContentGenerator, + ContentGeneratorConfig, +} from './contentGenerator.js'; import OpenAI from 'openai'; import { logApiError, logApiResponse } from '../telemetry/loggers.js'; import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js'; @@ -28,6 +32,17 @@ import { Config } from '../config/config.js'; import { openaiLogger } from '../utils/openaiLogger.js'; import { safeJsonParse } from '../utils/safeJsonParse.js'; +// Extended types to support cache_control +interface ChatCompletionContentPartTextWithCache + extends OpenAI.Chat.ChatCompletionContentPartText { + cache_control?: { type: 'ephemeral' }; +} + +type ChatCompletionContentPartWithCache = + | ChatCompletionContentPartTextWithCache + | OpenAI.Chat.ChatCompletionContentPartImage + | OpenAI.Chat.ChatCompletionContentPartRefusal; + // OpenAI API type definitions for logging interface OpenAIToolCall { id: string; @@ -38,9 +53,15 @@ interface OpenAIToolCall { }; } +interface OpenAIContentItem { + type: 'text'; + text: string; + cache_control?: { type: 'ephemeral' }; +} + interface OpenAIMessage { role: 'system' | 'user' | 'assistant' | 'tool'; - content: string | null; + content: string | null | OpenAIContentItem[]; tool_calls?: OpenAIToolCall[]; tool_call_id?: string; } @@ -60,15 +81,6 @@ interface OpenAIChoice { finish_reason: string; } -interface OpenAIRequestFormat { - model: string; - messages: OpenAIMessage[]; - temperature?: number; - max_tokens?: number; - top_p?: number; - tools?: unknown[]; -} - interface OpenAIResponseFormat { id: string; object: string; @@ -81,6 +93,7 @@ interface OpenAIResponseFormat { export class OpenAIContentGenerator implements ContentGenerator { protected client: OpenAI; private model: string; + private contentGeneratorConfig: ContentGeneratorConfig; private config: Config; private streamingToolCalls: Map< number, @@ -91,50 +104,40 @@ export class OpenAIContentGenerator implements ContentGenerator { } > = new Map(); - constructor(apiKey: string, model: string, config: Config) { - this.model = model; - this.config = config; - const baseURL = process.env.OPENAI_BASE_URL || ''; + constructor( + contentGeneratorConfig: ContentGeneratorConfig, + gcConfig: Config, + ) { + this.model = contentGeneratorConfig.model; + this.contentGeneratorConfig = contentGeneratorConfig; + this.config = gcConfig; - // Configure timeout settings - using progressive timeouts - const timeoutConfig = { - // Base timeout for most requests (2 minutes) - timeout: 120000, - // Maximum retries for failed requests - maxRetries: 3, - // HTTP client options - httpAgent: undefined, // Let the client use default agent - }; - - // Allow config to override timeout settings - const contentGeneratorConfig = this.config.getContentGeneratorConfig(); - if (contentGeneratorConfig?.timeout) { - timeoutConfig.timeout = contentGeneratorConfig.timeout; - } - if (contentGeneratorConfig?.maxRetries !== undefined) { - timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries; - } - - const version = config.getCliVersion() || 'unknown'; + const version = gcConfig.getCliVersion() || 'unknown'; const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`; // Check if using OpenRouter and add required headers - const isOpenRouter = baseURL.includes('openrouter.ai'); + const isOpenRouterProvider = this.isOpenRouterProvider(); + const isDashScopeProvider = this.isDashScopeProvider(); + const defaultHeaders = { 'User-Agent': userAgent, - ...(isOpenRouter + ...(isOpenRouterProvider ? { 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git', 'X-Title': 'Qwen Code', } - : {}), + : isDashScopeProvider + ? { + 'X-DashScope-CacheControl': 'enable', + } + : {}), }; this.client = new OpenAI({ - apiKey, - baseURL, - timeout: timeoutConfig.timeout, - maxRetries: timeoutConfig.maxRetries, + apiKey: contentGeneratorConfig.apiKey, + baseURL: contentGeneratorConfig.baseUrl, + timeout: contentGeneratorConfig.timeout ?? 120000, + maxRetries: contentGeneratorConfig.maxRetries ?? 3, defaultHeaders, }); } @@ -185,22 +188,25 @@ export class OpenAIContentGenerator implements ContentGenerator { ); } + private isOpenRouterProvider(): boolean { + const baseURL = this.contentGeneratorConfig.baseUrl || ''; + return baseURL.includes('openrouter.ai'); + } + /** - * Determine if metadata should be included in the request. - * Only include the `metadata` field if the provider is QWEN_OAUTH - * or the baseUrl is 'https://dashscope.aliyuncs.com/compatible-mode/v1'. - * This is because some models/providers do not support metadata or need extra configuration. + * Determine if this is a DashScope provider. + * DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs. * - * @returns true if metadata should be included, false otherwise + * @returns true if this is a DashScope provider, false otherwise */ - private shouldIncludeMetadata(): boolean { - const authType = this.config.getContentGeneratorConfig?.()?.authType; - // baseUrl may be undefined; default to empty string if so - const baseUrl = this.client?.baseURL || ''; + private isDashScopeProvider(): boolean { + const authType = this.contentGeneratorConfig.authType; + const baseUrl = this.contentGeneratorConfig.baseUrl; return ( authType === AuthType.QWEN_OAUTH || - baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' + baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' || + baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1' ); } @@ -213,7 +219,7 @@ export class OpenAIContentGenerator implements ContentGenerator { private buildMetadata( userPromptId: string, ): { metadata: { sessionId?: string; promptId: string } } | undefined { - if (!this.shouldIncludeMetadata()) { + if (!this.isDashScopeProvider()) { return undefined; } @@ -225,35 +231,44 @@ export class OpenAIContentGenerator implements ContentGenerator { }; } + private async buildCreateParams( + request: GenerateContentParameters, + userPromptId: string, + ): Promise[0]> { + const messages = this.convertToOpenAIFormat(request); + + // Build sampling parameters with clear priority: + // 1. Request-level parameters (highest priority) + // 2. Config-level sampling parameters (medium priority) + // 3. Default values (lowest priority) + const samplingParams = this.buildSamplingParameters(request); + + const createParams: Parameters< + typeof this.client.chat.completions.create + >[0] = { + model: this.model, + messages, + ...samplingParams, + ...(this.buildMetadata(userPromptId) || {}), + }; + + if (request.config?.tools) { + createParams.tools = await this.convertGeminiToolsToOpenAI( + request.config.tools, + ); + } + + return createParams; + } + async generateContent( request: GenerateContentParameters, userPromptId: string, ): Promise { const startTime = Date.now(); - const messages = this.convertToOpenAIFormat(request); + const createParams = await this.buildCreateParams(request, userPromptId); try { - // Build sampling parameters with clear priority: - // 1. Request-level parameters (highest priority) - // 2. Config-level sampling parameters (medium priority) - // 3. Default values (lowest priority) - const samplingParams = this.buildSamplingParameters(request); - - const createParams: Parameters< - typeof this.client.chat.completions.create - >[0] = { - model: this.model, - messages, - ...samplingParams, - ...(this.buildMetadata(userPromptId) || {}), - }; - - if (request.config?.tools) { - createParams.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - // console.log('createParams', createParams); const completion = (await this.client.chat.completions.create( createParams, )) as OpenAI.Chat.ChatCompletion; @@ -267,15 +282,15 @@ export class OpenAIContentGenerator implements ContentGenerator { this.model, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, response.usageMetadata, ); logApiResponse(this.config, responseEvent); // Log interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { + const openaiRequest = createParams; const openaiResponse = this.convertGeminiResponseToOpenAI(response); await openaiLogger.logInteraction(openaiRequest, openaiResponse); } @@ -300,7 +315,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -309,10 +324,9 @@ export class OpenAIContentGenerator implements ContentGenerator { logApiError(this.config, errorEvent); // Log error interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { await openaiLogger.logInteraction( - openaiRequest, + createParams, undefined, error as Error, ); @@ -343,29 +357,12 @@ export class OpenAIContentGenerator implements ContentGenerator { userPromptId: string, ): Promise> { const startTime = Date.now(); - const messages = this.convertToOpenAIFormat(request); + const createParams = await this.buildCreateParams(request, userPromptId); + + createParams.stream = true; + createParams.stream_options = { include_usage: true }; try { - // Build sampling parameters with clear priority - const samplingParams = this.buildSamplingParameters(request); - - const createParams: Parameters< - typeof this.client.chat.completions.create - >[0] = { - model: this.model, - messages, - ...samplingParams, - stream: true, - stream_options: { include_usage: true }, - ...(this.buildMetadata(userPromptId) || {}), - }; - - if (request.config?.tools) { - createParams.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - const stream = (await this.client.chat.completions.create( createParams, )) as AsyncIterable; @@ -397,16 +394,15 @@ export class OpenAIContentGenerator implements ContentGenerator { this.model, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, finalUsageMetadata, ); logApiResponse(this.config, responseEvent); // Log interaction if enabled (same as generateContent method) - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = - await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { + const openaiRequest = createParams; // For streaming, we combine all responses into a single response for logging const combinedResponse = this.combineStreamResponsesForLogging(responses); @@ -433,7 +429,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -442,11 +438,9 @@ export class OpenAIContentGenerator implements ContentGenerator { logApiError(this.config, errorEvent); // Log error interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = - await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { await openaiLogger.logInteraction( - openaiRequest, + createParams, undefined, error as Error, ); @@ -487,7 +481,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -944,7 +938,114 @@ export class OpenAIContentGenerator implements ContentGenerator { // Clean up orphaned tool calls and merge consecutive assistant messages const cleanedMessages = this.cleanOrphanedToolCalls(messages); - return this.mergeConsecutiveAssistantMessages(cleanedMessages); + const mergedMessages = + this.mergeConsecutiveAssistantMessages(cleanedMessages); + + // Add cache control to system and last messages for DashScope providers + return this.addCacheControlFlag(mergedMessages, 'both'); + } + + /** + * Add cache control flag to specified message(s) for DashScope providers + */ + private addCacheControlFlag( + messages: OpenAI.Chat.ChatCompletionMessageParam[], + target: 'system' | 'last' | 'both' = 'both', + ): OpenAI.Chat.ChatCompletionMessageParam[] { + if (!this.isDashScopeProvider() || messages.length === 0) { + return messages; + } + + let updatedMessages = [...messages]; + + // Add cache control to system message if requested + if (target === 'system' || target === 'both') { + updatedMessages = this.addCacheControlToMessage( + updatedMessages, + 'system', + ); + } + + // Add cache control to last message if requested + if (target === 'last' || target === 'both') { + updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last'); + } + + return updatedMessages; + } + + /** + * Helper method to add cache control to a specific message + */ + private addCacheControlToMessage( + messages: OpenAI.Chat.ChatCompletionMessageParam[], + target: 'system' | 'last', + ): OpenAI.Chat.ChatCompletionMessageParam[] { + const updatedMessages = [...messages]; + let messageIndex: number; + + if (target === 'system') { + // Find the first system message + messageIndex = messages.findIndex((msg) => msg.role === 'system'); + if (messageIndex === -1) { + return updatedMessages; + } + } else { + // Get the last message + messageIndex = messages.length - 1; + } + + const message = updatedMessages[messageIndex]; + + // Only process messages that have content + if ('content' in message && message.content !== null) { + if (typeof message.content === 'string') { + // Convert string content to array format with cache control + const messageWithArrayContent = { + ...message, + content: [ + { + type: 'text', + text: message.content, + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache, + ], + }; + updatedMessages[messageIndex] = + messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam; + } else if (Array.isArray(message.content)) { + // If content is already an array, add cache_control to the last item + const contentArray = [ + ...message.content, + ] as ChatCompletionContentPartWithCache[]; + if (contentArray.length > 0) { + const lastItem = contentArray[contentArray.length - 1]; + if (lastItem.type === 'text') { + // Add cache_control to the last text item + contentArray[contentArray.length - 1] = { + ...lastItem, + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache; + } else { + // If the last item is not text, add a new text item with cache_control + contentArray.push({ + type: 'text', + text: '', + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache); + } + + const messageWithCache = { + ...message, + content: contentArray, + }; + updatedMessages[messageIndex] = + messageWithCache as OpenAI.Chat.ChatCompletionMessageParam; + } + } + } + + return updatedMessages; } /** @@ -1368,8 +1469,7 @@ export class OpenAIContentGenerator implements ContentGenerator { private buildSamplingParameters( request: GenerateContentParameters, ): Record { - const configSamplingParams = - this.config.getContentGeneratorConfig()?.samplingParams; + const configSamplingParams = this.contentGeneratorConfig.samplingParams; const params = { // Temperature: config > request > default @@ -1431,313 +1531,6 @@ export class OpenAIContentGenerator implements ContentGenerator { return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED; } - /** - * Convert Gemini request format to OpenAI chat completion format for logging - */ - private async convertGeminiRequestToOpenAI( - request: GenerateContentParameters, - ): Promise { - const messages: OpenAIMessage[] = []; - - // Handle system instruction - if (request.config?.systemInstruction) { - const systemInstruction = request.config.systemInstruction; - let systemText = ''; - - if (Array.isArray(systemInstruction)) { - systemText = systemInstruction - .map((content) => { - if (typeof content === 'string') return content; - if ('parts' in content) { - const contentObj = content as Content; - return ( - contentObj.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || '' - ); - } - return ''; - }) - .join('\n'); - } else if (typeof systemInstruction === 'string') { - systemText = systemInstruction; - } else if ( - typeof systemInstruction === 'object' && - 'parts' in systemInstruction - ) { - const systemContent = systemInstruction as Content; - systemText = - systemContent.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || ''; - } - - if (systemText) { - messages.push({ - role: 'system', - content: systemText, - }); - } - } - - // Handle contents - if (Array.isArray(request.contents)) { - for (const content of request.contents) { - if (typeof content === 'string') { - messages.push({ role: 'user', content }); - } else if ('role' in content && 'parts' in content) { - const functionCalls: FunctionCall[] = []; - const functionResponses: FunctionResponse[] = []; - const textParts: string[] = []; - - for (const part of content.parts || []) { - if (typeof part === 'string') { - textParts.push(part); - } else if ('text' in part && part.text) { - textParts.push(part.text); - } else if ('functionCall' in part && part.functionCall) { - functionCalls.push(part.functionCall); - } else if ('functionResponse' in part && part.functionResponse) { - functionResponses.push(part.functionResponse); - } - } - - // Handle function responses (tool results) - if (functionResponses.length > 0) { - for (const funcResponse of functionResponses) { - messages.push({ - role: 'tool', - tool_call_id: funcResponse.id || '', - content: - typeof funcResponse.response === 'string' - ? funcResponse.response - : JSON.stringify(funcResponse.response), - }); - } - } - // Handle model messages with function calls - else if (content.role === 'model' && functionCalls.length > 0) { - const toolCalls = functionCalls.map((fc, index) => ({ - id: fc.id || `call_${index}`, - type: 'function' as const, - function: { - name: fc.name || '', - arguments: JSON.stringify(fc.args || {}), - }, - })); - - messages.push({ - role: 'assistant', - content: textParts.join('\n') || null, - tool_calls: toolCalls, - }); - } - // Handle regular text messages - else { - const role = content.role === 'model' ? 'assistant' : 'user'; - const text = textParts.join('\n'); - if (text) { - messages.push({ role, content: text }); - } - } - } - } - } else if (request.contents) { - if (typeof request.contents === 'string') { - messages.push({ role: 'user', content: request.contents }); - } else if ('role' in request.contents && 'parts' in request.contents) { - const content = request.contents; - const role = content.role === 'model' ? 'assistant' : 'user'; - const text = - content.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || ''; - messages.push({ role, content: text }); - } - } - - // Clean up orphaned tool calls and merge consecutive assistant messages - const cleanedMessages = this.cleanOrphanedToolCallsForLogging(messages); - const mergedMessages = - this.mergeConsecutiveAssistantMessagesForLogging(cleanedMessages); - - const openaiRequest: OpenAIRequestFormat = { - model: this.model, - messages: mergedMessages, - }; - - // Add sampling parameters using the same logic as actual API calls - const samplingParams = this.buildSamplingParameters(request); - Object.assign(openaiRequest, samplingParams); - - // Convert tools if present - if (request.config?.tools) { - openaiRequest.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - - return openaiRequest; - } - - /** - * Clean up orphaned tool calls for logging purposes - */ - private cleanOrphanedToolCallsForLogging( - messages: OpenAIMessage[], - ): OpenAIMessage[] { - const cleaned: OpenAIMessage[] = []; - const toolCallIds = new Set(); - const toolResponseIds = new Set(); - - // First pass: collect all tool call IDs and tool response IDs - for (const message of messages) { - if (message.role === 'assistant' && message.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - toolCallIds.add(toolCall.id); - } - } - } else if (message.role === 'tool' && message.tool_call_id) { - toolResponseIds.add(message.tool_call_id); - } - } - - // Second pass: filter out orphaned messages - for (const message of messages) { - if (message.role === 'assistant' && message.tool_calls) { - // Filter out tool calls that don't have corresponding responses - const validToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && toolResponseIds.has(toolCall.id), - ); - - if (validToolCalls.length > 0) { - // Keep the message but only with valid tool calls - const cleanedMessage = { ...message }; - cleanedMessage.tool_calls = validToolCalls; - cleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - // Keep the message if it has text content, but remove tool calls - const cleanedMessage = { ...message }; - delete cleanedMessage.tool_calls; - cleaned.push(cleanedMessage); - } - // If no valid tool calls and no content, skip the message entirely - } else if (message.role === 'tool' && message.tool_call_id) { - // Only keep tool responses that have corresponding tool calls - if (toolCallIds.has(message.tool_call_id)) { - cleaned.push(message); - } - } else { - // Keep all other messages as-is - cleaned.push(message); - } - } - - // Final validation: ensure every assistant message with tool_calls has corresponding tool responses - const finalCleaned: OpenAIMessage[] = []; - const finalToolCallIds = new Set(); - - // Collect all remaining tool call IDs - for (const message of cleaned) { - if (message.role === 'assistant' && message.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - finalToolCallIds.add(toolCall.id); - } - } - } - } - - // Verify all tool calls have responses - const finalToolResponseIds = new Set(); - for (const message of cleaned) { - if (message.role === 'tool' && message.tool_call_id) { - finalToolResponseIds.add(message.tool_call_id); - } - } - - // Remove any remaining orphaned tool calls - for (const message of cleaned) { - if (message.role === 'assistant' && message.tool_calls) { - const finalValidToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id), - ); - - if (finalValidToolCalls.length > 0) { - const cleanedMessage = { ...message }; - cleanedMessage.tool_calls = finalValidToolCalls; - finalCleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - const cleanedMessage = { ...message }; - delete cleanedMessage.tool_calls; - finalCleaned.push(cleanedMessage); - } - } else { - finalCleaned.push(message); - } - } - - return finalCleaned; - } - - /** - * Merge consecutive assistant messages to combine split text and tool calls for logging - */ - private mergeConsecutiveAssistantMessagesForLogging( - messages: OpenAIMessage[], - ): OpenAIMessage[] { - const merged: OpenAIMessage[] = []; - - for (const message of messages) { - if (message.role === 'assistant' && merged.length > 0) { - const lastMessage = merged[merged.length - 1]; - - // If the last message is also an assistant message, merge them - if (lastMessage.role === 'assistant') { - // Combine content - const combinedContent = [ - lastMessage.content || '', - message.content || '', - ] - .filter(Boolean) - .join(''); - - // Combine tool calls - const combinedToolCalls = [ - ...(lastMessage.tool_calls || []), - ...(message.tool_calls || []), - ]; - - // Update the last message with combined data - lastMessage.content = combinedContent || null; - if (combinedToolCalls.length > 0) { - lastMessage.tool_calls = combinedToolCalls; - } - - continue; // Skip adding the current message since it's been merged - } - } - - // Add the message as-is if no merging is needed - merged.push(message); - } - - return merged; - } - /** * Convert Gemini response format to OpenAI chat completion format for logging */ diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts index d2878f93..a56aed81 100644 --- a/packages/core/src/qwen/qwenContentGenerator.test.ts +++ b/packages/core/src/qwen/qwenContentGenerator.test.ts @@ -21,6 +21,7 @@ import { } from '@google/genai'; import { QwenContentGenerator } from './qwenContentGenerator.js'; import { Config } from '../config/config.js'; +import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js'; // Mock the OpenAIContentGenerator parent class vi.mock('../core/openaiContentGenerator.js', () => ({ @@ -30,10 +31,13 @@ vi.mock('../core/openaiContentGenerator.js', () => ({ baseURL: string; }; - constructor(apiKey: string, _model: string, _config: Config) { + constructor( + contentGeneratorConfig: ContentGeneratorConfig, + _config: Config, + ) { this.client = { - apiKey, - baseURL: 'https://api.openai.com/v1', + apiKey: contentGeneratorConfig.apiKey || 'test-key', + baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1', }; } @@ -131,9 +135,13 @@ describe('QwenContentGenerator', () => { }; // Create QwenContentGenerator instance + const contentGeneratorConfig = { + model: 'qwen-turbo', + authType: AuthType.QWEN_OAUTH, + }; qwenContentGenerator = new QwenContentGenerator( mockQwenClient, - 'qwen-turbo', + contentGeneratorConfig, mockConfig, ); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts index 1158d547..2a9468bd 100644 --- a/packages/core/src/qwen/qwenContentGenerator.ts +++ b/packages/core/src/qwen/qwenContentGenerator.ts @@ -20,6 +20,7 @@ import { EmbedContentParameters, EmbedContentResponse, } from '@google/genai'; +import { ContentGeneratorConfig } from '../core/contentGenerator.js'; // Default fallback base URL if no endpoint is provided const DEFAULT_QWEN_BASE_URL = @@ -36,9 +37,13 @@ export class QwenContentGenerator extends OpenAIContentGenerator { private currentEndpoint: string | null = null; private refreshPromise: Promise | null = null; - constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) { + constructor( + qwenClient: IQwenOAuth2Client, + contentGeneratorConfig: ContentGeneratorConfig, + config: Config, + ) { // Initialize with empty API key, we'll override it dynamically - super('', model, config); + super(contentGeneratorConfig, config); this.qwenClient = qwenClient; // Set default base URL, will be updated dynamically From ed5a2d0fa49adbe5354c74a05a199431b8212d36 Mon Sep 17 00:00:00 2001 From: ajiwo <1269872+ajiwo@users.noreply.github.com> Date: Thu, 21 Aug 2025 17:35:30 +0700 Subject: [PATCH 7/7] Limit grep result (#407) * feat: implement result limiting for GrepTool to prevent context overflow --- docs/tools/file-system.md | 30 +++++++++++ packages/core/src/tools/grep.test.ts | 80 ++++++++++++++++++++++++++++ packages/core/src/tools/grep.ts | 69 ++++++++++++++++++++++-- 3 files changed, 176 insertions(+), 3 deletions(-) diff --git a/docs/tools/file-system.md b/docs/tools/file-system.md index 965ff9af..45c1eaa7 100644 --- a/docs/tools/file-system.md +++ b/docs/tools/file-system.md @@ -89,10 +89,13 @@ Qwen Code provides a comprehensive suite of tools for interacting with the local - `pattern` (string, required): The regular expression (regex) to search for (e.g., `"function\s+myFunction"`). - `path` (string, optional): The absolute path to the directory to search within. Defaults to the current working directory. - `include` (string, optional): A glob pattern to filter which files are searched (e.g., `"*.js"`, `"src/**/*.{ts,tsx}"`). If omitted, searches most files (respecting common ignores). + - `maxResults` (number, optional): Maximum number of matches to return to prevent context overflow (default: 20, max: 100). Use lower values for broad searches, higher for specific searches. - **Behavior:** - Uses `git grep` if available in a Git repository for speed; otherwise, falls back to system `grep` or a JavaScript-based search. - Returns a list of matching lines, each prefixed with its file path (relative to the search directory) and line number. + - Limits results to a maximum of 20 matches by default to prevent context overflow. When results are truncated, shows a clear warning with guidance on refining searches. - **Output (`llmContent`):** A formatted string of matches, e.g.: + ``` Found 3 matches for pattern "myFunction" in path "." (filter: "*.ts"): --- @@ -103,9 +106,36 @@ Qwen Code provides a comprehensive suite of tools for interacting with the local File: src/index.ts L5: import { myFunction } from './utils'; --- + + WARNING: Results truncated to prevent context overflow. To see more results: + - Use a more specific pattern to reduce matches + - Add file filters with the 'include' parameter (e.g., "*.js", "src/**") + - Specify a narrower 'path' to search in a subdirectory + - Increase 'maxResults' parameter if you need more matches (current: 20) ``` + - **Confirmation:** No. +### `search_file_content` examples + +Search for a pattern with default result limiting: + +``` +search_file_content(pattern="function\s+myFunction", path="src") +``` + +Search for a pattern with custom result limiting: + +``` +search_file_content(pattern="function", path="src", maxResults=50) +``` + +Search for a pattern with file filtering and custom result limiting: + +``` +search_file_content(pattern="function", include="*.js", maxResults=10) +``` + ## 6. `replace` (Edit) `replace` replaces text within a file. By default, replaces a single occurrence, but can replace multiple occurrences when `expected_replacements` is specified. This tool is designed for precise, targeted changes and requires significant context around the `old_string` to ensure it modifies the correct location. diff --git a/packages/core/src/tools/grep.test.ts b/packages/core/src/tools/grep.test.ts index dd567170..d77da29c 100644 --- a/packages/core/src/tools/grep.test.ts +++ b/packages/core/src/tools/grep.test.ts @@ -439,4 +439,84 @@ describe('GrepTool', () => { expect(invocation.getDescription()).toBe("'testPattern' within ./"); }); }); + + describe('Result limiting', () => { + beforeEach(async () => { + // Create many test files with matches to test limiting + for (let i = 1; i <= 30; i++) { + const fileName = `test${i}.txt`; + const content = `This is test file ${i} with the pattern testword in it.`; + await fs.writeFile(path.join(tempRootDir, fileName), content); + } + }); + + it('should limit results to default 20 matches', async () => { + const params: GrepToolParams = { pattern: 'testword' }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 20 matches'); + expect(result.llmContent).toContain( + 'showing first 20 of 30+ total matches', + ); + expect(result.llmContent).toContain('WARNING: Results truncated'); + expect(result.returnDisplay).toContain( + 'Found 20 matches (truncated from 30+)', + ); + }); + + it('should respect custom maxResults parameter', async () => { + const params: GrepToolParams = { pattern: 'testword', maxResults: 5 }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 5 matches'); + expect(result.llmContent).toContain( + 'showing first 5 of 30+ total matches', + ); + expect(result.llmContent).toContain('current: 5'); + expect(result.returnDisplay).toContain( + 'Found 5 matches (truncated from 30+)', + ); + }); + + it('should not show truncation warning when all results fit', async () => { + const params: GrepToolParams = { pattern: 'testword', maxResults: 50 }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 30 matches'); + expect(result.llmContent).not.toContain('WARNING: Results truncated'); + expect(result.llmContent).not.toContain('showing first'); + expect(result.returnDisplay).toBe('Found 30 matches'); + }); + + it('should validate maxResults parameter', () => { + const invalidParams = [ + { pattern: 'test', maxResults: 0 }, + { pattern: 'test', maxResults: 101 }, + { pattern: 'test', maxResults: -1 }, + { pattern: 'test', maxResults: 1.5 }, + ]; + + invalidParams.forEach((params) => { + const error = grepTool.validateToolParams(params as GrepToolParams); + expect(error).toBeTruthy(); // Just check that validation fails + expect(error).toMatch(/maxResults|must be/); // Check it's about maxResults validation + }); + }); + + it('should accept valid maxResults parameter', () => { + const validParams = [ + { pattern: 'test', maxResults: 1 }, + { pattern: 'test', maxResults: 50 }, + { pattern: 'test', maxResults: 100 }, + ]; + + validParams.forEach((params) => { + const error = grepTool.validateToolParams(params); + expect(error).toBeNull(); + }); + }); + }); }); diff --git a/packages/core/src/tools/grep.ts b/packages/core/src/tools/grep.ts index 41e77c0f..cd14fbe5 100644 --- a/packages/core/src/tools/grep.ts +++ b/packages/core/src/tools/grep.ts @@ -43,6 +43,11 @@ export interface GrepToolParams { * File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}") */ include?: string; + + /** + * Maximum number of matches to return (optional, defaults to 20) + */ + maxResults?: number; } /** @@ -124,6 +129,10 @@ class GrepToolInvocation extends BaseToolInvocation< // Collect matches from all search directories let allMatches: GrepMatch[] = []; + const maxResults = this.params.maxResults ?? 20; // Default to 20 results + let totalMatchesFound = 0; + let searchTruncated = false; + for (const searchDir of searchDirectories) { const matches = await this.performGrepSearch({ pattern: this.params.pattern, @@ -132,6 +141,8 @@ class GrepToolInvocation extends BaseToolInvocation< signal, }); + totalMatchesFound += matches.length; + // Add directory prefix if searching multiple directories if (searchDirectories.length > 1) { const dirName = path.basename(searchDir); @@ -140,7 +151,20 @@ class GrepToolInvocation extends BaseToolInvocation< }); } - allMatches = allMatches.concat(matches); + // Apply result limiting + const remainingSlots = maxResults - allMatches.length; + if (remainingSlots <= 0) { + searchTruncated = true; + break; + } + + if (matches.length > remainingSlots) { + allMatches = allMatches.concat(matches.slice(0, remainingSlots)); + searchTruncated = true; + break; + } else { + allMatches = allMatches.concat(matches); + } } let searchLocationDescription: string; @@ -176,7 +200,14 @@ class GrepToolInvocation extends BaseToolInvocation< const matchCount = allMatches.length; const matchTerm = matchCount === 1 ? 'match' : 'matches'; - let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}: + // Build the header with truncation info if needed + let headerText = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}`; + + if (searchTruncated) { + headerText += ` (showing first ${matchCount} of ${totalMatchesFound}+ total matches)`; + } + + let llmContent = `${headerText}: --- `; @@ -189,9 +220,23 @@ class GrepToolInvocation extends BaseToolInvocation< llmContent += '---\n'; } + // Add truncation guidance if results were limited + if (searchTruncated) { + llmContent += `\nWARNING: Results truncated to prevent context overflow. To see more results: +- Use a more specific pattern to reduce matches +- Add file filters with the 'include' parameter (e.g., "*.js", "src/**") +- Specify a narrower 'path' to search in a subdirectory +- Increase 'maxResults' parameter if you need more matches (current: ${maxResults})`; + } + + let displayText = `Found ${matchCount} ${matchTerm}`; + if (searchTruncated) { + displayText += ` (truncated from ${totalMatchesFound}+)`; + } + return { llmContent: llmContent.trim(), - returnDisplay: `Found ${matchCount} ${matchTerm}`, + returnDisplay: displayText, }; } catch (error) { console.error(`Error during GrepLogic execution: ${error}`); @@ -567,6 +612,13 @@ export class GrepTool extends BaseDeclarativeTool { "Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).", type: 'string', }, + maxResults: { + description: + 'Optional: Maximum number of matches to return to prevent context overflow (default: 20, max: 100). Use lower values for broad searches, higher for specific searches.', + type: 'number', + minimum: 1, + maximum: 100, + }, }, required: ['pattern'], type: 'object', @@ -635,6 +687,17 @@ export class GrepTool extends BaseDeclarativeTool { return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`; } + // Validate maxResults if provided + if (params.maxResults !== undefined) { + if ( + !Number.isInteger(params.maxResults) || + params.maxResults < 1 || + params.maxResults > 100 + ) { + return `maxResults must be an integer between 1 and 100, got: ${params.maxResults}`; + } + } + // Only validate path if one is provided if (params.path) { try {