Mirror of https://github.com/QwenLM/qwen-code.git (synced 2026-01-09 18:39:12 +00:00)

Compare commits: feat/gemin...feat/skill (13 commits)
Commits in this comparison:

- 251031cfc5
- 77c257d9d0
- 4311af96eb
- b49c11e9a2
- 9cdd85c62a
- 00547ba439
- fc1dac9dc7
- 338eb9038d
- e0b9044833
- f33f43e2f7
- 177fc42f04
- 2560c2d1a2
- bd6e16d41b
.github/workflows/e2e.yml (vendored, 9 lines changed)
@@ -18,8 +18,6 @@ jobs:

```yaml
          - 'sandbox:docker'
        node-version:
          - '20.x'
          - '22.x'
          - '24.x'
    steps:
      - name: 'Checkout'
        uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5
```

@@ -67,10 +65,13 @@ jobs:

```yaml
          OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
          OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
          KEEP_OUTPUT: 'true'
          SANDBOX: '${{ matrix.sandbox }}'
          VERBOSE: 'true'
        run: |-
          npm run "test:integration:${SANDBOX}"
          if [[ "${{ matrix.sandbox }}" == "sandbox:docker" ]]; then
            npm run test:integration:sandbox:docker
          else
            npm run test:integration:sandbox:none
          fi

  e2e-test-macos:
    name: 'E2E Test - macOS'
```
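The workflow's test step now selects between two npm scripts. The same targets can be invoked locally with the script names referenced above (a sketch; the script definitions live in the repository's package.json):

```bash
# Run the integration tests without a sandbox (what the workflow does for non-Docker matrix entries).
npm run test:integration:sandbox:none

# Or run them inside the Docker sandbox, mirroring the sandbox:docker matrix entry.
npm run test:integration:sandbox:docker
```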
@@ -43,6 +43,7 @@ Qwen Code uses JSON settings files for persistent configuration. There are four

In addition to a project settings file, a project's `.qwen` directory can contain other project-specific files related to Qwen Code's operation, such as:

- [Custom sandbox profiles](../features/sandbox) (e.g. `.qwen/sandbox-macos-custom.sb`, `.qwen/sandbox.Dockerfile`).
- [Agent Skills](../features/skills) (experimental) under `.qwen/skills/` (each Skill is a directory containing a `SKILL.md`).

### Available settings in `settings.json`

@@ -380,6 +381,8 @@ Arguments passed directly when running the CLI can override other configurations

| `--telemetry-otlp-protocol` | | Sets the OTLP protocol for telemetry (`grpc` or `http`). | | Defaults to `grpc`. See [telemetry](../../developers/development/telemetry) for more information. |
| `--telemetry-log-prompts` | | Enables logging of prompts for telemetry. | | See [telemetry](../../developers/development/telemetry) for more information. |
| `--checkpointing` | | Enables [checkpointing](../features/checkpointing). | | |
| `--experimental-acp` | | Enables ACP mode (Agent Control Protocol). Useful for IDE/editor integrations like [Zed](../integration-zed). | | Experimental. |
| `--experimental-skills` | | Enables experimental [Agent Skills](../features/skills) (registers the `skill` tool and loads Skills from `.qwen/skills/` and `~/.qwen/skills/`). | | Experimental. |
| `--extensions` | `-e` | Specifies a list of extensions to use for the session. | Extension names | If not provided, all available extensions are used. Use the special term `qwen -e none` to disable all extensions. Example: `qwen -e my-extension -e my-other-extension` |
| `--list-extensions` | `-l` | Lists all available extensions and exits. | | |
| `--proxy` | | Sets the proxy for the CLI. | Proxy URL | Example: `--proxy http://localhost:7890`. |
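Putting those pieces together, a project that uses both a custom sandbox profile and a couple of Skills might have a `.qwen` directory shaped like this (an illustrative layout; the sandbox file names come from the list above, while the Skill names are hypothetical):

```text
.qwen/
├── settings.json
├── sandbox-macos-custom.sb
├── sandbox.Dockerfile
└── skills/
    ├── pdf-processing/
    │   └── SKILL.md
    └── commit-messages/
        └── SKILL.md
```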
@@ -1,6 +1,7 @@

```ts
export default {
  commands: 'Commands',
  'sub-agents': 'SubAgents',
  skills: 'Skills (Experimental)',
  headless: 'Headless Mode',
  checkpointing: {
    display: 'hidden',
```
@@ -189,19 +189,20 @@ qwen -p "Write code" --output-format stream-json --include-partial-messages | jq

Key command-line options for headless usage:

| Option | Description | Example |
| ---------------------------- | ------------------------------------------------------- | ------------------------------------------------------------------------ |
| `--prompt`, `-p` | Run in headless mode | `qwen -p "query"` |
| `--output-format`, `-o` | Specify output format (text, json, stream-json) | `qwen -p "query" --output-format json` |
| `--input-format` | Specify input format (text, stream-json) | `qwen --input-format text --output-format stream-json` |
| `--include-partial-messages` | Include partial messages in stream-json output | `qwen -p "query" --output-format stream-json --include-partial-messages` |
| `--debug`, `-d` | Enable debug mode | `qwen -p "query" --debug` |
| `--all-files`, `-a` | Include all files in context | `qwen -p "query" --all-files` |
| `--include-directories` | Include additional directories | `qwen -p "query" --include-directories src,docs` |
| `--yolo`, `-y` | Auto-approve all actions | `qwen -p "query" --yolo` |
| `--approval-mode` | Set approval mode | `qwen -p "query" --approval-mode auto_edit` |
| `--continue` | Resume the most recent session for this project | `qwen --continue -p "Pick up where we left off"` |
| `--resume [sessionId]` | Resume a specific session (or choose interactively) | `qwen --resume 123e... -p "Finish the refactor"` |
| `--experimental-skills` | Enable experimental Skills (registers the `skill` tool) | `qwen --experimental-skills -p "What Skills are available?"` |

For complete details on all available configuration options, settings files, and environment variables, see the [Configuration Guide](../configuration/settings).
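These flags combine. For example, a headless run that also loads Skills and emits machine-readable output could look like this (a sketch; the exact JSON shape depends on your CLI version):

```bash
# Enable experimental Skills for a one-off headless query and capture the result as JSON.
qwen --experimental-skills \
  -p "List the available Skills and what each one does" \
  --output-format json
```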
docs/users/features/skills.md (new file, 282 lines)
@@ -0,0 +1,282 @@
# Agent Skills (Experimental)

> Create, manage, and share Skills to extend Qwen Code’s capabilities.

This guide shows you how to create, use, and manage Agent Skills in **Qwen Code**. Skills are modular capabilities that extend the model’s effectiveness through organized folders containing instructions (and optionally scripts/resources).

> [!note]
>
> Skills are currently **experimental** and must be enabled with `--experimental-skills`.

## Prerequisites

- Qwen Code (recent version)
- Run with the experimental flag enabled:

  ```bash
  qwen --experimental-skills
  ```

- Basic familiarity with Qwen Code ([Quickstart](../quickstart.md))
## What are Agent Skills?

Agent Skills package expertise into discoverable capabilities. Each Skill consists of a `SKILL.md` file with instructions that the model can load when relevant, plus optional supporting files like scripts and templates.

### How Skills are invoked

Skills are **model-invoked** — the model autonomously decides when to use them based on your request and the Skill’s description. This is different from slash commands, which are **user-invoked** (you explicitly type `/command`).
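For example, you never type a Skill’s name to trigger it; you simply describe the task, and the model matches it against Skill descriptions (the Skill name below is purely illustrative):

```text
You: Can you pull the totals out of invoices/march.pdf into a table?
Qwen Code: loads the hypothetical pdf-processing Skill, because its description mentions PDFs, and follows its instructions.
```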
### Benefits

- Extend Qwen Code for your workflows
- Share expertise across your team via git
- Reduce repetitive prompting
- Compose multiple Skills for complex tasks

## Create a Skill

Skills are stored as directories containing a `SKILL.md` file.

### Personal Skills

Personal Skills are available across all your projects. Store them in `~/.qwen/skills/`:

```bash
mkdir -p ~/.qwen/skills/my-skill-name
```

Use personal Skills for:

- Your individual workflows and preferences
- Experimental Skills you’re developing
- Personal productivity helpers

### Project Skills

Project Skills are shared with your team. Store them in `.qwen/skills/` within your project:

```bash
mkdir -p .qwen/skills/my-skill-name
```

Use project Skills for:

- Team workflows and conventions
- Project-specific expertise
- Shared utilities and scripts

Project Skills can be checked into git and automatically become available to teammates.

## Write `SKILL.md`

Create a `SKILL.md` file with YAML frontmatter and Markdown content:

```markdown
---
name: your-skill-name
description: Brief description of what this Skill does and when to use it
---

# Your Skill Name

## Instructions
Provide clear, step-by-step guidance for Qwen Code.

## Examples
Show concrete examples of using this Skill.
```

### Field requirements

Qwen Code currently validates that:

- `name` is a non-empty string
- `description` is a non-empty string

Recommended conventions (not strictly enforced yet):

- Use lowercase letters, numbers, and hyphens in `name`
- Make `description` specific: include both **what** the Skill does and **when** to use it (key words users will naturally mention)
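As a concrete illustration, here is what a small, complete Skill might look like. This is a hypothetical example; the name, trigger words, and commit conventions shown are not shipped with Qwen Code:

```markdown
---
name: commit-messages
description: Write conventional commit messages from staged changes. Use when the user asks for a commit message, changelog entry, or help describing a diff.
---

# Commit Messages

## Instructions
1. Run `git diff --staged` to inspect what is about to be committed.
2. Summarize the change in an imperative subject line of at most 72 characters.
3. Use the Conventional Commits prefix (`feat:`, `fix:`, `docs:`, `refactor:`, ...) that best matches the diff.
4. Add a short body only when the subject alone cannot explain why the change was made.

## Examples
- `feat: add --experimental-skills flag to the CLI argument parser`
- `docs: describe personal vs. project Skill directories`
```

Keeping the instructions this short makes them easy for the model to follow verbatim once the Skill is selected.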
## Add supporting files

Create additional files alongside `SKILL.md`:

```text
my-skill/
├── SKILL.md         (required)
├── reference.md     (optional documentation)
├── examples.md      (optional examples)
├── scripts/
│   └── helper.py    (optional utility)
└── templates/
    └── template.txt (optional template)
```

Reference these files from `SKILL.md`:

````markdown
For advanced usage, see [reference.md](reference.md).

Run the helper script:

```bash
python scripts/helper.py input.txt
```
````

## View available Skills

When `--experimental-skills` is enabled, Qwen Code discovers Skills from:

- Personal Skills: `~/.qwen/skills/`
- Project Skills: `.qwen/skills/`

To view available Skills, ask Qwen Code directly:

```text
What Skills are available?
```

Or inspect the filesystem:

```bash
# List personal Skills
ls ~/.qwen/skills/

# List project Skills (if in a project directory)
ls .qwen/skills/

# View a specific Skill’s content
cat ~/.qwen/skills/my-skill/SKILL.md
```

## Test a Skill

After creating a Skill, test it by asking questions that match your description.

Example: if your description mentions “PDF files”:

```text
Can you help me extract text from this PDF?
```

The model autonomously decides to use your Skill if it matches the request — you don’t need to explicitly invoke it.
## Debug a Skill

If Qwen Code doesn’t use your Skill, check these common issues:

### Make the description specific

Too vague:

```yaml
description: Helps with documents
```

Specific:

```yaml
description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDFs, forms, or document extraction.
```

### Verify file path

- Personal Skills: `~/.qwen/skills/<skill-name>/SKILL.md`
- Project Skills: `.qwen/skills/<skill-name>/SKILL.md`

```bash
# Personal
ls ~/.qwen/skills/my-skill/SKILL.md

# Project
ls .qwen/skills/my-skill/SKILL.md
```

### Check YAML syntax

Invalid YAML prevents the Skill metadata from loading correctly.

```bash
head -n 15 SKILL.md
```

Ensure:

- Opening `---` on line 1
- Closing `---` before Markdown content
- Valid YAML syntax (no tabs, correct indentation)
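If you want a quicker check than eyeballing the output, you can parse the frontmatter directly. This is just one way to do it and assumes `python3` with PyYAML is installed; it is not part of Qwen Code:

```bash
# Print the frontmatter block (between the first two '---' lines) and try to parse it as YAML.
awk 'NR==1 && $0=="---" {f=1; next} f && $0=="---" {exit} f' SKILL.md \
  | python3 -c "import sys, yaml; yaml.safe_load(sys.stdin); print('frontmatter parses as YAML')"
```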
### View errors

Run Qwen Code with debug mode to see Skill loading errors:

```bash
qwen --experimental-skills --debug
```

## Share Skills with your team

You can share Skills through project repositories:

1. Add the Skill under `.qwen/skills/`
2. Commit and push
3. Teammates pull the changes and run with `--experimental-skills`

```bash
git add .qwen/skills/
git commit -m "Add team Skill for PDF processing"
git push
```

## Update a Skill

Edit `SKILL.md` directly:

```bash
# Personal Skill
code ~/.qwen/skills/my-skill/SKILL.md

# Project Skill
code .qwen/skills/my-skill/SKILL.md
```

Changes take effect the next time you start Qwen Code; if it is already running, restart it to pick up the updates.

## Remove a Skill

Delete the Skill directory:

```bash
# Personal
rm -rf ~/.qwen/skills/my-skill

# Project (git rm also stages the deletion for the commit)
git rm -r .qwen/skills/my-skill
git commit -m "Remove unused Skill"
```
## Best practices

### Keep Skills focused

One Skill should address one capability:

- Focused: “PDF form filling”, “Excel analysis”, “Git commit messages”
- Too broad: “Document processing” (split into smaller Skills)

### Write clear descriptions

Help the model discover when to use Skills by including specific triggers:

```yaml
description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or .xlsx data.
```

### Test with your team

- Does the Skill activate when expected?
- Are the instructions clear?
- Are there missing examples or edge cases?
package-lock.json (generated, 2007 lines changed): file diff suppressed because it is too large.
@@ -18,6 +18,9 @@

```json
  "scripts": {
    "start": "cross-env node scripts/start.js",
    "debug": "cross-env DEBUG=1 node --inspect-brk scripts/start.js",
    "auth:npm": "npx google-artifactregistry-auth",
    "auth:docker": "gcloud auth configure-docker us-west1-docker.pkg.dev",
    "auth": "npm run auth:npm && npm run auth:docker",
    "generate": "node scripts/generate-git-commit-info.js",
    "build": "node scripts/build.js",
    "build-and-start": "npm run build && npm run start",
```

@@ -92,6 +95,7 @@

```json
    "eslint-plugin-react-hooks": "^5.2.0",
    "glob": "^10.5.0",
    "globals": "^16.0.0",
    "google-artifactregistry-auth": "^3.4.0",
    "husky": "^9.1.7",
    "json": "^11.0.0",
    "lint-staged": "^16.1.6",
```
@@ -36,10 +36,10 @@

```json
    "sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.6.0"
  },
  "dependencies": {
    "@google/genai": "1.30.0",
    "@google/genai": "1.16.0",
    "@iarna/toml": "^2.2.5",
    "@qwen-code/qwen-code-core": "file:../core",
    "@modelcontextprotocol/sdk": "^1.25.1",
    "@modelcontextprotocol/sdk": "^1.15.1",
    "@types/update-notifier": "^6.0.8",
    "ansi-regex": "^6.2.2",
    "command-exists": "^1.2.9",
```
@@ -26,23 +26,5 @@ export function validateAuthMethod(authMethod: string): string | null {

```ts
    return null;
  }

  if (authMethod === AuthType.USE_GEMINI) {
    const hasApiKey = process.env['GEMINI_API_KEY'];
    if (!hasApiKey) {
      return 'GEMINI_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
    }
    return null;
  }

  if (authMethod === AuthType.USE_VERTEX_AI) {
    const hasApiKey = process.env['GOOGLE_API_KEY'];
    if (!hasApiKey) {
      return 'GOOGLE_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
    }

    process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true';
    return null;
  }

  return 'Invalid auth method selected.';
}
```
@@ -112,6 +112,7 @@ export interface CliArgs {
|
||||
allowedMcpServerNames: string[] | undefined;
|
||||
allowedTools: string[] | undefined;
|
||||
experimentalAcp: boolean | undefined;
|
||||
experimentalSkills: boolean | undefined;
|
||||
extensions: string[] | undefined;
|
||||
listExtensions: boolean | undefined;
|
||||
openaiLogging: boolean | undefined;
|
||||
@@ -307,6 +308,11 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
type: 'boolean',
|
||||
description: 'Starts the agent in ACP mode',
|
||||
})
|
||||
.option('experimental-skills', {
|
||||
type: 'boolean',
|
||||
description: 'Enable experimental Skills feature',
|
||||
default: false,
|
||||
})
|
||||
.option('channel', {
|
||||
type: 'string',
|
||||
choices: ['VSCode', 'ACP', 'SDK', 'CI'],
|
||||
@@ -460,12 +466,7 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
})
|
||||
.option('auth-type', {
|
||||
type: 'string',
|
||||
choices: [
|
||||
AuthType.USE_OPENAI,
|
||||
AuthType.QWEN_OAUTH,
|
||||
AuthType.USE_GEMINI,
|
||||
AuthType.USE_VERTEX_AI,
|
||||
],
|
||||
choices: [AuthType.USE_OPENAI, AuthType.QWEN_OAUTH],
|
||||
description: 'Authentication type',
|
||||
})
|
||||
.deprecateOption(
|
||||
@@ -956,6 +957,7 @@ export async function loadCliConfig(
|
||||
maxSessionTurns:
|
||||
argv.maxSessionTurns ?? settings.model?.maxSessionTurns ?? -1,
|
||||
experimentalZedIntegration: argv.experimentalAcp || false,
|
||||
experimentalSkills: argv.experimentalSkills || false,
|
||||
listExtensions: argv.listExtensions || false,
|
||||
extensions: allExtensions,
|
||||
blockedMcpServers,
|
||||
|
||||
@@ -461,6 +461,7 @@ describe('gemini.tsx main function kitty protocol', () => {
|
||||
allowedMcpServerNames: undefined,
|
||||
allowedTools: undefined,
|
||||
experimentalAcp: undefined,
|
||||
experimentalSkills: undefined,
|
||||
extensions: undefined,
|
||||
listExtensions: undefined,
|
||||
openaiLogging: undefined,
|
||||
|
||||
@@ -4,8 +4,13 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config, AuthType } from '@qwen-code/qwen-code-core';
|
||||
import { InputFormat, logUserPrompt } from '@qwen-code/qwen-code-core';
|
||||
import type { Config } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
AuthType,
|
||||
getOauthClient,
|
||||
InputFormat,
|
||||
logUserPrompt,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { render } from 'ink';
|
||||
import dns from 'node:dns';
|
||||
import os from 'node:os';
|
||||
@@ -394,6 +399,15 @@ export async function main() {
|
||||
initializationResult = await initializeApp(config, settings);
|
||||
}
|
||||
|
||||
if (
|
||||
settings.merged.security?.auth?.selectedType ===
|
||||
AuthType.LOGIN_WITH_GOOGLE &&
|
||||
config.isBrowserLaunchSuppressed()
|
||||
) {
|
||||
// Do oauth before app renders to make copying the link possible.
|
||||
await getOauthClient(settings.merged.security.auth.selectedType, config);
|
||||
}
|
||||
|
||||
if (config.getExperimentalZedIntegration()) {
|
||||
return runAcpAgent(config, settings, extensions, argv);
|
||||
}
|
||||
|
||||
@@ -610,6 +610,8 @@ export abstract class BaseJsonOutputAdapter {
|
||||
const errorText = parseAndFormatApiError(
|
||||
event.value.error,
|
||||
this.config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
this.config.getModel(),
|
||||
);
|
||||
this.appendText(state, errorText, null);
|
||||
break;
|
||||
|
||||
@@ -221,6 +221,8 @@ export async function runNonInteractive(
|
||||
const errorText = parseAndFormatApiError(
|
||||
event.value.error,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
);
|
||||
process.stderr.write(`${errorText}\n`);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const mockPrompt = {
|
||||
{ name: 'trail', required: false, description: "The animal's trail." },
|
||||
],
|
||||
invoke: vi.fn().mockResolvedValue({
|
||||
messages: [{ content: { type: 'text', text: 'Hello, world!' } }],
|
||||
messages: [{ content: { text: 'Hello, world!' } }],
|
||||
}),
|
||||
};
|
||||
|
||||
|
||||
@@ -123,10 +123,7 @@ export class McpPromptLoader implements ICommandLoader {
|
||||
};
|
||||
}
|
||||
|
||||
const firstMessage = result.messages?.[0];
|
||||
const content = firstMessage?.content;
|
||||
|
||||
if (content?.type !== 'text') {
|
||||
if (!result.messages?.[0]?.content?.['text']) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
@@ -137,7 +134,7 @@ export class McpPromptLoader implements ICommandLoader {
|
||||
|
||||
return {
|
||||
type: 'submit_prompt',
|
||||
content: JSON.stringify(content.text),
|
||||
content: JSON.stringify(result.messages[0].content.text),
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import type { LoadedSettings } from '../config/settings.js';
|
||||
import type { InitializationResult } from '../core/initializer.js';
|
||||
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
|
||||
import { UIStateContext, type UIState } from './contexts/UIStateContext.js';
|
||||
import {
|
||||
UIActionsContext,
|
||||
@@ -55,6 +56,7 @@ vi.mock('./App.js', () => ({
|
||||
App: TestContextConsumer,
|
||||
}));
|
||||
|
||||
vi.mock('./hooks/useQuotaAndFallback.js');
|
||||
vi.mock('./hooks/useHistoryManager.js');
|
||||
vi.mock('./hooks/useThemeCommand.js');
|
||||
vi.mock('./auth/useAuth.js');
|
||||
@@ -120,6 +122,7 @@ describe('AppContainer State Management', () => {
|
||||
let mockInitResult: InitializationResult;
|
||||
|
||||
// Create typed mocks for all hooks
|
||||
const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock;
|
||||
const mockedUseHistory = useHistory as Mock;
|
||||
const mockedUseThemeCommand = useThemeCommand as Mock;
|
||||
const mockedUseAuthCommand = useAuthCommand as Mock;
|
||||
@@ -161,6 +164,10 @@ describe('AppContainer State Management', () => {
|
||||
capturedUIActions = null!;
|
||||
|
||||
// **Provide a default return value for EVERY mocked hook.**
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: null,
|
||||
handleProQuotaChoice: vi.fn(),
|
||||
});
|
||||
mockedUseHistory.mockReturnValue({
|
||||
history: [],
|
||||
addItem: vi.fn(),
|
||||
@@ -560,6 +567,75 @@ describe('AppContainer State Management', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Quota and Fallback Integration', () => {
|
||||
it('passes a null proQuotaRequest to UIStateContext by default', () => {
|
||||
// The default mock from beforeEach already sets proQuotaRequest to null
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert that the context value is as expected
|
||||
expect(capturedUIState.proQuotaRequest).toBeNull();
|
||||
});
|
||||
|
||||
it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => {
|
||||
// Arrange: Create a mock request object that a UI dialog would receive
|
||||
const mockRequest = {
|
||||
failedModel: 'gemini-pro',
|
||||
fallbackModel: 'gemini-flash',
|
||||
resolve: vi.fn(),
|
||||
};
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: mockRequest,
|
||||
handleProQuotaChoice: vi.fn(),
|
||||
});
|
||||
|
||||
// Act: Render the container
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert: The mock request is correctly passed through the context
|
||||
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
|
||||
});
|
||||
|
||||
it('passes the handleProQuotaChoice function to UIActionsContext', () => {
|
||||
// Arrange: Create a mock handler function
|
||||
const mockHandler = vi.fn();
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: null,
|
||||
handleProQuotaChoice: mockHandler,
|
||||
});
|
||||
|
||||
// Act: Render the container
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert: The action in the context is the mock handler we provided
|
||||
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
|
||||
|
||||
// You can even verify that the plumbed function is callable
|
||||
capturedUIActions.handleProQuotaChoice('auth');
|
||||
expect(mockHandler).toHaveBeenCalledWith('auth');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Terminal Title Update Feature', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mock stdout for each test
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
type Config,
|
||||
type IdeInfo,
|
||||
type IdeContext,
|
||||
type UserTierId,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
IdeClient,
|
||||
ideContextStore,
|
||||
@@ -47,6 +48,7 @@ import { useHistory } from './hooks/useHistoryManager.js';
|
||||
import { useMemoryMonitor } from './hooks/useMemoryMonitor.js';
|
||||
import { useThemeCommand } from './hooks/useThemeCommand.js';
|
||||
import { useAuthCommand } from './auth/useAuth.js';
|
||||
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
|
||||
import { useEditorSettings } from './hooks/useEditorSettings.js';
|
||||
import { useSettingsCommand } from './hooks/useSettingsCommand.js';
|
||||
import { useModelCommand } from './hooks/useModelCommand.js';
|
||||
@@ -190,6 +192,8 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
|
||||
const [currentModel, setCurrentModel] = useState(getEffectiveModel());
|
||||
|
||||
const [userTier] = useState<UserTierId | undefined>(undefined);
|
||||
|
||||
const [isConfigInitialized, setConfigInitialized] = useState(false);
|
||||
|
||||
const [userMessages, setUserMessages] = useState<string[]>([]);
|
||||
@@ -363,6 +367,14 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
cancelAuthentication,
|
||||
} = useAuthCommand(settings, config, historyManager.addItem);
|
||||
|
||||
const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({
|
||||
config,
|
||||
historyManager,
|
||||
userTier,
|
||||
setAuthState,
|
||||
setModelSwitchedFromQuotaError,
|
||||
});
|
||||
|
||||
useInitializationAuthError(initializationResult.authError, onAuthError);
|
||||
|
||||
// Sync user tier from config when authentication changes
|
||||
@@ -740,7 +752,8 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
!initError &&
|
||||
!isProcessing &&
|
||||
(streamingState === StreamingState.Idle ||
|
||||
streamingState === StreamingState.Responding);
|
||||
streamingState === StreamingState.Responding) &&
|
||||
!proQuotaRequest;
|
||||
|
||||
const [controlsHeight, setControlsHeight] = useState(0);
|
||||
|
||||
@@ -1193,6 +1206,7 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
isAuthenticating ||
|
||||
isEditorDialogOpen ||
|
||||
showIdeRestartPrompt ||
|
||||
!!proQuotaRequest ||
|
||||
isSubagentCreateDialogOpen ||
|
||||
isAgentsManagerDialogOpen ||
|
||||
isApprovalModeDialogOpen ||
|
||||
@@ -1263,6 +1277,8 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
showWorkspaceMigrationDialog,
|
||||
workspaceExtensions,
|
||||
currentModel,
|
||||
userTier,
|
||||
proQuotaRequest,
|
||||
contextFileNames,
|
||||
errorCount,
|
||||
availableTerminalHeight,
|
||||
@@ -1351,6 +1367,8 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
showAutoAcceptIndicator,
|
||||
showWorkspaceMigrationDialog,
|
||||
workspaceExtensions,
|
||||
userTier,
|
||||
proQuotaRequest,
|
||||
contextFileNames,
|
||||
errorCount,
|
||||
availableTerminalHeight,
|
||||
@@ -1412,6 +1430,7 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
handleClearScreen,
|
||||
onWorkspaceMigrationDialogOpen,
|
||||
onWorkspaceMigrationDialogClose,
|
||||
handleProQuotaChoice,
|
||||
// Vision switch dialog
|
||||
handleVisionSwitchSelect,
|
||||
// Welcome back dialog
|
||||
@@ -1449,6 +1468,7 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
handleClearScreen,
|
||||
onWorkspaceMigrationDialogOpen,
|
||||
onWorkspaceMigrationDialogClose,
|
||||
handleProQuotaChoice,
|
||||
handleVisionSwitchSelect,
|
||||
handleWelcomeBackSelection,
|
||||
handleWelcomeBackClose,
|
||||
|
||||
@@ -168,7 +168,7 @@ describe('AuthDialog', () => {
|
||||
|
||||
it('should not show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to something else', () => {
|
||||
process.env['GEMINI_API_KEY'] = 'foobar';
|
||||
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
|
||||
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.LOGIN_WITH_GOOGLE;
|
||||
|
||||
const settings: LoadedSettings = new LoadedSettings(
|
||||
{
|
||||
@@ -212,7 +212,7 @@ describe('AuthDialog', () => {
|
||||
|
||||
it('should show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to use api key', () => {
|
||||
process.env['GEMINI_API_KEY'] = 'foobar';
|
||||
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
|
||||
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_GEMINI;
|
||||
|
||||
const settings: LoadedSettings = new LoadedSettings(
|
||||
{
|
||||
@@ -504,12 +504,12 @@ describe('AuthDialog', () => {
|
||||
},
|
||||
{
|
||||
settings: {
|
||||
security: { auth: { selectedType: AuthType.USE_OPENAI } },
|
||||
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
|
||||
ui: { customThemes: {} },
|
||||
mcpServers: {},
|
||||
},
|
||||
originalSettings: {
|
||||
security: { auth: { selectedType: AuthType.USE_OPENAI } },
|
||||
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
|
||||
ui: { customThemes: {} },
|
||||
mcpServers: {},
|
||||
},
|
||||
|
||||
@@ -225,24 +225,16 @@ export const useAuthCommand = (
|
||||
const defaultAuthType = process.env['QWEN_DEFAULT_AUTH_TYPE'];
|
||||
if (
|
||||
defaultAuthType &&
|
||||
![
|
||||
AuthType.QWEN_OAUTH,
|
||||
AuthType.USE_OPENAI,
|
||||
AuthType.USE_GEMINI,
|
||||
AuthType.USE_VERTEX_AI,
|
||||
].includes(defaultAuthType as AuthType)
|
||||
![AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].includes(
|
||||
defaultAuthType as AuthType,
|
||||
)
|
||||
) {
|
||||
onAuthError(
|
||||
t(
|
||||
'Invalid QWEN_DEFAULT_AUTH_TYPE value: "{{value}}". Valid values are: {{validValues}}',
|
||||
{
|
||||
value: defaultAuthType,
|
||||
validValues: [
|
||||
AuthType.QWEN_OAUTH,
|
||||
AuthType.USE_OPENAI,
|
||||
AuthType.USE_GEMINI,
|
||||
AuthType.USE_VERTEX_AI,
|
||||
].join(', '),
|
||||
validValues: [AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].join(', '),
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
@@ -15,6 +15,7 @@ vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const original = await importOriginal<typeof core>();
|
||||
return {
|
||||
...original,
|
||||
getOauthClient: vi.fn(original.getOauthClient),
|
||||
getIdeInstaller: vi.fn(original.getIdeInstaller),
|
||||
IdeClient: {
|
||||
getInstance: vi.fn(),
|
||||
|
||||
@@ -17,6 +17,7 @@ import { AuthDialog } from '../auth/AuthDialog.js';
|
||||
import { OpenAIKeyPrompt } from './OpenAIKeyPrompt.js';
|
||||
import { EditorSettingsDialog } from './EditorSettingsDialog.js';
|
||||
import { WorkspaceMigrationDialog } from './WorkspaceMigrationDialog.js';
|
||||
import { ProQuotaDialog } from './ProQuotaDialog.js';
|
||||
import { PermissionsModifyTrustDialog } from './PermissionsModifyTrustDialog.js';
|
||||
import { ModelDialog } from './ModelDialog.js';
|
||||
import { ApprovalModeDialog } from './ApprovalModeDialog.js';
|
||||
@@ -86,6 +87,15 @@ export const DialogManager = ({
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (uiState.proQuotaRequest) {
|
||||
return (
|
||||
<ProQuotaDialog
|
||||
failedModel={uiState.proQuotaRequest.failedModel}
|
||||
fallbackModel={uiState.proQuotaRequest.fallbackModel}
|
||||
onChoice={uiActions.handleProQuotaChoice}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (uiState.shouldShowIdePrompt) {
|
||||
return (
|
||||
<IdeIntegrationNudge
|
||||
|
||||
packages/cli/src/ui/components/ProQuotaDialog.test.tsx (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { render } from 'ink-testing-library';
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { ProQuotaDialog } from './ProQuotaDialog.js';
|
||||
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
|
||||
|
||||
// Mock the child component to make it easier to test the parent
|
||||
vi.mock('./shared/RadioButtonSelect.js', () => ({
|
||||
RadioButtonSelect: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('ProQuotaDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should render with correct title and options', () => {
|
||||
const { lastFrame } = render(
|
||||
<ProQuotaDialog
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={() => {}}
|
||||
/>,
|
||||
);
|
||||
|
||||
const output = lastFrame();
|
||||
expect(output).toContain('Pro quota limit reached for gemini-2.5-pro.');
|
||||
|
||||
// Check that RadioButtonSelect was called with the correct items
|
||||
expect(RadioButtonSelect).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
items: [
|
||||
{
|
||||
label: 'Change auth (executes the /auth command)',
|
||||
value: 'auth',
|
||||
key: 'auth',
|
||||
},
|
||||
{
|
||||
label: `Continue with gemini-2.5-flash`,
|
||||
value: 'continue',
|
||||
key: 'continue',
|
||||
},
|
||||
],
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call onChoice with "auth" when "Change auth" is selected', () => {
|
||||
const mockOnChoice = vi.fn();
|
||||
render(
|
||||
<ProQuotaDialog
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={mockOnChoice}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Get the onSelect function passed to RadioButtonSelect
|
||||
const onSelect = (RadioButtonSelect as Mock).mock.calls[0][0].onSelect;
|
||||
|
||||
// Simulate the selection
|
||||
onSelect('auth');
|
||||
|
||||
expect(mockOnChoice).toHaveBeenCalledWith('auth');
|
||||
});
|
||||
|
||||
it('should call onChoice with "continue" when "Continue with flash" is selected', () => {
|
||||
const mockOnChoice = vi.fn();
|
||||
render(
|
||||
<ProQuotaDialog
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={mockOnChoice}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Get the onSelect function passed to RadioButtonSelect
|
||||
const onSelect = (RadioButtonSelect as Mock).mock.calls[0][0].onSelect;
|
||||
|
||||
// Simulate the selection
|
||||
onSelect('continue');
|
||||
|
||||
expect(mockOnChoice).toHaveBeenCalledWith('continue');
|
||||
});
|
||||
});
|
||||
packages/cli/src/ui/components/ProQuotaDialog.tsx (new file, 55 lines)
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
|
||||
import { theme } from '../semantic-colors.js';
|
||||
import { t } from '../../i18n/index.js';
|
||||
|
||||
interface ProQuotaDialogProps {
|
||||
failedModel: string;
|
||||
fallbackModel: string;
|
||||
onChoice: (choice: 'auth' | 'continue') => void;
|
||||
}
|
||||
|
||||
export function ProQuotaDialog({
|
||||
failedModel,
|
||||
fallbackModel,
|
||||
onChoice,
|
||||
}: ProQuotaDialogProps): React.JSX.Element {
|
||||
const items = [
|
||||
{
|
||||
label: t('Change auth (executes the /auth command)'),
|
||||
value: 'auth' as const,
|
||||
key: 'auth',
|
||||
},
|
||||
{
|
||||
label: t('Continue with {{model}}', { model: fallbackModel }),
|
||||
value: 'continue' as const,
|
||||
key: 'continue',
|
||||
},
|
||||
];
|
||||
|
||||
const handleSelect = (choice: 'auth' | 'continue') => {
|
||||
onChoice(choice);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box borderStyle="round" flexDirection="column" paddingX={1}>
|
||||
<Text bold color={theme.status.warning}>
|
||||
{t('Pro quota limit reached for {{model}}.', { model: failedModel })}
|
||||
</Text>
|
||||
<Box marginTop={1}>
|
||||
<RadioButtonSelect
|
||||
items={items}
|
||||
initialIndex={1}
|
||||
onSelect={handleSelect}
|
||||
/>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
@@ -55,6 +55,7 @@ export interface UIActions {
|
||||
handleClearScreen: () => void;
|
||||
onWorkspaceMigrationDialogOpen: () => void;
|
||||
onWorkspaceMigrationDialogClose: () => void;
|
||||
handleProQuotaChoice: (choice: 'auth' | 'continue') => void;
|
||||
// Vision switch dialog
|
||||
handleVisionSwitchSelect: (outcome: VisionSwitchOutcome) => void;
|
||||
// Welcome back dialog
|
||||
|
||||
@@ -22,13 +22,21 @@ import type {
|
||||
AuthType,
|
||||
IdeContext,
|
||||
ApprovalMode,
|
||||
UserTierId,
|
||||
IdeInfo,
|
||||
FallbackIntent,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import type { DOMElement } from 'ink';
|
||||
import type { SessionStatsState } from '../contexts/SessionContext.js';
|
||||
import type { ExtensionUpdateState } from '../state/extensions.js';
|
||||
import type { UpdateObject } from '../utils/updateCheck.js';
|
||||
|
||||
export interface ProQuotaDialogRequest {
|
||||
failedModel: string;
|
||||
fallbackModel: string;
|
||||
resolve: (intent: FallbackIntent) => void;
|
||||
}
|
||||
|
||||
import { type UseHistoryManagerReturn } from '../hooks/useHistoryManager.js';
|
||||
import { type RestartReason } from '../hooks/useIdeTrustListener.js';
|
||||
|
||||
@@ -91,6 +99,8 @@ export interface UIState {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
workspaceExtensions: any[]; // Extension[]
|
||||
// Quota-related state
|
||||
userTier: UserTierId | undefined;
|
||||
proQuotaRequest: ProQuotaDialogRequest | null;
|
||||
currentModel: string;
|
||||
contextFileNames: string[];
|
||||
errorCount: number;
|
||||
|
||||
@@ -1323,7 +1323,7 @@ describe('useGeminiStream', () => {
|
||||
it('should call parseAndFormatApiError with the correct authType on stream initialization failure', async () => {
|
||||
// 1. Setup
|
||||
const mockError = new Error('Rate limit exceeded');
|
||||
const mockAuthType = AuthType.USE_VERTEX_AI;
|
||||
const mockAuthType = AuthType.LOGIN_WITH_GOOGLE;
|
||||
mockParseAndFormatApiError.mockClear();
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
@@ -1374,6 +1374,9 @@ describe('useGeminiStream', () => {
|
||||
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
|
||||
'Rate limit exceeded',
|
||||
mockAuthType,
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
'gemini-2.5-flash',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -2490,6 +2493,9 @@ describe('useGeminiStream', () => {
|
||||
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
|
||||
{ message: 'Test error' },
|
||||
expect.any(String),
|
||||
undefined,
|
||||
'gemini-2.5-pro',
|
||||
'gemini-2.5-flash',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
GitService,
|
||||
UnauthorizedError,
|
||||
UserPromptEvent,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
logConversationFinishedEvent,
|
||||
ConversationFinishedEvent,
|
||||
ApprovalMode,
|
||||
@@ -599,6 +600,9 @@ export const useGeminiStream = (
|
||||
text: parseAndFormatApiError(
|
||||
eventValue.error,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
@@ -650,9 +654,6 @@ export const useGeminiStream = (
|
||||
'Response stopped due to image safety violations.',
|
||||
[FinishReason.UNEXPECTED_TOOL_CALL]:
|
||||
'Response stopped due to unexpected tool call.',
|
||||
[FinishReason.IMAGE_PROHIBITED_CONTENT]:
|
||||
'Response stopped due to image prohibited content.',
|
||||
[FinishReason.NO_IMAGE]: 'Response stopped due to no image.',
|
||||
};
|
||||
|
||||
const message = finishReasonMessages[finishReason];
|
||||
@@ -769,17 +770,11 @@ export const useGeminiStream = (
|
||||
for await (const event of stream) {
|
||||
switch (event.type) {
|
||||
case ServerGeminiEventType.Thought:
|
||||
// If the thought has a subject, it's a discrete status update rather than
|
||||
// a streamed textual thought, so we update the thought state directly.
|
||||
if (event.value.subject) {
|
||||
setThought(event.value);
|
||||
} else {
|
||||
thoughtBuffer = handleThoughtEvent(
|
||||
event.value,
|
||||
thoughtBuffer,
|
||||
userMessageTimestamp,
|
||||
);
|
||||
}
|
||||
thoughtBuffer = handleThoughtEvent(
|
||||
event.value,
|
||||
thoughtBuffer,
|
||||
userMessageTimestamp,
|
||||
);
|
||||
break;
|
||||
case ServerGeminiEventType.Content:
|
||||
geminiMessageBuffer = handleContentEvent(
|
||||
@@ -850,7 +845,6 @@ export const useGeminiStream = (
|
||||
handleMaxSessionTurnsEvent,
|
||||
handleSessionTokenLimitExceededEvent,
|
||||
handleCitationEvent,
|
||||
setThought,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -993,6 +987,9 @@ export const useGeminiStream = (
|
||||
text: parseAndFormatApiError(
|
||||
getErrorMessage(error) || 'Unknown error',
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
|
||||
packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts (new file, 391 lines)
@@ -0,0 +1,391 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { act, renderHook } from '@testing-library/react';
|
||||
import {
|
||||
type Config,
|
||||
type FallbackModelHandler,
|
||||
UserTierId,
|
||||
AuthType,
|
||||
isGenericQuotaExceededError,
|
||||
isProQuotaExceededError,
|
||||
makeFakeConfig,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { useQuotaAndFallback } from './useQuotaAndFallback.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { AuthState, MessageType } from '../types.js';
|
||||
|
||||
// Mock the error checking functions from the core package to control test scenarios
|
||||
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@qwen-code/qwen-code-core')>();
|
||||
return {
|
||||
...original,
|
||||
isGenericQuotaExceededError: vi.fn(),
|
||||
isProQuotaExceededError: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Use a type alias for SpyInstance as it's not directly exported
|
||||
type SpyInstance = ReturnType<typeof vi.spyOn>;
|
||||
|
||||
describe('useQuotaAndFallback', () => {
|
||||
let mockConfig: Config;
|
||||
let mockHistoryManager: UseHistoryManagerReturn;
|
||||
let mockSetAuthState: Mock;
|
||||
let mockSetModelSwitchedFromQuotaError: Mock;
|
||||
let setFallbackHandlerSpy: SpyInstance;
|
||||
|
||||
const mockedIsGenericQuotaExceededError = isGenericQuotaExceededError as Mock;
|
||||
const mockedIsProQuotaExceededError = isProQuotaExceededError as Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig = makeFakeConfig();
|
||||
|
||||
// Spy on the method that requires the private field and mock its return.
|
||||
// This is cleaner than modifying the config class for tests.
|
||||
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
|
||||
model: 'test-model',
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
});
|
||||
|
||||
mockHistoryManager = {
|
||||
addItem: vi.fn(),
|
||||
history: [],
|
||||
updateItem: vi.fn(),
|
||||
clearItems: vi.fn(),
|
||||
loadHistory: vi.fn(),
|
||||
};
|
||||
mockSetAuthState = vi.fn();
|
||||
mockSetModelSwitchedFromQuotaError = vi.fn();
|
||||
|
||||
setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler');
|
||||
vi.spyOn(mockConfig, 'setQuotaErrorOccurred');
|
||||
|
||||
mockedIsGenericQuotaExceededError.mockReturnValue(false);
|
||||
mockedIsProQuotaExceededError.mockReturnValue(false);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should register a fallback handler on initialization', () => {
|
||||
renderHook(() =>
|
||||
useQuotaAndFallback({
|
||||
config: mockConfig,
|
||||
historyManager: mockHistoryManager,
|
||||
userTier: UserTierId.FREE,
|
||||
setAuthState: mockSetAuthState,
|
||||
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1);
|
||||
expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function);
|
||||
});
|
||||
|
||||
describe('Fallback Handler Logic', () => {
|
||||
// Helper function to render the hook and extract the registered handler
|
||||
const getRegisteredHandler = (
|
||||
userTier: UserTierId = UserTierId.FREE,
|
||||
): FallbackModelHandler => {
|
||||
renderHook(
|
||||
(props) =>
|
||||
useQuotaAndFallback({
|
||||
config: mockConfig,
|
||||
historyManager: mockHistoryManager,
|
||||
userTier: props.userTier,
|
||||
setAuthState: mockSetAuthState,
|
||||
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
|
||||
}),
|
||||
{ initialProps: { userTier } },
|
||||
);
|
||||
return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler;
|
||||
};
|
||||
|
||||
it('should return null and take no action if already in fallback mode', async () => {
|
||||
vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true);
|
||||
const handler = getRegisteredHandler();
|
||||
const result = await handler('gemini-pro', 'gemini-flash', new Error());
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => {
|
||||
// Override the default mock from beforeEach for this specific test
|
||||
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
|
||||
model: 'test-model',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
});
|
||||
|
||||
const handler = getRegisteredHandler();
|
||||
const result = await handler('gemini-pro', 'gemini-flash', new Error());
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('Automatic Fallback Scenarios', () => {
|
||||
const testCases = [
|
||||
{
|
||||
errorType: 'generic',
|
||||
tier: UserTierId.FREE,
|
||||
expectedMessageSnippets: [
|
||||
'Automatically switching from model-A to model-B',
|
||||
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
|
||||
],
|
||||
},
|
||||
{
|
||||
errorType: 'generic',
|
||||
tier: UserTierId.STANDARD, // Paid tier
|
||||
expectedMessageSnippets: [
|
||||
'Automatically switching from model-A to model-B',
|
||||
'switch to using a paid API key from AI Studio',
|
||||
],
|
||||
},
|
||||
{
|
||||
errorType: 'other',
|
||||
tier: UserTierId.FREE,
|
||||
expectedMessageSnippets: [
|
||||
'Automatically switching from model-A to model-B for faster responses',
|
||||
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
|
||||
],
|
||||
},
|
||||
{
|
||||
errorType: 'other',
|
||||
tier: UserTierId.LEGACY, // Paid tier
|
||||
expectedMessageSnippets: [
|
||||
'Automatically switching from model-A to model-B for faster responses',
|
||||
'switch to using a paid API key from AI Studio',
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
for (const { errorType, tier, expectedMessageSnippets } of testCases) {
|
||||
it(`should handle ${errorType} error for ${tier} tier correctly`, async () => {
|
||||
mockedIsGenericQuotaExceededError.mockReturnValue(
|
||||
errorType === 'generic',
|
||||
);
|
||||
|
||||
const handler = getRegisteredHandler(tier);
|
||||
const result = await handler(
|
||||
'model-A',
|
||||
'model-B',
|
||||
new Error('quota exceeded'),
|
||||
);
|
||||
|
||||
// Automatic fallbacks should return 'stop'
|
||||
expect(result).toBe('stop');
|
||||
|
||||
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: MessageType.INFO }),
|
||||
expect.any(Number),
|
||||
);
|
||||
|
||||
const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0]
|
||||
.text;
|
||||
for (const snippet of expectedMessageSnippets) {
|
||||
expect(message).toContain(snippet);
|
||||
}
|
||||
|
||||
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true);
|
||||
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('Interactive Fallback (Pro Quota Error)', () => {
|
||||
beforeEach(() => {
|
||||
mockedIsProQuotaExceededError.mockReturnValue(true);
|
||||
});
|
||||
|
||||
it('should set an interactive request and wait for user choice', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useQuotaAndFallback({
|
||||
config: mockConfig,
|
||||
historyManager: mockHistoryManager,
|
||||
userTier: UserTierId.FREE,
|
||||
setAuthState: mockSetAuthState,
|
||||
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
|
||||
}),
|
||||
);
|
||||
|
||||
const handler = setFallbackHandlerSpy.mock
|
||||
.calls[0][0] as FallbackModelHandler;
|
||||
|
||||
// Call the handler but do not await it, to check the intermediate state
|
||||
const promise = handler(
|
||||
'gemini-pro',
|
||||
'gemini-flash',
|
||||
new Error('pro quota'),
|
||||
);
|
||||
|
||||
await act(async () => {});
|
||||
|
||||
// The hook should now have a pending request for the UI to handle
|
||||
expect(result.current.proQuotaRequest).not.toBeNull();
|
||||
expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro');
|
||||
|
||||
// Simulate the user choosing to continue with the fallback model
|
||||
act(() => {
|
||||
result.current.handleProQuotaChoice('continue');
|
||||
});
|
||||
|
||||
// The original promise from the handler should now resolve
|
||||
const intent = await promise;
|
||||
expect(intent).toBe('retry');
|
||||
|
||||
// The pending request should be cleared from the state
|
||||
expect(result.current.proQuotaRequest).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle race conditions by stopping subsequent requests', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useQuotaAndFallback({
|
||||
config: mockConfig,
|
||||
historyManager: mockHistoryManager,
|
||||
userTier: UserTierId.FREE,
|
||||
setAuthState: mockSetAuthState,
|
||||
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
|
||||
}),
|
||||
);
|
||||
|
||||
const handler = setFallbackHandlerSpy.mock
|
||||
.calls[0][0] as FallbackModelHandler;
|
||||
|
||||
const promise1 = handler(
|
||||
'gemini-pro',
|
||||
'gemini-flash',
|
||||
new Error('pro quota 1'),
|
||||
);
|
||||
await act(async () => {});
|
||||
|
||||
const firstRequest = result.current.proQuotaRequest;
|
||||
expect(firstRequest).not.toBeNull();
|
        const result2 = await handler(
          'gemini-pro',
          'gemini-flash',
          new Error('pro quota 2'),
        );

        // The lock should have stopped the second request
        expect(result2).toBe('stop');
        expect(result.current.proQuotaRequest).toBe(firstRequest);

        act(() => {
          result.current.handleProQuotaChoice('continue');
        });

        const intent1 = await promise1;
        expect(intent1).toBe('retry');
        expect(result.current.proQuotaRequest).toBeNull();
      });
    });
  });

  describe('handleProQuotaChoice', () => {
    beforeEach(() => {
      mockedIsProQuotaExceededError.mockReturnValue(true);
    });

    it('should do nothing if there is no pending pro quota request', () => {
      const { result } = renderHook(() =>
        useQuotaAndFallback({
          config: mockConfig,
          historyManager: mockHistoryManager,
          userTier: UserTierId.FREE,
          setAuthState: mockSetAuthState,
          setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
        }),
      );

      act(() => {
        result.current.handleProQuotaChoice('auth');
      });

      expect(mockSetAuthState).not.toHaveBeenCalled();
      expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
    });

    it('should resolve intent to "auth" and trigger auth state update', async () => {
      const { result } = renderHook(() =>
        useQuotaAndFallback({
          config: mockConfig,
          historyManager: mockHistoryManager,
          userTier: UserTierId.FREE,
          setAuthState: mockSetAuthState,
          setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
        }),
      );

      const handler = setFallbackHandlerSpy.mock
        .calls[0][0] as FallbackModelHandler;
      const promise = handler(
        'gemini-pro',
        'gemini-flash',
        new Error('pro quota'),
      );
      await act(async () => {}); // Allow state to update

      act(() => {
        result.current.handleProQuotaChoice('auth');
      });

      const intent = await promise;
      expect(intent).toBe('auth');
      expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating);
      expect(result.current.proQuotaRequest).toBeNull();
    });

    it('should resolve intent to "retry" and add info message on continue', async () => {
      const { result } = renderHook(() =>
        useQuotaAndFallback({
          config: mockConfig,
          historyManager: mockHistoryManager,
          userTier: UserTierId.FREE,
          setAuthState: mockSetAuthState,
          setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
        }),
      );

      const handler = setFallbackHandlerSpy.mock
        .calls[0][0] as FallbackModelHandler;
      // The first `addItem` call is for the initial quota error message
      const promise = handler(
        'gemini-pro',
        'gemini-flash',
        new Error('pro quota'),
      );
      await act(async () => {}); // Allow state to update

      act(() => {
        result.current.handleProQuotaChoice('continue');
      });

      const intent = await promise;
      expect(intent).toBe('retry');
      expect(result.current.proQuotaRequest).toBeNull();

      // Check for the second "Switched to fallback model" message
      expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2);
      const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[1][0];
      expect(lastCall.type).toBe(MessageType.INFO);
      expect(lastCall.text).toContain('Switched to fallback model.');
    });
  });
});
packages/cli/src/ui/hooks/useQuotaAndFallback.ts (Normal file, 175 lines)
@@ -0,0 +1,175 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  AuthType,
  type Config,
  type FallbackModelHandler,
  type FallbackIntent,
  isGenericQuotaExceededError,
  isProQuotaExceededError,
  UserTierId,
} from '@qwen-code/qwen-code-core';
import { useCallback, useEffect, useRef, useState } from 'react';
import { type UseHistoryManagerReturn } from './useHistoryManager.js';
import { AuthState, MessageType } from '../types.js';
import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js';

interface UseQuotaAndFallbackArgs {
  config: Config;
  historyManager: UseHistoryManagerReturn;
  userTier: UserTierId | undefined;
  setAuthState: (state: AuthState) => void;
  setModelSwitchedFromQuotaError: (value: boolean) => void;
}

export function useQuotaAndFallback({
  config,
  historyManager,
  userTier,
  setAuthState,
  setModelSwitchedFromQuotaError,
}: UseQuotaAndFallbackArgs) {
  const [proQuotaRequest, setProQuotaRequest] =
    useState<ProQuotaDialogRequest | null>(null);
  const isDialogPending = useRef(false);

  // Set up Flash fallback handler
  useEffect(() => {
    const fallbackHandler: FallbackModelHandler = async (
      failedModel,
      fallbackModel,
      error,
    ): Promise<FallbackIntent | null> => {
      if (config.isInFallbackMode()) {
        return null;
      }

      // Fallbacks are currently only handled for OAuth users.
      const contentGeneratorConfig = config.getContentGeneratorConfig();
      if (
        !contentGeneratorConfig ||
        contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE
      ) {
        return null;
      }

      // Use actual user tier if available; otherwise, default to FREE tier behavior (safe default)
      const isPaidTier =
        userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;

      let message: string;

      if (error && isProQuotaExceededError(error)) {
        // Pro Quota specific messages (Interactive)
        if (isPaidTier) {
          message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
        } else {
          message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
        }
      } else if (error && isGenericQuotaExceededError(error)) {
        // Generic Quota (Automatic fallback)
        const actionMessage = `⚡ You have reached your daily quota limit.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`;

        if (isPaidTier) {
          message = `${actionMessage}
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
        } else {
          message = `${actionMessage}
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
        }
      } else {
        // Consecutive 429s or other errors (Automatic fallback)
        const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`;

        if (isPaidTier) {
          message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
        } else {
          message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
        }
      }

      // Add message to UI history
      historyManager.addItem(
        {
          type: MessageType.INFO,
          text: message,
        },
        Date.now(),
      );

      setModelSwitchedFromQuotaError(true);
      config.setQuotaErrorOccurred(true);

      // Interactive Fallback for Pro quota
      if (error && isProQuotaExceededError(error)) {
        if (isDialogPending.current) {
          return 'stop'; // A dialog is already active, so just stop this request.
        }
        isDialogPending.current = true;

        const intent: FallbackIntent = await new Promise<FallbackIntent>(
          (resolve) => {
            setProQuotaRequest({
              failedModel,
              fallbackModel,
              resolve,
            });
          },
        );

        return intent;
      }

      return 'stop';
    };

    config.setFallbackModelHandler(fallbackHandler);
  }, [config, historyManager, userTier, setModelSwitchedFromQuotaError]);

  const handleProQuotaChoice = useCallback(
    (choice: 'auth' | 'continue') => {
      if (!proQuotaRequest) return;

      const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry';
      proQuotaRequest.resolve(intent);
      setProQuotaRequest(null);
      isDialogPending.current = false; // Reset the flag here

      if (choice === 'auth') {
        setAuthState(AuthState.Updating);
      } else {
        historyManager.addItem(
          {
            type: MessageType.INFO,
            text: 'Switched to fallback model. Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.',
          },
          Date.now(),
        );
      }
    },
    [proQuotaRequest, setAuthState, historyManager],
  );

  return {
    proQuotaRequest,
    handleProQuotaChoice,
  };
}
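
Note (added for orientation, not part of the diff): the hook above hands back a pending `proQuotaRequest` plus a `handleProQuotaChoice` callback. A minimal, hedged consumption sketch in TypeScript, with the surrounding CLI wiring assumed to already provide the hook arguments:

// Sketch only; the argument object has the same shape as UseQuotaAndFallbackArgs above.
function useProQuotaDialog(args: Parameters<typeof useQuotaAndFallback>[0]) {
  const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback(args);

  // While proQuotaRequest is non-null, the fallback handler is blocked on a user
  // decision: 'auth' re-opens the auth flow, 'continue' retries on the fallback model.
  return {
    isDialogOpen: proQuotaRequest !== null,
    chooseAuth: () => handleProQuotaChoice('auth'),
    chooseContinue: () => handleProQuotaChoice('continue'),
  };
}
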
@@ -411,7 +411,7 @@ describe('useQwenAuth', () => {
     expect(geminiResult.current.qwenAuthState.authStatus).toBe('idle');

     const { result: oauthResult } = renderHook(() =>
-      useQwenAuth(AuthType.USE_OPENAI, true),
+      useQwenAuth(AuthType.LOGIN_WITH_GOOGLE, true),
     );
     expect(oauthResult.current.qwenAuthState.authStatus).toBe('idle');
   });
|
||||
@@ -62,7 +62,7 @@ const mockConfig = {
   getAllowedTools: vi.fn(() => []),
   getContentGeneratorConfig: () => ({
     model: 'test-model',
-    authType: 'gemini-api-key',
+    authType: 'oauth-personal',
   }),
   getUseSmartEdit: () => false,
   getUseModelRouter: () => false,
|
||||
@@ -21,13 +21,6 @@ function getAuthTypeFromEnv(): AuthType | undefined {
     return AuthType.QWEN_OAUTH;
   }

-  if (process.env['GEMINI_API_KEY']) {
-    return AuthType.USE_GEMINI;
-  }
-  if (process.env['GOOGLE_API_KEY']) {
-    return AuthType.USE_VERTEX_AI;
-  }
-
   return undefined;
 }
|
||||
|
||||
@@ -23,8 +23,8 @@
     "scripts/postinstall.js"
   ],
   "dependencies": {
-    "@google/genai": "1.30.0",
-    "@modelcontextprotocol/sdk": "^1.25.1",
+    "@google/genai": "1.16.0",
+    "@modelcontextprotocol/sdk": "^1.11.0",
     "@opentelemetry/api": "^1.9.0",
     "async-mutex": "^0.5.0",
     "@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0",
@@ -34,6 +34,7 @@
     "@opentelemetry/exporter-trace-otlp-grpc": "^0.203.0",
     "@opentelemetry/exporter-trace-otlp-http": "^0.203.0",
     "@opentelemetry/instrumentation-http": "^0.203.0",
+    "@opentelemetry/resource-detector-gcp": "^0.40.0",
     "@opentelemetry/sdk-node": "^0.203.0",
     "@types/html-to-text": "^9.0.4",
     "@xterm/headless": "5.5.0",
@@ -47,7 +48,7 @@
     "fdir": "^6.4.6",
     "fzf": "^0.5.2",
     "glob": "^10.5.0",
-    "google-auth-library": "^10.5.0",
+    "google-auth-library": "^9.11.0",
     "html-to-text": "^9.0.5",
     "https-proxy-agent": "^7.0.6",
     "ignore": "^7.0.0",
|
||||
packages/core/src/code_assist/codeAssist.ts (Normal file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import type { ContentGenerator } from '../core/contentGenerator.js';
import { AuthType } from '../core/contentGenerator.js';
import { getOauthClient } from './oauth2.js';
import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';

export async function createCodeAssistContentGenerator(
  httpOptions: HttpOptions,
  authType: AuthType,
  config: Config,
  sessionId?: string,
): Promise<ContentGenerator> {
  if (
    authType === AuthType.LOGIN_WITH_GOOGLE ||
    authType === AuthType.CLOUD_SHELL
  ) {
    const authClient = await getOauthClient(authType, config);
    const userData = await setupUser(authClient);
    return new CodeAssistServer(
      authClient,
      userData.projectId,
      httpOptions,
      sessionId,
      userData.userTier,
    );
  }

  throw new Error(`Unsupported authType: ${authType}`);
}

export function getCodeAssistServer(
  config: Config,
): CodeAssistServer | undefined {
  let server = config.getContentGenerator();

  // Unwrap LoggingContentGenerator if present
  if (server instanceof LoggingContentGenerator) {
    server = server.getWrapped();
  }

  if (!(server instanceof CodeAssistServer)) {
    return undefined;
  }
  return server;
}
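
Note (added for orientation, not part of the diff): a hedged usage sketch of the two exports above. `config` is assumed to be an already-initialized Config, and the HttpOptions header value is an illustrative assumption.

// Sketch only: create the Code Assist content generator for Google OAuth login.
const generator = await createCodeAssistContentGenerator(
  { headers: { 'User-Agent': 'qwen-code' } }, // assumed HttpOptions shape
  AuthType.LOGIN_WITH_GOOGLE,
  config,
  'session-123', // illustrative session id
);

// Later, code that needs CodeAssistServer-specific APIs can recover the server
// from the config, even when it is wrapped by LoggingContentGenerator:
const server = getCodeAssistServer(config);
if (server) {
  // server is a CodeAssistServer here; undefined for non-Code-Assist auth types.
}
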
packages/core/src/code_assist/converter.test.ts (Normal file, 456 lines)
@@ -0,0 +1,456 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import type { CaGenerateContentResponse } from './converter.js';
|
||||
import {
|
||||
toGenerateContentRequest,
|
||||
fromGenerateContentResponse,
|
||||
toContents,
|
||||
} from './converter.js';
|
||||
import type {
|
||||
ContentListUnion,
|
||||
GenerateContentParameters,
|
||||
} from '@google/genai';
|
||||
import {
|
||||
GenerateContentResponse,
|
||||
FinishReason,
|
||||
BlockedReason,
|
||||
type Part,
|
||||
} from '@google/genai';
|
||||
|
||||
describe('converter', () => {
|
||||
describe('toCodeAssistRequest', () => {
|
||||
it('should convert a simple request with project', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'my-session',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request without a project', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
undefined,
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: undefined,
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'my-session',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request with sessionId', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'session-123',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'session-123',
|
||||
},
|
||||
user_prompt_id: 'my-prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle string content', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle Part[] content', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ text: 'Hello' }, { text: 'World' }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'user', parts: [{ text: 'World' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle system instructions', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
systemInstruction: 'You are a helpful assistant.',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.systemInstruction).toEqual({
|
||||
role: 'user',
|
||||
parts: [{ text: 'You are a helpful assistant.' }],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle generation config', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle all generation config fields', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
config: {
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
topK: 3,
|
||||
candidateCount: 4,
|
||||
maxOutputTokens: 5,
|
||||
stopSequences: ['a'],
|
||||
responseLogprobs: true,
|
||||
logprobs: 6,
|
||||
presencePenalty: 0.7,
|
||||
frequencyPenalty: 0.8,
|
||||
seed: 9,
|
||||
responseMimeType: 'application/json',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-prompt',
|
||||
'my-project',
|
||||
'my-session',
|
||||
);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
topK: 3,
|
||||
candidateCount: 4,
|
||||
maxOutputTokens: 5,
|
||||
stopSequences: ['a'],
|
||||
responseLogprobs: true,
|
||||
logprobs: 6,
|
||||
presencePenalty: 0.7,
|
||||
frequencyPenalty: 0.8,
|
||||
seed: 9,
|
||||
responseMimeType: 'application/json',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('fromCodeAssistResponse', () => {
|
||||
it('should convert a simple response', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'Hi there!' }],
|
||||
},
|
||||
finishReason: FinishReason.STOP,
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
|
||||
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
|
||||
});
|
||||
|
||||
it('should handle prompt feedback and usage metadata', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
promptFeedback: {
|
||||
blockReason: BlockedReason.SAFETY,
|
||||
safetyRatings: [],
|
||||
},
|
||||
usageMetadata: {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 20,
|
||||
totalTokenCount: 30,
|
||||
},
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.promptFeedback).toEqual(
|
||||
codeAssistRes.response.promptFeedback,
|
||||
);
|
||||
expect(genaiRes.usageMetadata).toEqual(
|
||||
codeAssistRes.response.usageMetadata,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle automatic function calling history', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
automaticFunctionCallingHistory: [
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: 'test_function',
|
||||
args: {
|
||||
foo: 'bar',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
|
||||
codeAssistRes.response.automaticFunctionCallingHistory,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle modelVersion', () => {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
modelVersion: 'qwen3-coder-plus',
|
||||
},
|
||||
};
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.modelVersion).toEqual('qwen3-coder-plus');
|
||||
});
|
||||
});
|
||||
|
||||
describe('toContents', () => {
|
||||
it('should handle Content', () => {
|
||||
const content: ContentListUnion = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
};
|
||||
expect(toContents(content)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of Contents', () => {
|
||||
const contents: ContentListUnion = [
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'hi' }] },
|
||||
];
|
||||
expect(toContents(contents)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'hi' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle Part', () => {
|
||||
const part: ContentListUnion = { text: 'a part' };
|
||||
expect(toContents(part)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'a part' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of Parts', () => {
|
||||
const parts = [{ text: 'part 1' }, 'part 2'];
|
||||
expect(toContents(parts)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'part 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'part 2' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle string', () => {
|
||||
const str: ContentListUnion = 'a string';
|
||||
expect(toContents(str)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'a string' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle array of strings', () => {
|
||||
const strings: ContentListUnion = ['string 1', 'string 2'];
|
||||
expect(toContents(strings)).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'string 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'string 2' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should convert thought parts to text parts for API compatibility', () => {
|
||||
const contentWithThought: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'regular text' },
|
||||
{ thought: 'thinking about the problem' } as Part & {
|
||||
thought: string;
|
||||
},
|
||||
{ text: 'more text' },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithThought)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'regular text' },
|
||||
{ text: '[Thought: thinking about the problem]' },
|
||||
{ text: 'more text' },
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should combine text and thought for text parts with thoughts', () => {
|
||||
const contentWithTextAndThought: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 'Here is my response',
|
||||
thought: 'I need to be careful here',
|
||||
} as Part & { thought: string },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithTextAndThought)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 'Here is my response\n[Thought: I need to be careful here]',
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should preserve non-thought properties while removing thought', () => {
|
||||
const contentWithComplexPart: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
|
||||
thought: 'Performing calculation',
|
||||
} as Part & { thought: string },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithComplexPart)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should convert invalid text content to valid text part with thought', () => {
|
||||
const contentWithInvalidText: ContentListUnion = {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: 123, // Invalid - should be string
|
||||
thought: 'Processing number',
|
||||
} as Part & { thought: string; text: number },
|
||||
],
|
||||
};
|
||||
expect(toContents(contentWithInvalidText)).toEqual([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
text: '123\n[Thought: Processing number]',
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
packages/core/src/code_assist/converter.ts (Normal file, 285 lines)
@@ -0,0 +1,285 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
Content,
|
||||
ContentListUnion,
|
||||
ContentUnion,
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
GenerationConfigRoutingConfig,
|
||||
MediaResolution,
|
||||
Candidate,
|
||||
ModelSelectionConfig,
|
||||
GenerateContentResponsePromptFeedback,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
Part,
|
||||
SafetySetting,
|
||||
PartUnion,
|
||||
SpeechConfigUnion,
|
||||
ThinkingConfig,
|
||||
ToolListUnion,
|
||||
ToolConfig,
|
||||
} from '@google/genai';
|
||||
import { GenerateContentResponse } from '@google/genai';
|
||||
|
||||
export interface CAGenerateContentRequest {
|
||||
model: string;
|
||||
project?: string;
|
||||
user_prompt_id?: string;
|
||||
request: VertexGenerateContentRequest;
|
||||
}
|
||||
|
||||
interface VertexGenerateContentRequest {
|
||||
contents: Content[];
|
||||
systemInstruction?: Content;
|
||||
cachedContent?: string;
|
||||
tools?: ToolListUnion;
|
||||
toolConfig?: ToolConfig;
|
||||
labels?: Record<string, string>;
|
||||
safetySettings?: SafetySetting[];
|
||||
generationConfig?: VertexGenerationConfig;
|
||||
session_id?: string;
|
||||
}
|
||||
|
||||
interface VertexGenerationConfig {
|
||||
temperature?: number;
|
||||
topP?: number;
|
||||
topK?: number;
|
||||
candidateCount?: number;
|
||||
maxOutputTokens?: number;
|
||||
stopSequences?: string[];
|
||||
responseLogprobs?: boolean;
|
||||
logprobs?: number;
|
||||
presencePenalty?: number;
|
||||
frequencyPenalty?: number;
|
||||
seed?: number;
|
||||
responseMimeType?: string;
|
||||
responseJsonSchema?: unknown;
|
||||
responseSchema?: unknown;
|
||||
routingConfig?: GenerationConfigRoutingConfig;
|
||||
modelSelectionConfig?: ModelSelectionConfig;
|
||||
responseModalities?: string[];
|
||||
mediaResolution?: MediaResolution;
|
||||
speechConfig?: SpeechConfigUnion;
|
||||
audioTimestamp?: boolean;
|
||||
thinkingConfig?: ThinkingConfig;
|
||||
}
|
||||
|
||||
export interface CaGenerateContentResponse {
|
||||
response: VertexGenerateContentResponse;
|
||||
}
|
||||
|
||||
interface VertexGenerateContentResponse {
|
||||
candidates: Candidate[];
|
||||
automaticFunctionCallingHistory?: Content[];
|
||||
promptFeedback?: GenerateContentResponsePromptFeedback;
|
||||
usageMetadata?: GenerateContentResponseUsageMetadata;
|
||||
modelVersion?: string;
|
||||
}
|
||||
|
||||
export interface CaCountTokenRequest {
|
||||
request: VertexCountTokenRequest;
|
||||
}
|
||||
|
||||
interface VertexCountTokenRequest {
|
||||
model: string;
|
||||
contents: Content[];
|
||||
}
|
||||
|
||||
export interface CaCountTokenResponse {
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export function toCountTokenRequest(
|
||||
req: CountTokensParameters,
|
||||
): CaCountTokenRequest {
|
||||
return {
|
||||
request: {
|
||||
model: 'models/' + req.model,
|
||||
contents: toContents(req.contents),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function fromCountTokenResponse(
|
||||
res: CaCountTokenResponse,
|
||||
): CountTokensResponse {
|
||||
return {
|
||||
totalTokens: res.totalTokens,
|
||||
};
|
||||
}
|
||||
|
||||
export function toGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
project?: string,
|
||||
sessionId?: string,
|
||||
): CAGenerateContentRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
user_prompt_id: userPromptId,
|
||||
request: toVertexGenerateContentRequest(req, sessionId),
|
||||
};
|
||||
}
|
||||
|
||||
export function fromGenerateContentResponse(
|
||||
res: CaGenerateContentResponse,
|
||||
): GenerateContentResponse {
|
||||
const inres = res.response;
|
||||
const out = new GenerateContentResponse();
|
||||
out.candidates = inres.candidates;
|
||||
out.automaticFunctionCallingHistory = inres.automaticFunctionCallingHistory;
|
||||
out.promptFeedback = inres.promptFeedback;
|
||||
out.usageMetadata = inres.usageMetadata;
|
||||
out.modelVersion = inres.modelVersion;
|
||||
return out;
|
||||
}
|
||||
|
||||
function toVertexGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
sessionId?: string,
|
||||
): VertexGenerateContentRequest {
|
||||
return {
|
||||
contents: toContents(req.contents),
|
||||
systemInstruction: maybeToContent(req.config?.systemInstruction),
|
||||
cachedContent: req.config?.cachedContent,
|
||||
tools: req.config?.tools,
|
||||
toolConfig: req.config?.toolConfig,
|
||||
labels: req.config?.labels,
|
||||
safetySettings: req.config?.safetySettings,
|
||||
generationConfig: toVertexGenerationConfig(req.config),
|
||||
session_id: sessionId,
|
||||
};
|
||||
}
|
||||
|
||||
export function toContents(contents: ContentListUnion): Content[] {
|
||||
if (Array.isArray(contents)) {
|
||||
// it's a Content[] or a PartsUnion[]
|
||||
return contents.map(toContent);
|
||||
}
|
||||
// it's a Content or a PartsUnion
|
||||
return [toContent(contents)];
|
||||
}
|
||||
|
||||
function maybeToContent(content?: ContentUnion): Content | undefined {
|
||||
if (!content) {
|
||||
return undefined;
|
||||
}
|
||||
return toContent(content);
|
||||
}
|
||||
|
||||
function toContent(content: ContentUnion): Content {
|
||||
if (Array.isArray(content)) {
|
||||
// it's a PartsUnion[]
|
||||
return {
|
||||
role: 'user',
|
||||
parts: toParts(content),
|
||||
};
|
||||
}
|
||||
if (typeof content === 'string') {
|
||||
// it's a string
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [{ text: content }],
|
||||
};
|
||||
}
|
||||
if ('parts' in content) {
|
||||
// it's a Content - process parts to handle thought filtering
|
||||
return {
|
||||
...content,
|
||||
parts: content.parts
|
||||
? toParts(content.parts.filter((p) => p != null))
|
||||
: [],
|
||||
};
|
||||
}
|
||||
// it's a Part
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [toPart(content as Part)],
|
||||
};
|
||||
}
|
||||
|
||||
export function toParts(parts: PartUnion[]): Part[] {
|
||||
return parts.map(toPart);
|
||||
}
|
||||
|
||||
function toPart(part: PartUnion): Part {
|
||||
if (typeof part === 'string') {
|
||||
// it's a string
|
||||
return { text: part };
|
||||
}
|
||||
|
||||
// Handle thought parts for CountToken API compatibility
|
||||
// The CountToken API expects parts to have certain required "oneof" fields initialized,
|
||||
// but thought parts don't conform to this schema and cause API failures
|
||||
if ('thought' in part && part.thought) {
|
||||
const thoughtText = `[Thought: ${part.thought}]`;
|
||||
|
||||
const newPart = { ...part };
|
||||
delete (newPart as Record<string, unknown>)['thought'];
|
||||
|
||||
const hasApiContent =
|
||||
'functionCall' in newPart ||
|
||||
'functionResponse' in newPart ||
|
||||
'inlineData' in newPart ||
|
||||
'fileData' in newPart;
|
||||
|
||||
if (hasApiContent) {
|
||||
// It's a functionCall or other non-text part. Just strip the thought.
|
||||
return newPart;
|
||||
}
|
||||
|
||||
// If no other valid API content, this must be a text part.
|
||||
// Combine existing text (if any) with the thought, preserving other properties.
|
||||
const text = (newPart as { text?: unknown }).text;
|
||||
const existingText = text ? String(text) : '';
|
||||
const combinedText = existingText
|
||||
? `${existingText}\n${thoughtText}`
|
||||
: thoughtText;
|
||||
|
||||
return {
|
||||
...newPart,
|
||||
text: combinedText,
|
||||
};
|
||||
}
|
||||
|
||||
return part;
|
||||
}
|
||||
|
||||
function toVertexGenerationConfig(
|
||||
config?: GenerateContentConfig,
|
||||
): VertexGenerationConfig | undefined {
|
||||
if (!config) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
temperature: config.temperature,
|
||||
topP: config.topP,
|
||||
topK: config.topK,
|
||||
candidateCount: config.candidateCount,
|
||||
maxOutputTokens: config.maxOutputTokens,
|
||||
stopSequences: config.stopSequences,
|
||||
responseLogprobs: config.responseLogprobs,
|
||||
logprobs: config.logprobs,
|
||||
presencePenalty: config.presencePenalty,
|
||||
frequencyPenalty: config.frequencyPenalty,
|
||||
seed: config.seed,
|
||||
responseMimeType: config.responseMimeType,
|
||||
responseSchema: config.responseSchema,
|
||||
responseJsonSchema: config.responseJsonSchema,
|
||||
routingConfig: config.routingConfig,
|
||||
modelSelectionConfig: config.modelSelectionConfig,
|
||||
responseModalities: config.responseModalities,
|
||||
mediaResolution: config.mediaResolution,
|
||||
speechConfig: config.speechConfig,
|
||||
audioTimestamp: config.audioTimestamp,
|
||||
thinkingConfig: config.thinkingConfig,
|
||||
};
|
||||
}
|
||||
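
Note (added for orientation, not part of the diff): converter.ts above is exercised by converter.test.ts roughly as follows; the ids and values are illustrative only.

// Sketch only: wrap GenerateContentParameters into the Code Assist request
// envelope, then unwrap a Code Assist response back into GenerateContentResponse.
const caRequest = toGenerateContentRequest(
  { model: 'gemini-pro', contents: 'Hello' },
  'my-prompt-id', // user_prompt_id
  'my-project',   // optional project
  'my-session',   // optional session_id
);

const genaiResponse = fromGenerateContentResponse({
  response: { candidates: [] },
});
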
packages/core/src/code_assist/oauth-credential-storage.test.ts (Normal file, 217 lines)
@@ -0,0 +1,217 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Credentials } from 'google-auth-library';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { promises as fs } from 'node:fs';
|
||||
|
||||
// Mock external dependencies
|
||||
const mockHybridTokenStorage = vi.hoisted(() => ({
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
}));
|
||||
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
|
||||
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
|
||||
}));
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
rm: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.mock('node:os');
|
||||
vi.mock('node:path');
|
||||
|
||||
describe('OAuthCredentialStorage', () => {
|
||||
const mockCredentials: Credentials = {
|
||||
access_token: 'mock_access_token',
|
||||
refresh_token: 'mock_refresh_token',
|
||||
expiry_date: Date.now() + 3600 * 1000,
|
||||
token_type: 'Bearer',
|
||||
scope: 'email profile',
|
||||
};
|
||||
|
||||
const mockMcpCredentials: OAuthCredentials = {
|
||||
serverName: 'main-account',
|
||||
token: {
|
||||
accessToken: 'mock_access_token',
|
||||
refreshToken: 'mock_refresh_token',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'email profile',
|
||||
expiresAt: mockCredentials.expiry_date!,
|
||||
},
|
||||
updatedAt: expect.any(Number),
|
||||
};
|
||||
|
||||
const oldFilePath = '/mock/home/.qwen/oauth.json';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
|
||||
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
|
||||
|
||||
vi.spyOn(os, 'homedir').mockReturnValue('/mock/home');
|
||||
vi.spyOn(path, 'join').mockReturnValue(oldFilePath);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('loadCredentials', () => {
|
||||
it('should load credentials from HybridTokenStorage if available', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should return null if no credentials found and no old file to migrate', async () => {
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue({
|
||||
message: 'File not found',
|
||||
code: 'ENOENT',
|
||||
});
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw an error if loading fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
|
||||
new Error('Loading error'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if read file fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(
|
||||
new Error('Permission denied'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not throw error if migration file removal failed', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
it('should save credentials to HybridTokenStorage', async () => {
|
||||
await OAuthCredentialStorage.saveCredentials(mockCredentials);
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if access_token is missing', async () => {
|
||||
const invalidCredentials: Credentials = {
|
||||
...mockCredentials,
|
||||
access_token: undefined,
|
||||
};
|
||||
await expect(
|
||||
OAuthCredentialStorage.saveCredentials(invalidCredentials),
|
||||
).rejects.toThrow(
|
||||
'Attempted to save credentials without an access token.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCredentials', () => {
|
||||
it('should delete credentials from HybridTokenStorage', async () => {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
});
|
||||
|
||||
it('should attempt to remove the old file-based storage', async () => {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
|
||||
});
|
||||
|
||||
it('should not throw an error if deleting old file fails', async () => {
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
|
||||
|
||||
await expect(
|
||||
OAuthCredentialStorage.clearCredentials(),
|
||||
).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
|
||||
new Error('Deletion error'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
|
||||
'Failed to clear OAuth credentials',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
packages/core/src/code_assist/oauth-credential-storage.ts (Normal file, 130 lines)
@@ -0,0 +1,130 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Credentials } from 'google-auth-library';
|
||||
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
|
||||
import { OAUTH_FILE } from '../config/storage.js';
|
||||
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { promises as fs } from 'node:fs';
|
||||
|
||||
const QWEN_DIR = '.qwen';
|
||||
const KEYCHAIN_SERVICE_NAME = 'qwen-code-oauth';
|
||||
const MAIN_ACCOUNT_KEY = 'main-account';
|
||||
|
||||
export class OAuthCredentialStorage {
|
||||
private static storage: HybridTokenStorage = new HybridTokenStorage(
|
||||
KEYCHAIN_SERVICE_NAME,
|
||||
);
|
||||
|
||||
/**
|
||||
* Load cached OAuth credentials
|
||||
*/
|
||||
static async loadCredentials(): Promise<Credentials | null> {
|
||||
try {
|
||||
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
if (credentials?.token) {
|
||||
const { accessToken, refreshToken, expiresAt, tokenType, scope } =
|
||||
credentials.token;
|
||||
// Convert from OAuthCredentials format to Google Credentials format
|
||||
const googleCreds: Credentials = {
|
||||
access_token: accessToken,
|
||||
refresh_token: refreshToken || undefined,
|
||||
token_type: tokenType || undefined,
|
||||
scope: scope || undefined,
|
||||
};
|
||||
|
||||
if (expiresAt) {
|
||||
googleCreds.expiry_date = expiresAt;
|
||||
}
|
||||
|
||||
return googleCreds;
|
||||
}
|
||||
|
||||
// Fallback: Try to migrate from old file-based storage
|
||||
return await this.migrateFromFileStorage();
|
||||
} catch (error: unknown) {
|
||||
console.error(error);
|
||||
throw new Error('Failed to load OAuth credentials');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save OAuth credentials
|
||||
*/
|
||||
static async saveCredentials(credentials: Credentials): Promise<void> {
|
||||
if (!credentials.access_token) {
|
||||
throw new Error('Attempted to save credentials without an access token.');
|
||||
}
|
||||
|
||||
// Convert Google Credentials to OAuthCredentials format
|
||||
const mcpCredentials: OAuthCredentials = {
|
||||
serverName: MAIN_ACCOUNT_KEY,
|
||||
token: {
|
||||
accessToken: credentials.access_token,
|
||||
refreshToken: credentials.refresh_token || undefined,
|
||||
tokenType: credentials.token_type || 'Bearer',
|
||||
scope: credentials.scope || undefined,
|
||||
expiresAt: credentials.expiry_date || undefined,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
await this.storage.setCredentials(mcpCredentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached OAuth credentials
|
||||
*/
|
||||
static async clearCredentials(): Promise<void> {
|
||||
try {
|
||||
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
// Also try to remove the old file if it exists
|
||||
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
|
||||
await fs.rm(oldFilePath, { force: true }).catch(() => {});
|
||||
} catch (error: unknown) {
|
||||
console.error(error);
|
||||
throw new Error('Failed to clear OAuth credentials');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate credentials from old file-based storage to keychain
|
||||
*/
|
||||
private static async migrateFromFileStorage(): Promise<Credentials | null> {
|
||||
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
|
||||
|
||||
let credsJson: string;
|
||||
try {
|
||||
credsJson = await fs.readFile(oldFilePath, 'utf-8');
|
||||
} catch (error: unknown) {
|
||||
if (
|
||||
typeof error === 'object' &&
|
||||
error !== null &&
|
||||
'code' in error &&
|
||||
error.code === 'ENOENT'
|
||||
) {
|
||||
// File doesn't exist, so no migration.
|
||||
return null;
|
||||
}
|
||||
// Other read errors should propagate.
|
||||
throw error;
|
||||
}
|
||||
|
||||
const credentials = JSON.parse(credsJson) as Credentials;
|
||||
|
||||
// Save to new storage
|
||||
await this.saveCredentials(credentials);
|
||||
|
||||
// Remove old file after successful migration
|
||||
await fs.rm(oldFilePath, { force: true }).catch(() => {});
|
||||
|
||||
return credentials;
|
||||
}
|
||||
}
|
||||
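
Note (added for orientation, not part of the diff): the OAuthCredentialStorage class above is used statically; a hedged sketch with placeholder token values:

// Sketch only: round-trip credentials through the keychain-backed storage.
await OAuthCredentialStorage.saveCredentials({
  access_token: 'ya29.example', // required; saveCredentials throws without it
  refresh_token: 'refresh-example',
  expiry_date: Date.now() + 3600 * 1000,
});

const creds = await OAuthCredentialStorage.loadCredentials(); // null if nothing stored
await OAuthCredentialStorage.clearCredentials(); // also removes the legacy ~/.qwen/oauth.json
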
packages/core/src/code_assist/oauth2.test.ts (Normal file, 1166 lines)
File diff suppressed because it is too large

packages/core/src/code_assist/oauth2.ts (Normal file, 563 lines)
@@ -0,0 +1,563 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Credentials } from 'google-auth-library';
|
||||
import {
|
||||
CodeChallengeMethod,
|
||||
Compute,
|
||||
OAuth2Client,
|
||||
} from 'google-auth-library';
|
||||
import crypto from 'node:crypto';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as http from 'node:http';
|
||||
import * as net from 'node:net';
|
||||
import path from 'node:path';
|
||||
import readline from 'node:readline';
|
||||
import url from 'node:url';
|
||||
import open from 'open';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { FatalAuthenticationError, getErrorMessage } from '../utils/errors.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
|
||||
|
||||
const userAccountManager = new UserAccountManager();
|
||||
|
||||
// OAuth Client ID used to initiate OAuth2Client class.
|
||||
const OAUTH_CLIENT_ID =
|
||||
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
|
||||
|
||||
// OAuth Secret value used to initiate OAuth2Client class.
|
||||
// Note: It's ok to save this in git because this is an installed application
|
||||
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
|
||||
// "The process results in a client ID and, in some cases, a client secret,
|
||||
// which you embed in the source code of your application. (In this context,
|
||||
// the client secret is obviously not treated as a secret.)"
|
||||
const OAUTH_CLIENT_SECRET = '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';
|
||||
|
||||
// OAuth Scopes for Cloud Code authorization.
|
||||
const OAUTH_SCOPE = [
|
||||
'https://www.googleapis.com/auth/cloud-platform',
|
||||
'https://www.googleapis.com/auth/userinfo.email',
|
||||
'https://www.googleapis.com/auth/userinfo.profile',
|
||||
];
|
||||
|
||||
const HTTP_REDIRECT = 301;
|
||||
const SIGN_IN_SUCCESS_URL =
|
||||
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
|
||||
const SIGN_IN_FAILURE_URL =
|
||||
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
|
||||
|
||||
/**
|
||||
* An Authentication URL for updating the credentials of a Oauth2Client
|
||||
* as well as a promise that will resolve when the credentials have
|
||||
* been refreshed (or which throws error when refreshing credentials failed).
|
||||
*/
|
||||
export interface OauthWebLogin {
|
||||
authUrl: string;
|
||||
loginCompletePromise: Promise<void>;
|
||||
}
|
||||
|
||||
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
|
||||
|
||||
function getUseEncryptedStorageFlag() {
|
||||
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
}
|
||||
|
||||
async function initOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
const client = new OAuth2Client({
|
||||
clientId: OAUTH_CLIENT_ID,
|
||||
clientSecret: OAUTH_CLIENT_SECRET,
|
||||
transporterOptions: {
|
||||
proxy: config.getProxy(),
|
||||
},
|
||||
});
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
|
||||
if (
|
||||
process.env['GOOGLE_GENAI_USE_GCA'] &&
|
||||
process.env['GOOGLE_CLOUD_ACCESS_TOKEN']
|
||||
) {
|
||||
client.setCredentials({
|
||||
access_token: process.env['GOOGLE_CLOUD_ACCESS_TOKEN'],
|
||||
});
|
||||
await fetchAndCacheUserInfo(client);
|
||||
return client;
|
||||
}
|
||||
|
||||
client.on('tokens', async (tokens: Credentials) => {
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.saveCredentials(tokens);
|
||||
} else {
|
||||
await cacheCredentials(tokens);
|
||||
}
|
||||
});
|
||||
|
||||
// If there are cached creds on disk, they always take precedence
|
||||
if (await loadCachedCredentials(client)) {
|
||||
// Found valid cached credentials.
|
||||
// Check if we need to retrieve Google Account ID or Email
|
||||
if (!userAccountManager.getCachedGoogleAccount()) {
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
// Non-fatal, continue with existing auth.
|
||||
console.warn('Failed to fetch user info:', getErrorMessage(error));
|
||||
}
|
||||
}
|
||||
console.log('Loaded cached credentials.');
|
||||
return client;
|
||||
}
|
||||
|
||||
// In Google Cloud Shell, we can use Application Default Credentials (ADC)
|
||||
// provided via its metadata server to authenticate non-interactively using
|
||||
// the identity of the user logged into Cloud Shell.
|
||||
if (authType === AuthType.CLOUD_SHELL) {
|
||||
try {
|
||||
console.log("Attempting to authenticate via Cloud Shell VM's ADC.");
|
||||
const computeClient = new Compute({
|
||||
// We can leave this empty, since the metadata server will provide
|
||||
// the service account email.
|
||||
});
|
||||
await computeClient.getAccessToken();
|
||||
console.log('Authentication successful.');
|
||||
|
||||
// Do not cache creds in this case; note that Compute client will handle its own refresh
|
||||
return computeClient;
|
||||
} catch (e) {
|
||||
throw new Error(
|
||||
`Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ${getErrorMessage(
|
||||
e,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (config.isBrowserLaunchSuppressed()) {
|
||||
let success = false;
|
||||
const maxRetries = 2;
|
||||
for (let i = 0; !success && i < maxRetries; i++) {
|
||||
success = await authWithUserCode(client);
|
||||
if (!success) {
|
||||
console.error(
|
||||
'\nFailed to authenticate with user code.',
|
||||
i === maxRetries - 1 ? '' : 'Retrying...\n',
|
||||
);
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
throw new FatalAuthenticationError(
|
||||
'Failed to authenticate with user code.',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const webLogin = await authWithWeb(client);
|
||||
|
||||
console.log(
|
||||
`\n\nCode Assist login required.\n` +
|
||||
`Attempting to open authentication page in your browser.\n` +
|
||||
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
|
||||
);
|
||||
try {
|
||||
// Attempt to open the authentication URL in the default browser.
|
||||
// We do not use the `wait` option here because the main script's execution
|
||||
// is already paused by `loginCompletePromise`, which awaits the server callback.
|
||||
const childProcess = await open(webLogin.authUrl);
|
||||
|
||||
// IMPORTANT: Attach an error handler to the returned child process.
|
||||
// Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
|
||||
// in a minimal Docker container), it will emit an unhandled 'error' event,
|
||||
// causing the entire Node.js process to crash.
|
||||
childProcess.on('error', (error) => {
|
||||
console.error(
|
||||
'Failed to open browser automatically. Please try running again with NO_BROWSER=true set.',
|
||||
);
|
||||
console.error('Browser error details:', getErrorMessage(error));
|
||||
});
|
||||
} catch (err) {
|
||||
console.error(
|
||||
'An unexpected error occurred while trying to open the browser:',
|
||||
getErrorMessage(err),
|
||||
'\nThis might be due to browser compatibility issues or system configuration.',
|
||||
'\nPlease try running again with NO_BROWSER=true set for manual authentication.',
|
||||
);
|
||||
throw new FatalAuthenticationError(
|
||||
`Failed to open browser: ${getErrorMessage(err)}`,
|
||||
);
|
||||
}
|
||||
console.log('Waiting for authentication...');
|
||||
|
||||
// Add timeout to prevent infinite waiting when browser tab gets stuck
|
||||
const authTimeout = 5 * 60 * 1000; // 5 minutes timeout
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(() => {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'Authentication timed out after 5 minutes. The browser tab may have gotten stuck in a loading state. ' +
|
||||
'Please try again or use NO_BROWSER=true for manual authentication.',
|
||||
),
|
||||
);
|
||||
}, authTimeout);
|
||||
});
|
||||
|
||||
await Promise.race([webLogin.loginCompletePromise, timeoutPromise]);
|
||||
}
|
||||
|
||||
return client;
|
||||
}
|
||||
|
||||
export async function getOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
if (!oauthClientPromises.has(authType)) {
|
||||
oauthClientPromises.set(authType, initOauthClient(authType, config));
|
||||
}
|
||||
return oauthClientPromises.get(authType)!;
|
||||
}
|
||||
|
||||
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
|
||||
const redirectUri = 'https://codeassist.google.com/authcode';
|
||||
const codeVerifier = await client.generateCodeVerifierAsync();
|
||||
const state = crypto.randomBytes(32).toString('hex');
|
||||
const authUrl: string = client.generateAuthUrl({
|
||||
redirect_uri: redirectUri,
|
||||
access_type: 'offline',
|
||||
scope: OAUTH_SCOPE,
|
||||
code_challenge_method: CodeChallengeMethod.S256,
|
||||
code_challenge: codeVerifier.codeChallenge,
|
||||
state,
|
||||
});
|
||||
console.log('Please visit the following URL to authorize the application:');
|
||||
console.log('');
|
||||
console.log(authUrl);
|
||||
console.log('');
|
||||
|
||||
const code = await new Promise<string>((resolve) => {
|
||||
const rl = readline.createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
});
|
||||
rl.question('Enter the authorization code: ', (code) => {
|
||||
rl.close();
|
||||
resolve(code.trim());
|
||||
});
|
||||
});
|
||||
|
||||
if (!code) {
|
||||
console.error('Authorization code is required.');
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const { tokens } = await client.getToken({
|
||||
code,
|
||||
codeVerifier: codeVerifier.codeVerifier,
|
||||
redirect_uri: redirectUri,
|
||||
});
|
||||
client.setCredentials(tokens);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'Failed to authenticate with authorization code:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
|
||||
const port = await getAvailablePort();
|
||||
// The hostname used for the HTTP server binding (e.g., '0.0.0.0' in Docker).
|
||||
const host = process.env['OAUTH_CALLBACK_HOST'] || 'localhost';
|
||||
// The `redirectUri` sent to Google's authorization server MUST use a loopback IP literal
|
||||
// (i.e., 'localhost' or '127.0.0.1'). This is a strict security policy for credentials of
|
||||
// type 'Desktop app' or 'Web application' (when using loopback flow) to mitigate
|
||||
// authorization code interception attacks.
|
||||
const redirectUri = `http://localhost:${port}/oauth2callback`;
|
||||
const state = crypto.randomBytes(32).toString('hex');
|
||||
const authUrl = client.generateAuthUrl({
|
||||
redirect_uri: redirectUri,
|
||||
access_type: 'offline',
|
||||
scope: OAUTH_SCOPE,
|
||||
state,
|
||||
});
|
||||
|
||||
const loginCompletePromise = new Promise<void>((resolve, reject) => {
|
||||
const server = http.createServer(async (req, res) => {
|
||||
try {
|
||||
if (req.url!.indexOf('/oauth2callback') === -1) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'OAuth callback not received. Unexpected request: ' + req.url,
|
||||
),
|
||||
);
|
||||
}
|
||||
// acquire the code from the querystring, and close the web server.
|
||||
const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
|
||||
if (qs.get('error')) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
|
||||
const errorCode = qs.get('error');
|
||||
const errorDescription =
|
||||
qs.get('error_description') || 'No additional details provided';
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Google OAuth error: ${errorCode}. ${errorDescription}`,
|
||||
),
|
||||
);
|
||||
} else if (qs.get('state') !== state) {
|
||||
res.end('State mismatch. Possible CSRF attack');
|
||||
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'OAuth state mismatch. Possible CSRF attack or browser session issue.',
|
||||
),
|
||||
);
|
||||
} else if (qs.get('code')) {
|
||||
try {
|
||||
const { tokens } = await client.getToken({
|
||||
code: qs.get('code')!,
|
||||
redirect_uri: redirectUri,
|
||||
});
|
||||
client.setCredentials(tokens);
|
||||
|
||||
// Retrieve and cache Google Account ID during authentication
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to retrieve Google Account ID during authentication:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
// Don't fail the auth flow if Google Account ID retrieval fails
|
||||
}
|
||||
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
|
||||
res.end();
|
||||
resolve();
|
||||
} catch (error) {
|
||||
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
|
||||
res.end();
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Failed to exchange authorization code for tokens: ${getErrorMessage(error)}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
'No authorization code received from Google OAuth. Please try authenticating again.',
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
// Provide more specific error message for unexpected errors during OAuth flow
|
||||
if (e instanceof FatalAuthenticationError) {
|
||||
reject(e);
|
||||
} else {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`Unexpected error during OAuth authentication: ${getErrorMessage(e)}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
server.close();
|
||||
}
|
||||
});
|
||||
|
||||
server.listen(port, host, () => {
|
||||
// Server started successfully
|
||||
});
|
||||
|
||||
server.on('error', (err) => {
|
||||
reject(
|
||||
new FatalAuthenticationError(
|
||||
`OAuth callback server error: ${getErrorMessage(err)}`,
|
||||
),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
authUrl,
|
||||
loginCompletePromise,
|
||||
};
|
||||
}
|
||||
|
||||
export function getAvailablePort(): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
let port = 0;
|
||||
try {
|
||||
const portStr = process.env['OAUTH_CALLBACK_PORT'];
|
||||
if (portStr) {
|
||||
port = parseInt(portStr, 10);
|
||||
if (isNaN(port) || port <= 0 || port > 65535) {
|
||||
return reject(
|
||||
new Error(`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`),
|
||||
);
|
||||
}
|
||||
return resolve(port);
|
||||
}
|
||||
const server = net.createServer();
|
||||
server.listen(0, () => {
|
||||
const address = server.address()! as net.AddressInfo;
|
||||
port = address.port;
|
||||
});
|
||||
server.on('listening', () => {
|
||||
server.close();
|
||||
server.unref();
|
||||
});
|
||||
server.on('error', (e) => reject(e));
|
||||
server.on('close', () => resolve(port));
|
||||
} catch (e) {
|
||||
reject(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
const credentials = await OAuthCredentialStorage.loadCredentials();
|
||||
if (credentials) {
|
||||
client.setCredentials(credentials);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const pathsToTry = [
|
||||
Storage.getOAuthCredsPath(),
|
||||
process.env['GOOGLE_APPLICATION_CREDENTIALS'],
|
||||
].filter((p): p is string => !!p);
|
||||
|
||||
for (const keyFile of pathsToTry) {
|
||||
try {
|
||||
const creds = await fs.readFile(keyFile, 'utf-8');
|
||||
client.setCredentials(JSON.parse(creds));
|
||||
|
||||
// This will verify locally that the credentials look good.
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This will check with the server to see if it hasn't been revoked.
|
||||
await client.getTokenInfo(token);
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
// Log specific error for debugging, but continue trying other paths
|
||||
console.debug(
|
||||
`Failed to load credentials from ${keyFile}:`,
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
async function cacheCredentials(credentials: Credentials) {
|
||||
const filePath = Storage.getOAuthCredsPath();
|
||||
await fs.mkdir(path.dirname(filePath), { recursive: true });
|
||||
|
||||
const credString = JSON.stringify(credentials, null, 2);
|
||||
await fs.writeFile(filePath, credString, { mode: 0o600 });
|
||||
try {
|
||||
await fs.chmod(filePath, 0o600);
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
}
|
||||
|
||||
export function clearOauthClientCache() {
|
||||
oauthClientPromises.clear();
|
||||
}
|
||||
|
||||
export async function clearCachedCredentialFile() {
|
||||
try {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
} else {
|
||||
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
|
||||
}
|
||||
// Clear the Google Account ID cache when credentials are cleared
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
// Clear the in-memory OAuth client cache to force re-authentication
|
||||
clearOauthClientCache();
|
||||
|
||||
/**
|
||||
* Also clear Qwen SharedTokenManager cache and credentials file to prevent stale credentials
|
||||
* when switching between auth types
|
||||
* TODO: We do not depend on code_assist, we'll have to build an independent auth-cleaning procedure.
|
||||
*/
|
||||
try {
|
||||
const { SharedTokenManager } = await import(
|
||||
'../qwen/sharedTokenManager.js'
|
||||
);
|
||||
const { clearQwenCredentials } = await import('../qwen/qwenOAuth2.js');
|
||||
|
||||
const sharedManager = SharedTokenManager.getInstance();
|
||||
sharedManager.clearCache();
|
||||
|
||||
await clearQwenCredentials();
|
||||
} catch (qwenError) {
|
||||
console.debug('Could not clear Qwen credentials:', qwenError);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to clear cached credentials:', e);
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
|
||||
try {
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
console.error(
|
||||
'Failed to fetch user info:',
|
||||
response.status,
|
||||
response.statusText,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const userInfo = await response.json();
|
||||
await userAccountManager.cacheGoogleAccount(userInfo.email);
|
||||
} catch (error) {
|
||||
console.error('Error retrieving user info:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to ensure test isolation
|
||||
export function resetOauthClientForTesting() {
|
||||
oauthClientPromises.clear();
|
||||
}
|
||||
255
packages/core/src/code_assist/server.test.ts
Normal file
255
packages/core/src/code_assist/server.test.ts
Normal file
@@ -0,0 +1,255 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, it, expect, vi } from 'vitest';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import { OAuth2Client } from 'google-auth-library';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
vi.mock('google-auth-library');
|
||||
|
||||
describe('CodeAssistServer', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
it('should be able to be constructed', () => {
|
||||
const auth = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
auth,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
expect(server).toBeInstanceOf(CodeAssistServer);
|
||||
});
|
||||
|
||||
it('should call the generateContent endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.generateContent(
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
},
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'generateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'response',
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = (async function* () {
|
||||
yield {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const stream = await server.generateContentStream(
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
},
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
for await (const res of stream) {
|
||||
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
||||
'streamGenerateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
|
||||
}
|
||||
});
|
||||
|
||||
it('should call the onboardUser endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
name: 'operations/123',
|
||||
done: true,
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.onboardUser({
|
||||
tierId: 'test-tier',
|
||||
cloudaicompanionProject: 'test-project',
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'onboardUser',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response.name).toBe('operations/123');
|
||||
});
|
||||
|
||||
it('should call the loadCodeAssist endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
currentTier: {
|
||||
id: UserTierId.FREE,
|
||||
name: 'Free',
|
||||
description: 'free tier',
|
||||
},
|
||||
allowedTiers: [],
|
||||
ineligibleTiers: [],
|
||||
cloudaicompanionProject: 'projects/test',
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.loadCodeAssist({
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response).toEqual(mockResponse);
|
||||
});
|
||||
|
||||
it('should return 0 for countTokens', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
totalTokens: 100,
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.countTokens({
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
});
|
||||
expect(response.totalTokens).toBe(100);
|
||||
});
|
||||
|
||||
it('should throw an error for embedContent', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
await expect(
|
||||
server.embedContent({
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockVpcScError = {
|
||||
response: {
|
||||
data: {
|
||||
error: {
|
||||
details: [
|
||||
{
|
||||
reason: 'SECURITY_POLICY_VIOLATED',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);
|
||||
|
||||
const response = await server.loadCodeAssist({
|
||||
metadata: {},
|
||||
});
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(response).toEqual({
|
||||
currentTier: { id: UserTierId.STANDARD },
|
||||
});
|
||||
});
|
||||
});
|
||||
253
packages/core/src/code_assist/server.ts
Normal file
253
packages/core/src/code_assist/server.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type {
|
||||
CodeAssistGlobalUserSettingResponse,
|
||||
GoogleRpcResponse,
|
||||
LoadCodeAssistRequest,
|
||||
LoadCodeAssistResponse,
|
||||
LongRunningOperationResponse,
|
||||
OnboardUserRequest,
|
||||
SetCodeAssistGlobalUserSettingRequest,
|
||||
} from './types.js';
|
||||
import type {
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import * as readline from 'node:readline';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { UserTierId } from './types.js';
|
||||
import type {
|
||||
CaCountTokenResponse,
|
||||
CaGenerateContentResponse,
|
||||
} from './converter.js';
|
||||
import {
|
||||
fromCountTokenResponse,
|
||||
fromGenerateContentResponse,
|
||||
toCountTokenRequest,
|
||||
toGenerateContentRequest,
|
||||
} from './converter.js';
|
||||
|
||||
/** HTTP options to be used in each of the requests. */
|
||||
export interface HttpOptions {
|
||||
/** Additional HTTP headers to be sent with the request. */
|
||||
headers?: Record<string, string>;
|
||||
}
|
||||
|
||||
export const CODE_ASSIST_ENDPOINT = 'https://localhost:0'; // Disable Google Code Assist API Request
|
||||
export const CODE_ASSIST_API_VERSION = 'v1internal';
|
||||
|
||||
export class CodeAssistServer implements ContentGenerator {
|
||||
constructor(
|
||||
readonly client: OAuth2Client,
|
||||
readonly projectId?: string,
|
||||
readonly httpOptions: HttpOptions = {},
|
||||
readonly sessionId?: string,
|
||||
readonly userTier?: UserTierId,
|
||||
) {}
|
||||
|
||||
async generateContentStream(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
||||
'streamGenerateContent',
|
||||
toGenerateContentRequest(
|
||||
req,
|
||||
userPromptId,
|
||||
this.projectId,
|
||||
this.sessionId,
|
||||
),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||
for await (const resp of resps) {
|
||||
yield fromGenerateContentResponse(resp);
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
req: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const resp = await this.requestPost<CaGenerateContentResponse>(
|
||||
'generateContent',
|
||||
toGenerateContentRequest(
|
||||
req,
|
||||
userPromptId,
|
||||
this.projectId,
|
||||
this.sessionId,
|
||||
),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return fromGenerateContentResponse(resp);
|
||||
}
|
||||
|
||||
async onboardUser(
|
||||
req: OnboardUserRequest,
|
||||
): Promise<LongRunningOperationResponse> {
|
||||
return await this.requestPost<LongRunningOperationResponse>(
|
||||
'onboardUser',
|
||||
req,
|
||||
);
|
||||
}
|
||||
|
||||
async loadCodeAssist(
|
||||
req: LoadCodeAssistRequest,
|
||||
): Promise<LoadCodeAssistResponse> {
|
||||
try {
|
||||
return await this.requestPost<LoadCodeAssistResponse>(
|
||||
'loadCodeAssist',
|
||||
req,
|
||||
);
|
||||
} catch (e) {
|
||||
if (isVpcScAffectedUser(e)) {
|
||||
return {
|
||||
currentTier: { id: UserTierId.STANDARD },
|
||||
};
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
|
||||
return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
|
||||
'getCodeAssistGlobalUserSetting',
|
||||
);
|
||||
}
|
||||
|
||||
async setCodeAssistGlobalUserSetting(
|
||||
req: SetCodeAssistGlobalUserSettingRequest,
|
||||
): Promise<CodeAssistGlobalUserSettingResponse> {
|
||||
return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
|
||||
'setCodeAssistGlobalUserSetting',
|
||||
req,
|
||||
);
|
||||
}
|
||||
|
||||
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
|
||||
const resp = await this.requestPost<CaCountTokenResponse>(
|
||||
'countTokens',
|
||||
toCountTokenRequest(req),
|
||||
);
|
||||
return fromCountTokenResponse(resp);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
_req: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
throw Error();
|
||||
}
|
||||
|
||||
async requestPost<T>(
|
||||
method: string,
|
||||
req: object,
|
||||
signal?: AbortSignal,
|
||||
): Promise<T> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'json',
|
||||
body: JSON.stringify(req),
|
||||
signal,
|
||||
});
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'json',
|
||||
signal,
|
||||
});
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestStreamingPost<T>(
|
||||
method: string,
|
||||
req: object,
|
||||
signal?: AbortSignal,
|
||||
): Promise<AsyncGenerator<T>> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
method: 'POST',
|
||||
params: {
|
||||
alt: 'sse',
|
||||
},
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...this.httpOptions.headers,
|
||||
},
|
||||
responseType: 'stream',
|
||||
body: JSON.stringify(req),
|
||||
signal,
|
||||
});
|
||||
|
||||
return (async function* (): AsyncGenerator<T> {
|
||||
const rl = readline.createInterface({
|
||||
input: res.data as NodeJS.ReadableStream,
|
||||
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
|
||||
});
|
||||
|
||||
let bufferedLines: string[] = [];
|
||||
for await (const line of rl) {
|
||||
// blank lines are used to separate JSON objects in the stream
|
||||
if (line === '') {
|
||||
if (bufferedLines.length === 0) {
|
||||
continue; // no data to yield
|
||||
}
|
||||
yield JSON.parse(bufferedLines.join('\n')) as T;
|
||||
bufferedLines = []; // Reset the buffer after yielding
|
||||
} else if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else {
|
||||
throw new Error(`Unexpected line format in response: ${line}`);
|
||||
}
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
getMethodUrl(method: string): string {
|
||||
const endpoint =
|
||||
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
|
||||
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
|
||||
}
|
||||
}
|
||||
|
||||
function isVpcScAffectedUser(error: unknown): boolean {
|
||||
if (error && typeof error === 'object' && 'response' in error) {
|
||||
const gaxiosError = error as {
|
||||
response?: {
|
||||
data?: unknown;
|
||||
};
|
||||
};
|
||||
const response = gaxiosError.response?.data as
|
||||
| GoogleRpcResponse
|
||||
| undefined;
|
||||
if (Array.isArray(response?.error?.details)) {
|
||||
return response.error.details.some(
|
||||
(detail) => detail.reason === 'SECURITY_POLICY_VIOLATED',
|
||||
);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
224
packages/core/src/code_assist/setup.test.ts
Normal file
224
packages/core/src/code_assist/setup.test.ts
Normal file
@@ -0,0 +1,224 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { setupUser, ProjectIdRequiredError } from './setup.js';
|
||||
import { CodeAssistServer } from '../code_assist/server.js';
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type { GeminiUserTier } from './types.js';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
vi.mock('../code_assist/server.js');
|
||||
|
||||
const mockPaidTier: GeminiUserTier = {
|
||||
id: UserTierId.STANDARD,
|
||||
name: 'paid',
|
||||
description: 'Paid tier',
|
||||
isDefault: true,
|
||||
};
|
||||
|
||||
const mockFreeTier: GeminiUserTier = {
|
||||
id: UserTierId.FREE,
|
||||
name: 'free',
|
||||
description: 'Free tier',
|
||||
isDefault: true,
|
||||
};
|
||||
|
||||
describe('setupUser for existing user', () => {
|
||||
let mockLoad: ReturnType<typeof vi.fn>;
|
||||
let mockOnboardUser: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockLoad = vi.fn();
|
||||
mockOnboardUser = vi.fn().mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
vi.mocked(CodeAssistServer).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
loadCodeAssist: mockLoad,
|
||||
onboardUser: mockOnboardUser,
|
||||
}) as unknown as CodeAssistServer,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
currentTier: mockPaidTier,
|
||||
});
|
||||
await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
cloudaicompanionProject: 'server-project',
|
||||
currentTier: mockPaidTier,
|
||||
});
|
||||
const projectId = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(projectId).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
// And the server itself requires a project ID internally
|
||||
vi.mocked(CodeAssistServer).mockImplementation(() => {
|
||||
throw new ProjectIdRequiredError();
|
||||
});
|
||||
|
||||
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
|
||||
ProjectIdRequiredError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setupUser for new user', () => {
|
||||
let mockLoad: ReturnType<typeof vi.fn>;
|
||||
let mockOnboardUser: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockLoad = vi.fn();
|
||||
mockOnboardUser = vi.fn().mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
vi.mocked(CodeAssistServer).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
loadCodeAssist: mockLoad,
|
||||
onboardUser: mockOnboardUser,
|
||||
}) as unknown as CodeAssistServer,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
'test-project',
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(mockLoad).toHaveBeenCalled();
|
||||
expect(mockOnboardUser).toHaveBeenCalledWith({
|
||||
tierId: 'standard-tier',
|
||||
cloudaicompanionProject: 'test-project',
|
||||
metadata: {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
duetProject: 'test-project',
|
||||
},
|
||||
});
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockFreeTier],
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||
{},
|
||||
undefined,
|
||||
{},
|
||||
'',
|
||||
undefined,
|
||||
);
|
||||
expect(mockLoad).toHaveBeenCalled();
|
||||
expect(mockOnboardUser).toHaveBeenCalledWith({
|
||||
tierId: 'free-tier',
|
||||
cloudaicompanionProject: undefined,
|
||||
metadata: {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
},
|
||||
});
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'free-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
mockOnboardUser.mockResolvedValue({
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: undefined,
|
||||
},
|
||||
});
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
expect(userData).toEqual({
|
||||
projectId: 'test-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
mockOnboardUser.mockResolvedValue({
|
||||
done: true,
|
||||
response: {},
|
||||
});
|
||||
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
|
||||
ProjectIdRequiredError,
|
||||
);
|
||||
});
|
||||
});
|
||||
124
packages/core/src/code_assist/setup.ts
Normal file
124
packages/core/src/code_assist/setup.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
ClientMetadata,
|
||||
GeminiUserTier,
|
||||
LoadCodeAssistResponse,
|
||||
OnboardUserRequest,
|
||||
} from './types.js';
|
||||
import { UserTierId } from './types.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
|
||||
export class ProjectIdRequiredError extends Error {
|
||||
constructor() {
|
||||
super(
|
||||
'This account requires setting the GOOGLE_CLOUD_PROJECT env var. See https://goo.gle/gemini-cli-auth-docs#workspace-gca',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export interface UserData {
|
||||
projectId: string;
|
||||
userTier: UserTierId;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param projectId the user's project id, if any
|
||||
* @returns the user's actual project id
|
||||
*/
|
||||
export async function setupUser(client: OAuth2Client): Promise<UserData> {
|
||||
const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
|
||||
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
|
||||
const coreClientMetadata: ClientMetadata = {
|
||||
ideType: 'IDE_UNSPECIFIED',
|
||||
platform: 'PLATFORM_UNSPECIFIED',
|
||||
pluginType: 'GEMINI',
|
||||
};
|
||||
|
||||
const loadRes = await caServer.loadCodeAssist({
|
||||
cloudaicompanionProject: projectId,
|
||||
metadata: {
|
||||
...coreClientMetadata,
|
||||
duetProject: projectId,
|
||||
},
|
||||
});
|
||||
|
||||
if (loadRes.currentTier) {
|
||||
if (!loadRes.cloudaicompanionProject) {
|
||||
if (projectId) {
|
||||
return {
|
||||
projectId,
|
||||
userTier: loadRes.currentTier.id,
|
||||
};
|
||||
}
|
||||
throw new ProjectIdRequiredError();
|
||||
}
|
||||
return {
|
||||
projectId: loadRes.cloudaicompanionProject,
|
||||
userTier: loadRes.currentTier.id,
|
||||
};
|
||||
}
|
||||
|
||||
const tier = getOnboardTier(loadRes);
|
||||
|
||||
let onboardReq: OnboardUserRequest;
|
||||
if (tier.id === UserTierId.FREE) {
|
||||
// The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
|
||||
onboardReq = {
|
||||
tierId: tier.id,
|
||||
cloudaicompanionProject: undefined,
|
||||
metadata: coreClientMetadata,
|
||||
};
|
||||
} else {
|
||||
onboardReq = {
|
||||
tierId: tier.id,
|
||||
cloudaicompanionProject: projectId,
|
||||
metadata: {
|
||||
...coreClientMetadata,
|
||||
duetProject: projectId,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Poll onboardUser until long running operation is complete.
|
||||
let lroRes = await caServer.onboardUser(onboardReq);
|
||||
while (!lroRes.done) {
|
||||
await new Promise((f) => setTimeout(f, 5000));
|
||||
lroRes = await caServer.onboardUser(onboardReq);
|
||||
}
|
||||
|
||||
if (!lroRes.response?.cloudaicompanionProject?.id) {
|
||||
if (projectId) {
|
||||
return {
|
||||
projectId,
|
||||
userTier: tier.id,
|
||||
};
|
||||
}
|
||||
throw new ProjectIdRequiredError();
|
||||
}
|
||||
|
||||
return {
|
||||
projectId: lroRes.response.cloudaicompanionProject.id,
|
||||
userTier: tier.id,
|
||||
};
|
||||
}
|
||||
|
||||
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
|
||||
for (const tier of res.allowedTiers || []) {
|
||||
if (tier.isDefault) {
|
||||
return tier;
|
||||
}
|
||||
}
|
||||
return {
|
||||
name: '',
|
||||
description: '',
|
||||
id: UserTierId.LEGACY,
|
||||
userDefinedCloudaicompanionProject: true,
|
||||
};
|
||||
}
|
||||
201
packages/core/src/code_assist/types.ts
Normal file
201
packages/core/src/code_assist/types.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export interface ClientMetadata {
|
||||
ideType?: ClientMetadataIdeType;
|
||||
ideVersion?: string;
|
||||
pluginVersion?: string;
|
||||
platform?: ClientMetadataPlatform;
|
||||
updateChannel?: string;
|
||||
duetProject?: string;
|
||||
pluginType?: ClientMetadataPluginType;
|
||||
ideName?: string;
|
||||
}
|
||||
|
||||
export type ClientMetadataIdeType =
|
||||
| 'IDE_UNSPECIFIED'
|
||||
| 'VSCODE'
|
||||
| 'INTELLIJ'
|
||||
| 'VSCODE_CLOUD_WORKSTATION'
|
||||
| 'INTELLIJ_CLOUD_WORKSTATION'
|
||||
| 'CLOUD_SHELL';
|
||||
export type ClientMetadataPlatform =
|
||||
| 'PLATFORM_UNSPECIFIED'
|
||||
| 'DARWIN_AMD64'
|
||||
| 'DARWIN_ARM64'
|
||||
| 'LINUX_AMD64'
|
||||
| 'LINUX_ARM64'
|
||||
| 'WINDOWS_AMD64';
|
||||
export type ClientMetadataPluginType =
|
||||
| 'PLUGIN_UNSPECIFIED'
|
||||
| 'CLOUD_CODE'
|
||||
| 'GEMINI'
|
||||
| 'AIPLUGIN_INTELLIJ'
|
||||
| 'AIPLUGIN_STUDIO';
|
||||
|
||||
export interface LoadCodeAssistRequest {
|
||||
cloudaicompanionProject?: string;
|
||||
metadata: ClientMetadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents LoadCodeAssistResponse proto json field
|
||||
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
|
||||
*/
|
||||
export interface LoadCodeAssistResponse {
|
||||
currentTier?: GeminiUserTier | null;
|
||||
allowedTiers?: GeminiUserTier[] | null;
|
||||
ineligibleTiers?: IneligibleTier[] | null;
|
||||
cloudaicompanionProject?: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist.
|
||||
*/
|
||||
export interface GeminiUserTier {
|
||||
id: UserTierId;
|
||||
name?: string;
|
||||
description?: string;
|
||||
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
|
||||
userDefinedCloudaicompanionProject?: boolean | null;
|
||||
isDefault?: boolean;
|
||||
privacyNotice?: PrivacyNotice;
|
||||
hasAcceptedTos?: boolean;
|
||||
hasOnboardedPreviously?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
|
||||
* @param reasonCode mnemonic code representing the reason for in-eligibility.
|
||||
* @param reasonMessage message to display to the user.
|
||||
* @param tierId id of the tier.
|
||||
* @param tierName name of the tier.
|
||||
*/
|
||||
export interface IneligibleTier {
|
||||
reasonCode: IneligibleTierReasonCode;
|
||||
reasonMessage: string;
|
||||
tierId: UserTierId;
|
||||
tierName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* List of predefined reason codes when a tier is blocked from a specific tier.
|
||||
* https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
|
||||
*/
|
||||
export enum IneligibleTierReasonCode {
|
||||
// go/keep-sorted start
|
||||
DASHER_USER = 'DASHER_USER',
|
||||
INELIGIBLE_ACCOUNT = 'INELIGIBLE_ACCOUNT',
|
||||
NON_USER_ACCOUNT = 'NON_USER_ACCOUNT',
|
||||
RESTRICTED_AGE = 'RESTRICTED_AGE',
|
||||
RESTRICTED_NETWORK = 'RESTRICTED_NETWORK',
|
||||
UNKNOWN = 'UNKNOWN',
|
||||
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
|
||||
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
|
||||
// go/keep-sorted end
|
||||
}
|
||||
/**
|
||||
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
|
||||
*
|
||||
* //depot/google3/cloud/developer_experience/cloudcode/pa/service/usertier.go;l=16
|
||||
*/
|
||||
export enum UserTierId {
|
||||
FREE = 'free-tier',
|
||||
LEGACY = 'legacy-tier',
|
||||
STANDARD = 'standard-tier',
|
||||
}
|
||||
|
||||
/**
|
||||
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
|
||||
* privacy notice.
|
||||
*/
|
||||
export interface PrivacyNotice {
|
||||
showNotice: boolean;
|
||||
noticeText?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Proto signature of OnboardUserRequest as payload to OnboardUser call
|
||||
*/
|
||||
export interface OnboardUserRequest {
|
||||
tierId: string | undefined;
|
||||
cloudaicompanionProject: string | undefined;
|
||||
metadata: ClientMetadata | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents LongRunningOperation proto
|
||||
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
|
||||
*/
|
||||
export interface LongRunningOperationResponse {
|
||||
name: string;
|
||||
done?: boolean;
|
||||
response?: OnboardUserResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents OnboardUserResponse proto
|
||||
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
|
||||
*/
|
||||
export interface OnboardUserResponse {
|
||||
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
|
||||
cloudaicompanionProject?: {
|
||||
id: string;
|
||||
name: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Status code of user license status
|
||||
* it does not strictly correspond to the proto
|
||||
* Error value is an additional value assigned to error responses from OnboardUser
|
||||
*/
|
||||
export enum OnboardUserStatusCode {
|
||||
Default = 'DEFAULT',
|
||||
Notice = 'NOTICE',
|
||||
Warning = 'WARNING',
|
||||
Error = 'ERROR',
|
||||
}
|
||||
|
||||
/**
|
||||
* Status of user onboarded to gemini
|
||||
*/
|
||||
export interface OnboardUserStatus {
|
||||
statusCode: OnboardUserStatusCode;
|
||||
displayMessage: string;
|
||||
helpLink: HelpLinkUrl | undefined;
|
||||
}
|
||||
|
||||
export interface HelpLinkUrl {
|
||||
description: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface SetCodeAssistGlobalUserSettingRequest {
|
||||
cloudaicompanionProject?: string;
|
||||
freeTierDataCollectionOptin: boolean;
|
||||
}
|
||||
|
||||
export interface CodeAssistGlobalUserSettingResponse {
|
||||
cloudaicompanionProject?: string;
|
||||
freeTierDataCollectionOptin: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Relevant fields that can be returned from a Google RPC response
|
||||
*/
|
||||
export interface GoogleRpcResponse {
|
||||
error?: {
|
||||
details?: GoogleRpcErrorInfo[];
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Relevant fields that can be returned in the details of an error returned from GoogleRPCs
|
||||
*/
|
||||
interface GoogleRpcErrorInfo {
|
||||
reason?: string;
|
||||
}
|
||||
@@ -283,6 +283,23 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
|
||||
it('should strip thoughts when switching from GenAI to Vertex', async () => {
|
||||
const config = new Config(baseParams);
|
||||
|
||||
vi.mocked(createContentGeneratorConfig).mockImplementation(
|
||||
(_: Config, authType: AuthType | undefined) =>
|
||||
({ authType }) as unknown as ContentGeneratorConfig,
|
||||
);
|
||||
|
||||
await config.refreshAuth(AuthType.USE_GEMINI);
|
||||
|
||||
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
|
||||
|
||||
expect(
|
||||
config.getGeminiClient().stripThoughtsFromHistory,
|
||||
).toHaveBeenCalledWith();
|
||||
});
|
||||
|
||||
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
|
||||
const config = new Config(baseParams);
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
AuthType,
|
||||
} from '../core/contentGenerator.js';
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
|
||||
@@ -27,6 +26,7 @@ import type { AnyToolInvocation } from '../tools/tools.js';
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import {
|
||||
AuthType,
|
||||
createContentGenerator,
|
||||
createContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
@@ -54,6 +54,7 @@ import { canUseRipgrep } from '../utils/ripgrepUtils.js';
|
||||
import { RipGrepTool } from '../tools/ripGrep.js';
|
||||
import { ShellTool } from '../tools/shell.js';
|
||||
import { SmartEditTool } from '../tools/smart-edit.js';
|
||||
import { SkillTool } from '../tools/skill.js';
|
||||
import { TaskTool } from '../tools/task.js';
|
||||
import { TodoWriteTool } from '../tools/todoWrite.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
@@ -65,6 +66,7 @@ import { WriteFileTool } from '../tools/write-file.js';
|
||||
import { ideContextStore } from '../ide/ideContext.js';
|
||||
import { InputFormat, OutputFormat } from '../output/types.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { SkillManager } from '../skills/skill-manager.js';
|
||||
import { SubagentManager } from '../subagents/subagent-manager.js';
|
||||
import type { SubagentConfig } from '../subagents/types.js';
|
||||
import {
|
||||
@@ -305,6 +307,7 @@ export interface ConfigParameters {
|
||||
extensionContextFilePaths?: string[];
|
||||
maxSessionTurns?: number;
|
||||
sessionTokenLimit?: number;
|
||||
experimentalSkills?: boolean;
|
||||
experimentalZedIntegration?: boolean;
|
||||
listExtensions?: boolean;
|
||||
extensions?: GeminiCLIExtension[];
|
||||
@@ -389,6 +392,7 @@ export class Config {
|
||||
private toolRegistry!: ToolRegistry;
|
||||
private promptRegistry!: PromptRegistry;
|
||||
private subagentManager!: SubagentManager;
|
||||
private skillManager!: SkillManager;
|
||||
private fileSystemService: FileSystemService;
|
||||
private contentGeneratorConfig!: ContentGeneratorConfig;
|
||||
private contentGenerator!: ContentGenerator;
|
||||
@@ -458,6 +462,7 @@ export class Config {
|
||||
| undefined;
|
||||
private readonly cliVersion?: string;
|
||||
private readonly experimentalZedIntegration: boolean = false;
|
||||
private readonly experimentalSkills: boolean = false;
|
||||
private readonly chatRecordingEnabled: boolean;
|
||||
private readonly loadMemoryFromIncludeDirectories: boolean = false;
|
||||
private readonly webSearch?: {
|
||||
@@ -557,6 +562,7 @@ export class Config {
|
||||
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
|
||||
this.experimentalZedIntegration =
|
||||
params.experimentalZedIntegration ?? false;
|
||||
this.experimentalSkills = params.experimentalSkills ?? false;
|
||||
this.listExtensions = params.listExtensions ?? false;
|
||||
this._extensions = params.extensions ?? [];
|
||||
this._blockedMcpServers = params.blockedMcpServers ?? [];
|
||||
@@ -644,6 +650,7 @@ export class Config {
|
||||
}
|
||||
this.promptRegistry = new PromptRegistry();
|
||||
this.subagentManager = new SubagentManager(this);
|
||||
this.skillManager = new SkillManager(this);
|
||||
|
||||
// Load session subagents if they were provided before initialization
|
||||
if (this.sessionSubagents.length > 0) {
|
||||
@@ -684,6 +691,16 @@ export class Config {
|
||||
}
|
||||
|
||||
async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) {
|
||||
// Vertex and Genai have incompatible encryption and sending history with
|
||||
// throughtSignature from Genai to Vertex will fail, we need to strip them
|
||||
if (
|
||||
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
|
||||
authMethod === AuthType.LOGIN_WITH_GOOGLE
|
||||
) {
|
||||
// Restore the conversation history to the new client
|
||||
this.geminiClient.stripThoughtsFromHistory();
|
||||
}
|
||||
|
||||
const newContentGeneratorConfig = createContentGeneratorConfig(
|
||||
this,
|
||||
authMethod,
|
||||
@@ -1066,6 +1083,10 @@ export class Config {
|
||||
return this.experimentalZedIntegration;
|
||||
}
|
||||
|
||||
getExperimentalSkills(): boolean {
|
||||
return this.experimentalSkills;
|
||||
}
|
||||
|
||||
getListExtensions(): boolean {
|
||||
return this.listExtensions;
|
||||
}
|
||||
@@ -1296,6 +1317,10 @@ export class Config {
|
||||
return this.subagentManager;
|
||||
}
|
||||
|
||||
getSkillManager(): SkillManager {
|
||||
return this.skillManager;
|
||||
}
|
||||
|
||||
async createToolRegistry(
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
): Promise<ToolRegistry> {
|
||||
@@ -1338,6 +1363,9 @@ export class Config {
|
||||
};
|
||||
|
||||
registerCoreTool(TaskTool, this);
|
||||
if (this.getExperimentalSkills()) {
|
||||
registerCoreTool(SkillTool, this);
|
||||
}
|
||||
registerCoreTool(LSTool, this);
|
||||
registerCoreTool(ReadFileTool, this);
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
config as unknown as { contentGeneratorConfig: unknown }
|
||||
).contentGeneratorConfig = {
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
authType: 'gemini-api-key',
|
||||
authType: 'oauth-personal',
|
||||
};
|
||||
});
|
||||
|
||||
|
||||
@@ -126,6 +126,10 @@ export class Storage {
|
||||
return path.join(this.getExtensionsDir(), 'qwen-extension.json');
|
||||
}
|
||||
|
||||
getUserSkillsDir(): string {
|
||||
return path.join(Storage.getGlobalQwenDir(), 'skills');
|
||||
}
|
||||
|
||||
getHistoryFilePath(): string {
|
||||
return path.join(this.getProjectTempDir(), 'shell_history');
|
||||
}
|
||||
|
||||
@@ -73,7 +73,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
// Create generator instance
|
||||
@@ -300,7 +299,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
@@ -335,7 +333,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
|
||||
@@ -146,11 +146,12 @@ describe('BaseLlmClient', () => {
|
||||
// Validate the parameters passed to the underlying generator
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: defaultOptions.contents,
|
||||
config: expect.objectContaining({
|
||||
config: {
|
||||
abortSignal: defaultOptions.abortSignal,
|
||||
topP: 0.8,
|
||||
tools: [
|
||||
{
|
||||
functionDeclarations: [
|
||||
@@ -162,8 +163,9 @@ describe('BaseLlmClient', () => {
|
||||
],
|
||||
},
|
||||
],
|
||||
}),
|
||||
}),
|
||||
// Crucial: systemInstruction should NOT be in the config object if not provided
|
||||
},
|
||||
},
|
||||
'test-prompt-id',
|
||||
);
|
||||
});
|
||||
@@ -186,6 +188,7 @@ describe('BaseLlmClient', () => {
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
temperature: 0.8,
|
||||
topP: 0.8, // Default should remain if not overridden
|
||||
topK: 10,
|
||||
tools: expect.any(Array),
|
||||
}),
|
||||
|
||||
@@ -64,6 +64,11 @@ export interface GenerateJsonOptions {
|
||||
* A client dedicated to stateless, utility-focused LLM calls.
|
||||
*/
|
||||
export class BaseLlmClient {
|
||||
// Default configuration for utility tasks
|
||||
private readonly defaultUtilityConfig: GenerateContentConfig = {
|
||||
topP: 0.8,
|
||||
};
|
||||
|
||||
constructor(
|
||||
private readonly contentGenerator: ContentGenerator,
|
||||
private readonly config: Config,
|
||||
@@ -84,6 +89,7 @@ export class BaseLlmClient {
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...this.defaultUtilityConfig,
|
||||
...options.config,
|
||||
...(systemInstruction && { systemInstruction }),
|
||||
};
|
||||
|
||||
@@ -15,7 +15,11 @@ import {
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import { GeminiClient } from './client.js';
|
||||
import {
|
||||
isThinkingDefault,
|
||||
isThinkingSupported,
|
||||
GeminiClient,
|
||||
} from './client.js';
|
||||
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
|
||||
import {
|
||||
AuthType,
|
||||
@@ -243,6 +247,40 @@ describe('findCompressSplitPoint', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('isThinkingSupported', () => {
|
||||
it('should return true for gemini-2.5', () => {
|
||||
expect(isThinkingSupported('gemini-2.5')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5-pro', () => {
|
||||
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for other models', () => {
|
||||
expect(isThinkingSupported('gemini-1.5-flash')).toBe(false);
|
||||
expect(isThinkingSupported('some-other-model')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isThinkingDefault', () => {
|
||||
it('should return false for gemini-2.5-flash-lite', () => {
|
||||
expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5', () => {
|
||||
expect(isThinkingDefault('gemini-2.5')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5-pro', () => {
|
||||
expect(isThinkingDefault('gemini-2.5-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for other models', () => {
|
||||
expect(isThinkingDefault('gemini-1.5-flash')).toBe(false);
|
||||
expect(isThinkingDefault('some-other-model')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
@@ -2266,15 +2304,16 @@ ${JSON.stringify(
|
||||
);
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
{
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
config: expect.objectContaining({
|
||||
config: {
|
||||
abortSignal,
|
||||
systemInstruction: getCoreSystemPrompt(''),
|
||||
temperature: 0.5,
|
||||
}),
|
||||
topP: 0.8,
|
||||
},
|
||||
contents,
|
||||
}),
|
||||
},
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -15,7 +15,11 @@ import type {
|
||||
|
||||
// Config
|
||||
import { ApprovalMode, type Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
DEFAULT_THINKING_MODE,
|
||||
} from '../config/models.js';
|
||||
|
||||
// Core modules
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
@@ -74,10 +78,24 @@ import { type File, type IdeContext } from '../ide/types.js';
|
||||
// Fallback handling
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||
}
|
||||
|
||||
export function isThinkingDefault(model: string) {
|
||||
if (model.startsWith('gemini-2.5-flash-lite')) {
|
||||
return false;
|
||||
}
|
||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||
}
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private readonly generateContentConfig: GenerateContentConfig = {
|
||||
topP: 0.8,
|
||||
};
|
||||
private sessionTurnCount = 0;
|
||||
|
||||
private readonly loopDetector: LoopDetectionService;
|
||||
@@ -189,10 +207,20 @@ export class GeminiClient {
|
||||
const model = this.config.getModel();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory, model);
|
||||
|
||||
const config: GenerateContentConfig = { ...this.generateContentConfig };
|
||||
|
||||
if (isThinkingSupported(model)) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: DEFAULT_THINKING_MODE,
|
||||
};
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
this.config,
|
||||
{
|
||||
systemInstruction,
|
||||
...config,
|
||||
tools,
|
||||
},
|
||||
history,
|
||||
@@ -589,6 +617,11 @@ export class GeminiClient {
|
||||
): Promise<GenerateContentResponse> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
const configToUse: GenerateContentConfig = {
|
||||
...this.generateContentConfig,
|
||||
...generationConfig,
|
||||
};
|
||||
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const finalSystemInstruction = generationConfig.systemInstruction
|
||||
@@ -597,7 +630,7 @@ export class GeminiClient {
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...generationConfig,
|
||||
...configToUse,
|
||||
systemInstruction: finalSystemInstruction,
|
||||
};
|
||||
|
||||
@@ -638,7 +671,7 @@ export class GeminiClient {
|
||||
`Error generating content via API with model ${currentAttemptModel}.`,
|
||||
{
|
||||
requestContents: contents,
|
||||
requestConfig: generationConfig,
|
||||
requestConfig: configToUse,
|
||||
},
|
||||
'generateContent-api',
|
||||
);
|
||||
|
||||
@@ -5,19 +5,42 @@
 */

import { describe, it, expect, vi } from 'vitest';
import type { ContentGenerator } from './contentGenerator.js';
import { createContentGenerator, AuthType } from './contentGenerator.js';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { GoogleGenAI } from '@google/genai';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './geminiContentGenerator/loggingContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';

vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');

const mockConfig = {
  getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;

describe('createContentGenerator', () => {
  it('should create a Gemini content generator', async () => {
  it('should create a CodeAssistContentGenerator', async () => {
    const mockGenerator = {} as unknown as ContentGenerator;
    vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
      mockGenerator as never,
    );
    const generator = await createContentGenerator(
      {
        model: 'test-model',
        authType: AuthType.LOGIN_WITH_GOOGLE,
      },
      mockConfig,
    );
    expect(createCodeAssistContentGenerator).toHaveBeenCalled();
    expect(generator).toEqual(
      new LoggingContentGenerator(mockGenerator, mockConfig),
    );
  });

  it('should create a GoogleGenAI content generator', async () => {
    const mockConfig = {
      getUsageStatisticsEnabled: () => true,
      getContentGeneratorConfig: () => ({}),
      getCliVersion: () => '1.0.0',
    } as unknown as Config;

    const mockGenerator = {
@@ -42,17 +65,17 @@ describe('createContentGenerator', () => {
        },
      },
    });
    // We expect it to be a LoggingContentGenerator wrapping a GeminiContentGenerator
    expect(generator).toBeInstanceOf(LoggingContentGenerator);
    const wrapped = (generator as LoggingContentGenerator).getWrapped();
    expect(wrapped).toBeDefined();
    expect(generator).toEqual(
      new LoggingContentGenerator(
        (mockGenerator as GoogleGenAI).models,
        mockConfig,
      ),
    );
  });

  it('should create a Gemini content generator with client install id logging disabled', async () => {
  it('should create a GoogleGenAI content generator with client install id logging disabled', async () => {
    const mockConfig = {
      getUsageStatisticsEnabled: () => false,
      getContentGeneratorConfig: () => ({}),
      getCliVersion: () => '1.0.0',
    } as unknown as Config;
    const mockGenerator = {
      models: {},
@@ -75,6 +98,11 @@ describe('createContentGenerator', () => {
        },
      },
    });
    expect(generator).toBeInstanceOf(LoggingContentGenerator);
    expect(generator).toEqual(
      new LoggingContentGenerator(
        (mockGenerator as GoogleGenAI).models,
        mockConfig,
      ),
    );
  });
});

@@ -12,9 +12,15 @@ import type {
  GenerateContentParameters,
  GenerateContentResponse,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
import type { Config } from '../config/config.js';

import type { UserTierId } from '../code_assist/types.js';
import { InstallationManager } from '../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';

/**
 * Interface abstracting the core functionalities for generating content and counting tokens.
 */
@@ -33,12 +39,14 @@ export interface ContentGenerator {

  embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;

  useSummarizedThinking(): boolean;
  userTier?: UserTierId;
}

export enum AuthType {
  LOGIN_WITH_GOOGLE = 'oauth-personal',
  USE_GEMINI = 'gemini-api-key',
  USE_VERTEX_AI = 'vertex-ai',
  CLOUD_SHELL = 'cloud-shell',
  USE_OPENAI = 'openai',
  QWEN_OAUTH = 'qwen-oauth',
}
@@ -51,9 +59,12 @@ export type ContentGeneratorConfig = {
  authType?: AuthType | undefined;
  enableOpenAILogging?: boolean;
  openAILoggingDir?: string;
  timeout?: number; // Timeout configuration in milliseconds
  maxRetries?: number; // Maximum retries for failed requests
  disableCacheControl?: boolean; // Disable cache control for DashScope providers
  // Timeout configuration in milliseconds
  timeout?: number;
  // Maximum retries for failed requests
  maxRetries?: number;
  // Disable cache control for DashScope providers
  disableCacheControl?: boolean;
  samplingParams?: {
    top_p?: number;
    top_k?: number;
@@ -63,9 +74,6 @@ export type ContentGeneratorConfig = {
    temperature?: number;
    max_tokens?: number;
  };
  reasoning?: {
    effort?: 'low' | 'medium' | 'high';
  };
  proxy?: string | undefined;
  userAgent?: string;
  // Schema compliance mode for tool definitions
@@ -115,14 +123,48 @@ export async function createContentGenerator(
  gcConfig: Config,
  isInitialAuth?: boolean,
): Promise<ContentGenerator> {
  const version = process.env['CLI_VERSION'] || process.version;
  const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
  const baseHeaders: Record<string, string> = {
    'User-Agent': userAgent,
  };

  if (
    config.authType === AuthType.LOGIN_WITH_GOOGLE ||
    config.authType === AuthType.CLOUD_SHELL
  ) {
    const httpOptions = { headers: baseHeaders };
    return new LoggingContentGenerator(
      await createCodeAssistContentGenerator(
        httpOptions,
        config.authType,
        gcConfig,
      ),
      gcConfig,
    );
  }

  if (
    config.authType === AuthType.USE_GEMINI ||
    config.authType === AuthType.USE_VERTEX_AI
  ) {
    const { createGeminiContentGenerator } = await import(
      './geminiContentGenerator/index.js'
    );
    return createGeminiContentGenerator(config, gcConfig);
    let headers: Record<string, string> = { ...baseHeaders };
    if (gcConfig?.getUsageStatisticsEnabled()) {
      const installationManager = new InstallationManager();
      const installationId = installationManager.getInstallationId();
      headers = {
        ...headers,
        'x-gemini-api-privileged-user-id': `${installationId}`,
      };
    }
    const httpOptions = { headers };

    const googleGenAI = new GoogleGenAI({
      apiKey: config.apiKey === '' ? undefined : config.apiKey,
      vertexai: config.vertexai,
      httpOptions,
    });
    return new LoggingContentGenerator(googleGenAI.models, gcConfig);
  }

  if (config.authType === AuthType.USE_OPENAI) {

@@ -240,7 +240,7 @@ describe('CoreToolScheduler', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -318,7 +318,7 @@ describe('CoreToolScheduler', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -497,7 +497,7 @@ describe('CoreToolScheduler', () => {
      getExcludeTools: () => ['write_file', 'edit', 'run_shell_command'],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -584,7 +584,7 @@ describe('CoreToolScheduler', () => {
      getExcludeTools: () => ['write_file', 'edit'], // Different excluded tools
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -674,7 +674,7 @@ describe('CoreToolScheduler with payload', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1001,7 +1001,7 @@ describe('CoreToolScheduler edit cancellation', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1108,7 +1108,7 @@ describe('CoreToolScheduler YOLO mode', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1258,7 +1258,7 @@ describe('CoreToolScheduler cancellation during executing with live output', ()
      getApprovalMode: () => ApprovalMode.DEFAULT,
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getToolRegistry: () => mockToolRegistry,
      getShellExecutionConfig: () => ({
@@ -1350,7 +1350,7 @@ describe('CoreToolScheduler request queueing', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1482,7 +1482,7 @@ describe('CoreToolScheduler request queueing', () => {
      getToolRegistry: () => toolRegistry,
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 80,
@@ -1586,7 +1586,7 @@ describe('CoreToolScheduler request queueing', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1854,7 +1854,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,
@@ -1975,7 +1975,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
      getAllowedTools: () => [],
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,

@@ -100,7 +100,6 @@ describe('GeminiChat', () => {
      countTokens: vi.fn(),
      embedContent: vi.fn(),
      batchEmbedContents: vi.fn(),
      useSummarizedThinking: vi.fn().mockReturnValue(false),
    } as unknown as ContentGenerator;

    mockHandleFallback.mockClear();
@@ -112,7 +111,7 @@ describe('GeminiChat', () => {
      getUsageStatisticsEnabled: () => true,
      getDebugMode: () => false,
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        authType: 'gemini-api-key', // Ensure this is set for fallback tests
        authType: 'oauth-personal', // Ensure this is set for fallback tests
        model: 'test-model',
      }),
      getModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -719,99 +718,6 @@ describe('GeminiChat', () => {
        1,
      );
    });

    it('should handle summarized thinking by conditionally including thoughts in history', async () => {
      // Case 1: useSummarizedThinking is true -> thoughts NOT in history
      vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
        true,
      );
      const stream1 = (async function* () {
        yield {
          candidates: [
            {
              content: {
                role: 'model',
                parts: [{ thought: true, text: 'T1' }, { text: 'A1' }],
              },
              finishReason: 'STOP',
            },
          ],
        } as unknown as GenerateContentResponse;
      })();
      vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
        stream1,
      );

      const res1 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
      for await (const _ of res1);

      const history1 = chat.getHistory();
      expect(history1[1].parts).toEqual([{ text: 'A1' }]);

      // Case 2: useSummarizedThinking is false -> thoughts ARE in history
      chat.clearHistory();
      vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
        false,
      );
      const stream2 = (async function* () {
        yield {
          candidates: [
            {
              content: {
                role: 'model',
                parts: [{ thought: true, text: 'T2' }, { text: 'A2' }],
              },
              finishReason: 'STOP',
            },
          ],
        } as unknown as GenerateContentResponse;
      })();
      vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
        stream2,
      );

      const res2 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p2');
      for await (const _ of res2);

      const history2 = chat.getHistory();
      expect(history2[1].parts).toEqual([
        { text: 'T2', thought: true },
        { text: 'A2' },
      ]);
    });

    it('should keep parts with thoughtSignature when consolidating history', async () => {
      const stream = (async function* () {
        yield {
          candidates: [
            {
              content: {
                role: 'model',
                parts: [
                  {
                    text: 'p1',
                    thoughtSignature: 's1',
                  } as unknown as { text: string; thoughtSignature: string },
                ],
              },
              finishReason: 'STOP',
            },
          ],
        } as unknown as GenerateContentResponse;
      })();
      vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
        stream,
      );

      const res = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
      for await (const _ of res);

      const history = chat.getHistory();
      expect(history[1].parts![0]).toEqual({
        text: 'p1',
        thoughtSignature: 's1',
      });
    });
  });

  describe('addHistory', () => {
@@ -1476,7 +1382,7 @@ describe('GeminiChat', () => {
    });

    it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
      const authType = AuthType.USE_GEMINI;
      const authType = AuthType.LOGIN_WITH_GOOGLE;
      vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
        model: 'test-model',
        authType,
@@ -1626,7 +1532,7 @@ describe('GeminiChat', () => {
  });

  describe('stripThoughtsFromHistory', () => {
    it('should strip thoughts and thought signatures, and remove empty content objects', () => {
    it('should strip thought signatures', () => {
      chat.setHistory([
        {
          role: 'user',
@@ -1638,15 +1544,10 @@ describe('GeminiChat', () => {
            { text: 'thinking...', thought: true },
            { text: 'hi' },
            {
              text: 'hidden metadata',
              thoughtSignature: 'abc',
            } as unknown as { text: string; thoughtSignature: string },
              functionCall: { name: 'test', args: {} },
            },
          ],
        },
        {
          role: 'model',
          parts: [{ text: 'only thinking', thought: true }],
        },
      ]);

      chat.stripThoughtsFromHistory();
@@ -1658,7 +1559,7 @@ describe('GeminiChat', () => {
        },
        {
          role: 'model',
          parts: [{ text: 'hi' }, { text: 'hidden metadata' }],
          parts: [{ text: 'hi' }, { functionCall: { name: 'test', args: {} } }],
        },
      ]);
    });

@@ -92,7 +92,6 @@ export function isValidNonThoughtTextPart(part: Part): boolean {
  return (
    typeof part.text === 'string' &&
    !part.thought &&
    !part.thoughtSignature &&
    // Technically, the model should never generate parts that have text and
    // any of these but we don't trust them so check anyways.
    !part.functionCall &&
@@ -110,24 +109,18 @@ function isValidContent(content: Content): boolean {
    if (part === undefined || Object.keys(part).length === 0) {
      return false;
    }
    if (!isValidContentPart(part)) {
    if (
      !part.thought &&
      part.text !== undefined &&
      part.text === '' &&
      part.functionCall === undefined
    ) {
      return false;
    }
  }
  return true;
}

function isValidContentPart(part: Part): boolean {
  const isInvalid =
    !part.thought &&
    !part.thoughtSignature &&
    part.text !== undefined &&
    part.text === '' &&
    part.functionCall === undefined;

  return !isInvalid;
}

/**
 * Validates the history contains the correct roles.
 *
@@ -455,29 +448,15 @@ export class GeminiChat {
    if (!content.parts) return content;

    // Filter out thought parts entirely
    const filteredParts = content.parts
      .filter(
        (part) =>
          !(
            part &&
            typeof part === 'object' &&
            'thought' in part &&
            part.thought
          ),
      )
      .map((part) => {
        if (
    const filteredParts = content.parts.filter(
      (part) =>
        !(
          part &&
          typeof part === 'object' &&
          'thoughtSignature' in part
        ) {
          const newPart = { ...part };
          delete (newPart as { thoughtSignature?: string })
            .thoughtSignature;
          return newPart;
        }
        return part;
      });
          'thought' in part &&
          part.thought
        ),
    );

    return {
      ...content,
@@ -559,15 +538,11 @@ export class GeminiChat {
      yield chunk; // Yield every chunk to the UI immediately.
    }

    let thoughtText = '';
    // Only include thoughts if not using summarized thinking.
    if (!this.config.getContentGenerator().useSummarizedThinking()) {
      thoughtText = allModelParts
        .filter((part) => part.thought)
        .map((part) => part.text)
        .join('')
        .trim();
    }
    const thoughtParts = allModelParts.filter((part) => part.thought);
    const thoughtText = thoughtParts
      .map((part) => part.text)
      .join('')
      .trim();

    const contentParts = allModelParts.filter((part) => !part.thought);
    const consolidatedHistoryParts: Part[] = [];
@@ -580,7 +555,7 @@ export class GeminiChat {
        isValidNonThoughtTextPart(part)
      ) {
        lastPart.text += part.text;
      } else if (isValidContentPart(part)) {
      } else {
        consolidatedHistoryParts.push(part);
      }
    }

@@ -1,173 +0,0 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { GoogleGenAI } from '@google/genai';

vi.mock('@google/genai', () => {
  const mockGenerateContent = vi.fn();
  const mockGenerateContentStream = vi.fn();
  const mockCountTokens = vi.fn();
  const mockEmbedContent = vi.fn();

  return {
    GoogleGenAI: vi.fn().mockImplementation(() => ({
      models: {
        generateContent: mockGenerateContent,
        generateContentStream: mockGenerateContentStream,
        countTokens: mockCountTokens,
        embedContent: mockEmbedContent,
      },
    })),
  };
});

describe('GeminiContentGenerator', () => {
  let generator: GeminiContentGenerator;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let mockGoogleGenAI: any;

  beforeEach(() => {
    vi.clearAllMocks();
    generator = new GeminiContentGenerator({
      apiKey: 'test-api-key',
    });
    mockGoogleGenAI = vi.mocked(GoogleGenAI).mock.results[0].value;
  });

  it('should call generateContent on the underlying model', async () => {
    const request = { model: 'gemini-1.5-flash', contents: [] };
    const expectedResponse = { responseId: 'test-id' };
    mockGoogleGenAI.models.generateContent.mockResolvedValue(expectedResponse);

    const response = await generator.generateContent(request, 'prompt-id');

    expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
      expect.objectContaining({
        ...request,
        config: expect.objectContaining({
          temperature: 1,
          topP: 0.95,
          thinkingConfig: {
            includeThoughts: true,
            thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
          },
        }),
      }),
    );
    expect(response).toBe(expectedResponse);
  });

  it('should call generateContentStream on the underlying model', async () => {
    const request = { model: 'gemini-1.5-flash', contents: [] };
    const mockStream = (async function* () {
      yield { responseId: '1' };
    })();
    mockGoogleGenAI.models.generateContentStream.mockResolvedValue(mockStream);

    const stream = await generator.generateContentStream(request, 'prompt-id');

    expect(mockGoogleGenAI.models.generateContentStream).toHaveBeenCalledWith(
      expect.objectContaining({
        ...request,
        config: expect.objectContaining({
          temperature: 1,
          topP: 0.95,
          thinkingConfig: {
            includeThoughts: true,
            thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
          },
        }),
      }),
    );
    expect(stream).toBe(mockStream);
  });

  it('should call countTokens on the underlying model', async () => {
    const request = { model: 'gemini-1.5-flash', contents: [] };
    const expectedResponse = { totalTokens: 10 };
    mockGoogleGenAI.models.countTokens.mockResolvedValue(expectedResponse);

    const response = await generator.countTokens(request);

    expect(mockGoogleGenAI.models.countTokens).toHaveBeenCalledWith(request);
    expect(response).toBe(expectedResponse);
  });

  it('should call embedContent on the underlying model', async () => {
    const request = { model: 'embedding-model', contents: [] };
    const expectedResponse = { embeddings: [] };
    mockGoogleGenAI.models.embedContent.mockResolvedValue(expectedResponse);

    const response = await generator.embedContent(request);

    expect(mockGoogleGenAI.models.embedContent).toHaveBeenCalledWith(request);
    expect(response).toBe(expectedResponse);
  });

  it('should prioritize contentGeneratorConfig samplingParams over request config', async () => {
    const generatorWithParams = new GeminiContentGenerator({ apiKey: 'test' }, {
      model: 'gemini-1.5-flash',
      samplingParams: {
        temperature: 0.1,
        top_p: 0.2,
      },
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } as any);

    const request = {
      model: 'gemini-1.5-flash',
      contents: [],
      config: {
        temperature: 0.9,
        topP: 0.9,
      },
    };

    await generatorWithParams.generateContent(request, 'prompt-id');

    expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
      expect.objectContaining({
        config: expect.objectContaining({
          temperature: 0.1,
          topP: 0.2,
        }),
      }),
    );
  });

  it('should map reasoning effort to thinkingConfig', async () => {
    const generatorWithReasoning = new GeminiContentGenerator(
      { apiKey: 'test' },
      {
        model: 'gemini-2.5-pro',
        reasoning: {
          effort: 'high',
        },
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
      } as any,
    );

    const request = {
      model: 'gemini-2.5-pro',
      contents: [],
    };

    await generatorWithReasoning.generateContent(request, 'prompt-id');

    expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
      expect.objectContaining({
        config: expect.objectContaining({
          thinkingConfig: {
            includeThoughts: true,
            thinkingLevel: 'HIGH',
          },
        }),
      }),
    );
  });
});
@@ -1,144 +0,0 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import type {
  CountTokensParameters,
  CountTokensResponse,
  EmbedContentParameters,
  EmbedContentResponse,
  GenerateContentParameters,
  GenerateContentResponse,
  GenerateContentConfig,
  ThinkingLevel,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import type {
  ContentGenerator,
  ContentGeneratorConfig,
} from '../contentGenerator.js';

/**
 * A wrapper for GoogleGenAI that implements the ContentGenerator interface.
 */
export class GeminiContentGenerator implements ContentGenerator {
  private readonly googleGenAI: GoogleGenAI;
  private readonly contentGeneratorConfig?: ContentGeneratorConfig;

  constructor(
    options: {
      apiKey?: string;
      vertexai?: boolean;
      httpOptions?: { headers: Record<string, string> };
    },
    contentGeneratorConfig?: ContentGeneratorConfig,
  ) {
    this.googleGenAI = new GoogleGenAI(options);
    this.contentGeneratorConfig = contentGeneratorConfig;
  }

  private buildSamplingParameters(
    request: GenerateContentParameters,
  ): GenerateContentConfig {
    const configSamplingParams = this.contentGeneratorConfig?.samplingParams;
    const requestConfig = request.config || {};

    // Helper function to get parameter value with priority: config > request > default
    const getParameterValue = <T>(
      configValue: T | undefined,
      requestKey: keyof GenerateContentConfig,
      defaultValue?: T,
    ): T | undefined => {
      const requestValue = requestConfig[requestKey] as T | undefined;

      if (configValue !== undefined) return configValue;
      if (requestValue !== undefined) return requestValue;
      return defaultValue;
    };

    return {
      ...requestConfig,
      temperature: getParameterValue<number>(
        configSamplingParams?.temperature,
        'temperature',
        1,
      ),
      topP: getParameterValue<number>(
        configSamplingParams?.top_p,
        'topP',
        0.95,
      ),
      topK: getParameterValue<number>(configSamplingParams?.top_k, 'topK', 64),
      maxOutputTokens: getParameterValue<number>(
        configSamplingParams?.max_tokens,
        'maxOutputTokens',
      ),
      presencePenalty: getParameterValue<number>(
        configSamplingParams?.presence_penalty,
        'presencePenalty',
      ),
      frequencyPenalty: getParameterValue<number>(
        configSamplingParams?.frequency_penalty,
        'frequencyPenalty',
      ),
      thinkingConfig: getParameterValue(
        this.contentGeneratorConfig?.reasoning
          ? {
              includeThoughts: true,
              thinkingLevel: (this.contentGeneratorConfig.reasoning.effort ===
              'low'
                ? 'LOW'
                : this.contentGeneratorConfig.reasoning.effort === 'high'
                  ? 'HIGH'
                  : 'THINKING_LEVEL_UNSPECIFIED') as ThinkingLevel,
            }
          : undefined,
        'thinkingConfig',
        {
          includeThoughts: true,
          thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED' as ThinkingLevel,
        },
      ),
    };
  }

  async generateContent(
    request: GenerateContentParameters,
    _userPromptId: string,
  ): Promise<GenerateContentResponse> {
    const finalRequest = {
      ...request,
      config: this.buildSamplingParameters(request),
    };
    return this.googleGenAI.models.generateContent(finalRequest);
  }

  async generateContentStream(
    request: GenerateContentParameters,
    _userPromptId: string,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    const finalRequest = {
      ...request,
      config: this.buildSamplingParameters(request),
    };
    return this.googleGenAI.models.generateContentStream(finalRequest);
  }

  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    return this.googleGenAI.models.countTokens(request);
  }

  async embedContent(
    request: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    return this.googleGenAI.models.embedContent(request);
  }

  useSummarizedThinking(): boolean {
    return true;
  }
}
@@ -1,47 +0,0 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createGeminiContentGenerator } from './index.js';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';

vi.mock('./geminiContentGenerator.js', () => ({
  GeminiContentGenerator: vi.fn().mockImplementation(() => ({})),
}));

vi.mock('./loggingContentGenerator.js', () => ({
  LoggingContentGenerator: vi.fn().mockImplementation((wrapped) => wrapped),
}));

describe('createGeminiContentGenerator', () => {
  let mockConfig: Config;

  beforeEach(() => {
    vi.clearAllMocks();
    mockConfig = {
      getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
      getContentGeneratorConfig: vi.fn().mockReturnValue({}),
      getCliVersion: vi.fn().mockReturnValue('1.0.0'),
    } as unknown as Config;
  });

  it('should create a GeminiContentGenerator wrapped in LoggingContentGenerator', () => {
    const config = {
      model: 'gemini-1.5-flash',
      apiKey: 'test-key',
      authType: AuthType.USE_GEMINI,
    };

    const generator = createGeminiContentGenerator(config, mockConfig);

    expect(GeminiContentGenerator).toHaveBeenCalled();
    expect(LoggingContentGenerator).toHaveBeenCalled();
    expect(generator).toBeDefined();
  });
});
@@ -1,55 +0,0 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { GeminiContentGenerator } from './geminiContentGenerator.js';
import type {
  ContentGenerator,
  ContentGeneratorConfig,
} from '../contentGenerator.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';

export { GeminiContentGenerator } from './geminiContentGenerator.js';
export { LoggingContentGenerator } from './loggingContentGenerator.js';

/**
 * Create a Gemini content generator.
 */
export function createGeminiContentGenerator(
  config: ContentGeneratorConfig,
  gcConfig: Config,
): ContentGenerator {
  const version = process.env['CLI_VERSION'] || process.version;
  const userAgent =
    config.userAgent ||
    `QwenCode/${version} (${process.platform}; ${process.arch})`;
  const baseHeaders: Record<string, string> = {
    'User-Agent': userAgent,
  };

  let headers: Record<string, string> = { ...baseHeaders };
  if (gcConfig?.getUsageStatisticsEnabled()) {
    const installationManager = new InstallationManager();
    const installationId = installationManager.getInstallationId();
    headers = {
      ...headers,
      'x-gemini-api-privileged-user-id': `${installationId}`,
    };
  }
  const httpOptions = { headers };

  const geminiContentGenerator = new GeminiContentGenerator(
    {
      apiKey: config.apiKey === '' ? undefined : config.apiKey,
      vertexai: config.vertexai,
      httpOptions,
    },
    config,
  );

  return new LoggingContentGenerator(geminiContentGenerator, gcConfig);
}
@@ -13,24 +13,21 @@ import type {
  GenerateContentParameters,
  GenerateContentResponseUsageMetadata,
  GenerateContentResponse,
  ContentListUnion,
  ContentUnion,
  Part,
  PartUnion,
} from '@google/genai';
import {
  ApiRequestEvent,
  ApiResponseEvent,
  ApiErrorEvent,
} from '../../telemetry/types.js';
import type { Config } from '../../config/config.js';
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
import {
  logApiError,
  logApiRequest,
  logApiResponse,
} from '../../telemetry/loggers.js';
import type { ContentGenerator } from '../contentGenerator.js';
import { isStructuredError } from '../../utils/quotaErrorDetection.js';
} from '../telemetry/loggers.js';
import type { ContentGenerator } from './contentGenerator.js';
import { toContents } from '../code_assist/converter.js';
import { isStructuredError } from '../utils/quotaErrorDetection.js';

interface StructuredError {
  status: number;
@@ -115,7 +112,7 @@ export class LoggingContentGenerator implements ContentGenerator {
    userPromptId: string,
  ): Promise<GenerateContentResponse> {
    const startTime = Date.now();
    this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
    this.logApiRequest(toContents(req.contents), req.model, userPromptId);
    try {
      const response = await this.wrapped.generateContent(req, userPromptId);
      const durationMs = Date.now() - startTime;
@@ -140,7 +137,7 @@ export class LoggingContentGenerator implements ContentGenerator {
    userPromptId: string,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    const startTime = Date.now();
    this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
    this.logApiRequest(toContents(req.contents), req.model, userPromptId);

    let stream: AsyncGenerator<GenerateContentResponse>;
    try {
@@ -208,95 +205,4 @@ export class LoggingContentGenerator implements ContentGenerator {
  ): Promise<EmbedContentResponse> {
    return this.wrapped.embedContent(req);
  }

  useSummarizedThinking(): boolean {
    return this.wrapped.useSummarizedThinking();
  }

  private toContents(contents: ContentListUnion): Content[] {
    if (Array.isArray(contents)) {
      // it's a Content[] or a PartsUnion[]
      return contents.map((c) => this.toContent(c));
    }
    // it's a Content or a PartsUnion
    return [this.toContent(contents)];
  }

  private toContent(content: ContentUnion): Content {
    if (Array.isArray(content)) {
      // it's a PartsUnion[]
      return {
        role: 'user',
        parts: this.toParts(content),
      };
    }
    if (typeof content === 'string') {
      // it's a string
      return {
        role: 'user',
        parts: [{ text: content }],
      };
    }
    if ('parts' in content) {
      // it's a Content - process parts to handle thought filtering
      return {
        ...content,
        parts: content.parts
          ? this.toParts(content.parts.filter((p) => p != null))
          : [],
      };
    }
    // it's a Part
    return {
      role: 'user',
      parts: [this.toPart(content as Part)],
    };
  }

  private toParts(parts: PartUnion[]): Part[] {
    return parts.map((p) => this.toPart(p));
  }

  private toPart(part: PartUnion): Part {
    if (typeof part === 'string') {
      // it's a string
      return { text: part };
    }

    // Handle thought parts for CountToken API compatibility
    // The CountToken API expects parts to have certain required "oneof" fields initialized,
    // but thought parts don't conform to this schema and cause API failures
    if ('thought' in part && part.thought) {
      const thoughtText = `[Thought: ${part.thought}]`;

      const newPart = { ...part };
      delete (newPart as Record<string, unknown>)['thought'];

      const hasApiContent =
        'functionCall' in newPart ||
        'functionResponse' in newPart ||
        'inlineData' in newPart ||
        'fileData' in newPart;

      if (hasApiContent) {
        // It's a functionCall or other non-text part. Just strip the thought.
        return newPart;
      }

      // If no other valid API content, this must be a text part.
      // Combine existing text (if any) with the thought, preserving other properties.
      const text = (newPart as { text?: unknown }).text;
      const existingText = text ? String(text) : '';
      const combinedText = existingText
        ? `${existingText}\n${thoughtText}`
        : thoughtText;

      return {
        ...newPart,
        text: combinedText,
      };
    }

    return part;
  }
}
@@ -47,7 +47,7 @@ describe('executeToolCall', () => {
      getDebugMode: () => false,
      getContentGeneratorConfig: () => ({
        model: 'test-model',
        authType: 'gemini-api-key',
        authType: 'oauth-personal',
      }),
      getShellExecutionConfig: () => ({
        terminalWidth: 90,

@@ -99,7 +99,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
        },
      } as unknown as OpenAI),
      buildRequest: vi.fn().mockImplementation((req) => req),
      getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
    };

    generator = new OpenAIContentGenerator(
@@ -212,7 +211,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
        },
      } as unknown as OpenAI),
      buildRequest: vi.fn().mockImplementation((req) => req),
      getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
    };

    const testGenerator = new TestGenerator(
@@ -279,7 +277,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
        },
      } as unknown as OpenAI),
      buildRequest: vi.fn().mockImplementation((req) => req),
      getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
    };

    const testGenerator = new TestGenerator(

@@ -154,8 +154,4 @@ export class OpenAIContentGenerator implements ContentGenerator {
      );
    }
  }

  useSummarizedThinking(): boolean {
    return false;
  }
}

@@ -60,7 +60,6 @@ describe('ContentGenerationPipeline', () => {
      buildClient: vi.fn().mockReturnValue(mockClient),
      buildRequest: vi.fn().mockImplementation((req) => req),
      buildHeaders: vi.fn().mockReturnValue({}),
      getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
    };

    // Mock telemetry service

@@ -283,22 +283,16 @@ export class ContentGenerationPipeline {
  private buildSamplingParameters(
    request: GenerateContentParameters,
  ): Record<string, unknown> {
    const defaultSamplingParams =
      this.config.provider.getDefaultGenerationConfig();
    const configSamplingParams = this.contentGeneratorConfig.samplingParams;

    // Helper function to get parameter value with priority: config > request > default
    const getParameterValue = <T>(
      configKey: keyof NonNullable<typeof configSamplingParams>,
      requestKey?: keyof NonNullable<typeof request.config>,
      requestKey: keyof NonNullable<typeof request.config>,
      defaultValue?: T,
    ): T | undefined => {
      const configValue = configSamplingParams?.[configKey] as T | undefined;
      const requestValue = requestKey
        ? (request.config?.[requestKey] as T | undefined)
        : undefined;
      const defaultValue = requestKey
        ? (defaultSamplingParams[requestKey] as T)
        : undefined;
      const requestValue = request.config?.[requestKey] as T | undefined;

      if (configValue !== undefined) return configValue;
      if (requestValue !== undefined) return requestValue;
@@ -310,8 +304,12 @@ export class ContentGenerationPipeline {
      key: string,
      configKey: keyof NonNullable<typeof configSamplingParams>,
      requestKey?: keyof NonNullable<typeof request.config>,
    ): Record<string, T | undefined> => {
      const value = getParameterValue<T>(configKey, requestKey);
      defaultValue?: T,
    ): Record<string, T> | Record<string, never> => {
      const value = requestKey
        ? getParameterValue(configKey, requestKey, defaultValue)
        : ((configSamplingParams?.[configKey] as T | undefined) ??
          defaultValue);

      return value !== undefined ? { [key]: value } : {};
    };
@@ -325,18 +323,10 @@
      ...addParameterIfDefined('max_tokens', 'max_tokens', 'maxOutputTokens'),

      // Config-only parameters (no request fallback)
      ...addParameterIfDefined('top_k', 'top_k', 'topK'),
      ...addParameterIfDefined('top_k', 'top_k'),
      ...addParameterIfDefined('repetition_penalty', 'repetition_penalty'),
      ...addParameterIfDefined(
        'presence_penalty',
        'presence_penalty',
        'presencePenalty',
      ),
      ...addParameterIfDefined(
        'frequency_penalty',
        'frequency_penalty',
        'frequencyPenalty',
      ),
      ...addParameterIfDefined('presence_penalty', 'presence_penalty'),
      ...addParameterIfDefined('frequency_penalty', 'frequency_penalty'),
    };

    return params;

@@ -1,5 +1,4 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { AuthType } from '../../contentGenerator.js';
@@ -142,14 +141,6 @@ export class DashScopeOpenAICompatibleProvider
    };
  }

  getDefaultGenerationConfig(): GenerateContentConfig {
    return {
      temperature: 0.7,
      topP: 0.8,
      topK: 20,
    };
  }

  /**
   * Add cache control flag to specified message(s) for DashScope providers
   */

@@ -8,7 +8,6 @@ import type OpenAI from 'openai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DefaultOpenAICompatibleProvider } from './default.js';
import type { GenerateContentConfig } from '@google/genai';

export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
  constructor(
@@ -77,10 +76,4 @@ export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatiblePro
      messages,
    };
  }

  override getDefaultGenerationConfig(): GenerateContentConfig {
    return {
      temperature: 0,
    };
  }
}

@@ -1,5 +1,4 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
@@ -56,10 +55,4 @@ export class DefaultOpenAICompatibleProvider
      ...request, // Preserve all original parameters including sampling params
    };
  }

  getDefaultGenerationConfig(): GenerateContentConfig {
    return {
      topP: 0.95,
    };
  }
}

@@ -1,4 +1,3 @@
import type { GenerateContentConfig } from '@google/genai';
import type OpenAI from 'openai';

// Extended types to support cache_control for DashScope
@@ -23,7 +22,6 @@ export interface OpenAICompatibleProvider {
    request: OpenAI.Chat.ChatCompletionCreateParams,
    userPromptId: string,
  ): OpenAI.Chat.ChatCompletionCreateParams;
  getDefaultGenerationConfig(): GenerateContentConfig;
}

export type DashScopeRequestMetadata = {

@@ -36,6 +36,13 @@ vi.mock('../utils/errorReporting', () => ({
  reportError: vi.fn(),
}));

// Use the actual implementation from partUtils now that it's provided.
vi.mock('../utils/generateContentResponseUtilities', () => ({
  getResponseText: (resp: GenerateContentResponse) =>
    resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
    undefined,
}));

describe('Turn', () => {
  let turn: Turn;
  // Define a type for the mocked Chat instance for clarity
@@ -149,7 +156,6 @@ describe('Turn', () => {
        type: GeminiEventType.Thought,
        value: { subject: '', description: 'reasoning...' },
      },
      { type: GeminiEventType.Content, value: 'final answer' },
    ]);
  });

@@ -27,11 +27,7 @@ import {
  toFriendlyError,
} from '../utils/errors.js';
import type { GeminiChat } from './geminiChat.js';
import {
  getThoughtText,
  parseThought,
  type ThoughtSummary,
} from '../utils/thoughtUtils.js';
import { getThoughtText, type ThoughtSummary } from '../utils/thoughtUtils.js';

// Define a structure for tools passed to the server
export interface ServerTool {
@@ -270,12 +266,13 @@ export class Turn {
        this.currentResponseId = resp.responseId;
      }

      const thoughtText = getThoughtText(resp);
      if (thoughtText) {
      const thoughtPart = getThoughtText(resp);
      if (thoughtPart) {
        yield {
          type: GeminiEventType.Thought,
          value: parseThought(thoughtText),
          value: { subject: '', description: thoughtPart },
        };
        continue;
      }

      const text = getResponseText(resp);

@@ -4,10 +4,36 @@
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
  describe,
  it,
  expect,
  vi,
  beforeEach,
  type Mock,
  type MockInstance,
  afterEach,
} from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import {
  DEFAULT_GEMINI_FLASH_MODEL,
  DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';

// Mock the telemetry logger and event class
vi.mock('../telemetry/index.js', () => ({
  logFlashFallback: vi.fn(),
  FlashFallbackEvent: class {},
}));

const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;

const createMockConfig = (overrides: Partial<Config> = {}): Config =>
  ({
@@ -19,28 +45,174 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>

describe('handleFallback', () => {
  let mockConfig: Config;
  let mockHandler: Mock<FallbackModelHandler>;
  let consoleErrorSpy: MockInstance;

  beforeEach(() => {
    vi.clearAllMocks();
    mockConfig = createMockConfig();
    mockHandler = vi.fn();
    // Default setup: OAuth user, Pro model failed, handler injected
    mockConfig = createMockConfig({
      fallbackModelHandler: mockHandler,
    });
    consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
  });

  it('should return null for unknown auth types', async () => {
  afterEach(() => {
    consoleErrorSpy.mockRestore();
  });

  it('should return null immediately if authType is not OAuth', async () => {
    const result = await handleFallback(
      mockConfig,
      'test-model',
      'unknown-auth',
      MOCK_PRO_MODEL,
      AUTH_API_KEY,
    );
    expect(result).toBeNull();
    expect(mockHandler).not.toHaveBeenCalled();
    expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
  });

  it('should return null if the failed model is already the fallback model', async () => {
    const result = await handleFallback(
      mockConfig,
      FALLBACK_MODEL, // Failed model is Flash
      AUTH_OAUTH,
    );
    expect(result).toBeNull();
    expect(mockHandler).not.toHaveBeenCalled();
  });

  it('should return null if no fallbackHandler is injected in config', async () => {
    const configWithoutHandler = createMockConfig({
      fallbackModelHandler: undefined,
    });
    const result = await handleFallback(
      configWithoutHandler,
      MOCK_PRO_MODEL,
      AUTH_OAUTH,
    );
    expect(result).toBeNull();
  });

  it('should handle Qwen OAuth error', async () => {
    const result = await handleFallback(
      mockConfig,
      'test-model',
      AuthType.QWEN_OAUTH,
      new Error('unauthorized'),
  describe('when handler returns "retry"', () => {
    it('should activate fallback mode, log telemetry, and return true', async () => {
      mockHandler.mockResolvedValue('retry');

      const result = await handleFallback(
        mockConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBe(true);
      expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
      expect(logFlashFallback).toHaveBeenCalled();
    });
  });

  describe('when handler returns "stop"', () => {
    it('should activate fallback mode, log telemetry, and return false', async () => {
      mockHandler.mockResolvedValue('stop');

      const result = await handleFallback(
        mockConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBe(false);
      expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
      expect(logFlashFallback).toHaveBeenCalled();
    });
  });

  describe('when handler returns "auth"', () => {
    it('should NOT activate fallback mode and return false', async () => {
      mockHandler.mockResolvedValue('auth');

      const result = await handleFallback(
        mockConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBe(false);
      expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
      expect(logFlashFallback).not.toHaveBeenCalled();
    });
  });

  describe('when handler returns an unexpected value', () => {
    it('should log an error and return null', async () => {
      mockHandler.mockResolvedValue(null);

      const result = await handleFallback(
        mockConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBeNull();
      expect(consoleErrorSpy).toHaveBeenCalledWith(
        'Fallback UI handler failed:',
        new Error(
          'Unexpected fallback intent received from fallbackModelHandler: "null"',
        ),
      );
      expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
    });
  });

  it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
    const mockError = new Error('Quota Exceeded');
    mockHandler.mockResolvedValue('retry');

    await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);

    expect(mockHandler).toHaveBeenCalledWith(
      MOCK_PRO_MODEL,
      FALLBACK_MODEL,
      mockError,
    );
  });

  it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
    // Setup config where fallback mode is already active
    const activeFallbackConfig = createMockConfig({
      fallbackModelHandler: mockHandler,
      isInFallbackMode: vi.fn(() => true), // Already active
      setFallbackMode: vi.fn(),
    });

    mockHandler.mockResolvedValue('retry');

    const result = await handleFallback(
      activeFallbackConfig,
      MOCK_PRO_MODEL,
      AUTH_OAUTH,
    );

    // Should still return true to allow the retry (which will use the active fallback mode)
    expect(result).toBe(true);
    // Should still consult the handler
    expect(mockHandler).toHaveBeenCalled();
    // But should not mutate state or log telemetry again
    expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
    expect(logFlashFallback).not.toHaveBeenCalled();
  });

  it('should catch errors from the handler, log an error, and return null', async () => {
    const handlerError = new Error('UI interaction failed');
    mockHandler.mockRejectedValue(handlerError);

    const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);

    expect(result).toBeNull();
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      'Fallback UI handler failed:',
      handlerError,
    );
    expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
  });
});

@@ -6,6 +6,8 @@

import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';

export async function handleFallback(
  config: Config,
@@ -18,7 +20,48 @@ export async function handleFallback(
    return handleQwenOAuthError(error);
  }

  return null;
  // Applicability Checks
  if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;

  const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;

  if (failedModel === fallbackModel) return null;

  // Consult UI Handler for Intent
  const fallbackModelHandler = config.fallbackModelHandler;
  if (typeof fallbackModelHandler !== 'function') return null;

  try {
    // Pass the specific failed model to the UI handler.
    const intent = await fallbackModelHandler(
      failedModel,
      fallbackModel,
      error,
    );

    // Process Intent and Update State
    switch (intent) {
      case 'retry':
        // Activate fallback mode. The NEXT retry attempt will pick this up.
        activateFallbackMode(config, authType);
        return true; // Signal retryWithBackoff to continue.

      case 'stop':
        activateFallbackMode(config, authType);
        return false;

      case 'auth':
        return false;

      default:
        throw new Error(
          `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
        );
    }
  } catch (handlerError) {
    console.error('Fallback UI handler failed:', handlerError);
    return null;
  }
}

/**
@@ -75,3 +118,12 @@ async function handleQwenOAuthError(error?: unknown): Promise<string | null> {
  // For other errors, don't handle them specially
  return null;
}

function activateFallbackMode(config: Config, authType: string | undefined) {
  if (!config.isInFallbackMode()) {
    config.setFallbackMode(true);
    if (authType) {
      logFlashFallback(config, new FlashFallbackEvent(authType));
    }
  }
}

@@ -20,11 +20,26 @@ async function getProcessInfo(pid: number): Promise<{
|
||||
command: string;
|
||||
}> {
|
||||
// Only used for Unix systems (macOS and Linux)
|
||||
const { stdout } = await execAsync(`ps -p ${pid} -o ppid=,comm=`);
|
||||
const [ppidStr, ...commandParts] = stdout.trim().split(/\s+/);
|
||||
const parentPid = parseInt(ppidStr, 10);
|
||||
const command = commandParts.join(' ');
|
||||
return { parentPid, name: path.basename(command), command };
|
||||
try {
|
||||
const command = `ps -o ppid=,command= -p ${pid}`;
|
||||
const { stdout } = await execAsync(command);
|
||||
const trimmedStdout = stdout.trim();
|
||||
if (!trimmedStdout) {
|
||||
return { parentPid: 0, name: '', command: '' };
|
||||
}
|
||||
const parts = trimmedStdout.split(/\s+/);
|
||||
const ppidString = parts[0];
|
||||
const parentPid = parseInt(ppidString, 10);
|
||||
const fullCommand = trimmedStdout.substring(ppidString.length).trim();
|
||||
const processName = path.basename(fullCommand.split(' ')[0]);
|
||||
return {
|
||||
parentPid: isNaN(parentPid) ? 1 : parentPid,
|
||||
name: processName,
|
||||
command: fullCommand,
|
||||
};
|
||||
} catch (_e) {
|
||||
return { parentPid: 0, name: '', command: '' };
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Finds the IDE process info on Unix-like systems.
|
||||
|
||||
@@ -12,6 +12,7 @@ export * from './output/json-formatter.js';
|
||||
// Export Core Logic
|
||||
export * from './core/client.js';
|
||||
export * from './core/contentGenerator.js';
|
||||
export * from './core/loggingContentGenerator.js';
|
||||
export * from './core/geminiChat.js';
|
||||
export * from './core/logger.js';
|
||||
export * from './core/prompts.js';
|
||||
@@ -23,7 +24,11 @@ export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
export * from './fallback/types.js';
|
||||
|
||||
export * from './code_assist/codeAssist.js';
|
||||
export * from './code_assist/oauth2.js';
|
||||
export * from './qwen/qwenOAuth2.js';
|
||||
export * from './code_assist/server.js';
|
||||
export * from './code_assist/types.js';
|
||||
|
||||
// Export utilities
|
||||
export * from './utils/paths.js';
|
||||
@@ -80,6 +85,9 @@ export * from './tools/tool-registry.js';
|
||||
// Export subagents (Phase 1)
|
||||
export * from './subagents/index.js';
|
||||
|
||||
// Export skills
|
||||
export * from './skills/index.js';
|
||||
|
||||
// Export prompt logic
|
||||
export * from './prompts/mcp-prompts.js';
|
||||
|
||||
@@ -101,6 +109,7 @@ export * from './tools/mcp-client-manager.js';
|
||||
export * from './tools/mcp-tool.js';
|
||||
export * from './tools/sdk-control-client-transport.js';
|
||||
export * from './tools/task.js';
|
||||
export * from './tools/skill.js';
|
||||
export * from './tools/todoWrite.js';
|
||||
export * from './tools/exitPlanMode.js';
|
||||
|
||||
|
||||
@@ -907,5 +907,3 @@ export async function clearQwenCredentials(): Promise<void> {
|
||||
function getQwenCachedCredentialPath(): string {
|
||||
return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME);
|
||||
}
|
||||
|
||||
export const clearCachedCredentialFile = clearQwenCredentials;
|
||||
|
||||
31
packages/core/src/skills/index.ts
Normal file
31
packages/core/src/skills/index.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Skills feature implementation
|
||||
*
|
||||
* This module provides the foundation for the skills feature, which allows
|
||||
* users to define reusable skill configurations that can be loaded by the
|
||||
* model via a dedicated Skills tool.
|
||||
*
|
||||
* Skills are stored as directories in `.qwen/skills/` (project-level) or
|
||||
* `~/.qwen/skills/` (user-level), with each directory containing a SKILL.md
|
||||
* file with YAML frontmatter for metadata.
|
||||
*/
|
||||
|
||||
// Core types and interfaces
|
||||
export type {
|
||||
SkillConfig,
|
||||
SkillLevel,
|
||||
SkillValidationResult,
|
||||
ListSkillsOptions,
|
||||
SkillErrorCode,
|
||||
} from './types.js';
|
||||
|
||||
export { SkillError } from './types.js';
|
||||
|
||||
// Main management class
|
||||
export { SkillManager } from './skill-manager.js';
|
||||
463
packages/core/src/skills/skill-manager.test.ts
Normal file
463
packages/core/src/skills/skill-manager.test.ts
Normal file
@@ -0,0 +1,463 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import * as fs from 'fs/promises';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import { SkillManager } from './skill-manager.js';
|
||||
import { type SkillConfig, SkillError } from './types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
|
||||
// Mock file system operations
|
||||
vi.mock('fs/promises');
|
||||
vi.mock('os');
|
||||
|
||||
// Mock yaml parser - use vi.hoisted for proper hoisting
|
||||
const mockParseYaml = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('../utils/yaml-parser.js', () => ({
|
||||
parse: mockParseYaml,
|
||||
stringify: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('SkillManager', () => {
|
||||
let manager: SkillManager;
|
||||
let mockConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create mock Config object using test utility
|
||||
mockConfig = makeFakeConfig({});
|
||||
|
||||
// Mock the project root method
|
||||
vi.spyOn(mockConfig, 'getProjectRoot').mockReturnValue('/test/project');
|
||||
|
||||
// Mock os.homedir
|
||||
vi.mocked(os.homedir).mockReturnValue('/home/user');
|
||||
|
||||
// Reset and setup mocks
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Setup yaml parser mocks with sophisticated behavior
|
||||
mockParseYaml.mockImplementation((yamlString: string) => {
|
||||
// Handle different test cases based on YAML content
|
||||
if (yamlString.includes('allowedTools:')) {
|
||||
return {
|
||||
name: 'test-skill',
|
||||
description: 'A test skill',
|
||||
allowedTools: ['read_file', 'write_file'],
|
||||
};
|
||||
}
|
||||
if (yamlString.includes('name: skill1')) {
|
||||
return { name: 'skill1', description: 'First skill' };
|
||||
}
|
||||
if (yamlString.includes('name: skill2')) {
|
||||
return { name: 'skill2', description: 'Second skill' };
|
||||
}
|
||||
if (yamlString.includes('name: skill3')) {
|
||||
return { name: 'skill3', description: 'Third skill' };
|
||||
}
|
||||
if (!yamlString.includes('name:')) {
|
||||
return { description: 'A test skill' }; // Missing name case
|
||||
}
|
||||
if (!yamlString.includes('description:')) {
|
||||
return { name: 'test-skill' }; // Missing description case
|
||||
}
|
||||
// Default case
|
||||
return {
|
||||
name: 'test-skill',
|
||||
description: 'A test skill',
|
||||
};
|
||||
});
|
||||
|
||||
manager = new SkillManager(mockConfig);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
const validSkillConfig: SkillConfig = {
|
||||
name: 'test-skill',
|
||||
description: 'A test skill',
|
||||
level: 'project',
|
||||
filePath: '/test/project/.qwen/skills/test-skill/SKILL.md',
|
||||
body: 'You are a helpful assistant with this skill.',
|
||||
};
|
||||
|
||||
const validMarkdown = `---
|
||||
name: test-skill
|
||||
description: A test skill
|
||||
---
|
||||
|
||||
You are a helpful assistant with this skill.
|
||||
`;
|
||||
|
||||
describe('parseSkillContent', () => {
|
||||
it('should parse valid markdown content', () => {
|
||||
const config = manager.parseSkillContent(
|
||||
validMarkdown,
|
||||
validSkillConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.name).toBe('test-skill');
|
||||
expect(config.description).toBe('A test skill');
|
||||
expect(config.body).toBe('You are a helpful assistant with this skill.');
|
||||
expect(config.level).toBe('project');
|
||||
expect(config.filePath).toBe(validSkillConfig.filePath);
|
||||
});
|
||||
|
||||
it('should parse content with allowedTools', () => {
|
||||
const markdownWithTools = `---
|
||||
name: test-skill
|
||||
description: A test skill
|
||||
allowedTools:
|
||||
- read_file
|
||||
- write_file
|
||||
---
|
||||
|
||||
You are a helpful assistant with this skill.
|
||||
`;
|
||||
|
||||
const config = manager.parseSkillContent(
|
||||
markdownWithTools,
|
||||
validSkillConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.allowedTools).toEqual(['read_file', 'write_file']);
|
||||
});
|
||||
|
||||
it('should determine level from file path', () => {
|
||||
const projectPath = '/test/project/.qwen/skills/test-skill/SKILL.md';
|
||||
const userPath = '/home/user/.qwen/skills/test-skill/SKILL.md';
|
||||
|
||||
const projectConfig = manager.parseSkillContent(
|
||||
validMarkdown,
|
||||
projectPath,
|
||||
'project',
|
||||
);
|
||||
const userConfig = manager.parseSkillContent(
|
||||
validMarkdown,
|
||||
userPath,
|
||||
'user',
|
||||
);
|
||||
|
||||
expect(projectConfig.level).toBe('project');
|
||||
expect(userConfig.level).toBe('user');
|
||||
});
|
||||
|
||||
it('should throw error for invalid frontmatter format', () => {
|
||||
const invalidMarkdown = `No frontmatter here
|
||||
Just content`;
|
||||
|
||||
expect(() =>
|
||||
manager.parseSkillContent(
|
||||
invalidMarkdown,
|
||||
validSkillConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SkillError);
|
||||
});
|
||||
|
||||
it('should throw error for missing name', () => {
|
||||
const markdownWithoutName = `---
|
||||
description: A test skill
|
||||
---
|
||||
|
||||
You are a helpful assistant.
|
||||
`;
|
||||
|
||||
expect(() =>
|
||||
manager.parseSkillContent(
|
||||
markdownWithoutName,
|
||||
validSkillConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SkillError);
|
||||
});
|
||||
|
||||
it('should throw error for missing description', () => {
|
||||
const markdownWithoutDescription = `---
|
||||
name: test-skill
|
||||
---
|
||||
|
||||
You are a helpful assistant.
|
||||
`;
|
||||
|
||||
expect(() =>
|
||||
manager.parseSkillContent(
|
||||
markdownWithoutDescription,
|
||||
validSkillConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SkillError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateConfig', () => {
|
||||
it('should validate valid configuration', () => {
|
||||
const result = manager.validateConfig(validSkillConfig);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should report error for missing name', () => {
|
||||
const invalidConfig = { ...validSkillConfig, name: '' };
|
||||
const result = manager.validateConfig(invalidConfig);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.errors).toContain('"name" cannot be empty');
|
||||
});
|
||||
|
||||
it('should report error for missing description', () => {
|
||||
const invalidConfig = { ...validSkillConfig, description: '' };
|
||||
const result = manager.validateConfig(invalidConfig);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.errors).toContain('"description" cannot be empty');
|
||||
});
|
||||
|
||||
it('should report error for invalid allowedTools type', () => {
|
||||
const invalidConfig = {
|
||||
...validSkillConfig,
|
||||
allowedTools: 'not-an-array' as unknown as string[],
|
||||
};
|
||||
const result = manager.validateConfig(invalidConfig);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.errors).toContain('"allowedTools" must be an array');
|
||||
});
|
||||
|
||||
it('should warn for empty body', () => {
|
||||
const configWithEmptyBody = { ...validSkillConfig, body: '' };
|
||||
const result = manager.validateConfig(configWithEmptyBody);
|
||||
|
||||
expect(result.isValid).toBe(true); // Still valid
|
||||
expect(result.warnings).toContain('Skill body is empty');
|
||||
});
|
||||
});
|
||||
|
||||
describe('loadSkill', () => {
|
||||
it('should load skill from project level first', async () => {
|
||||
vi.mocked(fs.readdir).mockResolvedValue([
|
||||
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
|
||||
vi.mocked(fs.access).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown);
|
||||
|
||||
const config = await manager.loadSkill('test-skill');
|
||||
|
||||
expect(config).toBeDefined();
|
||||
expect(config!.name).toBe('test-skill');
|
||||
});
|
||||
|
||||
it('should fall back to user level if project level fails', async () => {
|
||||
vi.mocked(fs.readdir)
|
||||
.mockRejectedValueOnce(new Error('Project dir not found')) // project level fails
|
||||
.mockResolvedValueOnce([
|
||||
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>); // user level succeeds
|
||||
vi.mocked(fs.access).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown);
|
||||
|
||||
const config = await manager.loadSkill('test-skill');
|
||||
|
||||
expect(config).toBeDefined();
|
||||
expect(config!.name).toBe('test-skill');
|
||||
});
|
||||
|
||||
it('should return null if not found at either level', async () => {
|
||||
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
|
||||
|
||||
const config = await manager.loadSkill('nonexistent');
|
||||
|
||||
expect(config).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('loadSkillForRuntime', () => {
|
||||
it('should load skill for runtime', async () => {
|
||||
vi.mocked(fs.readdir).mockResolvedValueOnce([
|
||||
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
|
||||
|
||||
vi.mocked(fs.access).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown); // SKILL.md
|
||||
|
||||
const config = await manager.loadSkillForRuntime('test-skill');
|
||||
|
||||
expect(config).toBeDefined();
|
||||
expect(config!.name).toBe('test-skill');
|
||||
});
|
||||
|
||||
it('should return null if skill not found', async () => {
|
||||
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
|
||||
|
||||
const config = await manager.loadSkillForRuntime('nonexistent');
|
||||
|
||||
expect(config).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('listSkills', () => {
|
||||
beforeEach(() => {
|
||||
// Mock directory listing for skills directories (with Dirent objects)
|
||||
vi.mocked(fs.readdir)
|
||||
.mockResolvedValueOnce([
|
||||
{ name: 'skill1', isDirectory: () => true, isFile: () => false },
|
||||
{ name: 'skill2', isDirectory: () => true, isFile: () => false },
|
||||
{
|
||||
name: 'not-a-dir.txt',
|
||||
isDirectory: () => false,
|
||||
isFile: () => true,
|
||||
},
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>)
|
||||
.mockResolvedValueOnce([
|
||||
{ name: 'skill3', isDirectory: () => true, isFile: () => false },
|
||||
{ name: 'skill1', isDirectory: () => true, isFile: () => false },
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
|
||||
|
||||
vi.mocked(fs.access).mockResolvedValue(undefined);
|
||||
|
||||
// Mock file reading for valid skills
|
||||
vi.mocked(fs.readFile).mockImplementation((filePath) => {
|
||||
const pathStr = String(filePath);
|
||||
if (pathStr.includes('skill1')) {
|
||||
return Promise.resolve(`---
|
||||
name: skill1
|
||||
description: First skill
|
||||
---
|
||||
Skill 1 content`);
|
||||
} else if (pathStr.includes('skill2')) {
|
||||
return Promise.resolve(`---
|
||||
name: skill2
|
||||
description: Second skill
|
||||
---
|
||||
Skill 2 content`);
|
||||
} else if (pathStr.includes('skill3')) {
|
||||
return Promise.resolve(`---
|
||||
name: skill3
|
||||
description: Third skill
|
||||
---
|
||||
Skill 3 content`);
|
||||
}
|
||||
return Promise.reject(new Error('File not found'));
|
||||
});
|
||||
});
|
||||
|
||||
it('should list skills from both levels', async () => {
|
||||
const skills = await manager.listSkills();
|
||||
|
||||
expect(skills).toHaveLength(3); // skill1 (project takes precedence), skill2, skill3
|
||||
expect(skills.map((s) => s.name).sort()).toEqual([
|
||||
'skill1',
|
||||
'skill2',
|
||||
'skill3',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should prioritize project level over user level', async () => {
|
||||
const skills = await manager.listSkills();
|
||||
const skill1 = skills.find((s) => s.name === 'skill1');
|
||||
|
||||
expect(skill1!.level).toBe('project');
|
||||
});
|
||||
|
||||
it('should filter by level', async () => {
|
||||
const projectSkills = await manager.listSkills({
|
||||
level: 'project',
|
||||
});
|
||||
|
||||
expect(projectSkills).toHaveLength(2); // skill1, skill2
|
||||
expect(projectSkills.every((s) => s.level === 'project')).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle empty directories', async () => {
|
||||
vi.mocked(fs.readdir).mockReset();
|
||||
vi.mocked(fs.readdir).mockResolvedValue(
|
||||
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
|
||||
);
|
||||
|
||||
const skills = await manager.listSkills({ force: true });
|
||||
|
||||
expect(skills).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle directory read errors', async () => {
|
||||
vi.mocked(fs.readdir).mockReset();
|
||||
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
|
||||
|
||||
const skills = await manager.listSkills({ force: true });
|
||||
|
||||
expect(skills).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSkillsBaseDir', () => {
|
||||
it('should return project-level base dir', () => {
|
||||
const baseDir = manager.getSkillsBaseDir('project');
|
||||
|
||||
expect(baseDir).toBe(path.join('/test/project', '.qwen', 'skills'));
|
||||
});
|
||||
|
||||
it('should return user-level base dir', () => {
|
||||
const baseDir = manager.getSkillsBaseDir('user');
|
||||
|
||||
expect(baseDir).toBe(path.join('/home/user', '.qwen', 'skills'));
|
||||
});
|
||||
});
|
||||
|
||||
describe('change listeners', () => {
|
||||
it('should notify listeners when cache is refreshed', async () => {
|
||||
const listener = vi.fn();
|
||||
manager.addChangeListener(listener);
|
||||
|
||||
vi.mocked(fs.readdir).mockResolvedValue(
|
||||
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
|
||||
);
|
||||
|
||||
await manager.refreshCache();
|
||||
|
||||
expect(listener).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should remove listener when cleanup function is called', async () => {
|
||||
const listener = vi.fn();
|
||||
const removeListener = manager.addChangeListener(listener);
|
||||
|
||||
removeListener();
|
||||
|
||||
vi.mocked(fs.readdir).mockResolvedValue(
|
||||
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
|
||||
);
|
||||
|
||||
await manager.refreshCache();
|
||||
|
||||
expect(listener).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('parse errors', () => {
|
||||
it('should track parse errors', async () => {
|
||||
vi.mocked(fs.readdir).mockResolvedValue([
|
||||
{ name: 'bad-skill', isDirectory: () => true, isFile: () => false },
|
||||
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
|
||||
vi.mocked(fs.access).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
'invalid content without frontmatter',
|
||||
);
|
||||
|
||||
await manager.listSkills({ force: true });
|
||||
|
||||
const errors = manager.getParseErrors();
|
||||
expect(errors.size).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
452
packages/core/src/skills/skill-manager.ts
Normal file
452
packages/core/src/skills/skill-manager.ts
Normal file
@@ -0,0 +1,452 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'fs/promises';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import { parse as parseYaml } from '../utils/yaml-parser.js';
|
||||
import type {
|
||||
SkillConfig,
|
||||
SkillLevel,
|
||||
ListSkillsOptions,
|
||||
SkillValidationResult,
|
||||
} from './types.js';
|
||||
import { SkillError, SkillErrorCode } from './types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
const QWEN_CONFIG_DIR = '.qwen';
|
||||
const SKILLS_CONFIG_DIR = 'skills';
|
||||
const SKILL_MANIFEST_FILE = 'SKILL.md';
|
||||
|
||||
/**
|
||||
* Manages skill configurations stored as directories containing SKILL.md files.
|
||||
* Provides discovery, parsing, validation, and caching for skills.
|
||||
*/
|
||||
export class SkillManager {
|
||||
private skillsCache: Map<SkillLevel, SkillConfig[]> | null = null;
|
||||
private readonly changeListeners: Set<() => void> = new Set();
|
||||
private parseErrors: Map<string, SkillError> = new Map();
|
||||
|
||||
constructor(private readonly config: Config) {}
|
||||
|
||||
/**
|
||||
* Adds a listener that will be called when skills change.
|
||||
* @returns A function to remove the listener.
|
||||
*/
|
||||
addChangeListener(listener: () => void): () => void {
|
||||
this.changeListeners.add(listener);
|
||||
return () => {
|
||||
this.changeListeners.delete(listener);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Notifies all registered change listeners.
|
||||
*/
|
||||
private notifyChangeListeners(): void {
|
||||
for (const listener of this.changeListeners) {
|
||||
try {
|
||||
listener();
|
||||
} catch (error) {
|
||||
console.warn('Skill change listener threw an error:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets any parse errors that occurred during skill loading.
|
||||
* @returns Map of skill paths to their parse errors.
|
||||
*/
|
||||
getParseErrors(): Map<string, SkillError> {
|
||||
return new Map(this.parseErrors);
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists all available skills.
|
||||
*
|
||||
* @param options - Filtering options
|
||||
* @returns Array of skill configurations
|
||||
*/
|
||||
async listSkills(options: ListSkillsOptions = {}): Promise<SkillConfig[]> {
|
||||
const skills: SkillConfig[] = [];
|
||||
const seenNames = new Set<string>();
|
||||
|
||||
const levelsToCheck: SkillLevel[] = options.level
|
||||
? [options.level]
|
||||
: ['project', 'user'];
|
||||
|
||||
// Check if we should use cache or force refresh
|
||||
const shouldUseCache = !options.force && this.skillsCache !== null;
|
||||
|
||||
// Initialize cache if it doesn't exist or we're forcing a refresh
|
||||
if (!shouldUseCache) {
|
||||
await this.refreshCache();
|
||||
}
|
||||
|
||||
// Collect skills from each level (project takes precedence over user)
|
||||
for (const level of levelsToCheck) {
|
||||
const levelSkills = this.skillsCache?.get(level) || [];
|
||||
|
||||
for (const skill of levelSkills) {
|
||||
// Skip if we've already seen this name (precedence: project > user)
|
||||
if (seenNames.has(skill.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
skills.push(skill);
|
||||
seenNames.add(skill.name);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by name for consistent ordering
|
||||
skills.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
return skills;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a skill configuration by name.
|
||||
* If level is specified, only searches that level.
|
||||
* If level is omitted, searches project-level first, then user-level.
|
||||
*
|
||||
* @param name - Name of the skill to load
|
||||
* @param level - Optional level to limit search to
|
||||
* @returns SkillConfig or null if not found
|
||||
*/
|
||||
async loadSkill(
|
||||
name: string,
|
||||
level?: SkillLevel,
|
||||
): Promise<SkillConfig | null> {
|
||||
if (level) {
|
||||
return this.findSkillByNameAtLevel(name, level);
|
||||
}
|
||||
|
||||
// Try project level first
|
||||
const projectSkill = await this.findSkillByNameAtLevel(name, 'project');
|
||||
if (projectSkill) {
|
||||
return projectSkill;
|
||||
}
|
||||
|
||||
// Try user level
|
||||
return this.findSkillByNameAtLevel(name, 'user');
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a skill with its full content, ready for runtime use.
|
||||
* This includes loading additional files from the skill directory.
|
||||
*
|
||||
* @param name - Name of the skill to load
|
||||
* @param level - Optional level to limit search to
|
||||
* @returns SkillConfig or null if not found
|
||||
*/
|
||||
async loadSkillForRuntime(
|
||||
name: string,
|
||||
level?: SkillLevel,
|
||||
): Promise<SkillConfig | null> {
|
||||
const skill = await this.loadSkill(name, level);
|
||||
if (!skill) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return skill;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a skill configuration.
|
||||
*
|
||||
* @param config - Configuration to validate
|
||||
* @returns Validation result
|
||||
*/
|
||||
validateConfig(config: Partial<SkillConfig>): SkillValidationResult {
|
||||
const errors: string[] = [];
|
||||
const warnings: string[] = [];
|
||||
|
||||
// Check required fields
|
||||
if (typeof config.name !== 'string') {
|
||||
errors.push('Missing or invalid "name" field');
|
||||
} else if (config.name.trim() === '') {
|
||||
errors.push('"name" cannot be empty');
|
||||
}
|
||||
|
||||
if (typeof config.description !== 'string') {
|
||||
errors.push('Missing or invalid "description" field');
|
||||
} else if (config.description.trim() === '') {
|
||||
errors.push('"description" cannot be empty');
|
||||
}
|
||||
|
||||
// Validate allowedTools if present
|
||||
if (config.allowedTools !== undefined) {
|
||||
if (!Array.isArray(config.allowedTools)) {
|
||||
errors.push('"allowedTools" must be an array');
|
||||
} else {
|
||||
for (const tool of config.allowedTools) {
|
||||
if (typeof tool !== 'string') {
|
||||
errors.push('"allowedTools" must contain only strings');
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if body is empty
|
||||
if (!config.body || config.body.trim() === '') {
|
||||
warnings.push('Skill body is empty');
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: errors.length === 0,
|
||||
errors,
|
||||
warnings,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the skills cache by loading all skills from disk.
|
||||
*/
|
||||
async refreshCache(): Promise<void> {
|
||||
const skillsCache = new Map<SkillLevel, SkillConfig[]>();
|
||||
this.parseErrors.clear();
|
||||
|
||||
const levels: SkillLevel[] = ['project', 'user'];
|
||||
|
||||
for (const level of levels) {
|
||||
const levelSkills = await this.listSkillsAtLevel(level);
|
||||
skillsCache.set(level, levelSkills);
|
||||
}
|
||||
|
||||
this.skillsCache = skillsCache;
|
||||
this.notifyChangeListeners();
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a SKILL.md file and returns the configuration.
|
||||
*
|
||||
* @param filePath - Path to the SKILL.md file
|
||||
* @param level - Storage level
|
||||
* @returns SkillConfig
|
||||
* @throws SkillError if parsing fails
|
||||
*/
|
||||
parseSkillFile(filePath: string, level: SkillLevel): Promise<SkillConfig> {
|
||||
return this.parseSkillFileInternal(filePath, level);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal implementation of skill file parsing.
|
||||
*/
|
||||
private async parseSkillFileInternal(
|
||||
filePath: string,
|
||||
level: SkillLevel,
|
||||
): Promise<SkillConfig> {
|
||||
let content: string;
|
||||
|
||||
try {
|
||||
content = await fs.readFile(filePath, 'utf8');
|
||||
} catch (error) {
|
||||
const skillError = new SkillError(
|
||||
`Failed to read skill file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
SkillErrorCode.FILE_ERROR,
|
||||
);
|
||||
this.parseErrors.set(filePath, skillError);
|
||||
throw skillError;
|
||||
}
|
||||
|
||||
return this.parseSkillContent(content, filePath, level);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses skill content from a string.
|
||||
*
|
||||
* @param content - File content
|
||||
* @param filePath - File path for error reporting
|
||||
* @param level - Storage level
|
||||
* @returns SkillConfig
|
||||
* @throws SkillError if parsing fails
|
||||
*/
|
||||
parseSkillContent(
|
||||
content: string,
|
||||
filePath: string,
|
||||
level: SkillLevel,
|
||||
): SkillConfig {
|
||||
try {
|
||||
// Split frontmatter and content
|
||||
const frontmatterRegex = /^---\n([\s\S]*?)\n---\n([\s\S]*)$/;
|
||||
const match = content.match(frontmatterRegex);
|
||||
|
||||
if (!match) {
|
||||
throw new Error('Invalid format: missing YAML frontmatter');
|
||||
}
|
||||
|
||||
const [, frontmatterYaml, body] = match;
|
||||
|
||||
// Parse YAML frontmatter
|
||||
const frontmatter = parseYaml(frontmatterYaml) as Record<string, unknown>;
|
||||
|
||||
// Extract required fields
|
||||
const nameRaw = frontmatter['name'];
|
||||
const descriptionRaw = frontmatter['description'];
|
||||
|
||||
if (nameRaw == null || nameRaw === '') {
|
||||
throw new Error('Missing "name" in frontmatter');
|
||||
}
|
||||
|
||||
if (descriptionRaw == null || descriptionRaw === '') {
|
||||
throw new Error('Missing "description" in frontmatter');
|
||||
}
|
||||
|
||||
// Convert to strings
|
||||
const name = String(nameRaw);
|
||||
const description = String(descriptionRaw);
|
||||
|
||||
// Extract optional fields
|
||||
const allowedToolsRaw = frontmatter['allowedTools'] as
|
||||
| unknown[]
|
||||
| undefined;
|
||||
let allowedTools: string[] | undefined;
|
||||
|
||||
if (allowedToolsRaw !== undefined) {
|
||||
if (Array.isArray(allowedToolsRaw)) {
|
||||
allowedTools = allowedToolsRaw.map(String);
|
||||
} else {
|
||||
throw new Error('"allowedTools" must be an array');
|
||||
}
|
||||
}
|
||||
|
||||
const config: SkillConfig = {
|
||||
name,
|
||||
description,
|
||||
allowedTools,
|
||||
level,
|
||||
filePath,
|
||||
body: body.trim(),
|
||||
};
|
||||
|
||||
// Validate the parsed configuration
|
||||
const validation = this.validateConfig(config);
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Validation failed: ${validation.errors.join(', ')}`);
|
||||
}
|
||||
|
||||
return config;
|
||||
} catch (error) {
|
||||
const skillError = new SkillError(
|
||||
`Failed to parse skill file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
SkillErrorCode.PARSE_ERROR,
|
||||
);
|
||||
this.parseErrors.set(filePath, skillError);
|
||||
throw skillError;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the base directory for skills at a specific level.
|
||||
*
|
||||
* @param level - Storage level
|
||||
* @returns Absolute directory path
|
||||
*/
|
||||
getSkillsBaseDir(level: SkillLevel): string {
|
||||
const baseDir =
|
||||
level === 'project'
|
||||
? path.join(
|
||||
this.config.getProjectRoot(),
|
||||
QWEN_CONFIG_DIR,
|
||||
SKILLS_CONFIG_DIR,
|
||||
)
|
||||
: path.join(os.homedir(), QWEN_CONFIG_DIR, SKILLS_CONFIG_DIR);
|
||||
|
||||
return baseDir;
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists skills at a specific level.
|
||||
*
|
||||
* @param level - Storage level to scan
|
||||
* @returns Array of skill configurations
|
||||
*/
|
||||
private async listSkillsAtLevel(level: SkillLevel): Promise<SkillConfig[]> {
|
||||
const projectRoot = this.config.getProjectRoot();
|
||||
const homeDir = os.homedir();
|
||||
const isHomeDirectory = path.resolve(projectRoot) === path.resolve(homeDir);
|
||||
|
||||
// If project level is requested but project root is same as home directory,
|
||||
// return empty array to avoid conflicts between project and global skills
|
||||
if (level === 'project' && isHomeDirectory) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const baseDir = this.getSkillsBaseDir(level);
|
||||
|
||||
try {
|
||||
const entries = await fs.readdir(baseDir, { withFileTypes: true });
|
||||
const skills: SkillConfig[] = [];
|
||||
|
||||
for (const entry of entries) {
|
||||
// Only process directories (each skill is a directory)
|
||||
if (!entry.isDirectory()) continue;
|
||||
|
||||
const skillDir = path.join(baseDir, entry.name);
|
||||
const skillManifest = path.join(skillDir, SKILL_MANIFEST_FILE);
|
||||
|
||||
try {
|
||||
// Check if SKILL.md exists
|
||||
await fs.access(skillManifest);
|
||||
|
||||
const config = await this.parseSkillFileInternal(
|
||||
skillManifest,
|
||||
level,
|
||||
);
|
||||
skills.push(config);
|
||||
} catch (error) {
|
||||
// Skip directories without valid SKILL.md
|
||||
if (error instanceof SkillError) {
|
||||
// Parse error was already recorded
|
||||
console.warn(
|
||||
`Failed to parse skill at ${skillDir}: ${error.message}`,
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return skills;
|
||||
} catch (_error) {
|
||||
// Directory doesn't exist or can't be read
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a skill by name at a specific level.
|
||||
*
|
||||
* @param name - Name of the skill to find
|
||||
* @param level - Storage level to search
|
||||
* @returns SkillConfig or null if not found
|
||||
*/
|
||||
private async findSkillByNameAtLevel(
|
||||
name: string,
|
||||
level: SkillLevel,
|
||||
): Promise<SkillConfig | null> {
|
||||
await this.ensureLevelCache(level);
|
||||
|
||||
const levelSkills = this.skillsCache?.get(level) || [];
|
||||
|
||||
// Find the skill with matching name
|
||||
return levelSkills.find((skill) => skill.name === name) || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the cache is populated for a specific level without loading other levels.
|
||||
*/
|
||||
private async ensureLevelCache(level: SkillLevel): Promise<void> {
|
||||
if (!this.skillsCache) {
|
||||
this.skillsCache = new Map<SkillLevel, SkillConfig[]>();
|
||||
}
|
||||
|
||||
if (!this.skillsCache.has(level)) {
|
||||
const levelSkills = await this.listSkillsAtLevel(level);
|
||||
this.skillsCache.set(level, levelSkills);
|
||||
}
|
||||
}
|
||||
}
|
||||
105
packages/core/src/skills/types.ts
Normal file
105
packages/core/src/skills/types.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Represents the storage level for a skill configuration.
|
||||
* - 'project': Stored in `.qwen/skills/` within the project directory
|
||||
* - 'user': Stored in `~/.qwen/skills/` in the user's home directory
|
||||
*/
|
||||
export type SkillLevel = 'project' | 'user';
|
||||
|
||||
/**
|
||||
* Core configuration for a skill as stored in SKILL.md files.
|
||||
* Each skill directory contains a SKILL.md file with YAML frontmatter
|
||||
* containing metadata, followed by markdown content describing the skill.
|
||||
*/
|
||||
export interface SkillConfig {
|
||||
/** Unique name identifier for the skill */
|
||||
name: string;
|
||||
|
||||
/** Human-readable description of what this skill provides */
|
||||
description: string;
|
||||
|
||||
/**
|
||||
* Optional list of tool names that this skill is allowed to use.
|
||||
* For v1, this is informational only (no gating).
|
||||
*/
|
||||
allowedTools?: string[];
|
||||
|
||||
/**
|
||||
* Storage level - determines where the configuration file is stored
|
||||
*/
|
||||
level: SkillLevel;
|
||||
|
||||
/**
|
||||
* Absolute path to the skill directory containing SKILL.md
|
||||
*/
|
||||
filePath: string;
|
||||
|
||||
/**
|
||||
* The markdown body content from SKILL.md (after the frontmatter)
|
||||
*/
|
||||
body: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runtime configuration for a skill when it's being actively used.
|
||||
* Extends SkillConfig with additional runtime-specific fields.
|
||||
*/
|
||||
export type SkillRuntimeConfig = SkillConfig;
|
||||
|
||||
/**
|
||||
* Result of a validation operation on a skill configuration.
|
||||
*/
|
||||
export interface SkillValidationResult {
|
||||
/** Whether the configuration is valid */
|
||||
isValid: boolean;
|
||||
|
||||
/** Array of error messages if validation failed */
|
||||
errors: string[];
|
||||
|
||||
/** Array of warning messages (non-blocking issues) */
|
||||
warnings: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for listing skills.
|
||||
*/
|
||||
export interface ListSkillsOptions {
|
||||
/** Filter by storage level */
|
||||
level?: SkillLevel;
|
||||
|
||||
/** Force refresh from disk, bypassing cache. Defaults to false. */
|
||||
force?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when a skill operation fails.
|
||||
*/
|
||||
export class SkillError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
readonly code: SkillErrorCode,
|
||||
readonly skillName?: string,
|
||||
) {
|
||||
super(message);
|
||||
this.name = 'SkillError';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error codes for skill operations.
|
||||
*/
|
||||
export const SkillErrorCode = {
|
||||
NOT_FOUND: 'NOT_FOUND',
|
||||
INVALID_CONFIG: 'INVALID_CONFIG',
|
||||
INVALID_NAME: 'INVALID_NAME',
|
||||
FILE_ERROR: 'FILE_ERROR',
|
||||
PARSE_ERROR: 'PARSE_ERROR',
|
||||
} as const;
|
||||
|
||||
export type SkillErrorCode =
|
||||
(typeof SkillErrorCode)[keyof typeof SkillErrorCode];
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
ToolCallEvent,
|
||||
} from '../types.js';
|
||||
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
|
||||
import { UserAccountManager } from '../../utils/userAccountManager.js';
|
||||
import { InstallationManager } from '../../utils/installationManager.js';
|
||||
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
|
||||
|
||||
@@ -89,8 +90,10 @@ expect.extend({
|
||||
},
|
||||
});
|
||||
|
||||
vi.mock('../../utils/userAccountManager.js');
|
||||
vi.mock('../../utils/installationManager.js');
|
||||
|
||||
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
|
||||
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
|
||||
|
||||
// TODO(richieforeman): Consider moving this to test setup globally.
|
||||
@@ -125,7 +128,11 @@ describe('ClearcutLogger', () => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
function setup({ config = {} as Partial<ConfigParameters> } = {}) {
|
||||
function setup({
|
||||
config = {} as Partial<ConfigParameters>,
|
||||
lifetimeGoogleAccounts = 1,
|
||||
cachedGoogleAccount = 'test@google.com',
|
||||
} = {}) {
|
||||
server.resetHandlers(
|
||||
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
|
||||
);
|
||||
@@ -139,6 +146,10 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
ClearcutLogger.clearInstance();
|
||||
|
||||
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
|
||||
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
|
||||
lifetimeGoogleAccounts,
|
||||
);
|
||||
mockInstallMgr.getInstallationId = vi
|
||||
.fn()
|
||||
.mockReturnValue('test-installation-id');
|
||||
@@ -184,6 +195,19 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
|
||||
describe('createLogEvent', () => {
|
||||
it('logs the total number of google accounts', () => {
|
||||
const { logger } = setup({
|
||||
lifetimeGoogleAccounts: 9001,
|
||||
});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: '9001',
|
||||
});
|
||||
});
|
||||
|
||||
it('logs the current surface from a github action', () => {
|
||||
const { logger } = setup({});
|
||||
|
||||
@@ -227,6 +251,7 @@ describe('ClearcutLogger', () => {
|
||||
// Define expected values
|
||||
const session_id = 'test-session-id';
|
||||
const auth_type = AuthType.USE_GEMINI;
|
||||
const google_accounts = 123;
|
||||
const surface = 'ide-1234';
|
||||
const cli_version = CLI_VERSION;
|
||||
const git_commit_hash = GIT_COMMIT_INFO;
|
||||
@@ -235,6 +260,7 @@ describe('ClearcutLogger', () => {
|
||||
|
||||
// Setup logger with expected values
|
||||
const { logger, loggerConfig } = setup({
|
||||
lifetimeGoogleAccounts: google_accounts,
|
||||
config: {},
|
||||
});
|
||||
vi.spyOn(loggerConfig, 'getContentGeneratorConfig').mockReturnValue({
|
||||
@@ -257,6 +283,10 @@ describe('ClearcutLogger', () => {
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
|
||||
value: JSON.stringify(auth_type),
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: `${google_accounts}`,
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: surface,
|
||||
@@ -374,14 +404,10 @@ describe('ClearcutLogger', () => {
|
||||
vi.stubEnv(key, value);
|
||||
}
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
expect(event?.event_metadata[0]).toEqual(
|
||||
expect.arrayContaining([
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: expectedValue,
|
||||
},
|
||||
]),
|
||||
);
|
||||
expect(event?.event_metadata[0][3]).toEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: expectedValue,
|
||||
});
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -34,6 +34,7 @@ import type {
|
||||
import { EventMetadataKey } from './event-metadata-key.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { InstallationManager } from '../../utils/installationManager.js';
|
||||
import { UserAccountManager } from '../../utils/userAccountManager.js';
|
||||
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
|
||||
import { FixedDeque } from 'mnemonist';
|
||||
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
|
||||
@@ -156,6 +157,7 @@ export class ClearcutLogger {
|
||||
private sessionData: EventValue[] = [];
|
||||
private promptId: string = '';
|
||||
private readonly installationManager: InstallationManager;
|
||||
private readonly userAccountManager: UserAccountManager;
|
||||
|
||||
/**
|
||||
* Queue of pending events that need to be flushed to the server. New events
|
||||
@@ -184,6 +186,7 @@ export class ClearcutLogger {
|
||||
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
|
||||
this.promptId = config?.getSessionId() ?? '';
|
||||
this.installationManager = new InstallationManager();
|
||||
this.userAccountManager = new UserAccountManager();
|
||||
}
|
||||
|
||||
static getInstance(config?: Config): ClearcutLogger | undefined {
|
||||
@@ -230,11 +233,14 @@ export class ClearcutLogger {
|
||||
}
|
||||
|
||||
createLogEvent(eventName: EventNames, data: EventValue[] = []): LogEvent {
|
||||
const email = this.userAccountManager.getCachedGoogleAccount();
|
||||
|
||||
if (eventName !== EventNames.START_SESSION) {
|
||||
data.push(...this.sessionData);
|
||||
}
|
||||
const totalAccounts = this.userAccountManager.getLifetimeGoogleAccounts();
|
||||
|
||||
data = this.addDefaultFields(data);
|
||||
data = this.addDefaultFields(data, totalAccounts);
|
||||
|
||||
const logEvent: LogEvent = {
|
||||
console_type: 'GEMINI_CLI',
|
||||
@@ -243,7 +249,12 @@ export class ClearcutLogger {
|
||||
event_metadata: [data],
|
||||
};
|
||||
|
||||
logEvent.client_install_id = this.installationManager.getInstallationId();
|
||||
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
|
||||
if (email) {
|
||||
logEvent.client_email = email;
|
||||
} else {
|
||||
logEvent.client_install_id = this.installationManager.getInstallationId();
|
||||
}
|
||||
|
||||
return logEvent;
|
||||
}
|
||||
@@ -1007,7 +1018,7 @@ export class ClearcutLogger {
|
||||
* Adds default fields to data, and returns a new data array. This fields
|
||||
* should exist on all log events.
|
||||
*/
|
||||
addDefaultFields(data: EventValue[]): EventValue[] {
|
||||
addDefaultFields(data: EventValue[], totalAccounts: number): EventValue[] {
|
||||
const surface = determineSurface();
|
||||
|
||||
const defaultLogMetadata: EventValue[] = [
|
||||
@@ -1021,6 +1032,10 @@ export class ClearcutLogger {
|
||||
this.config?.getContentGeneratorConfig()?.authType,
|
||||
),
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
|
||||
value: `${totalAccounts}`,
|
||||
},
|
||||
{
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
|
||||
value: surface,
|
||||
|
||||
@@ -33,6 +33,7 @@ export const EVENT_MALFORMED_JSON_RESPONSE =
|
||||
export const EVENT_FILE_OPERATION = 'qwen-code.file_operation';
|
||||
export const EVENT_MODEL_SLASH_COMMAND = 'qwen-code.slash_command.model';
|
||||
export const EVENT_SUBAGENT_EXECUTION = 'qwen-code.subagent_execution';
|
||||
export const EVENT_SKILL_LAUNCH = 'qwen-code.skill_launch';
|
||||
export const EVENT_AUTH = 'qwen-code.auth';
|
||||
|
||||
// Performance Events
|
||||
|
||||
@@ -44,6 +44,7 @@ export {
|
||||
logRipgrepFallback,
|
||||
logNextSpeakerCheck,
|
||||
logAuth,
|
||||
logSkillLaunch,
|
||||
} from './loggers.js';
|
||||
export type { SlashCommandEvent, ChatCompressionEvent } from './types.js';
|
||||
export {
|
||||
@@ -63,6 +64,7 @@ export {
|
||||
RipgrepFallbackEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
AuthEvent,
|
||||
SkillLaunchEvent,
|
||||
} from './types.js';
|
||||
export { makeSlashCommandEvent, makeChatCompressionEvent } from './types.js';
|
||||
export type { TelemetryEvent } from './types.js';
|
||||
|
||||
@@ -83,6 +83,7 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import * as uiTelemetry from './uiTelemetry.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
|
||||
describe('loggers', () => {
|
||||
@@ -100,6 +101,10 @@ describe('loggers', () => {
|
||||
vi.spyOn(uiTelemetry.uiTelemetryService, 'addEvent').mockImplementation(
|
||||
mockUiEvent.addEvent,
|
||||
);
|
||||
vi.spyOn(
|
||||
UserAccountManager.prototype,
|
||||
'getCachedGoogleAccount',
|
||||
).mockReturnValue('test-user@example.com');
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2025-01-01T00:00:00.000Z'));
|
||||
});
|
||||
@@ -183,6 +188,7 @@ describe('loggers', () => {
|
||||
body: 'CLI configuration loaded.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_CLI_CONFIG,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -200,6 +206,8 @@ describe('loggers', () => {
|
||||
mcp_tools: undefined,
|
||||
mcp_tools_count: undefined,
|
||||
output_format: 'json',
|
||||
skills: undefined,
|
||||
subagents: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -227,6 +235,7 @@ describe('loggers', () => {
|
||||
body: 'User prompt. Length: 11.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_USER_PROMPT,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
prompt_length: 11,
|
||||
@@ -248,7 +257,7 @@ describe('loggers', () => {
|
||||
const event = new UserPromptEvent(
|
||||
11,
|
||||
'prompt-id-9',
|
||||
AuthType.USE_GEMINI,
|
||||
AuthType.CLOUD_SHELL,
|
||||
'test-prompt',
|
||||
);
|
||||
|
||||
@@ -258,11 +267,12 @@ describe('loggers', () => {
|
||||
body: 'User prompt. Length: 11.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_USER_PROMPT,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
prompt_length: 11,
|
||||
prompt_id: 'prompt-id-9',
|
||||
auth_type: 'gemini-api-key',
|
||||
auth_type: 'cloud-shell',
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -305,7 +315,7 @@ describe('loggers', () => {
|
||||
'test-model',
|
||||
100,
|
||||
'prompt-id-1',
|
||||
AuthType.USE_GEMINI,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
usageData,
|
||||
'test-response',
|
||||
);
|
||||
@@ -316,6 +326,7 @@ describe('loggers', () => {
|
||||
body: 'API response from test-model. Status: 200. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_RESPONSE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
[SemanticAttributes.HTTP_STATUS_CODE]: 200,
|
||||
@@ -331,7 +342,7 @@ describe('loggers', () => {
|
||||
total_token_count: 0,
|
||||
response_text: 'test-response',
|
||||
prompt_id: 'prompt-id-1',
|
||||
auth_type: 'gemini-api-key',
|
||||
auth_type: 'oauth-personal',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -377,6 +388,7 @@ describe('loggers', () => {
|
||||
body: 'API request to test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_REQUEST,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -395,6 +407,7 @@ describe('loggers', () => {
|
||||
body: 'API request to test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_API_REQUEST,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -419,6 +432,7 @@ describe('loggers', () => {
|
||||
body: 'Switching to flash as Fallback.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_FLASH_FALLBACK,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
auth_type: 'vertex-ai',
|
||||
@@ -453,6 +467,7 @@ describe('loggers', () => {
|
||||
expect(emittedEvent.attributes).toEqual(
|
||||
expect.objectContaining({
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_RIPGREP_FALLBACK,
|
||||
error: 'ripgrep is not available',
|
||||
}),
|
||||
@@ -471,6 +486,7 @@ describe('loggers', () => {
|
||||
expect(emittedEvent.attributes).toEqual(
|
||||
expect.objectContaining({
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_RIPGREP_FALLBACK,
|
||||
error: 'rg not found',
|
||||
}),
|
||||
@@ -584,6 +600,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: accept. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -667,6 +684,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: reject. Success: false. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -743,6 +761,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Decision: modify. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -818,6 +837,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -892,6 +912,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: test-function. Success: false. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'test-function',
|
||||
@@ -980,6 +1001,7 @@ describe('loggers', () => {
|
||||
body: 'Tool call: mock_mcp_tool. Success: true. Duration: 100ms.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_TOOL_CALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
function_name: 'mock_mcp_tool',
|
||||
@@ -1027,6 +1049,7 @@ describe('loggers', () => {
|
||||
body: 'Malformed JSON response from test-model.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_MALFORMED_JSON_RESPONSE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
model: 'test-model',
|
||||
@@ -1070,6 +1093,7 @@ describe('loggers', () => {
|
||||
body: 'File operation: read. Lines: 10.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_FILE_OPERATION,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
tool_name: 'test-tool',
|
||||
@@ -1115,6 +1139,7 @@ describe('loggers', () => {
|
||||
body: 'Tool output truncated for test-tool.',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': 'tool_output_truncated',
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
eventName: 'tool_output_truncated',
|
||||
@@ -1161,6 +1186,7 @@ describe('loggers', () => {
|
||||
body: 'Installed extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_INSTALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1199,6 +1225,7 @@ describe('loggers', () => {
|
||||
body: 'Uninstalled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_UNINSTALL,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1235,6 +1262,7 @@ describe('loggers', () => {
|
||||
body: 'Enabled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_ENABLE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
@@ -1271,6 +1299,7 @@ describe('loggers', () => {
|
||||
body: 'Disabled extension vscode',
|
||||
attributes: {
|
||||
'session.id': 'test-session-id',
|
||||
'user.email': 'test-user@example.com',
|
||||
'event.name': EVENT_EXTENSION_DISABLE,
|
||||
'event.timestamp': '2025-01-01T00:00:00.000Z',
|
||||
extension_name: 'vscode',
|
||||
|
||||
@@ -9,6 +9,7 @@ import { logs } from '@opentelemetry/api-logs';
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
import type { Config } from '../config/config.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import {
  EVENT_API_ERROR,
  EVENT_API_CANCEL,
@@ -37,6 +38,7 @@ import {
  EVENT_MALFORMED_JSON_RESPONSE,
  EVENT_INVALID_CHUNK,
  EVENT_AUTH,
  EVENT_SKILL_LAUNCH,
} from './constants.js';
import {
  recordApiErrorMetrics,
@@ -84,6 +86,7 @@ import type {
  MalformedJsonResponseEvent,
  InvalidChunkEvent,
  AuthEvent,
  SkillLaunchEvent,
} from './types.js';
import type { UiEvent } from './uiTelemetry.js';
import { uiTelemetryService } from './uiTelemetry.js';
@@ -92,8 +95,11 @@ const shouldLogUserPrompts = (config: Config): boolean =>
  config.getTelemetryLogPromptsEnabled();

function getCommonAttributes(config: Config): LogAttributes {
  const userAccountManager = new UserAccountManager();
  const email = userAccountManager.getCachedGoogleAccount();
  return {
    'session.id': config.getSessionId(),
    ...(email && { 'user.email': email }),
  };
}

@@ -123,6 +129,8 @@ export function logStartSession(
    mcp_tools: event.mcp_tools,
    mcp_tools_count: event.mcp_tools_count,
    output_format: event.output_format,
    skills: event.skills,
    subagents: event.subagents,
  };

  const logger = logs.getLogger(SERVICE_NAME);
@@ -865,3 +873,21 @@ export function logAuth(config: Config, event: AuthEvent): void {
  };
  logger.emit(logRecord);
}

export function logSkillLaunch(config: Config, event: SkillLaunchEvent): void {
  if (!isTelemetrySdkInitialized()) return;

  const attributes: LogAttributes = {
    ...getCommonAttributes(config),
    ...event,
    'event.name': EVENT_SKILL_LAUNCH,
    'event.timestamp': new Date().toISOString(),
  };

  const logger = logs.getLogger(SERVICE_NAME);
  const logRecord: LogRecord = {
    body: `Skill launch: ${event.skill_name}. Success: ${event.success}.`,
    attributes,
  };
  logger.emit(logRecord);
}
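For context, a minimal sketch of how the new `logSkillLaunch` export might be exercised from a call site. Only `logSkillLaunch` and `SkillLaunchEvent` are taken from this diff; the wrapper function, its location, and the import path (shown as the package-relative `../telemetry/index.js` path used by skill.ts later in this diff) are illustrative assumptions.

```ts
// Sketch only (not part of this diff): emit a skill-launch telemetry record.
// `config` is assumed to be the session's existing Config instance.
import { logSkillLaunch, SkillLaunchEvent } from '../telemetry/index.js';
import type { Config } from '../config/config.js';

export function reportSkillLaunch(
  config: Config,
  skillName: string,
  success: boolean,
): void {
  // SkillLaunchEvent stamps event.name and event.timestamp in its constructor,
  // so callers only supply the skill name and the outcome.
  logSkillLaunch(config, new SkillLaunchEvent(skillName, success));
}
```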
@@ -38,6 +38,7 @@ import type {
  ModelSlashCommandEvent,
  ExtensionDisableEvent,
  AuthEvent,
  SkillLaunchEvent,
  RipgrepFallbackEvent,
  EndSessionEvent,
} from '../types.js';
@@ -391,6 +392,8 @@ export class QwenLogger {
        telemetry_enabled: event.telemetry_enabled,
        telemetry_log_user_prompts_enabled:
          event.telemetry_log_user_prompts_enabled,
        skills: event.skills,
        subagents: event.subagents,
      },
    });

@@ -827,6 +830,18 @@ export class QwenLogger {
    this.flushIfNeeded();
  }

  logSkillLaunchEvent(event: SkillLaunchEvent): void {
    const rumEvent = this.createActionEvent('misc', 'skill_launch', {
      properties: {
        skill_name: event.skill_name,
        success: event.success ? 1 : 0,
      },
    });

    this.enqueueLogEvent(rumEvent);
    this.flushIfNeeded();
  }

  logChatCompressionEvent(event: ChatCompressionEvent): void {
    const rumEvent = this.createActionEvent('misc', 'chat_compression', {
      properties: {
@@ -18,6 +18,9 @@ import {
import type { FileOperation } from './metrics.js';
export { ToolCallDecision };
import type { OutputFormat } from '../output/types.js';
import { ToolNames } from '../tools/tool-names.js';
import type { SkillTool } from '../tools/skill.js';
import type { TaskTool } from '../tools/task.js';

export interface BaseTelemetryEvent {
  'event.name': string;
@@ -47,6 +50,8 @@ export class StartSessionEvent implements BaseTelemetryEvent {
  mcp_tools_count?: number;
  mcp_tools?: string;
  output_format: OutputFormat;
  skills?: string;
  subagents?: string;

  constructor(config: Config) {
    const generatorConfig = config.getContentGeneratorConfig();
@@ -79,6 +84,7 @@ export class StartSessionEvent implements BaseTelemetryEvent {
      config.getFileFilteringRespectGitIgnore();
    this.mcp_servers_count = mcpServers ? Object.keys(mcpServers).length : 0;
    this.output_format = config.getOutputFormat();

    if (toolRegistry) {
      const mcpTools = toolRegistry
        .getAllTools()
@@ -87,6 +93,22 @@ export class StartSessionEvent implements BaseTelemetryEvent {
      this.mcp_tools = mcpTools
        .map((tool) => (tool as DiscoveredMCPTool).name)
        .join(',');

      const skillTool = toolRegistry.getTool(ToolNames.SKILL) as
        | SkillTool
        | undefined;
      const skillNames = skillTool?.getAvailableSkillNames?.();
      if (skillNames && skillNames.length > 0) {
        this.skills = skillNames.join(',');
      }

      const taskTool = toolRegistry.getTool(ToolNames.TASK) as
        | TaskTool
        | undefined;
      const subagentNames = taskTool?.getAvailableSubagentNames?.();
      if (subagentNames && subagentNames.length > 0) {
        this.subagents = subagentNames.join(',');
      }
    }
  }
}
@@ -721,6 +743,20 @@ export class AuthEvent implements BaseTelemetryEvent {
  }
}

export class SkillLaunchEvent implements BaseTelemetryEvent {
  'event.name': 'skill_launch';
  'event.timestamp': string;
  skill_name: string;
  success: boolean;

  constructor(skill_name: string, success: boolean) {
    this['event.name'] = 'skill_launch';
    this['event.timestamp'] = new Date().toISOString();
    this.skill_name = skill_name;
    this.success = success;
  }
}

export type TelemetryEvent =
  | StartSessionEvent
  | EndSessionEvent
@@ -749,7 +785,8 @@ export type TelemetryEvent =
  | ExtensionUninstallEvent
  | ToolOutputTruncatedEvent
  | ModelSlashCommandEvent
  | AuthEvent;
  | AuthEvent
  | SkillLaunchEvent;

export class ExtensionDisableEvent implements BaseTelemetryEvent {
  'event.name': 'extension_disable';
@@ -31,6 +31,8 @@ describe('LSTool', () => {
      tempSecondaryDir,
    ]);

    const userSkillsBase = path.join(os.homedir(), '.qwen', 'skills');

    mockConfig = {
      getTargetDir: () => tempRootDir,
      getWorkspaceContext: () => mockWorkspaceContext,
@@ -39,6 +41,9 @@ describe('LSTool', () => {
        respectGitIgnore: true,
        respectQwenIgnore: true,
      }),
      storage: {
        getUserSkillsDir: () => userSkillsBase,
      },
    } as unknown as Config;

    lsTool = new LSTool(mockConfig);
@@ -288,7 +293,7 @@ describe('LSTool', () => {
      };
      const invocation = lsTool.build(params);
      const description = invocation.getDescription();
      const expected = path.relative(tempRootDir, params.path);
      const expected = path.resolve(params.path);
      expect(description).toBe(expected);
    });
  });
@@ -9,6 +9,7 @@ import path from 'node:path';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isSubpath } from '../utils/paths.js';
import type { Config } from '../config/config.js';
import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js';
import { ToolErrorType } from './tool-error.js';
@@ -311,8 +312,14 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
      return `Path must be absolute: ${params.path}`;
    }

    const userSkillsBase = this.config.storage.getUserSkillsDir();
    const isUnderUserSkills = isSubpath(userSkillsBase, params.path);

    const workspaceContext = this.config.getWorkspaceContext();
    if (!workspaceContext.isPathWithinWorkspace(params.path)) {
    if (
      !workspaceContext.isPathWithinWorkspace(params.path) &&
      !isUnderUserSkills
    ) {
      const directories = workspaceContext.getDirectories();
      return `Path must be within one of the workspace directories: ${directories.join(
        ', ',
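The hunk above relies on an `isSubpath(parent, child)` helper imported from `../utils/paths.js`, whose body is not part of this diff. A plausible shape, shown only as an assumption about its semantics (accept the parent directory itself and anything nested under it):

```ts
import path from 'node:path';

// Hypothetical sketch of the isSubpath helper used by ls.ts and read-file.ts
// in this diff; the real implementation lives in utils/paths.js and may differ.
function isSubpath(parentDir: string, childPath: string): boolean {
  const relative = path.relative(
    path.resolve(parentDir),
    path.resolve(childPath),
  );
  // Inside the parent when the relative path does not climb upwards ('..')
  // and does not land on a different root or drive.
  return (
    relative === '' || (!relative.startsWith('..') && !path.isAbsolute(relative))
  );
}
```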
@@ -217,9 +217,9 @@ describe('mcp-client', () => {
        false,
      );

      expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._url).toEqual(new URL('http://test-server'));
      expect(transport).toEqual(
        new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
      );
    });

    it('with headers', async () => {
@@ -232,13 +232,13 @@ describe('mcp-client', () => {
        false,
      );

      expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._url).toEqual(new URL('http://test-server'));
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._requestInit?.headers).toEqual({
        Authorization: 'derp',
      });
      expect(transport).toEqual(
        new StreamableHTTPClientTransport(new URL('http://test-server'), {
          requestInit: {
            headers: { Authorization: 'derp' },
          },
        }),
      );
    });
  });

@@ -251,9 +251,9 @@ describe('mcp-client', () => {
        },
        false,
      );
      expect(transport).toBeInstanceOf(SSEClientTransport);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._url).toEqual(new URL('http://test-server'));
      expect(transport).toEqual(
        new SSEClientTransport(new URL('http://test-server'), {}),
      );
    });

    it('with headers', async () => {
@@ -266,13 +266,13 @@ describe('mcp-client', () => {
        false,
      );

      expect(transport).toBeInstanceOf(SSEClientTransport);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._url).toEqual(new URL('http://test-server'));
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      expect((transport as any)._requestInit?.headers).toEqual({
        Authorization: 'derp',
      });
      expect(transport).toEqual(
        new SSEClientTransport(new URL('http://test-server'), {
          requestInit: {
            headers: { Authorization: 'derp' },
          },
        }),
      );
    });
  });
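The test changes above replace assertions on private transport fields (`_url`, `_requestInit`) with a structural comparison against a freshly constructed transport. A generic sketch of the same vitest pattern, using a stand-in class rather than the MCP SDK so nothing about the real transports is assumed:

```ts
import { describe, expect, it } from 'vitest';

// Stand-in for any transport-like class; illustrates only the assertion style.
class FakeTransport {
  constructor(
    readonly url: URL,
    readonly options: { headers?: Record<string, string> } = {},
  ) {}
}

describe('structural transport assertions', () => {
  it('compares whole objects instead of private fields', () => {
    const transport = new FakeTransport(new URL('http://test-server'), {
      headers: { Authorization: 'derp' },
    });

    // toEqual does a deep structural comparison, so the test no longer needs
    // eslint overrides to reach into underscore-prefixed internals.
    expect(transport).toEqual(
      new FakeTransport(new URL('http://test-server'), {
        headers: { Authorization: 'derp' },
      }),
    );
  });
});
```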
@@ -40,6 +40,7 @@ describe('ReadFileTool', () => {
      getWorkspaceContext: () => createMockWorkspaceContext(tempRootDir),
      storage: {
        getProjectTempDir: () => path.join(tempRootDir, '.temp'),
        getUserSkillsDir: () => path.join(os.homedir(), '.qwen', 'skills'),
      },
      getTruncateToolOutputThreshold: () => 2500,
      getTruncateToolOutputLines: () => 500,
@@ -20,6 +20,7 @@ import { FileOperation } from '../telemetry/metrics.js';
import { getProgrammingLanguage } from '../telemetry/telemetry-utils.js';
import { logFileOperation } from '../telemetry/loggers.js';
import { FileOperationEvent } from '../telemetry/types.js';
import { isSubpath } from '../utils/paths.js';

/**
 * Parameters for the ReadFile tool
@@ -183,15 +184,20 @@ export class ReadFileTool extends BaseDeclarativeTool<

    const workspaceContext = this.config.getWorkspaceContext();
    const projectTempDir = this.config.storage.getProjectTempDir();
    const userSkillsDir = this.config.storage.getUserSkillsDir();
    const resolvedFilePath = path.resolve(filePath);
    const resolvedProjectTempDir = path.resolve(projectTempDir);
    const isWithinTempDir =
      resolvedFilePath.startsWith(resolvedProjectTempDir + path.sep) ||
      resolvedFilePath === resolvedProjectTempDir;
    const isWithinTempDir = isSubpath(projectTempDir, resolvedFilePath);
    const isWithinUserSkills = isSubpath(userSkillsDir, resolvedFilePath);

    if (!workspaceContext.isPathWithinWorkspace(filePath) && !isWithinTempDir) {
    if (
      !workspaceContext.isPathWithinWorkspace(filePath) &&
      !isWithinTempDir &&
      !isWithinUserSkills
    ) {
      const directories = workspaceContext.getDirectories();
      return `File path must be within one of the workspace directories: ${directories.join(', ')} or within the project temp directory: ${projectTempDir}`;
      return `File path must be within one of the workspace directories: ${directories.join(
        ', ',
      )} or within the project temp directory: ${projectTempDir}`;
    }
    if (params.offset !== undefined && params.offset < 0) {
      return 'Offset must be a non-negative number';
442
packages/core/src/tools/skill.test.ts
Normal file
@@ -0,0 +1,442 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { SkillTool, type SkillParams } from './skill.js';
import type { PartListUnion } from '@google/genai';
import type { ToolResultDisplay } from './tools.js';
import type { Config } from '../config/config.js';
import { SkillManager } from '../skills/skill-manager.js';
import type { SkillConfig } from '../skills/types.js';
import { partToString } from '../utils/partUtils.js';

// Type for accessing protected methods in tests
type SkillToolWithProtectedMethods = SkillTool & {
  createInvocation: (params: SkillParams) => {
    execute: (
      signal?: AbortSignal,
      updateOutput?: (output: ToolResultDisplay) => void,
    ) => Promise<{
      llmContent: PartListUnion;
      returnDisplay: ToolResultDisplay;
    }>;
    getDescription: () => string;
    shouldConfirmExecute: () => Promise<boolean>;
  };
};

// Mock dependencies
vi.mock('../skills/skill-manager.js');
vi.mock('../telemetry/index.js', () => ({
  logSkillLaunch: vi.fn(),
  SkillLaunchEvent: class {
    constructor(
      public skill_name: string,
      public success: boolean,
    ) {}
  },
}));

const MockedSkillManager = vi.mocked(SkillManager);

describe('SkillTool', () => {
  let config: Config;
  let skillTool: SkillTool;
  let mockSkillManager: SkillManager;
  let changeListeners: Array<() => void>;

  const mockSkills: SkillConfig[] = [
    {
      name: 'code-review',
      description: 'Specialized skill for reviewing code quality',
      level: 'project',
      filePath: '/project/.qwen/skills/code-review/SKILL.md',
      body: 'Review code for quality and best practices.',
    },
    {
      name: 'testing',
      description: 'Skill for writing and running tests',
      level: 'user',
      filePath: '/home/user/.qwen/skills/testing/SKILL.md',
      body: 'Help write comprehensive tests.',
      allowedTools: ['read_file', 'write_file', 'shell'],
    },
  ];

  beforeEach(async () => {
    // Setup fake timers
    vi.useFakeTimers();

    // Create mock config
    config = {
      getProjectRoot: vi.fn().mockReturnValue('/test/project'),
      getSessionId: vi.fn().mockReturnValue('test-session-id'),
      getSkillManager: vi.fn(),
      getGeminiClient: vi.fn().mockReturnValue(undefined),
    } as unknown as Config;

    changeListeners = [];

    // Setup SkillManager mock
    mockSkillManager = {
      listSkills: vi.fn().mockResolvedValue(mockSkills),
      loadSkill: vi.fn(),
      loadSkillForRuntime: vi.fn(),
      addChangeListener: vi.fn((listener: () => void) => {
        changeListeners.push(listener);
        return () => {
          const index = changeListeners.indexOf(listener);
          if (index >= 0) {
            changeListeners.splice(index, 1);
          }
        };
      }),
      getParseErrors: vi.fn().mockReturnValue(new Map()),
    } as unknown as SkillManager;

    MockedSkillManager.mockImplementation(() => mockSkillManager);

    // Make config return the mock SkillManager
    vi.mocked(config.getSkillManager).mockReturnValue(mockSkillManager);

    // Create SkillTool instance
    skillTool = new SkillTool(config);

    // Allow async initialization to complete
    await vi.runAllTimersAsync();
  });

  afterEach(() => {
    vi.useRealTimers();
    vi.clearAllMocks();
  });

  describe('initialization', () => {
    it('should initialize with correct name and properties', () => {
      expect(skillTool.name).toBe('skill');
      expect(skillTool.displayName).toBe('Skill');
      expect(skillTool.kind).toBe('read');
    });

    it('should load available skills during initialization', () => {
      expect(mockSkillManager.listSkills).toHaveBeenCalled();
    });

    it('should subscribe to skill manager changes', () => {
      expect(mockSkillManager.addChangeListener).toHaveBeenCalledTimes(1);
    });

    it('should update description with available skills', () => {
      expect(skillTool.description).toContain('code-review');
      expect(skillTool.description).toContain(
        'Specialized skill for reviewing code quality',
      );
      expect(skillTool.description).toContain('testing');
      expect(skillTool.description).toContain(
        'Skill for writing and running tests',
      );
    });

    it('should handle empty skills list gracefully', async () => {
      vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);

      const emptySkillTool = new SkillTool(config);
      await vi.runAllTimersAsync();

      expect(emptySkillTool.description).toContain(
        'No skills are currently configured',
      );
    });

    it('should handle skill loading errors gracefully', async () => {
      vi.mocked(mockSkillManager.listSkills).mockRejectedValue(
        new Error('Loading failed'),
      );

      const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

      new SkillTool(config);
      await vi.runAllTimersAsync();

      expect(consoleSpy).toHaveBeenCalledWith(
        'Failed to load skills for Skills tool:',
        expect.any(Error),
      );
      consoleSpy.mockRestore();
    });
  });

  describe('schema generation', () => {
    it('should expose static schema without dynamic enums', () => {
      const schema = skillTool.schema;
      const properties = schema.parametersJsonSchema as {
        properties: {
          skill: {
            type: string;
            description: string;
            enum?: string[];
          };
        };
      };
      expect(properties.properties.skill.type).toBe('string');
      expect(properties.properties.skill.description).toBe(
        'The skill name (no arguments). E.g., "pdf" or "xlsx"',
      );
      expect(properties.properties.skill.enum).toBeUndefined();
    });

    it('should keep schema static even when no skills available', async () => {
      vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);

      const emptySkillTool = new SkillTool(config);
      await vi.runAllTimersAsync();

      const schema = emptySkillTool.schema;
      const properties = schema.parametersJsonSchema as {
        properties: {
          skill: {
            type: string;
            description: string;
            enum?: string[];
          };
        };
      };
      expect(properties.properties.skill.type).toBe('string');
      expect(properties.properties.skill.description).toBe(
        'The skill name (no arguments). E.g., "pdf" or "xlsx"',
      );
      expect(properties.properties.skill.enum).toBeUndefined();
    });
  });

  describe('validateToolParams', () => {
    it('should validate valid parameters', () => {
      const result = skillTool.validateToolParams({ skill: 'code-review' });
      expect(result).toBeNull();
    });

    it('should reject empty skill', () => {
      const result = skillTool.validateToolParams({ skill: '' });
      expect(result).toBe('Parameter "skill" must be a non-empty string.');
    });

    it('should reject non-existent skill', () => {
      const result = skillTool.validateToolParams({
        skill: 'non-existent',
      });
      expect(result).toBe(
        'Skill "non-existent" not found. Available skills: code-review, testing',
      );
    });

    it('should show appropriate message when no skills available', async () => {
      vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);

      const emptySkillTool = new SkillTool(config);
      await vi.runAllTimersAsync();

      const result = emptySkillTool.validateToolParams({
        skill: 'non-existent',
      });
      expect(result).toBe(
        'Skill "non-existent" not found. No skills are currently available.',
      );
    });
  });

  describe('refreshSkills', () => {
    it('should refresh when change listener fires', async () => {
      const newSkills: SkillConfig[] = [
        {
          name: 'new-skill',
          description: 'A brand new skill',
          level: 'project',
          filePath: '/project/.qwen/skills/new-skill/SKILL.md',
          body: 'New skill content.',
        },
      ];

      vi.mocked(mockSkillManager.listSkills).mockResolvedValueOnce(newSkills);

      const listener = changeListeners[0];
      expect(listener).toBeDefined();

      listener?.();
      await vi.runAllTimersAsync();

      expect(skillTool.description).toContain('new-skill');
      expect(skillTool.description).toContain('A brand new skill');
    });

    it('should refresh available skills and update description', async () => {
      const newSkills: SkillConfig[] = [
        {
          name: 'test-skill',
          description: 'A test skill',
          level: 'project',
          filePath: '/project/.qwen/skills/test-skill/SKILL.md',
          body: 'Test content.',
        },
      ];

      vi.mocked(mockSkillManager.listSkills).mockResolvedValue(newSkills);

      await skillTool.refreshSkills();

      expect(skillTool.description).toContain('test-skill');
      expect(skillTool.description).toContain('A test skill');
    });
  });

  describe('SkillToolInvocation', () => {
    const mockRuntimeConfig: SkillConfig = {
      ...mockSkills[0],
    };

    beforeEach(() => {
      vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
        mockRuntimeConfig,
      );
    });

    it('should execute skill load successfully', async () => {
      const params: SkillParams = {
        skill: 'code-review',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const result = await invocation.execute();

      expect(mockSkillManager.loadSkillForRuntime).toHaveBeenCalledWith(
        'code-review',
      );

      const llmText = partToString(result.llmContent);
      expect(llmText).toContain(
        'Base directory for this skill: /project/.qwen/skills/code-review',
      );
      expect(llmText.trim()).toContain(
        'Review code for quality and best practices.',
      );

      expect(result.returnDisplay).toBe('Launching skill: code-review');
    });

    it('should include allowedTools in result when present', async () => {
      const skillWithTools: SkillConfig = {
        ...mockSkills[1],
      };
      vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
        skillWithTools,
      );

      const params: SkillParams = {
        skill: 'testing',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const result = await invocation.execute();

      const llmText = partToString(result.llmContent);
      expect(llmText).toContain('testing');
      // Base description is omitted from llmContent; ensure body is present.
      expect(llmText).toContain('Help write comprehensive tests.');

      expect(result.returnDisplay).toBe('Launching skill: testing');
    });

    it('should handle skill not found error', async () => {
      vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(null);

      const params: SkillParams = {
        skill: 'non-existent',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const result = await invocation.execute();

      const llmText = partToString(result.llmContent);
      expect(llmText).toContain('Skill "non-existent" not found');
    });

    it('should handle execution errors gracefully', async () => {
      vi.mocked(mockSkillManager.loadSkillForRuntime).mockRejectedValue(
        new Error('Loading failed'),
      );

      const consoleSpy = vi
        .spyOn(console, 'error')
        .mockImplementation(() => {});

      const params: SkillParams = {
        skill: 'code-review',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const result = await invocation.execute();

      const llmText = partToString(result.llmContent);
      expect(llmText).toContain('Failed to load skill');
      expect(llmText).toContain('Loading failed');

      consoleSpy.mockRestore();
    });

    it('should not require confirmation', async () => {
      const params: SkillParams = {
        skill: 'code-review',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const shouldConfirm = await invocation.shouldConfirmExecute();

      expect(shouldConfirm).toBe(false);
    });

    it('should provide correct description', () => {
      const params: SkillParams = {
        skill: 'code-review',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const description = invocation.getDescription();

      expect(description).toBe('Launching skill: "code-review"');
    });

    it('should handle skill without additional files', async () => {
      vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
        mockSkills[0],
      );

      const params: SkillParams = {
        skill: 'code-review',
      };

      const invocation = (
        skillTool as SkillToolWithProtectedMethods
      ).createInvocation(params);
      const result = await invocation.execute();

      const llmText = partToString(result.llmContent);
      expect(llmText).not.toContain('## Additional Files');

      expect(result.returnDisplay).toBe('Launching skill: code-review');
    });
  });
});
264
packages/core/src/tools/skill.ts
Normal file
@@ -0,0 +1,264 @@
/**
 * @license
 * Copyright 2025 Qwen
 * SPDX-License-Identifier: Apache-2.0
 */

import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames, ToolDisplayNames } from './tool-names.js';
import type { ToolResult, ToolResultDisplay } from './tools.js';
import type { Config } from '../config/config.js';
import type { SkillManager } from '../skills/skill-manager.js';
import type { SkillConfig } from '../skills/types.js';
import { logSkillLaunch, SkillLaunchEvent } from '../telemetry/index.js';
import path from 'path';

export interface SkillParams {
  skill: string;
}

/**
 * Skill tool that enables the model to access skill definitions.
 * The tool dynamically loads available skills and includes them in its description
 * for the model to choose from.
 */
export class SkillTool extends BaseDeclarativeTool<SkillParams, ToolResult> {
  static readonly Name: string = ToolNames.SKILL;

  private skillManager: SkillManager;
  private availableSkills: SkillConfig[] = [];

  constructor(private readonly config: Config) {
    // Initialize with a basic schema first
    const initialSchema = {
      type: 'object',
      properties: {
        skill: {
          type: 'string',
          description: 'The skill name (no arguments). E.g., "pdf" or "xlsx"',
        },
      },
      required: ['skill'],
      additionalProperties: false,
      $schema: 'http://json-schema.org/draft-07/schema#',
    };

    super(
      SkillTool.Name,
      ToolDisplayNames.SKILL,
      'Execute a skill within the main conversation. Loading available skills...', // Initial description
      Kind.Read,
      initialSchema,
      true, // isOutputMarkdown
      false, // canUpdateOutput
    );

    this.skillManager = config.getSkillManager();
    this.skillManager.addChangeListener(() => {
      void this.refreshSkills();
    });

    // Initialize the tool asynchronously
    this.refreshSkills();
  }

  /**
   * Asynchronously initializes the tool by loading available skills
   * and updating the description and schema.
   */
  async refreshSkills(): Promise<void> {
    try {
      this.availableSkills = await this.skillManager.listSkills();
      this.updateDescriptionAndSchema();
    } catch (error) {
      console.warn('Failed to load skills for Skills tool:', error);
      this.availableSkills = [];
      this.updateDescriptionAndSchema();
    } finally {
      // Update the client with the new tools
      const geminiClient = this.config.getGeminiClient();
      if (geminiClient && geminiClient.isInitialized()) {
        await geminiClient.setTools();
      }
    }
  }

  /**
   * Updates the tool's description and schema based on available skills.
   */
  private updateDescriptionAndSchema(): void {
    let skillDescriptions = '';
    if (this.availableSkills.length === 0) {
      skillDescriptions =
        'No skills are currently configured. Skills can be created by adding directories with SKILL.md files to .qwen/skills/ or ~/.qwen/skills/.';
    } else {
      skillDescriptions = this.availableSkills
        .map(
          (skill) => `<skill>
<name>
${skill.name}
</name>
<description>
${skill.description} (${skill.level})
</description>
<location>
${skill.level}
</location>
</skill>`,
        )
        .join('\n');
    }

    const baseDescription = `Execute a skill within the main conversation

<skills_instructions>
When users ask you to perform tasks, check if any of the available skills below can help complete the task more effectively. Skills provide specialized capabilities and domain knowledge.

How to invoke:
- Use this tool with the skill name only (no arguments)
- Examples:
- \`skill: "pdf"\` - invoke the pdf skill
- \`skill: "xlsx"\` - invoke the xlsx skill
- \`skill: "ms-office-suite:pdf"\` - invoke using fully qualified name

Important:
- When a skill is relevant, you must invoke this tool IMMEDIATELY as your first action
- NEVER just announce or mention a skill in your text response without actually calling this tool
- This is a BLOCKING REQUIREMENT: invoke the relevant Skill tool BEFORE generating any other response about the task
- Only use skills listed in <available_skills> below
- Do not invoke a skill that is already running
- Do not use this tool for built-in CLI commands (like /help, /clear, etc.)
</skills_instructions>

<available_skills>
${skillDescriptions}
</available_skills>
`;
    // Update description using object property assignment
    (this as { description: string }).description = baseDescription;
  }

  override validateToolParams(params: SkillParams): string | null {
    // Validate required fields
    if (
      !params.skill ||
      typeof params.skill !== 'string' ||
      params.skill.trim() === ''
    ) {
      return 'Parameter "skill" must be a non-empty string.';
    }

    // Validate that the skill exists
    const skillExists = this.availableSkills.some(
      (skill) => skill.name === params.skill,
    );

    if (!skillExists) {
      const availableNames = this.availableSkills.map((s) => s.name);
      if (availableNames.length === 0) {
        return `Skill "${params.skill}" not found. No skills are currently available.`;
      }
      return `Skill "${params.skill}" not found. Available skills: ${availableNames.join(', ')}`;
    }

    return null;
  }

  protected createInvocation(params: SkillParams) {
    return new SkillToolInvocation(this.config, this.skillManager, params);
  }

  getAvailableSkillNames(): string[] {
    return this.availableSkills.map((skill) => skill.name);
  }
}

class SkillToolInvocation extends BaseToolInvocation<SkillParams, ToolResult> {
  constructor(
    private readonly config: Config,
    private readonly skillManager: SkillManager,
    params: SkillParams,
  ) {
    super(params);
  }

  getDescription(): string {
    return `Launching skill: "${this.params.skill}"`;
  }

  override async shouldConfirmExecute(): Promise<false> {
    // Skill loading is a read-only operation, no confirmation needed
    return false;
  }

  async execute(
    _signal?: AbortSignal,
    _updateOutput?: (output: ToolResultDisplay) => void,
  ): Promise<ToolResult> {
    try {
      // Load the skill with runtime config (includes additional files)
      const skill = await this.skillManager.loadSkillForRuntime(
        this.params.skill,
      );

      if (!skill) {
        // Log failed skill launch
        logSkillLaunch(
          this.config,
          new SkillLaunchEvent(this.params.skill, false),
        );

        // Get parse errors if any
        const parseErrors = this.skillManager.getParseErrors();
        const errorMessages: string[] = [];

        for (const [filePath, error] of parseErrors) {
          if (filePath.includes(this.params.skill)) {
            errorMessages.push(`Parse error at ${filePath}: ${error.message}`);
          }
        }

        const errorDetail =
          errorMessages.length > 0
            ? `\nErrors:\n${errorMessages.join('\n')}`
            : '';

        return {
          llmContent: `Skill "${this.params.skill}" not found.${errorDetail}`,
          returnDisplay: `Skill "${this.params.skill}" not found.${errorDetail}`,
        };
      }

      // Log successful skill launch
      logSkillLaunch(
        this.config,
        new SkillLaunchEvent(this.params.skill, true),
      );

      const baseDir = path.dirname(skill.filePath);

      // Build markdown content for LLM (show base dir, then body)
      const llmContent = `Base directory for this skill: ${baseDir}\n\n${skill.body}\n`;

      return {
        llmContent: [{ text: llmContent }],
        returnDisplay: `Launching skill: ${skill.name}`,
      };
    } catch (error) {
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      console.error(`[SkillsTool] Error launching skill: ${errorMessage}`);

      // Log failed skill launch
      logSkillLaunch(
        this.config,
        new SkillLaunchEvent(this.params.skill, false),
      );

      return {
        llmContent: `Failed to load skill "${this.params.skill}": ${errorMessage}`,
        returnDisplay: `Failed to load skill "${this.params.skill}": ${errorMessage}`,
      };
    }
  }
}
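For orientation, a sketch of how the new tool is driven end to end, following the patterns used in the test files above (LSTool's `build()` usage in ls.test.ts and the public `refreshSkills()` call in skill.test.ts). The `build()`/`execute()` signatures and the awaited refresh are assumptions drawn from those tests, not from skill.ts itself; the skill name is just an example value.

```ts
// Sketch only: assumes `config` already exposes a SkillManager that has
// discovered a skill named 'code-review'.
import { SkillTool } from './skill.js';
import type { Config } from '../config/config.js';

async function launchSkill(config: Config) {
  const tool = new SkillTool(config);
  await tool.refreshSkills(); // make sure availableSkills is populated

  // build() validates the params and returns a ToolInvocation, mirroring
  // how ls.test.ts drives LSTool.
  const invocation = tool.build({ skill: 'code-review' });
  const result = await invocation.execute(new AbortController().signal);

  // returnDisplay is the short UI string; llmContent carries the skill body
  // prefixed with its base directory.
  console.log(result.returnDisplay); // e.g. "Launching skill: code-review"
  return result;
}
```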
Some files were not shown because too many files have changed in this diff.