diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index 48404294..e645ce70 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -34,13 +34,6 @@ jobs: triage-issue: timeout-minutes: 5 if: ${{ github.repository == 'QwenLM/qwen-code' }} - permissions: - issues: write - contents: read - id-token: write - concurrency: - group: ${{ github.workflow }}-${{ github.event.issue.number }} - cancel-in-progress: true runs-on: ubuntu-latest steps: - name: Run Qwen Issue Triage @@ -50,8 +43,9 @@ jobs: ISSUE_TITLE: ${{ github.event.issue.title }} ISSUE_BODY: ${{ github.event.issue.body }} with: - version: 0.0.7 OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} + OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} settings_json: | { "maxSessionTurns": 25, @@ -72,7 +66,7 @@ jobs: ## Steps 1. Run: `gh label list --repo "${REPOSITORY}" --limit 100` to get all available labels. - 2. Review the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". + 2. Use right tool to review the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". 3. Ignore any existing priorities or tags on the issue. Just report your findings. 4. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. For area/* and kind/* limit yourself to only the single most applicable label in each case. 6. Apply the selected labels to this issue using: `gh issue edit "${ISSUE_NUMBER}" --repo "${REPOSITORY}" --add-label "label1,label2"`. @@ -180,151 +174,5 @@ jobs: owner: process.env.REPOSITORY.split('/')[0], repo: process.env.REPOSITORY.split('/')[1], issue_number: '${{ github.event.issue.number }}', - body: `There is a problem with the Gemini CLI issue triaging. Please check the [action logs](${process.env.RUN_URL}) for details.` + body: 'There is a problem with the Qwen Code issue triaging. Please check the [action logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details.' }) - - deduplicate-issues: - if: |- - github.repository == 'google-gemini/gemini-cli' && - vars.TRIAGE_DEDUPLICATE_ISSUES != '' && - (github.event_name == 'issues' || - github.event_name == 'workflow_dispatch' || - (github.event_name == 'issue_comment' && - contains(github.event.comment.body, '@gemini-cli /deduplicate') && - (github.event.comment.author_association == 'OWNER' || - github.event.comment.author_association == 'MEMBER' || - github.event.comment.author_association == 'COLLABORATOR'))) - - timeout-minutes: 20 - runs-on: 'ubuntu-latest' - steps: - - name: 'Checkout' - uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5 - - - name: 'Generate GitHub App Token' - id: 'generate_token' - uses: 'actions/create-github-app-token@a8d616148505b5069dccd32f177bb87d7f39123b' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ secrets.APP_ID }}' - private-key: '${{ secrets.PRIVATE_KEY }}' - - - name: 'Log in to GitHub Container Registry' - uses: 'docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1' # ratchet:docker/login-action@v3 - with: - registry: 'ghcr.io' - username: '${{ github.actor }}' - password: '${{ secrets.GITHUB_TOKEN }}' - - - name: 'Run Gemini Issue Deduplication' - uses: 'google-github-actions/run-gemini-cli@06123c6a203eb7a964ce3be7c48479cc66059f23' # ratchet:google-github-actions/run-gemini-cli@v0 - id: 'gemini_issue_deduplication' - env: - GITHUB_TOKEN: '${{ steps.generate_token.outputs.token }}' - ISSUE_TITLE: '${{ github.event.issue.title }}' - ISSUE_BODY: '${{ github.event.issue.body }}' - ISSUE_NUMBER: '${{ github.event.issue.number }}' - REPOSITORY: '${{ github.repository }}' - FIRESTORE_PROJECT: '${{ vars.FIRESTORE_PROJECT }}' - with: - gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}' - gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' - gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' - gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' - gemini_api_key: '${{ secrets.GEMINI_API_KEY }}' - use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}' - use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}' - settings: |- - { - "mcpServers": { - "issue_deduplication": { - "command": "docker", - "args": [ - "run", - "-i", - "--rm", - "--network", "host", - "-e", "GITHUB_TOKEN", - "-e", "GEMINI_API_KEY", - "-e", "DATABASE_TYPE", - "-e", "FIRESTORE_DATABASE_ID", - "-e", "GCP_PROJECT", - "-e", "GOOGLE_APPLICATION_CREDENTIALS=/app/gcp-credentials.json", - "-v", "${GOOGLE_APPLICATION_CREDENTIALS}:/app/gcp-credentials.json", - "ghcr.io/google-gemini/gemini-cli-issue-triage@sha256:e3de1523f6c83aabb3c54b76d08940a2bf42febcb789dd2da6f95169641f94d3" - ], - "env": { - "GITHUB_TOKEN": "${GITHUB_TOKEN}", - "GEMINI_API_KEY": "${{ secrets.GEMINI_API_KEY }}", - "DATABASE_TYPE":"firestore", - "GCP_PROJECT": "${FIRESTORE_PROJECT}", - "FIRESTORE_DATABASE_ID": "(default)", - "GOOGLE_APPLICATION_CREDENTIALS": "${GOOGLE_APPLICATION_CREDENTIALS}" - }, - "enabled": true, - "timeout": 600000 - } - }, - "maxSessionTurns": 25, - "coreTools": [ - "run_shell_command(echo)", - "run_shell_command(gh issue comment)", - "run_shell_command(gh issue view)", - "run_shell_command(gh issue edit)" - ], - "telemetry": { - "enabled": true, - "target": "gcp" - } - } - prompt: |- - ## Role - You are an issue de-duplication assistant. Your goal is to find - duplicate issues, label the current issue as a duplicate, and notify - the user by commenting on the current issue, while avoiding - duplicate comments. - ## Steps - 1. **Find Potential Duplicates:** - - The repository is ${{ github.repository }} and the issue number is ${{ github.event.issue.number }}. - - Use the `duplicates` tool with the `repo` and `issue_number` to find potential duplicates for the current issue. Do not use the `threshold` parameter. - - If no duplicates are found, you are done. - - Print the JSON output from the `duplicates` tool to the logs. - 2. **Refine Duplicates List (if necessary):** - - If the `duplicates` tool returns between 1 and 14 results, you must refine the list. - - For each potential duplicate issue, run `gh issue view --json title,body,comments` to fetch its content. - - Also fetch the content of the original issue: `gh issue view "${ISSUE_NUMBER}" --json title,body,comments`. - - Carefully analyze the content (title, body, comments) of the original issue and all potential duplicates. - - It is very important if the comments on either issue mention that they are not duplicates of each other, to treat them as not duplicates. - - Based on your analysis, create a final list containing only the issues you are highly confident are actual duplicates. - - If your final list is empty, you are done. - - Print to the logs if you omitted any potential duplicates based on your analysis. - - If the `duplicates` tool returned 15+ results, use the top 15 matches (based on descending similarity score value) to perform this step. - 3. **Format Final Duplicates List:** - Format the final list of duplicates into a markdown string. - The format should be: - "Found possible duplicate issues:\n\n- #${issue_number}\n\nIf you believe this is not a duplicate, please remove the `status/possible-duplicate` label." - Add an HTML comment to the end for identification: `` - 4. **Check for Existing Comment:** - - Run `gh issue view "${ISSUE_NUMBER}" --json comments` to get all - comments on the issue. - - Look for a comment made by a bot (the author's login often ends in `[bot]`) that contains ``. - - If you find such a comment, store its `id` and `body`. - 5. **Decide Action:** - - **If an existing comment is found:** - - Compare the new list of duplicate issues with the list from the existing comment's body. - - If they are the same, do nothing. - - If they are different, edit the existing comment. Use - `gh issue comment "${ISSUE_NUMBER}" --edit-comment --body "..."`. - The new body should be the new list of duplicates, but with the header "Found possible duplicate issues (updated):". - - **If no existing comment is found:** - - Create a new comment with the list of duplicates. - - Use `gh issue comment "${ISSUE_NUMBER}" --body "..."`. - 6. **Add Duplicate Label:** - - If you created or updated a comment in the previous step, add the `duplicate` label to the current issue. - - Use `gh issue edit "${ISSUE_NUMBER}" --add-label "status/possible-duplicate"`. - ## Guidelines - - Only use the `duplicates` and `run_shell_command` tools. - - The `run_shell_command` tool can be used with `gh issue view`, `gh issue comment`, and `gh issue edit`. - - Do not download or read media files like images, videos, or links. The `--json` flag for `gh issue view` will prevent this. - - Do not modify the issue content or status. - - Only comment on and label the current issue. - - Reference all shell variables as "${VAR}" (with quotes and braces). diff --git a/.github/workflows/gemini-scheduled-issue-triage.yml b/.github/workflows/gemini-scheduled-issue-triage.yml index d6780d8d..9eba47d4 100644 --- a/.github/workflows/gemini-scheduled-issue-triage.yml +++ b/.github/workflows/gemini-scheduled-issue-triage.yml @@ -60,7 +60,6 @@ jobs: ISSUES_TO_TRIAGE: ${{ steps.find_issues.outputs.issues_to_triage }} REPOSITORY: ${{ github.repository }} with: - version: 0.0.7 OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} @@ -86,7 +85,7 @@ jobs: ## Steps 1. Run: `gh label list --repo "${REPOSITORY}" --limit 100` to get all available labels. - 2. Check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) + 2. Use right tool to check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) 3. Review the issue title, body and any comments provided in the environment variables. 4. Ignore any existing priorities or tags on the issue. 5. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. diff --git a/docs/tools/file-system.md b/docs/tools/file-system.md index 965ff9af..45c1eaa7 100644 --- a/docs/tools/file-system.md +++ b/docs/tools/file-system.md @@ -89,10 +89,13 @@ Qwen Code provides a comprehensive suite of tools for interacting with the local - `pattern` (string, required): The regular expression (regex) to search for (e.g., `"function\s+myFunction"`). - `path` (string, optional): The absolute path to the directory to search within. Defaults to the current working directory. - `include` (string, optional): A glob pattern to filter which files are searched (e.g., `"*.js"`, `"src/**/*.{ts,tsx}"`). If omitted, searches most files (respecting common ignores). + - `maxResults` (number, optional): Maximum number of matches to return to prevent context overflow (default: 20, max: 100). Use lower values for broad searches, higher for specific searches. - **Behavior:** - Uses `git grep` if available in a Git repository for speed; otherwise, falls back to system `grep` or a JavaScript-based search. - Returns a list of matching lines, each prefixed with its file path (relative to the search directory) and line number. + - Limits results to a maximum of 20 matches by default to prevent context overflow. When results are truncated, shows a clear warning with guidance on refining searches. - **Output (`llmContent`):** A formatted string of matches, e.g.: + ``` Found 3 matches for pattern "myFunction" in path "." (filter: "*.ts"): --- @@ -103,9 +106,36 @@ Qwen Code provides a comprehensive suite of tools for interacting with the local File: src/index.ts L5: import { myFunction } from './utils'; --- + + WARNING: Results truncated to prevent context overflow. To see more results: + - Use a more specific pattern to reduce matches + - Add file filters with the 'include' parameter (e.g., "*.js", "src/**") + - Specify a narrower 'path' to search in a subdirectory + - Increase 'maxResults' parameter if you need more matches (current: 20) ``` + - **Confirmation:** No. +### `search_file_content` examples + +Search for a pattern with default result limiting: + +``` +search_file_content(pattern="function\s+myFunction", path="src") +``` + +Search for a pattern with custom result limiting: + +``` +search_file_content(pattern="function", path="src", maxResults=50) +``` + +Search for a pattern with file filtering and custom result limiting: + +``` +search_file_content(pattern="function", include="*.js", maxResults=10) +``` + ## 6. `replace` (Edit) `replace` replaces text within a file. By default, replaces a single occurrence, but can replace multiple occurrences when `expected_replacements` is specified. This tool is designed for precise, targeted changes and requires significant context around the `old_string` to ensure it modifies the correct location. diff --git a/package-lock.json b/package-lock.json index a9d095d3..07f74e88 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12554,6 +12554,7 @@ "strip-ansi": "^7.1.0", "tiktoken": "^1.0.21", "undici": "^7.10.0", + "uuid": "^9.0.1", "ws": "^8.18.0" }, "devDependencies": { diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 0538c1f9..4342ee0e 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -242,7 +242,7 @@ describe('parseArguments', () => { await expect(parseArguments()).rejects.toThrow('process.exit called'); expect(mockConsoleError).toHaveBeenCalledWith( - expect.stringContaining('无效的选项值:'), + expect.stringContaining('Invalid values:'), ); mockExit.mockRestore(); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index aa45f1b2..59501c88 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -564,7 +564,6 @@ export async function loadCliConfig( (typeof argv.openaiLogging === 'undefined' ? settings.enableOpenAILogging : argv.openaiLogging) ?? false, - sampling_params: settings.sampling_params, systemPromptMappings: (settings.systemPromptMappings ?? [ { baseUrls: [ diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 6e472067..4a21ebe5 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -495,7 +495,6 @@ export const SETTINGS_SCHEMA = { description: 'Show line numbers in the chat.', showInDialog: true, }, - contentGenerator: { type: 'object', label: 'Content Generator', @@ -505,15 +504,6 @@ export const SETTINGS_SCHEMA = { description: 'Content generator settings.', showInDialog: false, }, - sampling_params: { - type: 'object', - label: 'Sampling Params', - category: 'General', - requiresRestart: false, - default: undefined as Record | undefined, - description: 'Sampling parameters for the model.', - showInDialog: false, - }, enableOpenAILogging: { type: 'boolean', label: 'Enable OpenAI Logging', diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 5b46f15b..2207b9c9 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -192,6 +192,7 @@ export const useGeminiStream = ( return; } turnCancelledRef.current = true; + isSubmittingQueryRef.current = false; abortControllerRef.current?.abort(); if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, Date.now()); diff --git a/packages/core/package.json b/packages/core/package.json index 3363a669..8e511758 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -52,6 +52,7 @@ "strip-ansi": "^7.1.0", "tiktoken": "^1.0.21", "undici": "^7.10.0", + "uuid": "^9.0.1", "ws": "^8.18.0" }, "devDependencies": { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 621a1769..d09c24e6 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -203,7 +203,6 @@ export interface ConfigParameters { folderTrust?: boolean; ideMode?: boolean; enableOpenAILogging?: boolean; - sampling_params?: Record; systemPromptMappings?: Array<{ baseUrls: string[]; modelNames: string[]; @@ -212,6 +211,9 @@ export interface ConfigParameters { contentGenerator?: { timeout?: number; maxRetries?: number; + samplingParams?: { + [key: string]: unknown; + }; }; cliVersion?: string; loadMemoryFromIncludeDirectories?: boolean; @@ -287,10 +289,10 @@ export class Config { | Record | undefined; private readonly enableOpenAILogging: boolean; - private readonly sampling_params?: Record; private readonly contentGenerator?: { timeout?: number; maxRetries?: number; + samplingParams?: Record; }; private readonly cliVersion?: string; private readonly experimentalZedIntegration: boolean = false; @@ -367,7 +369,6 @@ export class Config { this.ideClient = IdeClient.getInstance(); this.systemPromptMappings = params.systemPromptMappings; this.enableOpenAILogging = params.enableOpenAILogging ?? false; - this.sampling_params = params.sampling_params; this.contentGenerator = params.contentGenerator; this.cliVersion = params.cliVersion; @@ -766,10 +767,6 @@ export class Config { return this.enableOpenAILogging; } - getSamplingParams(): Record | undefined { - return this.sampling_params; - } - getContentGeneratorTimeout(): number | undefined { return this.contentGenerator?.timeout; } @@ -778,6 +775,12 @@ export class Config { return this.contentGenerator?.maxRetries; } + getContentGeneratorSamplingParams(): ContentGeneratorConfig['samplingParams'] { + return this.contentGenerator?.samplingParams as + | ContentGeneratorConfig['samplingParams'] + | undefined; + } + getCliVersion(): string | undefined { return this.cliVersion; } diff --git a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts index c743c9b5..bb46b09b 100644 --- a/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts +++ b/packages/core/src/core/__tests__/openaiTimeoutHandling.test.ts @@ -7,6 +7,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { OpenAIContentGenerator } from '../openaiContentGenerator.js'; import { Config } from '../../config/config.js'; +import { AuthType } from '../contentGenerator.js'; import OpenAI from 'openai'; // Mock OpenAI @@ -41,9 +42,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => { mockConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', - enableOpenAILogging: false, - timeout: 120000, - maxRetries: 3, }), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; @@ -60,7 +58,12 @@ describe('OpenAIContentGenerator Timeout Handling', () => { vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); // Create generator instance - generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + }; + generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); }); afterEach(() => { @@ -237,12 +240,18 @@ describe('OpenAIContentGenerator Timeout Handling', () => { describe('timeout configuration', () => { it('should use default timeout configuration', () => { - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + baseUrl: 'http://localhost:8080', + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); // Verify OpenAI client was created with timeout config expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 120000, maxRetries: 3, defaultHeaders: { @@ -253,18 +262,23 @@ describe('OpenAIContentGenerator Timeout Handling', () => { it('should use custom timeout from config', () => { const customConfig = { - getContentGeneratorConfig: vi.fn().mockReturnValue({ - timeout: 300000, // 5 minutes - maxRetries: 5, - }), + getContentGeneratorConfig: vi.fn().mockReturnValue({}), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', customConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'http://localhost:8080', + authType: AuthType.USE_OPENAI, + timeout: 300000, + maxRetries: 5, + }; + new OpenAIContentGenerator(contentGeneratorConfig, customConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 300000, maxRetries: 5, defaultHeaders: { @@ -279,11 +293,17 @@ describe('OpenAIContentGenerator Timeout Handling', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', noTimeoutConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + baseUrl: 'http://localhost:8080', + }; + new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: 'http://localhost:8080', timeout: 120000, // default maxRetries: 3, // default defaultHeaders: { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 33235a67..0815ffcd 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -589,10 +589,7 @@ export class GeminiClient { model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL; try { const userMemory = this.config.getUserMemory(); - const systemPromptMappings = this.config.getSystemPromptMappings(); - const systemInstruction = getCoreSystemPrompt(userMemory, { - systemPromptMappings, - }); + const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { abortSignal, ...this.generateContentConfig, @@ -680,10 +677,7 @@ export class GeminiClient { try { const userMemory = this.config.getUserMemory(); - const systemPromptMappings = this.config.getSystemPromptMappings(); - const systemInstruction = getCoreSystemPrompt(userMemory, { - systemPromptMappings, - }); + const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { abortSignal, diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 5d735beb..2761c0c5 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -84,6 +84,7 @@ describe('createContentGeneratorConfig', () => { getSamplingParams: vi.fn().mockReturnValue(undefined), getContentGeneratorTimeout: vi.fn().mockReturnValue(undefined), getContentGeneratorMaxRetries: vi.fn().mockReturnValue(undefined), + getContentGeneratorSamplingParams: vi.fn().mockReturnValue(undefined), getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index e58c70b1..72552c90 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -53,6 +53,7 @@ export enum AuthType { export type ContentGeneratorConfig = { model: string; apiKey?: string; + baseUrl?: string; vertexai?: boolean; authType?: AuthType | undefined; enableOpenAILogging?: boolean; @@ -77,11 +78,16 @@ export function createContentGeneratorConfig( config: Config, authType: AuthType | undefined, ): ContentGeneratorConfig { + // google auth const geminiApiKey = process.env.GEMINI_API_KEY || undefined; const googleApiKey = process.env.GOOGLE_API_KEY || undefined; const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT || undefined; const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION || undefined; + + // openai auth const openaiApiKey = process.env.OPENAI_API_KEY; + const openaiBaseUrl = process.env.OPENAI_BASE_URL || undefined; + const openaiModel = process.env.OPENAI_MODEL || undefined; // Use runtime model from config if available; otherwise, fall back to parameter or default const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL; @@ -93,7 +99,7 @@ export function createContentGeneratorConfig( enableOpenAILogging: config.getEnableOpenAILogging(), timeout: config.getContentGeneratorTimeout(), maxRetries: config.getContentGeneratorMaxRetries(), - samplingParams: config.getSamplingParams(), + samplingParams: config.getContentGeneratorSamplingParams(), }; // If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now @@ -123,8 +129,8 @@ export function createContentGeneratorConfig( if (authType === AuthType.USE_OPENAI && openaiApiKey) { contentGeneratorConfig.apiKey = openaiApiKey; - contentGeneratorConfig.model = - process.env.OPENAI_MODEL || DEFAULT_GEMINI_MODEL; + contentGeneratorConfig.baseUrl = openaiBaseUrl; + contentGeneratorConfig.model = openaiModel || DEFAULT_QWEN_MODEL; return contentGeneratorConfig; } @@ -192,7 +198,7 @@ export async function createContentGenerator( ); // Always use OpenAIContentGenerator, logging is controlled by enableOpenAILogging flag - return new OpenAIContentGenerator(config.apiKey, config.model, gcConfig); + return new OpenAIContentGenerator(config, gcConfig); } if (config.authType === AuthType.QWEN_OAUTH) { @@ -213,7 +219,7 @@ export async function createContentGenerator( const qwenClient = await getQwenOauthClient(gcConfig); // Create the content generator with dynamic token management - return new QwenContentGenerator(qwenClient, config.model, gcConfig); + return new QwenContentGenerator(qwenClient, config, gcConfig); } catch (error) { throw new Error( `Failed to initialize Qwen: ${error instanceof Error ? error.message : String(error)}`, diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts index 46c8267e..8d03f0ae 100644 --- a/packages/core/src/core/openaiContentGenerator.test.ts +++ b/packages/core/src/core/openaiContentGenerator.test.ts @@ -7,6 +7,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { OpenAIContentGenerator } from './openaiContentGenerator.js'; import { Config } from '../config/config.js'; +import { AuthType } from './contentGenerator.js'; import OpenAI from 'openai'; import type { GenerateContentParameters, @@ -84,7 +85,20 @@ describe('OpenAIContentGenerator', () => { vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient); // Create generator instance - generator = new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + samplingParams: { + temperature: 0.7, + max_tokens: 1000, + top_p: 0.9, + }, + }; + generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); }); afterEach(() => { @@ -95,7 +109,7 @@ describe('OpenAIContentGenerator', () => { it('should initialize with basic configuration', () => { expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: undefined, timeout: 120000, maxRetries: 3, defaultHeaders: { @@ -105,9 +119,16 @@ describe('OpenAIContentGenerator', () => { }); it('should handle custom base URL', () => { - vi.stubEnv('OPENAI_BASE_URL', 'https://api.custom.com'); - - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'https://api.custom.com', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', @@ -121,9 +142,16 @@ describe('OpenAIContentGenerator', () => { }); it('should configure OpenRouter headers when using OpenRouter', () => { - vi.stubEnv('OPENAI_BASE_URL', 'https://openrouter.ai/api/v1'); - - new OpenAIContentGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + baseUrl: 'https://openrouter.ai/api/v1', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + }; + new OpenAIContentGenerator(contentGeneratorConfig, mockConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', @@ -147,11 +175,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; - new OpenAIContentGenerator('test-key', 'gpt-4', customConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + timeout: 300000, + maxRetries: 5, + }; + new OpenAIContentGenerator(contentGeneratorConfig, customConfig); expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-key', - baseURL: '', + baseURL: undefined, timeout: 300000, maxRetries: 5, defaultHeaders: { @@ -906,9 +941,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -1029,9 +1069,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -1587,7 +1632,23 @@ describe('OpenAIContentGenerator', () => { } } - const testGenerator = new TestGenerator('test-key', 'gpt-4', mockConfig); + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + timeout: 120000, + maxRetries: 3, + samplingParams: { + temperature: 0.7, + max_tokens: 1000, + top_p: 0.9, + }, + }; + const testGenerator = new TestGenerator( + contentGeneratorConfig, + mockConfig, + ); const consoleSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); @@ -1908,9 +1969,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + samplingParams: { + temperature: 0.8, + max_tokens: 500, + }, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -2093,9 +2163,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: true, + }; const loggingGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, loggingConfig, ); @@ -2350,9 +2425,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + samplingParams: { + temperature: undefined, + max_tokens: undefined, + top_p: undefined, + }, + }; const testGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, configWithUndefined, ); @@ -2408,9 +2492,22 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + apiKey: 'test-key', + authType: AuthType.USE_OPENAI, + samplingParams: { + temperature: 0.8, + max_tokens: 1500, + top_p: 0.95, + top_k: 40, + repetition_penalty: 1.1, + presence_penalty: 0.5, + frequency_penalty: 0.3, + }, + }; const testGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, fullSamplingConfig, ); @@ -2489,9 +2586,14 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + apiKey: 'test-key', + authType: AuthType.QWEN_OAUTH, + enableOpenAILogging: false, + }; const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2528,12 +2630,6 @@ describe('OpenAIContentGenerator', () => { }); it('should include metadata when baseURL is dashscope openai compatible mode', async () => { - // Mock environment to set dashscope base URL BEFORE creating the generator - vi.stubEnv( - 'OPENAI_BASE_URL', - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - ); - const dashscopeConfig = { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'openai', // Not QWEN_OAUTH @@ -2543,9 +2639,15 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + apiKey: 'test-key', + baseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + authType: AuthType.USE_OPENAI, + enableOpenAILogging: false, + }; const dashscopeGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, dashscopeConfig, ); @@ -2604,9 +2706,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const regularGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, regularConfig, ); @@ -2650,9 +2761,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const otherGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, otherAuthConfig, ); @@ -2699,9 +2819,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const otherBaseUrlGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, otherBaseUrlConfig, ); @@ -2748,9 +2877,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2804,8 +2942,6 @@ describe('OpenAIContentGenerator', () => { sessionId: 'streaming-session-id', promptId: 'streaming-prompt-id', }, - stream: true, - stream_options: { include_usage: true }, }), ); @@ -2827,9 +2963,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const regularGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, regularConfig, ); @@ -2901,9 +3046,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + const qwenGenerator = new OpenAIContentGenerator( - 'test-key', - 'qwen-turbo', + contentGeneratorConfig, qwenConfig, ); @@ -2955,9 +3109,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const noBaseUrlGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, noBaseUrlConfig, ); @@ -3004,9 +3167,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const undefinedAuthGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, undefinedAuthConfig, ); @@ -3050,9 +3222,18 @@ describe('OpenAIContentGenerator', () => { getCliVersion: vi.fn().mockReturnValue('1.0.0'), } as unknown as Config; + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + const undefinedConfigGenerator = new OpenAIContentGenerator( - 'test-key', - 'gpt-4', + contentGeneratorConfig, undefinedConfig, ); @@ -3089,4 +3270,232 @@ describe('OpenAIContentGenerator', () => { ); }); }); + + describe('cache control for DashScope', () => { + it('should add cache control to system message for DashScope providers', async () => { + // Mock environment to set dashscope base URL + vi.stubEnv( + 'OPENAI_BASE_URL', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + + const dashscopeConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + + const dashscopeGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + dashscopeConfig, + ); + + // Mock the client's baseURL property to return the expected value + Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { + value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + writable: true, + }); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'qwen-turbo', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + config: { + systemInstruction: 'You are a helpful assistant.', + }, + model: 'qwen-turbo', + }; + + await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); + + // Should include cache control in system message + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: 'You are a helpful assistant.', + cache_control: { type: 'ephemeral' }, + }), + ]), + }), + ]), + }), + ); + }); + + it('should add cache control to last message for DashScope providers', async () => { + // Mock environment to set dashscope base URL + vi.stubEnv( + 'OPENAI_BASE_URL', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + + const dashscopeConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('dashscope-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'qwen-turbo', + + apiKey: 'test-key', + + authType: AuthType.QWEN_OAUTH, + + enableOpenAILogging: false, + }; + + const dashscopeGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + dashscopeConfig, + ); + + // Mock the client's baseURL property to return the expected value + Object.defineProperty(dashscopeGenerator['client'], 'baseURL', { + value: 'https://dashscope.aliyuncs.com/compatible-mode/v1', + writable: true, + }); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'qwen-turbo', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello, how are you?' }] }], + model: 'qwen-turbo', + }; + + await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); + + // Should include cache control in last message + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'user', + content: expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: 'Hello, how are you?', + cache_control: { type: 'ephemeral' }, + }), + ]), + }), + ]), + }), + ); + }); + + it('should NOT add cache control for non-DashScope providers', async () => { + const regularConfig = { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'openai', + enableOpenAILogging: false, + }), + getSessionId: vi.fn().mockReturnValue('regular-session-id'), + getCliVersion: vi.fn().mockReturnValue('1.0.0'), + } as unknown as Config; + + const contentGeneratorConfig = { + model: 'gpt-4', + + apiKey: 'test-key', + + authType: AuthType.USE_OPENAI, + + enableOpenAILogging: false, + }; + + const regularGenerator = new OpenAIContentGenerator( + contentGeneratorConfig, + regularConfig, + ); + + const mockResponse = { + id: 'chatcmpl-123', + choices: [ + { + index: 0, + message: { role: 'assistant', content: 'Response' }, + finish_reason: 'stop', + }, + ], + created: 1677652288, + model: 'gpt-4', + }; + + mockOpenAIClient.chat.completions.create.mockResolvedValue(mockResponse); + + const request: GenerateContentParameters = { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + config: { + systemInstruction: 'You are a helpful assistant.', + }, + model: 'gpt-4', + }; + + await regularGenerator.generateContent(request, 'regular-prompt-id'); + + // Should NOT include cache control (messages should be strings, not arrays) + expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: 'You are a helpful assistant.', + }), + expect.objectContaining({ + role: 'user', + content: 'Hello', + }), + ]), + }), + ); + }); + }); }); diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index eeba5db7..3f223f3e 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -20,7 +20,11 @@ import { FunctionCall, FunctionResponse, } from '@google/genai'; -import { AuthType, ContentGenerator } from './contentGenerator.js'; +import { + AuthType, + ContentGenerator, + ContentGeneratorConfig, +} from './contentGenerator.js'; import OpenAI from 'openai'; import { logApiError, logApiResponse } from '../telemetry/loggers.js'; import { ApiErrorEvent, ApiResponseEvent } from '../telemetry/types.js'; @@ -28,6 +32,17 @@ import { Config } from '../config/config.js'; import { openaiLogger } from '../utils/openaiLogger.js'; import { safeJsonParse } from '../utils/safeJsonParse.js'; +// Extended types to support cache_control +interface ChatCompletionContentPartTextWithCache + extends OpenAI.Chat.ChatCompletionContentPartText { + cache_control?: { type: 'ephemeral' }; +} + +type ChatCompletionContentPartWithCache = + | ChatCompletionContentPartTextWithCache + | OpenAI.Chat.ChatCompletionContentPartImage + | OpenAI.Chat.ChatCompletionContentPartRefusal; + // OpenAI API type definitions for logging interface OpenAIToolCall { id: string; @@ -38,9 +53,15 @@ interface OpenAIToolCall { }; } +interface OpenAIContentItem { + type: 'text'; + text: string; + cache_control?: { type: 'ephemeral' }; +} + interface OpenAIMessage { role: 'system' | 'user' | 'assistant' | 'tool'; - content: string | null; + content: string | null | OpenAIContentItem[]; tool_calls?: OpenAIToolCall[]; tool_call_id?: string; } @@ -60,15 +81,6 @@ interface OpenAIChoice { finish_reason: string; } -interface OpenAIRequestFormat { - model: string; - messages: OpenAIMessage[]; - temperature?: number; - max_tokens?: number; - top_p?: number; - tools?: unknown[]; -} - interface OpenAIResponseFormat { id: string; object: string; @@ -81,6 +93,7 @@ interface OpenAIResponseFormat { export class OpenAIContentGenerator implements ContentGenerator { protected client: OpenAI; private model: string; + private contentGeneratorConfig: ContentGeneratorConfig; private config: Config; private streamingToolCalls: Map< number, @@ -91,50 +104,40 @@ export class OpenAIContentGenerator implements ContentGenerator { } > = new Map(); - constructor(apiKey: string, model: string, config: Config) { - this.model = model; - this.config = config; - const baseURL = process.env.OPENAI_BASE_URL || ''; + constructor( + contentGeneratorConfig: ContentGeneratorConfig, + gcConfig: Config, + ) { + this.model = contentGeneratorConfig.model; + this.contentGeneratorConfig = contentGeneratorConfig; + this.config = gcConfig; - // Configure timeout settings - using progressive timeouts - const timeoutConfig = { - // Base timeout for most requests (2 minutes) - timeout: 120000, - // Maximum retries for failed requests - maxRetries: 3, - // HTTP client options - httpAgent: undefined, // Let the client use default agent - }; - - // Allow config to override timeout settings - const contentGeneratorConfig = this.config.getContentGeneratorConfig(); - if (contentGeneratorConfig?.timeout) { - timeoutConfig.timeout = contentGeneratorConfig.timeout; - } - if (contentGeneratorConfig?.maxRetries !== undefined) { - timeoutConfig.maxRetries = contentGeneratorConfig.maxRetries; - } - - const version = config.getCliVersion() || 'unknown'; + const version = gcConfig.getCliVersion() || 'unknown'; const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`; // Check if using OpenRouter and add required headers - const isOpenRouter = baseURL.includes('openrouter.ai'); + const isOpenRouterProvider = this.isOpenRouterProvider(); + const isDashScopeProvider = this.isDashScopeProvider(); + const defaultHeaders = { 'User-Agent': userAgent, - ...(isOpenRouter + ...(isOpenRouterProvider ? { 'HTTP-Referer': 'https://github.com/QwenLM/qwen-code.git', 'X-Title': 'Qwen Code', } - : {}), + : isDashScopeProvider + ? { + 'X-DashScope-CacheControl': 'enable', + } + : {}), }; this.client = new OpenAI({ - apiKey, - baseURL, - timeout: timeoutConfig.timeout, - maxRetries: timeoutConfig.maxRetries, + apiKey: contentGeneratorConfig.apiKey, + baseURL: contentGeneratorConfig.baseUrl, + timeout: contentGeneratorConfig.timeout ?? 120000, + maxRetries: contentGeneratorConfig.maxRetries ?? 3, defaultHeaders, }); } @@ -185,22 +188,25 @@ export class OpenAIContentGenerator implements ContentGenerator { ); } + private isOpenRouterProvider(): boolean { + const baseURL = this.contentGeneratorConfig.baseUrl || ''; + return baseURL.includes('openrouter.ai'); + } + /** - * Determine if metadata should be included in the request. - * Only include the `metadata` field if the provider is QWEN_OAUTH - * or the baseUrl is 'https://dashscope.aliyuncs.com/compatible-mode/v1'. - * This is because some models/providers do not support metadata or need extra configuration. + * Determine if this is a DashScope provider. + * DashScope providers include QWEN_OAUTH auth type or specific DashScope base URLs. * - * @returns true if metadata should be included, false otherwise + * @returns true if this is a DashScope provider, false otherwise */ - private shouldIncludeMetadata(): boolean { - const authType = this.config.getContentGeneratorConfig?.()?.authType; - // baseUrl may be undefined; default to empty string if so - const baseUrl = this.client?.baseURL || ''; + private isDashScopeProvider(): boolean { + const authType = this.contentGeneratorConfig.authType; + const baseUrl = this.contentGeneratorConfig.baseUrl; return ( authType === AuthType.QWEN_OAUTH || - baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' + baseUrl === 'https://dashscope.aliyuncs.com/compatible-mode/v1' || + baseUrl === 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1' ); } @@ -213,7 +219,7 @@ export class OpenAIContentGenerator implements ContentGenerator { private buildMetadata( userPromptId: string, ): { metadata: { sessionId?: string; promptId: string } } | undefined { - if (!this.shouldIncludeMetadata()) { + if (!this.isDashScopeProvider()) { return undefined; } @@ -225,35 +231,44 @@ export class OpenAIContentGenerator implements ContentGenerator { }; } + private async buildCreateParams( + request: GenerateContentParameters, + userPromptId: string, + ): Promise[0]> { + const messages = this.convertToOpenAIFormat(request); + + // Build sampling parameters with clear priority: + // 1. Request-level parameters (highest priority) + // 2. Config-level sampling parameters (medium priority) + // 3. Default values (lowest priority) + const samplingParams = this.buildSamplingParameters(request); + + const createParams: Parameters< + typeof this.client.chat.completions.create + >[0] = { + model: this.model, + messages, + ...samplingParams, + ...(this.buildMetadata(userPromptId) || {}), + }; + + if (request.config?.tools) { + createParams.tools = await this.convertGeminiToolsToOpenAI( + request.config.tools, + ); + } + + return createParams; + } + async generateContent( request: GenerateContentParameters, userPromptId: string, ): Promise { const startTime = Date.now(); - const messages = this.convertToOpenAIFormat(request); + const createParams = await this.buildCreateParams(request, userPromptId); try { - // Build sampling parameters with clear priority: - // 1. Request-level parameters (highest priority) - // 2. Config-level sampling parameters (medium priority) - // 3. Default values (lowest priority) - const samplingParams = this.buildSamplingParameters(request); - - const createParams: Parameters< - typeof this.client.chat.completions.create - >[0] = { - model: this.model, - messages, - ...samplingParams, - ...(this.buildMetadata(userPromptId) || {}), - }; - - if (request.config?.tools) { - createParams.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - // console.log('createParams', createParams); const completion = (await this.client.chat.completions.create( createParams, )) as OpenAI.Chat.ChatCompletion; @@ -267,15 +282,15 @@ export class OpenAIContentGenerator implements ContentGenerator { this.model, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, response.usageMetadata, ); logApiResponse(this.config, responseEvent); // Log interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { + const openaiRequest = createParams; const openaiResponse = this.convertGeminiResponseToOpenAI(response); await openaiLogger.logInteraction(openaiRequest, openaiResponse); } @@ -300,7 +315,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -309,10 +324,9 @@ export class OpenAIContentGenerator implements ContentGenerator { logApiError(this.config, errorEvent); // Log error interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { await openaiLogger.logInteraction( - openaiRequest, + createParams, undefined, error as Error, ); @@ -343,29 +357,12 @@ export class OpenAIContentGenerator implements ContentGenerator { userPromptId: string, ): Promise> { const startTime = Date.now(); - const messages = this.convertToOpenAIFormat(request); + const createParams = await this.buildCreateParams(request, userPromptId); + + createParams.stream = true; + createParams.stream_options = { include_usage: true }; try { - // Build sampling parameters with clear priority - const samplingParams = this.buildSamplingParameters(request); - - const createParams: Parameters< - typeof this.client.chat.completions.create - >[0] = { - model: this.model, - messages, - ...samplingParams, - stream: true, - stream_options: { include_usage: true }, - ...(this.buildMetadata(userPromptId) || {}), - }; - - if (request.config?.tools) { - createParams.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - const stream = (await this.client.chat.completions.create( createParams, )) as AsyncIterable; @@ -397,16 +394,15 @@ export class OpenAIContentGenerator implements ContentGenerator { this.model, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, finalUsageMetadata, ); logApiResponse(this.config, responseEvent); // Log interaction if enabled (same as generateContent method) - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = - await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { + const openaiRequest = createParams; // For streaming, we combine all responses into a single response for logging const combinedResponse = this.combineStreamResponsesForLogging(responses); @@ -433,7 +429,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -442,11 +438,9 @@ export class OpenAIContentGenerator implements ContentGenerator { logApiError(this.config, errorEvent); // Log error interaction if enabled - if (this.config.getContentGeneratorConfig()?.enableOpenAILogging) { - const openaiRequest = - await this.convertGeminiRequestToOpenAI(request); + if (this.contentGeneratorConfig.enableOpenAILogging) { await openaiLogger.logInteraction( - openaiRequest, + createParams, undefined, error as Error, ); @@ -487,7 +481,7 @@ export class OpenAIContentGenerator implements ContentGenerator { errorMessage, durationMs, userPromptId, - this.config.getContentGeneratorConfig()?.authType, + this.contentGeneratorConfig.authType, // eslint-disable-next-line @typescript-eslint/no-explicit-any (error as any).type, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -563,7 +557,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Add combined text if any if (combinedText) { - combinedParts.push({ text: combinedText.trimEnd() }); + combinedParts.push({ text: combinedText }); } // Add function calls @@ -944,7 +938,114 @@ export class OpenAIContentGenerator implements ContentGenerator { // Clean up orphaned tool calls and merge consecutive assistant messages const cleanedMessages = this.cleanOrphanedToolCalls(messages); - return this.mergeConsecutiveAssistantMessages(cleanedMessages); + const mergedMessages = + this.mergeConsecutiveAssistantMessages(cleanedMessages); + + // Add cache control to system and last messages for DashScope providers + return this.addCacheControlFlag(mergedMessages, 'both'); + } + + /** + * Add cache control flag to specified message(s) for DashScope providers + */ + private addCacheControlFlag( + messages: OpenAI.Chat.ChatCompletionMessageParam[], + target: 'system' | 'last' | 'both' = 'both', + ): OpenAI.Chat.ChatCompletionMessageParam[] { + if (!this.isDashScopeProvider() || messages.length === 0) { + return messages; + } + + let updatedMessages = [...messages]; + + // Add cache control to system message if requested + if (target === 'system' || target === 'both') { + updatedMessages = this.addCacheControlToMessage( + updatedMessages, + 'system', + ); + } + + // Add cache control to last message if requested + if (target === 'last' || target === 'both') { + updatedMessages = this.addCacheControlToMessage(updatedMessages, 'last'); + } + + return updatedMessages; + } + + /** + * Helper method to add cache control to a specific message + */ + private addCacheControlToMessage( + messages: OpenAI.Chat.ChatCompletionMessageParam[], + target: 'system' | 'last', + ): OpenAI.Chat.ChatCompletionMessageParam[] { + const updatedMessages = [...messages]; + let messageIndex: number; + + if (target === 'system') { + // Find the first system message + messageIndex = messages.findIndex((msg) => msg.role === 'system'); + if (messageIndex === -1) { + return updatedMessages; + } + } else { + // Get the last message + messageIndex = messages.length - 1; + } + + const message = updatedMessages[messageIndex]; + + // Only process messages that have content + if ('content' in message && message.content !== null) { + if (typeof message.content === 'string') { + // Convert string content to array format with cache control + const messageWithArrayContent = { + ...message, + content: [ + { + type: 'text', + text: message.content, + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache, + ], + }; + updatedMessages[messageIndex] = + messageWithArrayContent as OpenAI.Chat.ChatCompletionMessageParam; + } else if (Array.isArray(message.content)) { + // If content is already an array, add cache_control to the last item + const contentArray = [ + ...message.content, + ] as ChatCompletionContentPartWithCache[]; + if (contentArray.length > 0) { + const lastItem = contentArray[contentArray.length - 1]; + if (lastItem.type === 'text') { + // Add cache_control to the last text item + contentArray[contentArray.length - 1] = { + ...lastItem, + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache; + } else { + // If the last item is not text, add a new text item with cache_control + contentArray.push({ + type: 'text', + text: '', + cache_control: { type: 'ephemeral' }, + } as ChatCompletionContentPartTextWithCache); + } + + const messageWithCache = { + ...message, + content: contentArray, + }; + updatedMessages[messageIndex] = + messageWithCache as OpenAI.Chat.ChatCompletionMessageParam; + } + } + } + + return updatedMessages; } /** @@ -1164,11 +1265,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Handle text content if (choice.message.content) { - if (typeof choice.message.content === 'string') { - parts.push({ text: choice.message.content.trimEnd() }); - } else { - parts.push({ text: choice.message.content }); - } + parts.push({ text: choice.message.content }); } // Handle tool calls @@ -1253,11 +1350,7 @@ export class OpenAIContentGenerator implements ContentGenerator { // Handle text content if (choice.delta?.content) { - if (typeof choice.delta.content === 'string') { - parts.push({ text: choice.delta.content.trimEnd() }); - } else { - parts.push({ text: choice.delta.content }); - } + parts.push({ text: choice.delta.content }); } // Handle tool calls - only accumulate during streaming, emit when complete @@ -1376,8 +1469,7 @@ export class OpenAIContentGenerator implements ContentGenerator { private buildSamplingParameters( request: GenerateContentParameters, ): Record { - const configSamplingParams = - this.config.getContentGeneratorConfig()?.samplingParams; + const configSamplingParams = this.contentGeneratorConfig.samplingParams; const params = { // Temperature: config > request > default @@ -1439,313 +1531,6 @@ export class OpenAIContentGenerator implements ContentGenerator { return mapping[openaiReason] || FinishReason.FINISH_REASON_UNSPECIFIED; } - /** - * Convert Gemini request format to OpenAI chat completion format for logging - */ - private async convertGeminiRequestToOpenAI( - request: GenerateContentParameters, - ): Promise { - const messages: OpenAIMessage[] = []; - - // Handle system instruction - if (request.config?.systemInstruction) { - const systemInstruction = request.config.systemInstruction; - let systemText = ''; - - if (Array.isArray(systemInstruction)) { - systemText = systemInstruction - .map((content) => { - if (typeof content === 'string') return content; - if ('parts' in content) { - const contentObj = content as Content; - return ( - contentObj.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || '' - ); - } - return ''; - }) - .join('\n'); - } else if (typeof systemInstruction === 'string') { - systemText = systemInstruction; - } else if ( - typeof systemInstruction === 'object' && - 'parts' in systemInstruction - ) { - const systemContent = systemInstruction as Content; - systemText = - systemContent.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || ''; - } - - if (systemText) { - messages.push({ - role: 'system', - content: systemText, - }); - } - } - - // Handle contents - if (Array.isArray(request.contents)) { - for (const content of request.contents) { - if (typeof content === 'string') { - messages.push({ role: 'user', content }); - } else if ('role' in content && 'parts' in content) { - const functionCalls: FunctionCall[] = []; - const functionResponses: FunctionResponse[] = []; - const textParts: string[] = []; - - for (const part of content.parts || []) { - if (typeof part === 'string') { - textParts.push(part); - } else if ('text' in part && part.text) { - textParts.push(part.text); - } else if ('functionCall' in part && part.functionCall) { - functionCalls.push(part.functionCall); - } else if ('functionResponse' in part && part.functionResponse) { - functionResponses.push(part.functionResponse); - } - } - - // Handle function responses (tool results) - if (functionResponses.length > 0) { - for (const funcResponse of functionResponses) { - messages.push({ - role: 'tool', - tool_call_id: funcResponse.id || '', - content: - typeof funcResponse.response === 'string' - ? funcResponse.response - : JSON.stringify(funcResponse.response), - }); - } - } - // Handle model messages with function calls - else if (content.role === 'model' && functionCalls.length > 0) { - const toolCalls = functionCalls.map((fc, index) => ({ - id: fc.id || `call_${index}`, - type: 'function' as const, - function: { - name: fc.name || '', - arguments: JSON.stringify(fc.args || {}), - }, - })); - - messages.push({ - role: 'assistant', - content: textParts.join('\n') || null, - tool_calls: toolCalls, - }); - } - // Handle regular text messages - else { - const role = content.role === 'model' ? 'assistant' : 'user'; - const text = textParts.join('\n'); - if (text) { - messages.push({ role, content: text }); - } - } - } - } - } else if (request.contents) { - if (typeof request.contents === 'string') { - messages.push({ role: 'user', content: request.contents }); - } else if ('role' in request.contents && 'parts' in request.contents) { - const content = request.contents; - const role = content.role === 'model' ? 'assistant' : 'user'; - const text = - content.parts - ?.map((p: Part) => - typeof p === 'string' ? p : 'text' in p ? p.text : '', - ) - .join('\n') || ''; - messages.push({ role, content: text }); - } - } - - // Clean up orphaned tool calls and merge consecutive assistant messages - const cleanedMessages = this.cleanOrphanedToolCallsForLogging(messages); - const mergedMessages = - this.mergeConsecutiveAssistantMessagesForLogging(cleanedMessages); - - const openaiRequest: OpenAIRequestFormat = { - model: this.model, - messages: mergedMessages, - }; - - // Add sampling parameters using the same logic as actual API calls - const samplingParams = this.buildSamplingParameters(request); - Object.assign(openaiRequest, samplingParams); - - // Convert tools if present - if (request.config?.tools) { - openaiRequest.tools = await this.convertGeminiToolsToOpenAI( - request.config.tools, - ); - } - - return openaiRequest; - } - - /** - * Clean up orphaned tool calls for logging purposes - */ - private cleanOrphanedToolCallsForLogging( - messages: OpenAIMessage[], - ): OpenAIMessage[] { - const cleaned: OpenAIMessage[] = []; - const toolCallIds = new Set(); - const toolResponseIds = new Set(); - - // First pass: collect all tool call IDs and tool response IDs - for (const message of messages) { - if (message.role === 'assistant' && message.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - toolCallIds.add(toolCall.id); - } - } - } else if (message.role === 'tool' && message.tool_call_id) { - toolResponseIds.add(message.tool_call_id); - } - } - - // Second pass: filter out orphaned messages - for (const message of messages) { - if (message.role === 'assistant' && message.tool_calls) { - // Filter out tool calls that don't have corresponding responses - const validToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && toolResponseIds.has(toolCall.id), - ); - - if (validToolCalls.length > 0) { - // Keep the message but only with valid tool calls - const cleanedMessage = { ...message }; - cleanedMessage.tool_calls = validToolCalls; - cleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - // Keep the message if it has text content, but remove tool calls - const cleanedMessage = { ...message }; - delete cleanedMessage.tool_calls; - cleaned.push(cleanedMessage); - } - // If no valid tool calls and no content, skip the message entirely - } else if (message.role === 'tool' && message.tool_call_id) { - // Only keep tool responses that have corresponding tool calls - if (toolCallIds.has(message.tool_call_id)) { - cleaned.push(message); - } - } else { - // Keep all other messages as-is - cleaned.push(message); - } - } - - // Final validation: ensure every assistant message with tool_calls has corresponding tool responses - const finalCleaned: OpenAIMessage[] = []; - const finalToolCallIds = new Set(); - - // Collect all remaining tool call IDs - for (const message of cleaned) { - if (message.role === 'assistant' && message.tool_calls) { - for (const toolCall of message.tool_calls) { - if (toolCall.id) { - finalToolCallIds.add(toolCall.id); - } - } - } - } - - // Verify all tool calls have responses - const finalToolResponseIds = new Set(); - for (const message of cleaned) { - if (message.role === 'tool' && message.tool_call_id) { - finalToolResponseIds.add(message.tool_call_id); - } - } - - // Remove any remaining orphaned tool calls - for (const message of cleaned) { - if (message.role === 'assistant' && message.tool_calls) { - const finalValidToolCalls = message.tool_calls.filter( - (toolCall) => toolCall.id && finalToolResponseIds.has(toolCall.id), - ); - - if (finalValidToolCalls.length > 0) { - const cleanedMessage = { ...message }; - cleanedMessage.tool_calls = finalValidToolCalls; - finalCleaned.push(cleanedMessage); - } else if ( - typeof message.content === 'string' && - message.content.trim() - ) { - const cleanedMessage = { ...message }; - delete cleanedMessage.tool_calls; - finalCleaned.push(cleanedMessage); - } - } else { - finalCleaned.push(message); - } - } - - return finalCleaned; - } - - /** - * Merge consecutive assistant messages to combine split text and tool calls for logging - */ - private mergeConsecutiveAssistantMessagesForLogging( - messages: OpenAIMessage[], - ): OpenAIMessage[] { - const merged: OpenAIMessage[] = []; - - for (const message of messages) { - if (message.role === 'assistant' && merged.length > 0) { - const lastMessage = merged[merged.length - 1]; - - // If the last message is also an assistant message, merge them - if (lastMessage.role === 'assistant') { - // Combine content - const combinedContent = [ - lastMessage.content || '', - message.content || '', - ] - .filter(Boolean) - .join(''); - - // Combine tool calls - const combinedToolCalls = [ - ...(lastMessage.tool_calls || []), - ...(message.tool_calls || []), - ]; - - // Update the last message with combined data - lastMessage.content = combinedContent || null; - if (combinedToolCalls.length > 0) { - lastMessage.tool_calls = combinedToolCalls; - } - - continue; // Skip adding the current message since it's been merged - } - } - - // Add the message as-is if no merging is needed - merged.push(message); - } - - return merged; - } - /** * Convert Gemini response format to OpenAI chat completion format for logging */ @@ -1776,7 +1561,7 @@ export class OpenAIContentGenerator implements ContentGenerator { } } - messageContent = textParts.join('').trimEnd(); + messageContent = textParts.join(''); } const choice: OpenAIChoice = { diff --git a/packages/core/src/ide/detect-ide.test.ts b/packages/core/src/ide/detect-ide.test.ts index 85249ad6..41a713e6 100644 --- a/packages/core/src/ide/detect-ide.test.ts +++ b/packages/core/src/ide/detect-ide.test.ts @@ -54,15 +54,39 @@ describe('detectIde', () => { expected: DetectedIde.FirebaseStudio, }, ])('detects the IDE for $expected', ({ env, expected }) => { + // Clear all environment variables first + vi.unstubAllEnvs(); + + // Set TERM_PROGRAM to vscode (required for all IDE detection) vi.stubEnv('TERM_PROGRAM', 'vscode'); + + // Explicitly stub all environment variables that detectIde() checks to undefined + // This ensures no real environment variables interfere with the tests + vi.stubEnv('__COG_BASHRC_SOURCED', undefined); + vi.stubEnv('REPLIT_USER', undefined); + vi.stubEnv('CURSOR_TRACE_ID', undefined); + vi.stubEnv('CODESPACES', undefined); + vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined); + vi.stubEnv('CLOUD_SHELL', undefined); + vi.stubEnv('TERM_PRODUCT', undefined); + vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined); + vi.stubEnv('MONOSPACE_ENV', undefined); + + // Set only the specific environment variables for this test case for (const [key, value] of Object.entries(env)) { vi.stubEnv(key, value); } + expect(detectIde()).toBe(expected); }); it('returns undefined for non-vscode', () => { + // Clear all environment variables first + vi.unstubAllEnvs(); + + // Set TERM_PROGRAM to something other than vscode vi.stubEnv('TERM_PROGRAM', 'definitely-not-vscode'); + expect(detectIde()).toBeUndefined(); }); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts index d2878f93..a56aed81 100644 --- a/packages/core/src/qwen/qwenContentGenerator.test.ts +++ b/packages/core/src/qwen/qwenContentGenerator.test.ts @@ -21,6 +21,7 @@ import { } from '@google/genai'; import { QwenContentGenerator } from './qwenContentGenerator.js'; import { Config } from '../config/config.js'; +import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js'; // Mock the OpenAIContentGenerator parent class vi.mock('../core/openaiContentGenerator.js', () => ({ @@ -30,10 +31,13 @@ vi.mock('../core/openaiContentGenerator.js', () => ({ baseURL: string; }; - constructor(apiKey: string, _model: string, _config: Config) { + constructor( + contentGeneratorConfig: ContentGeneratorConfig, + _config: Config, + ) { this.client = { - apiKey, - baseURL: 'https://api.openai.com/v1', + apiKey: contentGeneratorConfig.apiKey || 'test-key', + baseURL: contentGeneratorConfig.baseUrl || 'https://api.openai.com/v1', }; } @@ -131,9 +135,13 @@ describe('QwenContentGenerator', () => { }; // Create QwenContentGenerator instance + const contentGeneratorConfig = { + model: 'qwen-turbo', + authType: AuthType.QWEN_OAUTH, + }; qwenContentGenerator = new QwenContentGenerator( mockQwenClient, - 'qwen-turbo', + contentGeneratorConfig, mockConfig, ); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts index f9daf5c0..4180efa2 100644 --- a/packages/core/src/qwen/qwenContentGenerator.ts +++ b/packages/core/src/qwen/qwenContentGenerator.ts @@ -20,6 +20,7 @@ import { EmbedContentParameters, EmbedContentResponse, } from '@google/genai'; +import { ContentGeneratorConfig } from '../core/contentGenerator.js'; // Default fallback base URL if no endpoint is provided const DEFAULT_QWEN_BASE_URL = @@ -36,9 +37,13 @@ export class QwenContentGenerator extends OpenAIContentGenerator { private currentEndpoint: string | null = null; private refreshPromise: Promise | null = null; - constructor(qwenClient: IQwenOAuth2Client, model: string, config: Config) { + constructor( + qwenClient: IQwenOAuth2Client, + contentGeneratorConfig: ContentGeneratorConfig, + config: Config, + ) { // Initialize with empty API key, we'll override it dynamically - super('', model, config); + super(contentGeneratorConfig, config); this.qwenClient = qwenClient; // Set default base URL, will be updated dynamically diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts index f2ce4d19..a63a803d 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts @@ -205,10 +205,26 @@ describe('ClearcutLogger', () => { 'logs the current surface for as $expectedValue, preempting vscode detection', ({ env, expectedValue }) => { const { logger } = setup({}); + + // Clear all environment variables that could interfere with surface detection + vi.stubEnv('SURFACE', undefined); + vi.stubEnv('GITHUB_SHA', undefined); + vi.stubEnv('CURSOR_TRACE_ID', undefined); + vi.stubEnv('__COG_BASHRC_SOURCED', undefined); + vi.stubEnv('REPLIT_USER', undefined); + vi.stubEnv('CODESPACES', undefined); + vi.stubEnv('EDITOR_IN_CLOUD_SHELL', undefined); + vi.stubEnv('CLOUD_SHELL', undefined); + vi.stubEnv('TERM_PRODUCT', undefined); + vi.stubEnv('FIREBASE_DEPLOY_AGENT', undefined); + vi.stubEnv('MONOSPACE_ENV', undefined); + + // Set the specific environment variables for this test case for (const [key, value] of Object.entries(env)) { vi.stubEnv(key, value); } vi.stubEnv('TERM_PROGRAM', 'vscode'); + const event = logger?.createLogEvent('abc', []); expect(event?.event_metadata[0][1]).toEqual({ gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE, diff --git a/packages/core/src/tools/grep.test.ts b/packages/core/src/tools/grep.test.ts index dd567170..d77da29c 100644 --- a/packages/core/src/tools/grep.test.ts +++ b/packages/core/src/tools/grep.test.ts @@ -439,4 +439,84 @@ describe('GrepTool', () => { expect(invocation.getDescription()).toBe("'testPattern' within ./"); }); }); + + describe('Result limiting', () => { + beforeEach(async () => { + // Create many test files with matches to test limiting + for (let i = 1; i <= 30; i++) { + const fileName = `test${i}.txt`; + const content = `This is test file ${i} with the pattern testword in it.`; + await fs.writeFile(path.join(tempRootDir, fileName), content); + } + }); + + it('should limit results to default 20 matches', async () => { + const params: GrepToolParams = { pattern: 'testword' }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 20 matches'); + expect(result.llmContent).toContain( + 'showing first 20 of 30+ total matches', + ); + expect(result.llmContent).toContain('WARNING: Results truncated'); + expect(result.returnDisplay).toContain( + 'Found 20 matches (truncated from 30+)', + ); + }); + + it('should respect custom maxResults parameter', async () => { + const params: GrepToolParams = { pattern: 'testword', maxResults: 5 }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 5 matches'); + expect(result.llmContent).toContain( + 'showing first 5 of 30+ total matches', + ); + expect(result.llmContent).toContain('current: 5'); + expect(result.returnDisplay).toContain( + 'Found 5 matches (truncated from 30+)', + ); + }); + + it('should not show truncation warning when all results fit', async () => { + const params: GrepToolParams = { pattern: 'testword', maxResults: 50 }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Found 30 matches'); + expect(result.llmContent).not.toContain('WARNING: Results truncated'); + expect(result.llmContent).not.toContain('showing first'); + expect(result.returnDisplay).toBe('Found 30 matches'); + }); + + it('should validate maxResults parameter', () => { + const invalidParams = [ + { pattern: 'test', maxResults: 0 }, + { pattern: 'test', maxResults: 101 }, + { pattern: 'test', maxResults: -1 }, + { pattern: 'test', maxResults: 1.5 }, + ]; + + invalidParams.forEach((params) => { + const error = grepTool.validateToolParams(params as GrepToolParams); + expect(error).toBeTruthy(); // Just check that validation fails + expect(error).toMatch(/maxResults|must be/); // Check it's about maxResults validation + }); + }); + + it('should accept valid maxResults parameter', () => { + const validParams = [ + { pattern: 'test', maxResults: 1 }, + { pattern: 'test', maxResults: 50 }, + { pattern: 'test', maxResults: 100 }, + ]; + + validParams.forEach((params) => { + const error = grepTool.validateToolParams(params); + expect(error).toBeNull(); + }); + }); + }); }); diff --git a/packages/core/src/tools/grep.ts b/packages/core/src/tools/grep.ts index e5b834d4..d78cb04a 100644 --- a/packages/core/src/tools/grep.ts +++ b/packages/core/src/tools/grep.ts @@ -43,6 +43,11 @@ export interface GrepToolParams { * File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}") */ include?: string; + + /** + * Maximum number of matches to return (optional, defaults to 20) + */ + maxResults?: number; } /** @@ -124,6 +129,10 @@ class GrepToolInvocation extends BaseToolInvocation< // Collect matches from all search directories let allMatches: GrepMatch[] = []; + const maxResults = this.params.maxResults ?? 20; // Default to 20 results + let totalMatchesFound = 0; + let searchTruncated = false; + for (const searchDir of searchDirectories) { const matches = await this.performGrepSearch({ pattern: this.params.pattern, @@ -132,6 +141,8 @@ class GrepToolInvocation extends BaseToolInvocation< signal, }); + totalMatchesFound += matches.length; + // Add directory prefix if searching multiple directories if (searchDirectories.length > 1) { const dirName = path.basename(searchDir); @@ -140,7 +151,20 @@ class GrepToolInvocation extends BaseToolInvocation< }); } - allMatches = allMatches.concat(matches); + // Apply result limiting + const remainingSlots = maxResults - allMatches.length; + if (remainingSlots <= 0) { + searchTruncated = true; + break; + } + + if (matches.length > remainingSlots) { + allMatches = allMatches.concat(matches.slice(0, remainingSlots)); + searchTruncated = true; + break; + } else { + allMatches = allMatches.concat(matches); + } } let searchLocationDescription: string; @@ -176,7 +200,14 @@ class GrepToolInvocation extends BaseToolInvocation< const matchCount = allMatches.length; const matchTerm = matchCount === 1 ? 'match' : 'matches'; - let llmContent = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}: + // Build the header with truncation info if needed + let headerText = `Found ${matchCount} ${matchTerm} for pattern "${this.params.pattern}" ${searchLocationDescription}${this.params.include ? ` (filter: "${this.params.include}")` : ''}`; + + if (searchTruncated) { + headerText += ` (showing first ${matchCount} of ${totalMatchesFound}+ total matches)`; + } + + let llmContent = `${headerText}: --- `; @@ -189,9 +220,23 @@ class GrepToolInvocation extends BaseToolInvocation< llmContent += '---\n'; } + // Add truncation guidance if results were limited + if (searchTruncated) { + llmContent += `\nWARNING: Results truncated to prevent context overflow. To see more results: +- Use a more specific pattern to reduce matches +- Add file filters with the 'include' parameter (e.g., "*.js", "src/**") +- Specify a narrower 'path' to search in a subdirectory +- Increase 'maxResults' parameter if you need more matches (current: ${maxResults})`; + } + + let displayText = `Found ${matchCount} ${matchTerm}`; + if (searchTruncated) { + displayText += ` (truncated from ${totalMatchesFound}+)`; + } + return { llmContent: llmContent.trim(), - returnDisplay: `Found ${matchCount} ${matchTerm}`, + returnDisplay: displayText, }; } catch (error) { console.error(`Error during GrepLogic execution: ${error}`); @@ -567,6 +612,13 @@ export class GrepTool extends BaseDeclarativeTool { "Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).", type: 'string', }, + maxResults: { + description: + 'Optional: Maximum number of matches to return to prevent context overflow (default: 20, max: 100). Use lower values for broad searches, higher for specific searches.', + type: 'number', + minimum: 1, + maximum: 100, + }, }, required: ['pattern'], type: 'object', @@ -635,6 +687,17 @@ export class GrepTool extends BaseDeclarativeTool { return `Invalid regular expression pattern provided: ${params.pattern}. Error: ${getErrorMessage(error)}`; } + // Validate maxResults if provided + if (params.maxResults !== undefined) { + if ( + !Number.isInteger(params.maxResults) || + params.maxResults < 1 || + params.maxResults > 100 + ) { + return `maxResults must be an integer between 1 and 100, got: ${params.maxResults}`; + } + } + // Only validate path if one is provided if (params.path) { try {