Compare commits

..

13 Commits

Author SHA1 Message Date
tanzhenxin
251031cfc5 fix a link in skills.md 2025-12-23 15:09:23 +08:00
tanzhenxin
77c257d9d0 fix flaky tests 2025-12-23 14:50:47 +08:00
tanzhenxin
4311af96eb add docs 2025-12-23 10:53:09 +08:00
tanzhenxin
b49c11e9a2 add experimental-skills flag to enable skills feature 2025-12-23 10:24:57 +08:00
tanzhenxin
9cdd85c62a Merge branch 'main' into feat/skills 2025-12-22 16:00:57 +08:00
tanzhenxin
00547ba439 Merge pull request #1311 from QwenLM/fix/e2e
fix e2e workflow
2025-12-22 14:54:07 +08:00
tanzhenxin
fc1dac9dc7 update 2025-12-22 14:32:51 +08:00
tanzhenxin
338eb9038d fix e2e workflow 2025-12-22 14:28:36 +08:00
tanzhenxin
e0b9044833 Merge pull request #1310 from QwenLM/fix/process-info-robust-20251222
Improve robustness of getProcessInfo with try-catch and empty output fallback
2025-12-22 14:02:51 +08:00
xuewenjie
f33f43e2f7 feat: improve getProcessInfo robustness with try-catch and empty output fallback 2025-12-22 11:38:38 +08:00
tanzhenxin
177fc42f04 Merge branch 'main' into feat/skills 2025-12-15 14:25:56 +08:00
tanzhenxin
2560c2d1a2 Merge branch 'main' into feat/skills 2025-12-11 14:50:07 +08:00
tanzhenxin
bd6e16d41b draft version of skill tool feature 2025-12-10 17:18:44 +08:00
118 changed files with 9758 additions and 2609 deletions

View File

@@ -18,8 +18,6 @@ jobs:
- 'sandbox:docker'
node-version:
- '20.x'
- '22.x'
- '24.x'
steps:
- name: 'Checkout'
uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5
@@ -67,10 +65,13 @@ jobs:
OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
KEEP_OUTPUT: 'true'
SANDBOX: '${{ matrix.sandbox }}'
VERBOSE: 'true'
run: |-
npm run "test:integration:${SANDBOX}"
if [[ "${{ matrix.sandbox }}" == "sandbox:docker" ]]; then
npm run test:integration:sandbox:docker
else
npm run test:integration:sandbox:none
fi
e2e-test-macos:
name: 'E2E Test - macOS'

View File

@@ -43,6 +43,7 @@ Qwen Code uses JSON settings files for persistent configuration. There are four
In addition to a project settings file, a project's `.qwen` directory can contain other project-specific files related to Qwen Code's operation, such as:
- [Custom sandbox profiles](../features/sandbox) (e.g. `.qwen/sandbox-macos-custom.sb`, `.qwen/sandbox.Dockerfile`).
- [Agent Skills](../features/skills) (experimental) under `.qwen/skills/` (each Skill is a directory containing a `SKILL.md`).
### Available settings in `settings.json`
@@ -380,6 +381,8 @@ Arguments passed directly when running the CLI can override other configurations
| `--telemetry-otlp-protocol` | | Sets the OTLP protocol for telemetry (`grpc` or `http`). | | Defaults to `grpc`. See [telemetry](../../developers/development/telemetry) for more information. |
| `--telemetry-log-prompts` | | Enables logging of prompts for telemetry. | | See [telemetry](../../developers/development/telemetry) for more information. |
| `--checkpointing` | | Enables [checkpointing](../features/checkpointing). | | |
| `--experimental-acp` | | Enables ACP mode (Agent Control Protocol). Useful for IDE/editor integrations like [Zed](../integration-zed). | | Experimental. |
| `--experimental-skills` | | Enables experimental [Agent Skills](../features/skills) (registers the `skill` tool and loads Skills from `.qwen/skills/` and `~/.qwen/skills/`). | | Experimental. |
| `--extensions` | `-e` | Specifies a list of extensions to use for the session. | Extension names | If not provided, all available extensions are used. Use the special term `qwen -e none` to disable all extensions. Example: `qwen -e my-extension -e my-other-extension` |
| `--list-extensions` | `-l` | Lists all available extensions and exits. | | |
| `--proxy` | | Sets the proxy for the CLI. | Proxy URL | Example: `--proxy http://localhost:7890`. |

View File

@@ -1,6 +1,7 @@
export default {
commands: 'Commands',
'sub-agents': 'SubAgents',
skills: 'Skills (Experimental)',
headless: 'Headless Mode',
checkpointing: {
display: 'hidden',

View File

@@ -189,19 +189,20 @@ qwen -p "Write code" --output-format stream-json --include-partial-messages | jq
Key command-line options for headless usage:
| Option | Description | Example |
| ---------------------------- | --------------------------------------------------- | ------------------------------------------------------------------------ |
| `--prompt`, `-p` | Run in headless mode | `qwen -p "query"` |
| `--output-format`, `-o` | Specify output format (text, json, stream-json) | `qwen -p "query" --output-format json` |
| `--input-format` | Specify input format (text, stream-json) | `qwen --input-format text --output-format stream-json` |
| `--include-partial-messages` | Include partial messages in stream-json output | `qwen -p "query" --output-format stream-json --include-partial-messages` |
| `--debug`, `-d` | Enable debug mode | `qwen -p "query" --debug` |
| `--all-files`, `-a` | Include all files in context | `qwen -p "query" --all-files` |
| `--include-directories` | Include additional directories | `qwen -p "query" --include-directories src,docs` |
| `--yolo`, `-y` | Auto-approve all actions | `qwen -p "query" --yolo` |
| `--approval-mode` | Set approval mode | `qwen -p "query" --approval-mode auto_edit` |
| `--continue` | Resume the most recent session for this project | `qwen --continue -p "Pick up where we left off"` |
| `--resume [sessionId]` | Resume a specific session (or choose interactively) | `qwen --resume 123e... -p "Finish the refactor"` |
| Option | Description | Example |
| ---------------------------- | ------------------------------------------------------- | ------------------------------------------------------------------------ |
| `--prompt`, `-p` | Run in headless mode | `qwen -p "query"` |
| `--output-format`, `-o` | Specify output format (text, json, stream-json) | `qwen -p "query" --output-format json` |
| `--input-format` | Specify input format (text, stream-json) | `qwen --input-format text --output-format stream-json` |
| `--include-partial-messages` | Include partial messages in stream-json output | `qwen -p "query" --output-format stream-json --include-partial-messages` |
| `--debug`, `-d` | Enable debug mode | `qwen -p "query" --debug` |
| `--all-files`, `-a` | Include all files in context | `qwen -p "query" --all-files` |
| `--include-directories` | Include additional directories | `qwen -p "query" --include-directories src,docs` |
| `--yolo`, `-y` | Auto-approve all actions | `qwen -p "query" --yolo` |
| `--approval-mode` | Set approval mode | `qwen -p "query" --approval-mode auto_edit` |
| `--continue` | Resume the most recent session for this project | `qwen --continue -p "Pick up where we left off"` |
| `--resume [sessionId]` | Resume a specific session (or choose interactively) | `qwen --resume 123e... -p "Finish the refactor"` |
| `--experimental-skills` | Enable experimental Skills (registers the `skill` tool) | `qwen --experimental-skills -p "What Skills are available?"` |
For complete details on all available configuration options, settings files, and environment variables, see the [Configuration Guide](../configuration/settings).

View File

@@ -0,0 +1,282 @@
# Agent Skills (Experimental)
> Create, manage, and share Skills to extend Qwen Codes capabilities.
This guide shows you how to create, use, and manage Agent Skills in **Qwen Code**. Skills are modular capabilities that extend the models effectiveness through organized folders containing instructions (and optionally scripts/resources).
> [!note]
>
> Skills are currently **experimental** and must be enabled with `--experimental-skills`.
## Prerequisites
- Qwen Code (recent version)
- Run with the experimental flag enabled:
```bash
qwen --experimental-skills
```
- Basic familiarity with Qwen Code ([Quickstart](../quickstart.md))
## What are Agent Skills?
Agent Skills package expertise into discoverable capabilities. Each Skill consists of a `SKILL.md` file with instructions that the model can load when relevant, plus optional supporting files like scripts and templates.
### How Skills are invoked
Skills are **model-invoked** — the model autonomously decides when to use them based on your request and the Skills description. This is different from slash commands, which are **user-invoked** (you explicitly type `/command`).
### Benefits
- Extend Qwen Code for your workflows
- Share expertise across your team via git
- Reduce repetitive prompting
- Compose multiple Skills for complex tasks
## Create a Skill
Skills are stored as directories containing a `SKILL.md` file.
### Personal Skills
Personal Skills are available across all your projects. Store them in `~/.qwen/skills/`:
```bash
mkdir -p ~/.qwen/skills/my-skill-name
```
Use personal Skills for:
- Your individual workflows and preferences
- Experimental Skills youre developing
- Personal productivity helpers
### Project Skills
Project Skills are shared with your team. Store them in `.qwen/skills/` within your project:
```bash
mkdir -p .qwen/skills/my-skill-name
```
Use project Skills for:
- Team workflows and conventions
- Project-specific expertise
- Shared utilities and scripts
Project Skills can be checked into git and automatically become available to teammates.
## Write `SKILL.md`
Create a `SKILL.md` file with YAML frontmatter and Markdown content:
```yaml
---
name: your-skill-name
description: Brief description of what this Skill does and when to use it
---
# Your Skill Name
## Instructions
Provide clear, step-by-step guidance for Qwen Code.
## Examples
Show concrete examples of using this Skill.
```
### Field requirements
Qwen Code currently validates that:
- `name` is a non-empty string
- `description` is a non-empty string
Recommended conventions (not strictly enforced yet):
- Use lowercase letters, numbers, and hyphens in `name`
- Make `description` specific: include both **what** the Skill does and **when** to use it (key words users will naturally mention)
## Add supporting files
Create additional files alongside `SKILL.md`:
```text
my-skill/
├── SKILL.md (required)
├── reference.md (optional documentation)
├── examples.md (optional examples)
├── scripts/
│ └── helper.py (optional utility)
└── templates/
└── template.txt (optional template)
```
Reference these files from `SKILL.md`:
````markdown
For advanced usage, see [reference.md](reference.md).
Run the helper script:
```bash
python scripts/helper.py input.txt
```
````
## View available Skills
When `--experimental-skills` is enabled, Qwen Code discovers Skills from:
- Personal Skills: `~/.qwen/skills/`
- Project Skills: `.qwen/skills/`
To view available Skills, ask Qwen Code directly:
```text
What Skills are available?
```
Or inspect the filesystem:
```bash
# List personal Skills
ls ~/.qwen/skills/
# List project Skills (if in a project directory)
ls .qwen/skills/
# View a specific Skills content
cat ~/.qwen/skills/my-skill/SKILL.md
```
## Test a Skill
After creating a Skill, test it by asking questions that match your description.
Example: if your description mentions “PDF files”:
```text
Can you help me extract text from this PDF?
```
The model autonomously decides to use your Skill if it matches the request — you dont need to explicitly invoke it.
## Debug a Skill
If Qwen Code doesnt use your Skill, check these common issues:
### Make the description specific
Too vague:
```yaml
description: Helps with documents
```
Specific:
```yaml
description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDFs, forms, or document extraction.
```
### Verify file path
- Personal Skills: `~/.qwen/skills/<skill-name>/SKILL.md`
- Project Skills: `.qwen/skills/<skill-name>/SKILL.md`
```bash
# Personal
ls ~/.qwen/skills/my-skill/SKILL.md
# Project
ls .qwen/skills/my-skill/SKILL.md
```
### Check YAML syntax
Invalid YAML prevents the Skill metadata from loading correctly.
```bash
cat SKILL.md | head -n 15
```
Ensure:
- Opening `---` on line 1
- Closing `---` before Markdown content
- Valid YAML syntax (no tabs, correct indentation)
### View errors
Run Qwen Code with debug mode to see Skill loading errors:
```bash
qwen --experimental-skills --debug
```
## Share Skills with your team
You can share Skills through project repositories:
1. Add the Skill under `.qwen/skills/`
2. Commit and push
3. Teammates pull the changes and run with `--experimental-skills`
```bash
git add .qwen/skills/
git commit -m "Add team Skill for PDF processing"
git push
```
## Update a Skill
Edit `SKILL.md` directly:
```bash
# Personal Skill
code ~/.qwen/skills/my-skill/SKILL.md
# Project Skill
code .qwen/skills/my-skill/SKILL.md
```
Changes take effect the next time you start Qwen Code. If Qwen Code is already running, restart it to load the updates.
## Remove a Skill
Delete the Skill directory:
```bash
# Personal
rm -rf ~/.qwen/skills/my-skill
# Project
rm -rf .qwen/skills/my-skill
git commit -m "Remove unused Skill"
```
## Best practices
### Keep Skills focused
One Skill should address one capability:
- Focused: “PDF form filling”, “Excel analysis”, “Git commit messages”
- Too broad: “Document processing” (split into smaller Skills)
### Write clear descriptions
Help the model discover when to use Skills by including specific triggers:
```yaml
description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or .xlsx data.
```
### Test with your team
- Does the Skill activate when expected?
- Are the instructions clear?
- Are there missing examples or edge cases?

2007
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,9 @@
"scripts": {
"start": "cross-env node scripts/start.js",
"debug": "cross-env DEBUG=1 node --inspect-brk scripts/start.js",
"auth:npm": "npx google-artifactregistry-auth",
"auth:docker": "gcloud auth configure-docker us-west1-docker.pkg.dev",
"auth": "npm run auth:npm && npm run auth:docker",
"generate": "node scripts/generate-git-commit-info.js",
"build": "node scripts/build.js",
"build-and-start": "npm run build && npm run start",
@@ -92,6 +95,7 @@
"eslint-plugin-react-hooks": "^5.2.0",
"glob": "^10.5.0",
"globals": "^16.0.0",
"google-artifactregistry-auth": "^3.4.0",
"husky": "^9.1.7",
"json": "^11.0.0",
"lint-staged": "^16.1.6",

View File

@@ -36,10 +36,10 @@
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.6.0"
},
"dependencies": {
"@google/genai": "1.30.0",
"@google/genai": "1.16.0",
"@iarna/toml": "^2.2.5",
"@qwen-code/qwen-code-core": "file:../core",
"@modelcontextprotocol/sdk": "^1.25.1",
"@modelcontextprotocol/sdk": "^1.15.1",
"@types/update-notifier": "^6.0.8",
"ansi-regex": "^6.2.2",
"command-exists": "^1.2.9",

View File

@@ -26,23 +26,5 @@ export function validateAuthMethod(authMethod: string): string | null {
return null;
}
if (authMethod === AuthType.USE_GEMINI) {
const hasApiKey = process.env['GEMINI_API_KEY'];
if (!hasApiKey) {
return 'GEMINI_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
}
return null;
}
if (authMethod === AuthType.USE_VERTEX_AI) {
const hasApiKey = process.env['GOOGLE_API_KEY'];
if (!hasApiKey) {
return 'GOOGLE_API_KEY environment variable not found. Please set it in your .env file or environment variables.';
}
process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true';
return null;
}
return 'Invalid auth method selected.';
}

View File

@@ -112,6 +112,7 @@ export interface CliArgs {
allowedMcpServerNames: string[] | undefined;
allowedTools: string[] | undefined;
experimentalAcp: boolean | undefined;
experimentalSkills: boolean | undefined;
extensions: string[] | undefined;
listExtensions: boolean | undefined;
openaiLogging: boolean | undefined;
@@ -307,6 +308,11 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
type: 'boolean',
description: 'Starts the agent in ACP mode',
})
.option('experimental-skills', {
type: 'boolean',
description: 'Enable experimental Skills feature',
default: false,
})
.option('channel', {
type: 'string',
choices: ['VSCode', 'ACP', 'SDK', 'CI'],
@@ -460,12 +466,7 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
})
.option('auth-type', {
type: 'string',
choices: [
AuthType.USE_OPENAI,
AuthType.QWEN_OAUTH,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
],
choices: [AuthType.USE_OPENAI, AuthType.QWEN_OAUTH],
description: 'Authentication type',
})
.deprecateOption(
@@ -956,6 +957,7 @@ export async function loadCliConfig(
maxSessionTurns:
argv.maxSessionTurns ?? settings.model?.maxSessionTurns ?? -1,
experimentalZedIntegration: argv.experimentalAcp || false,
experimentalSkills: argv.experimentalSkills || false,
listExtensions: argv.listExtensions || false,
extensions: allExtensions,
blockedMcpServers,

View File

@@ -461,6 +461,7 @@ describe('gemini.tsx main function kitty protocol', () => {
allowedMcpServerNames: undefined,
allowedTools: undefined,
experimentalAcp: undefined,
experimentalSkills: undefined,
extensions: undefined,
listExtensions: undefined,
openaiLogging: undefined,

View File

@@ -4,8 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config, AuthType } from '@qwen-code/qwen-code-core';
import { InputFormat, logUserPrompt } from '@qwen-code/qwen-code-core';
import type { Config } from '@qwen-code/qwen-code-core';
import {
AuthType,
getOauthClient,
InputFormat,
logUserPrompt,
} from '@qwen-code/qwen-code-core';
import { render } from 'ink';
import dns from 'node:dns';
import os from 'node:os';
@@ -394,6 +399,15 @@ export async function main() {
initializationResult = await initializeApp(config, settings);
}
if (
settings.merged.security?.auth?.selectedType ===
AuthType.LOGIN_WITH_GOOGLE &&
config.isBrowserLaunchSuppressed()
) {
// Do oauth before app renders to make copying the link possible.
await getOauthClient(settings.merged.security.auth.selectedType, config);
}
if (config.getExperimentalZedIntegration()) {
return runAcpAgent(config, settings, extensions, argv);
}

View File

@@ -610,6 +610,8 @@ export abstract class BaseJsonOutputAdapter {
const errorText = parseAndFormatApiError(
event.value.error,
this.config.getContentGeneratorConfig()?.authType,
undefined,
this.config.getModel(),
);
this.appendText(state, errorText, null);
break;

View File

@@ -221,6 +221,8 @@ export async function runNonInteractive(
const errorText = parseAndFormatApiError(
event.value.error,
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
);
process.stderr.write(`${errorText}\n`);
}

View File

@@ -28,7 +28,7 @@ const mockPrompt = {
{ name: 'trail', required: false, description: "The animal's trail." },
],
invoke: vi.fn().mockResolvedValue({
messages: [{ content: { type: 'text', text: 'Hello, world!' } }],
messages: [{ content: { text: 'Hello, world!' } }],
}),
};

View File

@@ -123,10 +123,7 @@ export class McpPromptLoader implements ICommandLoader {
};
}
const firstMessage = result.messages?.[0];
const content = firstMessage?.content;
if (content?.type !== 'text') {
if (!result.messages?.[0]?.content?.['text']) {
return {
type: 'message',
messageType: 'error',
@@ -137,7 +134,7 @@ export class McpPromptLoader implements ICommandLoader {
return {
type: 'submit_prompt',
content: JSON.stringify(content.text),
content: JSON.stringify(result.messages[0].content.text),
};
} catch (error) {
return {

View File

@@ -23,6 +23,7 @@ import {
} from '@qwen-code/qwen-code-core';
import type { LoadedSettings } from '../config/settings.js';
import type { InitializationResult } from '../core/initializer.js';
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
import { UIStateContext, type UIState } from './contexts/UIStateContext.js';
import {
UIActionsContext,
@@ -55,6 +56,7 @@ vi.mock('./App.js', () => ({
App: TestContextConsumer,
}));
vi.mock('./hooks/useQuotaAndFallback.js');
vi.mock('./hooks/useHistoryManager.js');
vi.mock('./hooks/useThemeCommand.js');
vi.mock('./auth/useAuth.js');
@@ -120,6 +122,7 @@ describe('AppContainer State Management', () => {
let mockInitResult: InitializationResult;
// Create typed mocks for all hooks
const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock;
const mockedUseHistory = useHistory as Mock;
const mockedUseThemeCommand = useThemeCommand as Mock;
const mockedUseAuthCommand = useAuthCommand as Mock;
@@ -161,6 +164,10 @@ describe('AppContainer State Management', () => {
capturedUIActions = null!;
// **Provide a default return value for EVERY mocked hook.**
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: null,
handleProQuotaChoice: vi.fn(),
});
mockedUseHistory.mockReturnValue({
history: [],
addItem: vi.fn(),
@@ -560,6 +567,75 @@ describe('AppContainer State Management', () => {
});
});
describe('Quota and Fallback Integration', () => {
it('passes a null proQuotaRequest to UIStateContext by default', () => {
// The default mock from beforeEach already sets proQuotaRequest to null
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert that the context value is as expected
expect(capturedUIState.proQuotaRequest).toBeNull();
});
it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => {
// Arrange: Create a mock request object that a UI dialog would receive
const mockRequest = {
failedModel: 'gemini-pro',
fallbackModel: 'gemini-flash',
resolve: vi.fn(),
};
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: mockRequest,
handleProQuotaChoice: vi.fn(),
});
// Act: Render the container
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert: The mock request is correctly passed through the context
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
});
it('passes the handleProQuotaChoice function to UIActionsContext', () => {
// Arrange: Create a mock handler function
const mockHandler = vi.fn();
mockedUseQuotaAndFallback.mockReturnValue({
proQuotaRequest: null,
handleProQuotaChoice: mockHandler,
});
// Act: Render the container
render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);
// Assert: The action in the context is the mock handler we provided
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
// You can even verify that the plumbed function is callable
capturedUIActions.handleProQuotaChoice('auth');
expect(mockHandler).toHaveBeenCalledWith('auth');
});
});
describe('Terminal Title Update Feature', () => {
beforeEach(() => {
// Reset mock stdout for each test

View File

@@ -32,6 +32,7 @@ import {
type Config,
type IdeInfo,
type IdeContext,
type UserTierId,
DEFAULT_GEMINI_FLASH_MODEL,
IdeClient,
ideContextStore,
@@ -47,6 +48,7 @@ import { useHistory } from './hooks/useHistoryManager.js';
import { useMemoryMonitor } from './hooks/useMemoryMonitor.js';
import { useThemeCommand } from './hooks/useThemeCommand.js';
import { useAuthCommand } from './auth/useAuth.js';
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
import { useEditorSettings } from './hooks/useEditorSettings.js';
import { useSettingsCommand } from './hooks/useSettingsCommand.js';
import { useModelCommand } from './hooks/useModelCommand.js';
@@ -190,6 +192,8 @@ export const AppContainer = (props: AppContainerProps) => {
const [currentModel, setCurrentModel] = useState(getEffectiveModel());
const [userTier] = useState<UserTierId | undefined>(undefined);
const [isConfigInitialized, setConfigInitialized] = useState(false);
const [userMessages, setUserMessages] = useState<string[]>([]);
@@ -363,6 +367,14 @@ export const AppContainer = (props: AppContainerProps) => {
cancelAuthentication,
} = useAuthCommand(settings, config, historyManager.addItem);
const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({
config,
historyManager,
userTier,
setAuthState,
setModelSwitchedFromQuotaError,
});
useInitializationAuthError(initializationResult.authError, onAuthError);
// Sync user tier from config when authentication changes
@@ -740,7 +752,8 @@ export const AppContainer = (props: AppContainerProps) => {
!initError &&
!isProcessing &&
(streamingState === StreamingState.Idle ||
streamingState === StreamingState.Responding);
streamingState === StreamingState.Responding) &&
!proQuotaRequest;
const [controlsHeight, setControlsHeight] = useState(0);
@@ -1193,6 +1206,7 @@ export const AppContainer = (props: AppContainerProps) => {
isAuthenticating ||
isEditorDialogOpen ||
showIdeRestartPrompt ||
!!proQuotaRequest ||
isSubagentCreateDialogOpen ||
isAgentsManagerDialogOpen ||
isApprovalModeDialogOpen ||
@@ -1263,6 +1277,8 @@ export const AppContainer = (props: AppContainerProps) => {
showWorkspaceMigrationDialog,
workspaceExtensions,
currentModel,
userTier,
proQuotaRequest,
contextFileNames,
errorCount,
availableTerminalHeight,
@@ -1351,6 +1367,8 @@ export const AppContainer = (props: AppContainerProps) => {
showAutoAcceptIndicator,
showWorkspaceMigrationDialog,
workspaceExtensions,
userTier,
proQuotaRequest,
contextFileNames,
errorCount,
availableTerminalHeight,
@@ -1412,6 +1430,7 @@ export const AppContainer = (props: AppContainerProps) => {
handleClearScreen,
onWorkspaceMigrationDialogOpen,
onWorkspaceMigrationDialogClose,
handleProQuotaChoice,
// Vision switch dialog
handleVisionSwitchSelect,
// Welcome back dialog
@@ -1449,6 +1468,7 @@ export const AppContainer = (props: AppContainerProps) => {
handleClearScreen,
onWorkspaceMigrationDialogOpen,
onWorkspaceMigrationDialogClose,
handleProQuotaChoice,
handleVisionSwitchSelect,
handleWelcomeBackSelection,
handleWelcomeBackClose,

View File

@@ -168,7 +168,7 @@ describe('AuthDialog', () => {
it('should not show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to something else', () => {
process.env['GEMINI_API_KEY'] = 'foobar';
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.LOGIN_WITH_GOOGLE;
const settings: LoadedSettings = new LoadedSettings(
{
@@ -212,7 +212,7 @@ describe('AuthDialog', () => {
it('should show the GEMINI_API_KEY message if QWEN_DEFAULT_AUTH_TYPE is set to use api key', () => {
process.env['GEMINI_API_KEY'] = 'foobar';
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_OPENAI;
process.env['QWEN_DEFAULT_AUTH_TYPE'] = AuthType.USE_GEMINI;
const settings: LoadedSettings = new LoadedSettings(
{
@@ -504,12 +504,12 @@ describe('AuthDialog', () => {
},
{
settings: {
security: { auth: { selectedType: AuthType.USE_OPENAI } },
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
ui: { customThemes: {} },
mcpServers: {},
},
originalSettings: {
security: { auth: { selectedType: AuthType.USE_OPENAI } },
security: { auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE } },
ui: { customThemes: {} },
mcpServers: {},
},

View File

@@ -225,24 +225,16 @@ export const useAuthCommand = (
const defaultAuthType = process.env['QWEN_DEFAULT_AUTH_TYPE'];
if (
defaultAuthType &&
![
AuthType.QWEN_OAUTH,
AuthType.USE_OPENAI,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
].includes(defaultAuthType as AuthType)
![AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].includes(
defaultAuthType as AuthType,
)
) {
onAuthError(
t(
'Invalid QWEN_DEFAULT_AUTH_TYPE value: "{{value}}". Valid values are: {{validValues}}',
{
value: defaultAuthType,
validValues: [
AuthType.QWEN_OAUTH,
AuthType.USE_OPENAI,
AuthType.USE_GEMINI,
AuthType.USE_VERTEX_AI,
].join(', '),
validValues: [AuthType.QWEN_OAUTH, AuthType.USE_OPENAI].join(', '),
},
),
);

View File

@@ -15,6 +15,7 @@ vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const original = await importOriginal<typeof core>();
return {
...original,
getOauthClient: vi.fn(original.getOauthClient),
getIdeInstaller: vi.fn(original.getIdeInstaller),
IdeClient: {
getInstance: vi.fn(),

View File

@@ -17,6 +17,7 @@ import { AuthDialog } from '../auth/AuthDialog.js';
import { OpenAIKeyPrompt } from './OpenAIKeyPrompt.js';
import { EditorSettingsDialog } from './EditorSettingsDialog.js';
import { WorkspaceMigrationDialog } from './WorkspaceMigrationDialog.js';
import { ProQuotaDialog } from './ProQuotaDialog.js';
import { PermissionsModifyTrustDialog } from './PermissionsModifyTrustDialog.js';
import { ModelDialog } from './ModelDialog.js';
import { ApprovalModeDialog } from './ApprovalModeDialog.js';
@@ -86,6 +87,15 @@ export const DialogManager = ({
/>
);
}
if (uiState.proQuotaRequest) {
return (
<ProQuotaDialog
failedModel={uiState.proQuotaRequest.failedModel}
fallbackModel={uiState.proQuotaRequest.fallbackModel}
onChoice={uiActions.handleProQuotaChoice}
/>
);
}
if (uiState.shouldShowIdePrompt) {
return (
<IdeIntegrationNudge

View File

@@ -0,0 +1,91 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { render } from 'ink-testing-library';
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { ProQuotaDialog } from './ProQuotaDialog.js';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
// Mock the child component to make it easier to test the parent
vi.mock('./shared/RadioButtonSelect.js', () => ({
RadioButtonSelect: vi.fn(),
}));
describe('ProQuotaDialog', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('should render with correct title and options', () => {
const { lastFrame } = render(
<ProQuotaDialog
failedModel="gemini-2.5-pro"
fallbackModel="gemini-2.5-flash"
onChoice={() => {}}
/>,
);
const output = lastFrame();
expect(output).toContain('Pro quota limit reached for gemini-2.5-pro.');
// Check that RadioButtonSelect was called with the correct items
expect(RadioButtonSelect).toHaveBeenCalledWith(
expect.objectContaining({
items: [
{
label: 'Change auth (executes the /auth command)',
value: 'auth',
key: 'auth',
},
{
label: `Continue with gemini-2.5-flash`,
value: 'continue',
key: 'continue',
},
],
}),
undefined,
);
});
it('should call onChoice with "auth" when "Change auth" is selected', () => {
const mockOnChoice = vi.fn();
render(
<ProQuotaDialog
failedModel="gemini-2.5-pro"
fallbackModel="gemini-2.5-flash"
onChoice={mockOnChoice}
/>,
);
// Get the onSelect function passed to RadioButtonSelect
const onSelect = (RadioButtonSelect as Mock).mock.calls[0][0].onSelect;
// Simulate the selection
onSelect('auth');
expect(mockOnChoice).toHaveBeenCalledWith('auth');
});
it('should call onChoice with "continue" when "Continue with flash" is selected', () => {
const mockOnChoice = vi.fn();
render(
<ProQuotaDialog
failedModel="gemini-2.5-pro"
fallbackModel="gemini-2.5-flash"
onChoice={mockOnChoice}
/>,
);
// Get the onSelect function passed to RadioButtonSelect
const onSelect = (RadioButtonSelect as Mock).mock.calls[0][0].onSelect;
// Simulate the selection
onSelect('continue');
expect(mockOnChoice).toHaveBeenCalledWith('continue');
});
});

View File

@@ -0,0 +1,55 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Box, Text } from 'ink';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
import { theme } from '../semantic-colors.js';
import { t } from '../../i18n/index.js';
interface ProQuotaDialogProps {
failedModel: string;
fallbackModel: string;
onChoice: (choice: 'auth' | 'continue') => void;
}
export function ProQuotaDialog({
failedModel,
fallbackModel,
onChoice,
}: ProQuotaDialogProps): React.JSX.Element {
const items = [
{
label: t('Change auth (executes the /auth command)'),
value: 'auth' as const,
key: 'auth',
},
{
label: t('Continue with {{model}}', { model: fallbackModel }),
value: 'continue' as const,
key: 'continue',
},
];
const handleSelect = (choice: 'auth' | 'continue') => {
onChoice(choice);
};
return (
<Box borderStyle="round" flexDirection="column" paddingX={1}>
<Text bold color={theme.status.warning}>
{t('Pro quota limit reached for {{model}}.', { model: failedModel })}
</Text>
<Box marginTop={1}>
<RadioButtonSelect
items={items}
initialIndex={1}
onSelect={handleSelect}
/>
</Box>
</Box>
);
}

View File

@@ -55,6 +55,7 @@ export interface UIActions {
handleClearScreen: () => void;
onWorkspaceMigrationDialogOpen: () => void;
onWorkspaceMigrationDialogClose: () => void;
handleProQuotaChoice: (choice: 'auth' | 'continue') => void;
// Vision switch dialog
handleVisionSwitchSelect: (outcome: VisionSwitchOutcome) => void;
// Welcome back dialog

View File

@@ -22,13 +22,21 @@ import type {
AuthType,
IdeContext,
ApprovalMode,
UserTierId,
IdeInfo,
FallbackIntent,
} from '@qwen-code/qwen-code-core';
import type { DOMElement } from 'ink';
import type { SessionStatsState } from '../contexts/SessionContext.js';
import type { ExtensionUpdateState } from '../state/extensions.js';
import type { UpdateObject } from '../utils/updateCheck.js';
export interface ProQuotaDialogRequest {
failedModel: string;
fallbackModel: string;
resolve: (intent: FallbackIntent) => void;
}
import { type UseHistoryManagerReturn } from '../hooks/useHistoryManager.js';
import { type RestartReason } from '../hooks/useIdeTrustListener.js';
@@ -91,6 +99,8 @@ export interface UIState {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
workspaceExtensions: any[]; // Extension[]
// Quota-related state
userTier: UserTierId | undefined;
proQuotaRequest: ProQuotaDialogRequest | null;
currentModel: string;
contextFileNames: string[];
errorCount: number;

View File

@@ -1323,7 +1323,7 @@ describe('useGeminiStream', () => {
it('should call parseAndFormatApiError with the correct authType on stream initialization failure', async () => {
// 1. Setup
const mockError = new Error('Rate limit exceeded');
const mockAuthType = AuthType.USE_VERTEX_AI;
const mockAuthType = AuthType.LOGIN_WITH_GOOGLE;
mockParseAndFormatApiError.mockClear();
mockSendMessageStream.mockReturnValue(
(async function* () {
@@ -1374,6 +1374,9 @@ describe('useGeminiStream', () => {
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
'Rate limit exceeded',
mockAuthType,
undefined,
'gemini-2.5-pro',
'gemini-2.5-flash',
);
});
});
@@ -2490,6 +2493,9 @@ describe('useGeminiStream', () => {
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
{ message: 'Test error' },
expect.any(String),
undefined,
'gemini-2.5-pro',
'gemini-2.5-flash',
);
});
});

View File

@@ -26,6 +26,7 @@ import {
GitService,
UnauthorizedError,
UserPromptEvent,
DEFAULT_GEMINI_FLASH_MODEL,
logConversationFinishedEvent,
ConversationFinishedEvent,
ApprovalMode,
@@ -599,6 +600,9 @@ export const useGeminiStream = (
text: parseAndFormatApiError(
eventValue.error,
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,
@@ -650,9 +654,6 @@ export const useGeminiStream = (
'Response stopped due to image safety violations.',
[FinishReason.UNEXPECTED_TOOL_CALL]:
'Response stopped due to unexpected tool call.',
[FinishReason.IMAGE_PROHIBITED_CONTENT]:
'Response stopped due to image prohibited content.',
[FinishReason.NO_IMAGE]: 'Response stopped due to no image.',
};
const message = finishReasonMessages[finishReason];
@@ -769,17 +770,11 @@ export const useGeminiStream = (
for await (const event of stream) {
switch (event.type) {
case ServerGeminiEventType.Thought:
// If the thought has a subject, it's a discrete status update rather than
// a streamed textual thought, so we update the thought state directly.
if (event.value.subject) {
setThought(event.value);
} else {
thoughtBuffer = handleThoughtEvent(
event.value,
thoughtBuffer,
userMessageTimestamp,
);
}
thoughtBuffer = handleThoughtEvent(
event.value,
thoughtBuffer,
userMessageTimestamp,
);
break;
case ServerGeminiEventType.Content:
geminiMessageBuffer = handleContentEvent(
@@ -850,7 +845,6 @@ export const useGeminiStream = (
handleMaxSessionTurnsEvent,
handleSessionTokenLimitExceededEvent,
handleCitationEvent,
setThought,
],
);
@@ -993,6 +987,9 @@ export const useGeminiStream = (
text: parseAndFormatApiError(
getErrorMessage(error) || 'Unknown error',
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,

View File

@@ -0,0 +1,391 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import { act, renderHook } from '@testing-library/react';
import {
type Config,
type FallbackModelHandler,
UserTierId,
AuthType,
isGenericQuotaExceededError,
isProQuotaExceededError,
makeFakeConfig,
} from '@qwen-code/qwen-code-core';
import { useQuotaAndFallback } from './useQuotaAndFallback.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import { AuthState, MessageType } from '../types.js';
// Mock the error checking functions from the core package to control test scenarios
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@qwen-code/qwen-code-core')>();
return {
...original,
isGenericQuotaExceededError: vi.fn(),
isProQuotaExceededError: vi.fn(),
};
});
// Use a type alias for SpyInstance as it's not directly exported
type SpyInstance = ReturnType<typeof vi.spyOn>;
describe('useQuotaAndFallback', () => {
let mockConfig: Config;
let mockHistoryManager: UseHistoryManagerReturn;
let mockSetAuthState: Mock;
let mockSetModelSwitchedFromQuotaError: Mock;
let setFallbackHandlerSpy: SpyInstance;
const mockedIsGenericQuotaExceededError = isGenericQuotaExceededError as Mock;
const mockedIsProQuotaExceededError = isProQuotaExceededError as Mock;
beforeEach(() => {
mockConfig = makeFakeConfig();
// Spy on the method that requires the private field and mock its return.
// This is cleaner than modifying the config class for tests.
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
model: 'test-model',
authType: AuthType.LOGIN_WITH_GOOGLE,
});
mockHistoryManager = {
addItem: vi.fn(),
history: [],
updateItem: vi.fn(),
clearItems: vi.fn(),
loadHistory: vi.fn(),
};
mockSetAuthState = vi.fn();
mockSetModelSwitchedFromQuotaError = vi.fn();
setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler');
vi.spyOn(mockConfig, 'setQuotaErrorOccurred');
mockedIsGenericQuotaExceededError.mockReturnValue(false);
mockedIsProQuotaExceededError.mockReturnValue(false);
});
afterEach(() => {
vi.clearAllMocks();
});
it('should register a fallback handler on initialization', () => {
renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1);
expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function);
});
describe('Fallback Handler Logic', () => {
// Helper function to render the hook and extract the registered handler
const getRegisteredHandler = (
userTier: UserTierId = UserTierId.FREE,
): FallbackModelHandler => {
renderHook(
(props) =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: props.userTier,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
{ initialProps: { userTier } },
);
return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler;
};
it('should return null and take no action if already in fallback mode', async () => {
vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true);
const handler = getRegisteredHandler();
const result = await handler('gemini-pro', 'gemini-flash', new Error());
expect(result).toBeNull();
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
});
it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => {
// Override the default mock from beforeEach for this specific test
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
model: 'test-model',
authType: AuthType.USE_GEMINI,
});
const handler = getRegisteredHandler();
const result = await handler('gemini-pro', 'gemini-flash', new Error());
expect(result).toBeNull();
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
});
describe('Automatic Fallback Scenarios', () => {
const testCases = [
{
errorType: 'generic',
tier: UserTierId.FREE,
expectedMessageSnippets: [
'Automatically switching from model-A to model-B',
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
],
},
{
errorType: 'generic',
tier: UserTierId.STANDARD, // Paid tier
expectedMessageSnippets: [
'Automatically switching from model-A to model-B',
'switch to using a paid API key from AI Studio',
],
},
{
errorType: 'other',
tier: UserTierId.FREE,
expectedMessageSnippets: [
'Automatically switching from model-A to model-B for faster responses',
'upgrade to a Gemini Code Assist Standard or Enterprise plan',
],
},
{
errorType: 'other',
tier: UserTierId.LEGACY, // Paid tier
expectedMessageSnippets: [
'Automatically switching from model-A to model-B for faster responses',
'switch to using a paid API key from AI Studio',
],
},
];
for (const { errorType, tier, expectedMessageSnippets } of testCases) {
it(`should handle ${errorType} error for ${tier} tier correctly`, async () => {
mockedIsGenericQuotaExceededError.mockReturnValue(
errorType === 'generic',
);
const handler = getRegisteredHandler(tier);
const result = await handler(
'model-A',
'model-B',
new Error('quota exceeded'),
);
// Automatic fallbacks should return 'stop'
expect(result).toBe('stop');
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
expect.objectContaining({ type: MessageType.INFO }),
expect.any(Number),
);
const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0]
.text;
for (const snippet of expectedMessageSnippets) {
expect(message).toContain(snippet);
}
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true);
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
});
}
});
describe('Interactive Fallback (Pro Quota Error)', () => {
beforeEach(() => {
mockedIsProQuotaExceededError.mockReturnValue(true);
});
it('should set an interactive request and wait for user choice', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
// Call the handler but do not await it, to check the intermediate state
const promise = handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota'),
);
await act(async () => {});
// The hook should now have a pending request for the UI to handle
expect(result.current.proQuotaRequest).not.toBeNull();
expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro');
// Simulate the user choosing to continue with the fallback model
act(() => {
result.current.handleProQuotaChoice('continue');
});
// The original promise from the handler should now resolve
const intent = await promise;
expect(intent).toBe('retry');
// The pending request should be cleared from the state
expect(result.current.proQuotaRequest).toBeNull();
});
it('should handle race conditions by stopping subsequent requests', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
const promise1 = handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota 1'),
);
await act(async () => {});
const firstRequest = result.current.proQuotaRequest;
expect(firstRequest).not.toBeNull();
const result2 = await handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota 2'),
);
// The lock should have stopped the second request
expect(result2).toBe('stop');
expect(result.current.proQuotaRequest).toBe(firstRequest);
act(() => {
result.current.handleProQuotaChoice('continue');
});
const intent1 = await promise1;
expect(intent1).toBe('retry');
expect(result.current.proQuotaRequest).toBeNull();
});
});
});
describe('handleProQuotaChoice', () => {
beforeEach(() => {
mockedIsProQuotaExceededError.mockReturnValue(true);
});
it('should do nothing if there is no pending pro quota request', () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
act(() => {
result.current.handleProQuotaChoice('auth');
});
expect(mockSetAuthState).not.toHaveBeenCalled();
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
});
it('should resolve intent to "auth" and trigger auth state update', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
const promise = handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota'),
);
await act(async () => {}); // Allow state to update
act(() => {
result.current.handleProQuotaChoice('auth');
});
const intent = await promise;
expect(intent).toBe('auth');
expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating);
expect(result.current.proQuotaRequest).toBeNull();
});
it('should resolve intent to "retry" and add info message on continue', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setAuthState: mockSetAuthState,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
// The first `addItem` call is for the initial quota error message
const promise = handler(
'gemini-pro',
'gemini-flash',
new Error('pro quota'),
);
await act(async () => {}); // Allow state to update
act(() => {
result.current.handleProQuotaChoice('continue');
});
const intent = await promise;
expect(intent).toBe('retry');
expect(result.current.proQuotaRequest).toBeNull();
// Check for the second "Switched to fallback model" message
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2);
const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[1][0];
expect(lastCall.type).toBe(MessageType.INFO);
expect(lastCall.text).toContain('Switched to fallback model.');
});
});
});

View File

@@ -0,0 +1,175 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
AuthType,
type Config,
type FallbackModelHandler,
type FallbackIntent,
isGenericQuotaExceededError,
isProQuotaExceededError,
UserTierId,
} from '@qwen-code/qwen-code-core';
import { useCallback, useEffect, useRef, useState } from 'react';
import { type UseHistoryManagerReturn } from './useHistoryManager.js';
import { AuthState, MessageType } from '../types.js';
import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js';
interface UseQuotaAndFallbackArgs {
config: Config;
historyManager: UseHistoryManagerReturn;
userTier: UserTierId | undefined;
setAuthState: (state: AuthState) => void;
setModelSwitchedFromQuotaError: (value: boolean) => void;
}
export function useQuotaAndFallback({
config,
historyManager,
userTier,
setAuthState,
setModelSwitchedFromQuotaError,
}: UseQuotaAndFallbackArgs) {
const [proQuotaRequest, setProQuotaRequest] =
useState<ProQuotaDialogRequest | null>(null);
const isDialogPending = useRef(false);
// Set up Flash fallback handler
useEffect(() => {
const fallbackHandler: FallbackModelHandler = async (
failedModel,
fallbackModel,
error,
): Promise<FallbackIntent | null> => {
if (config.isInFallbackMode()) {
return null;
}
// Fallbacks are currently only handled for OAuth users.
const contentGeneratorConfig = config.getContentGeneratorConfig();
if (
!contentGeneratorConfig ||
contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE
) {
return null;
}
// Use actual user tier if available; otherwise, default to FREE tier behavior (safe default)
const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
let message: string;
if (error && isProQuotaExceededError(error)) {
// Pro Quota specific messages (Interactive)
if (isPaidTier) {
message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `⚡ You have reached your daily ${failedModel} quota limit.
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
} else if (error && isGenericQuotaExceededError(error)) {
// Generic Quota (Automatic fallback)
const actionMessage = `⚡ You have reached your daily quota limit.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`;
if (isPaidTier) {
message = `${actionMessage}
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `${actionMessage}
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
} else {
// Consecutive 429s or other errors (Automatic fallback)
const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`;
if (isPaidTier) {
message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
} else {
message = `${actionMessage}
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
⚡ You can switch authentication methods by typing /auth`;
}
}
// Add message to UI history
historyManager.addItem(
{
type: MessageType.INFO,
text: message,
},
Date.now(),
);
setModelSwitchedFromQuotaError(true);
config.setQuotaErrorOccurred(true);
// Interactive Fallback for Pro quota
if (error && isProQuotaExceededError(error)) {
if (isDialogPending.current) {
return 'stop'; // A dialog is already active, so just stop this request.
}
isDialogPending.current = true;
const intent: FallbackIntent = await new Promise<FallbackIntent>(
(resolve) => {
setProQuotaRequest({
failedModel,
fallbackModel,
resolve,
});
},
);
return intent;
}
return 'stop';
};
config.setFallbackModelHandler(fallbackHandler);
}, [config, historyManager, userTier, setModelSwitchedFromQuotaError]);
const handleProQuotaChoice = useCallback(
(choice: 'auth' | 'continue') => {
if (!proQuotaRequest) return;
const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry';
proQuotaRequest.resolve(intent);
setProQuotaRequest(null);
isDialogPending.current = false; // Reset the flag here
if (choice === 'auth') {
setAuthState(AuthState.Updating);
} else {
historyManager.addItem(
{
type: MessageType.INFO,
text: 'Switched to fallback model. Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.',
},
Date.now(),
);
}
},
[proQuotaRequest, setAuthState, historyManager],
);
return {
proQuotaRequest,
handleProQuotaChoice,
};
}

View File

@@ -411,7 +411,7 @@ describe('useQwenAuth', () => {
expect(geminiResult.current.qwenAuthState.authStatus).toBe('idle');
const { result: oauthResult } = renderHook(() =>
useQwenAuth(AuthType.USE_OPENAI, true),
useQwenAuth(AuthType.LOGIN_WITH_GOOGLE, true),
);
expect(oauthResult.current.qwenAuthState.authStatus).toBe('idle');
});

View File

@@ -62,7 +62,7 @@ const mockConfig = {
getAllowedTools: vi.fn(() => []),
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getUseSmartEdit: () => false,
getUseModelRouter: () => false,

View File

@@ -21,13 +21,6 @@ function getAuthTypeFromEnv(): AuthType | undefined {
return AuthType.QWEN_OAUTH;
}
if (process.env['GEMINI_API_KEY']) {
return AuthType.USE_GEMINI;
}
if (process.env['GOOGLE_API_KEY']) {
return AuthType.USE_VERTEX_AI;
}
return undefined;
}

View File

@@ -23,8 +23,8 @@
"scripts/postinstall.js"
],
"dependencies": {
"@google/genai": "1.30.0",
"@modelcontextprotocol/sdk": "^1.25.1",
"@google/genai": "1.16.0",
"@modelcontextprotocol/sdk": "^1.11.0",
"@opentelemetry/api": "^1.9.0",
"async-mutex": "^0.5.0",
"@opentelemetry/exporter-logs-otlp-grpc": "^0.203.0",
@@ -34,6 +34,7 @@
"@opentelemetry/exporter-trace-otlp-grpc": "^0.203.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.203.0",
"@opentelemetry/instrumentation-http": "^0.203.0",
"@opentelemetry/resource-detector-gcp": "^0.40.0",
"@opentelemetry/sdk-node": "^0.203.0",
"@types/html-to-text": "^9.0.4",
"@xterm/headless": "5.5.0",
@@ -47,7 +48,7 @@
"fdir": "^6.4.6",
"fzf": "^0.5.2",
"glob": "^10.5.0",
"google-auth-library": "^10.5.0",
"google-auth-library": "^9.11.0",
"html-to-text": "^9.0.5",
"https-proxy-agent": "^7.0.6",
"ignore": "^7.0.0",

View File

@@ -0,0 +1,54 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ContentGenerator } from '../core/contentGenerator.js';
import { AuthType } from '../core/contentGenerator.js';
import { getOauthClient } from './oauth2.js';
import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
authType: AuthType,
config: Config,
sessionId?: string,
): Promise<ContentGenerator> {
if (
authType === AuthType.LOGIN_WITH_GOOGLE ||
authType === AuthType.CLOUD_SHELL
) {
const authClient = await getOauthClient(authType, config);
const userData = await setupUser(authClient);
return new CodeAssistServer(
authClient,
userData.projectId,
httpOptions,
sessionId,
userData.userTier,
);
}
throw new Error(`Unsupported authType: ${authType}`);
}
export function getCodeAssistServer(
config: Config,
): CodeAssistServer | undefined {
let server = config.getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
if (!(server instanceof CodeAssistServer)) {
return undefined;
}
return server;
}

View File

@@ -0,0 +1,456 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import type { CaGenerateContentResponse } from './converter.js';
import {
toGenerateContentRequest,
fromGenerateContentResponse,
toContents,
} from './converter.js';
import type {
ContentListUnion,
GenerateContentParameters,
} from '@google/genai';
import {
GenerateContentResponse,
FinishReason,
BlockedReason,
type Part,
} from '@google/genai';
describe('converter', () => {
describe('toCodeAssistRequest', () => {
it('should convert a simple request with project', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'my-session',
},
user_prompt_id: 'my-prompt',
});
});
it('should convert a request without a project', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
undefined,
'my-session',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: undefined,
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'my-session',
},
user_prompt_id: 'my-prompt',
});
});
it('should convert a request with sessionId', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'session-123',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'session-123',
},
user_prompt_id: 'my-prompt',
});
});
it('should handle string content', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
]);
});
it('should handle Part[] content', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ text: 'Hello' }, { text: 'World' }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'user', parts: [{ text: 'World' }] },
]);
});
it('should handle system instructions', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
systemInstruction: 'You are a helpful assistant.',
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.systemInstruction).toEqual({
role: 'user',
parts: [{ text: 'You are a helpful assistant.' }],
});
});
it('should handle generation config', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
temperature: 0.8,
topK: 40,
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.8,
topK: 40,
});
});
it('should handle all generation config fields', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: 'Hello',
config: {
temperature: 0.1,
topP: 0.2,
topK: 3,
candidateCount: 4,
maxOutputTokens: 5,
stopSequences: ['a'],
responseLogprobs: true,
logprobs: 6,
presencePenalty: 0.7,
frequencyPenalty: 0.8,
seed: 9,
responseMimeType: 'application/json',
},
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-prompt',
'my-project',
'my-session',
);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.1,
topP: 0.2,
topK: 3,
candidateCount: 4,
maxOutputTokens: 5,
stopSequences: ['a'],
responseLogprobs: true,
logprobs: 6,
presencePenalty: 0.7,
frequencyPenalty: 0.8,
seed: 9,
responseMimeType: 'application/json',
});
});
});
describe('fromCodeAssistResponse', () => {
it('should convert a simple response', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'Hi there!' }],
},
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
});
it('should handle prompt feedback and usage metadata', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
promptFeedback: {
blockReason: BlockedReason.SAFETY,
safetyRatings: [],
},
usageMetadata: {
promptTokenCount: 10,
candidatesTokenCount: 20,
totalTokenCount: 30,
},
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.promptFeedback).toEqual(
codeAssistRes.response.promptFeedback,
);
expect(genaiRes.usageMetadata).toEqual(
codeAssistRes.response.usageMetadata,
);
});
it('should handle automatic function calling history', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
automaticFunctionCallingHistory: [
{
role: 'model',
parts: [
{
functionCall: {
name: 'test_function',
args: {
foo: 'bar',
},
},
},
],
},
],
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
codeAssistRes.response.automaticFunctionCallingHistory,
);
});
it('should handle modelVersion', () => {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
modelVersion: 'qwen3-coder-plus',
},
};
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.modelVersion).toEqual('qwen3-coder-plus');
});
});
describe('toContents', () => {
it('should handle Content', () => {
const content: ContentListUnion = {
role: 'user',
parts: [{ text: 'hello' }],
};
expect(toContents(content)).toEqual([
{ role: 'user', parts: [{ text: 'hello' }] },
]);
});
it('should handle array of Contents', () => {
const contents: ContentListUnion = [
{ role: 'user', parts: [{ text: 'hello' }] },
{ role: 'model', parts: [{ text: 'hi' }] },
];
expect(toContents(contents)).toEqual([
{ role: 'user', parts: [{ text: 'hello' }] },
{ role: 'model', parts: [{ text: 'hi' }] },
]);
});
it('should handle Part', () => {
const part: ContentListUnion = { text: 'a part' };
expect(toContents(part)).toEqual([
{ role: 'user', parts: [{ text: 'a part' }] },
]);
});
it('should handle array of Parts', () => {
const parts = [{ text: 'part 1' }, 'part 2'];
expect(toContents(parts)).toEqual([
{ role: 'user', parts: [{ text: 'part 1' }] },
{ role: 'user', parts: [{ text: 'part 2' }] },
]);
});
it('should handle string', () => {
const str: ContentListUnion = 'a string';
expect(toContents(str)).toEqual([
{ role: 'user', parts: [{ text: 'a string' }] },
]);
});
it('should handle array of strings', () => {
const strings: ContentListUnion = ['string 1', 'string 2'];
expect(toContents(strings)).toEqual([
{ role: 'user', parts: [{ text: 'string 1' }] },
{ role: 'user', parts: [{ text: 'string 2' }] },
]);
});
it('should convert thought parts to text parts for API compatibility', () => {
const contentWithThought: ContentListUnion = {
role: 'model',
parts: [
{ text: 'regular text' },
{ thought: 'thinking about the problem' } as Part & {
thought: string;
},
{ text: 'more text' },
],
};
expect(toContents(contentWithThought)).toEqual([
{
role: 'model',
parts: [
{ text: 'regular text' },
{ text: '[Thought: thinking about the problem]' },
{ text: 'more text' },
],
},
]);
});
it('should combine text and thought for text parts with thoughts', () => {
const contentWithTextAndThought: ContentListUnion = {
role: 'model',
parts: [
{
text: 'Here is my response',
thought: 'I need to be careful here',
} as Part & { thought: string },
],
};
expect(toContents(contentWithTextAndThought)).toEqual([
{
role: 'model',
parts: [
{
text: 'Here is my response\n[Thought: I need to be careful here]',
},
],
},
]);
});
it('should preserve non-thought properties while removing thought', () => {
const contentWithComplexPart: ContentListUnion = {
role: 'model',
parts: [
{
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
thought: 'Performing calculation',
} as Part & { thought: string },
],
};
expect(toContents(contentWithComplexPart)).toEqual([
{
role: 'model',
parts: [
{
functionCall: { name: 'calculate', args: { x: 5, y: 10 } },
},
],
},
]);
});
it('should convert invalid text content to valid text part with thought', () => {
const contentWithInvalidText: ContentListUnion = {
role: 'model',
parts: [
{
text: 123, // Invalid - should be string
thought: 'Processing number',
} as Part & { thought: string; text: number },
],
};
expect(toContents(contentWithInvalidText)).toEqual([
{
role: 'model',
parts: [
{
text: '123\n[Thought: Processing number]',
},
],
},
]);
});
});
});

View File

@@ -0,0 +1,285 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Content,
ContentListUnion,
ContentUnion,
GenerateContentConfig,
GenerateContentParameters,
CountTokensParameters,
CountTokensResponse,
GenerationConfigRoutingConfig,
MediaResolution,
Candidate,
ModelSelectionConfig,
GenerateContentResponsePromptFeedback,
GenerateContentResponseUsageMetadata,
Part,
SafetySetting,
PartUnion,
SpeechConfigUnion,
ThinkingConfig,
ToolListUnion,
ToolConfig,
} from '@google/genai';
import { GenerateContentResponse } from '@google/genai';
export interface CAGenerateContentRequest {
model: string;
project?: string;
user_prompt_id?: string;
request: VertexGenerateContentRequest;
}
interface VertexGenerateContentRequest {
contents: Content[];
systemInstruction?: Content;
cachedContent?: string;
tools?: ToolListUnion;
toolConfig?: ToolConfig;
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig;
session_id?: string;
}
interface VertexGenerationConfig {
temperature?: number;
topP?: number;
topK?: number;
candidateCount?: number;
maxOutputTokens?: number;
stopSequences?: string[];
responseLogprobs?: boolean;
logprobs?: number;
presencePenalty?: number;
frequencyPenalty?: number;
seed?: number;
responseMimeType?: string;
responseJsonSchema?: unknown;
responseSchema?: unknown;
routingConfig?: GenerationConfigRoutingConfig;
modelSelectionConfig?: ModelSelectionConfig;
responseModalities?: string[];
mediaResolution?: MediaResolution;
speechConfig?: SpeechConfigUnion;
audioTimestamp?: boolean;
thinkingConfig?: ThinkingConfig;
}
export interface CaGenerateContentResponse {
response: VertexGenerateContentResponse;
}
interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
modelVersion?: string;
}
export interface CaCountTokenRequest {
request: VertexCountTokenRequest;
}
interface VertexCountTokenRequest {
model: string;
contents: Content[];
}
export interface CaCountTokenResponse {
totalTokens: number;
}
export function toCountTokenRequest(
req: CountTokensParameters,
): CaCountTokenRequest {
return {
request: {
model: 'models/' + req.model,
contents: toContents(req.contents),
},
};
}
export function fromCountTokenResponse(
res: CaCountTokenResponse,
): CountTokensResponse {
return {
totalTokens: res.totalTokens,
};
}
export function toGenerateContentRequest(
req: GenerateContentParameters,
userPromptId: string,
project?: string,
sessionId?: string,
): CAGenerateContentRequest {
return {
model: req.model,
project,
user_prompt_id: userPromptId,
request: toVertexGenerateContentRequest(req, sessionId),
};
}
export function fromGenerateContentResponse(
res: CaGenerateContentResponse,
): GenerateContentResponse {
const inres = res.response;
const out = new GenerateContentResponse();
out.candidates = inres.candidates;
out.automaticFunctionCallingHistory = inres.automaticFunctionCallingHistory;
out.promptFeedback = inres.promptFeedback;
out.usageMetadata = inres.usageMetadata;
out.modelVersion = inres.modelVersion;
return out;
}
function toVertexGenerateContentRequest(
req: GenerateContentParameters,
sessionId?: string,
): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
systemInstruction: maybeToContent(req.config?.systemInstruction),
cachedContent: req.config?.cachedContent,
tools: req.config?.tools,
toolConfig: req.config?.toolConfig,
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
generationConfig: toVertexGenerationConfig(req.config),
session_id: sessionId,
};
}
export function toContents(contents: ContentListUnion): Content[] {
if (Array.isArray(contents)) {
// it's a Content[] or a PartsUnion[]
return contents.map(toContent);
}
// it's a Content or a PartsUnion
return [toContent(contents)];
}
function maybeToContent(content?: ContentUnion): Content | undefined {
if (!content) {
return undefined;
}
return toContent(content);
}
function toContent(content: ContentUnion): Content {
if (Array.isArray(content)) {
// it's a PartsUnion[]
return {
role: 'user',
parts: toParts(content),
};
}
if (typeof content === 'string') {
// it's a string
return {
role: 'user',
parts: [{ text: content }],
};
}
if ('parts' in content) {
// it's a Content - process parts to handle thought filtering
return {
...content,
parts: content.parts
? toParts(content.parts.filter((p) => p != null))
: [],
};
}
// it's a Part
return {
role: 'user',
parts: [toPart(content as Part)],
};
}
export function toParts(parts: PartUnion[]): Part[] {
return parts.map(toPart);
}
function toPart(part: PartUnion): Part {
if (typeof part === 'string') {
// it's a string
return { text: part };
}
// Handle thought parts for CountToken API compatibility
// The CountToken API expects parts to have certain required "oneof" fields initialized,
// but thought parts don't conform to this schema and cause API failures
if ('thought' in part && part.thought) {
const thoughtText = `[Thought: ${part.thought}]`;
const newPart = { ...part };
delete (newPart as Record<string, unknown>)['thought'];
const hasApiContent =
'functionCall' in newPart ||
'functionResponse' in newPart ||
'inlineData' in newPart ||
'fileData' in newPart;
if (hasApiContent) {
// It's a functionCall or other non-text part. Just strip the thought.
return newPart;
}
// If no other valid API content, this must be a text part.
// Combine existing text (if any) with the thought, preserving other properties.
const text = (newPart as { text?: unknown }).text;
const existingText = text ? String(text) : '';
const combinedText = existingText
? `${existingText}\n${thoughtText}`
: thoughtText;
return {
...newPart,
text: combinedText,
};
}
return part;
}
function toVertexGenerationConfig(
config?: GenerateContentConfig,
): VertexGenerationConfig | undefined {
if (!config) {
return undefined;
}
return {
temperature: config.temperature,
topP: config.topP,
topK: config.topK,
candidateCount: config.candidateCount,
maxOutputTokens: config.maxOutputTokens,
stopSequences: config.stopSequences,
responseLogprobs: config.responseLogprobs,
logprobs: config.logprobs,
presencePenalty: config.presencePenalty,
frequencyPenalty: config.frequencyPenalty,
seed: config.seed,
responseMimeType: config.responseMimeType,
responseSchema: config.responseSchema,
responseJsonSchema: config.responseJsonSchema,
routingConfig: config.routingConfig,
modelSelectionConfig: config.modelSelectionConfig,
responseModalities: config.responseModalities,
mediaResolution: config.mediaResolution,
speechConfig: config.speechConfig,
audioTimestamp: config.audioTimestamp,
thinkingConfig: config.thinkingConfig,
};
}

View File

@@ -0,0 +1,217 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
// Mock external dependencies
const mockHybridTokenStorage = vi.hoisted(() => ({
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
}));
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
}));
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
rm: vi.fn(),
},
}));
vi.mock('node:os');
vi.mock('node:path');
describe('OAuthCredentialStorage', () => {
const mockCredentials: Credentials = {
access_token: 'mock_access_token',
refresh_token: 'mock_refresh_token',
expiry_date: Date.now() + 3600 * 1000,
token_type: 'Bearer',
scope: 'email profile',
};
const mockMcpCredentials: OAuthCredentials = {
serverName: 'main-account',
token: {
accessToken: 'mock_access_token',
refreshToken: 'mock_refresh_token',
tokenType: 'Bearer',
scope: 'email profile',
expiresAt: mockCredentials.expiry_date!,
},
updatedAt: expect.any(Number),
};
const oldFilePath = '/mock/home/.qwen/oauth.json';
beforeEach(() => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
vi.spyOn(os, 'homedir').mockReturnValue('/mock/home');
vi.spyOn(path, 'join').mockReturnValue(oldFilePath);
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('loadCredentials', () => {
it('should load credentials from HybridTokenStorage if available', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
mockMcpCredentials,
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(result).toEqual(mockCredentials);
});
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
const result = await OAuthCredentialStorage.loadCredentials();
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
expect(result).toEqual(mockCredentials);
});
it('should return null if no credentials found and no old file to migrate', async () => {
vi.spyOn(fs, 'readFile').mockRejectedValue({
message: 'File not found',
code: 'ENOENT',
});
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toBeNull();
});
it('should throw an error if loading fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
new Error('Loading error'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
it('should throw an error if read file fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(
new Error('Permission denied'),
);
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials',
);
});
it('should not throw error if migration file removal failed', async () => {
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials),
);
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toEqual(mockCredentials);
});
});
describe('saveCredentials', () => {
it('should save credentials to HybridTokenStorage', async () => {
await OAuthCredentialStorage.saveCredentials(mockCredentials);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
mockMcpCredentials,
);
});
it('should throw an error if access_token is missing', async () => {
const invalidCredentials: Credentials = {
...mockCredentials,
access_token: undefined,
};
await expect(
OAuthCredentialStorage.saveCredentials(invalidCredentials),
).rejects.toThrow(
'Attempted to save credentials without an access token.',
);
});
});
describe('clearCredentials', () => {
it('should delete credentials from HybridTokenStorage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
'main-account',
);
});
it('should attempt to remove the old file-based storage', async () => {
await OAuthCredentialStorage.clearCredentials();
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
});
it('should not throw an error if deleting old file fails', async () => {
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
await expect(
OAuthCredentialStorage.clearCredentials(),
).resolves.toBeUndefined();
});
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
new Error('Deletion error'),
);
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
'Failed to clear OAuth credentials',
);
});
});
});

View File

@@ -0,0 +1,130 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type Credentials } from 'google-auth-library';
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
import { OAUTH_FILE } from '../config/storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path';
import * as os from 'node:os';
import { promises as fs } from 'node:fs';
const QWEN_DIR = '.qwen';
const KEYCHAIN_SERVICE_NAME = 'qwen-code-oauth';
const MAIN_ACCOUNT_KEY = 'main-account';
export class OAuthCredentialStorage {
private static storage: HybridTokenStorage = new HybridTokenStorage(
KEYCHAIN_SERVICE_NAME,
);
/**
* Load cached OAuth credentials
*/
static async loadCredentials(): Promise<Credentials | null> {
try {
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
if (credentials?.token) {
const { accessToken, refreshToken, expiresAt, tokenType, scope } =
credentials.token;
// Convert from OAuthCredentials format to Google Credentials format
const googleCreds: Credentials = {
access_token: accessToken,
refresh_token: refreshToken || undefined,
token_type: tokenType || undefined,
scope: scope || undefined,
};
if (expiresAt) {
googleCreds.expiry_date = expiresAt;
}
return googleCreds;
}
// Fallback: Try to migrate from old file-based storage
return await this.migrateFromFileStorage();
} catch (error: unknown) {
console.error(error);
throw new Error('Failed to load OAuth credentials');
}
}
/**
* Save OAuth credentials
*/
static async saveCredentials(credentials: Credentials): Promise<void> {
if (!credentials.access_token) {
throw new Error('Attempted to save credentials without an access token.');
}
// Convert Google Credentials to OAuthCredentials format
const mcpCredentials: OAuthCredentials = {
serverName: MAIN_ACCOUNT_KEY,
token: {
accessToken: credentials.access_token,
refreshToken: credentials.refresh_token || undefined,
tokenType: credentials.token_type || 'Bearer',
scope: credentials.scope || undefined,
expiresAt: credentials.expiry_date || undefined,
},
updatedAt: Date.now(),
};
await this.storage.setCredentials(mcpCredentials);
}
/**
* Clear cached OAuth credentials
*/
static async clearCredentials(): Promise<void> {
try {
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
// Also try to remove the old file if it exists
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
await fs.rm(oldFilePath, { force: true }).catch(() => {});
} catch (error: unknown) {
console.error(error);
throw new Error('Failed to clear OAuth credentials');
}
}
/**
* Migrate credentials from old file-based storage to keychain
*/
private static async migrateFromFileStorage(): Promise<Credentials | null> {
const oldFilePath = path.join(os.homedir(), QWEN_DIR, OAUTH_FILE);
let credsJson: string;
try {
credsJson = await fs.readFile(oldFilePath, 'utf-8');
} catch (error: unknown) {
if (
typeof error === 'object' &&
error !== null &&
'code' in error &&
error.code === 'ENOENT'
) {
// File doesn't exist, so no migration.
return null;
}
// Other read errors should propagate.
throw error;
}
const credentials = JSON.parse(credsJson) as Credentials;
// Save to new storage
await this.saveCredentials(credentials);
// Remove old file after successful migration
await fs.rm(oldFilePath, { force: true }).catch(() => {});
return credentials;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,563 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Credentials } from 'google-auth-library';
import {
CodeChallengeMethod,
Compute,
OAuth2Client,
} from 'google-auth-library';
import crypto from 'node:crypto';
import { promises as fs } from 'node:fs';
import * as http from 'node:http';
import * as net from 'node:net';
import path from 'node:path';
import readline from 'node:readline';
import url from 'node:url';
import open from 'open';
import type { Config } from '../config/config.js';
import { Storage } from '../config/storage.js';
import { AuthType } from '../core/contentGenerator.js';
import { FatalAuthenticationError, getErrorMessage } from '../utils/errors.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
const userAccountManager = new UserAccountManager();
// OAuth Client ID used to initiate OAuth2Client class.
const OAUTH_CLIENT_ID =
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
// OAuth Secret value used to initiate OAuth2Client class.
// Note: It's ok to save this in git because this is an installed application
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
// "The process results in a client ID and, in some cases, a client secret,
// which you embed in the source code of your application. (In this context,
// the client secret is obviously not treated as a secret.)"
const OAUTH_CLIENT_SECRET = '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';
// OAuth Scopes for Cloud Code authorization.
const OAUTH_SCOPE = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
];
const HTTP_REDIRECT = 301;
const SIGN_IN_SUCCESS_URL =
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
const SIGN_IN_FAILURE_URL =
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
/**
* An Authentication URL for updating the credentials of a Oauth2Client
* as well as a promise that will resolve when the credentials have
* been refreshed (or which throws error when refreshing credentials failed).
*/
export interface OauthWebLogin {
authUrl: string;
loginCompletePromise: Promise<void>;
}
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
function getUseEncryptedStorageFlag() {
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
}
async function initOauthClient(
authType: AuthType,
config: Config,
): Promise<OAuth2Client> {
const client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
transporterOptions: {
proxy: config.getProxy(),
},
});
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (
process.env['GOOGLE_GENAI_USE_GCA'] &&
process.env['GOOGLE_CLOUD_ACCESS_TOKEN']
) {
client.setCredentials({
access_token: process.env['GOOGLE_CLOUD_ACCESS_TOKEN'],
});
await fetchAndCacheUserInfo(client);
return client;
}
client.on('tokens', async (tokens: Credentials) => {
if (useEncryptedStorage) {
await OAuthCredentialStorage.saveCredentials(tokens);
} else {
await cacheCredentials(tokens);
}
});
// If there are cached creds on disk, they always take precedence
if (await loadCachedCredentials(client)) {
// Found valid cached credentials.
// Check if we need to retrieve Google Account ID or Email
if (!userAccountManager.getCachedGoogleAccount()) {
try {
await fetchAndCacheUserInfo(client);
} catch (error) {
// Non-fatal, continue with existing auth.
console.warn('Failed to fetch user info:', getErrorMessage(error));
}
}
console.log('Loaded cached credentials.');
return client;
}
// In Google Cloud Shell, we can use Application Default Credentials (ADC)
// provided via its metadata server to authenticate non-interactively using
// the identity of the user logged into Cloud Shell.
if (authType === AuthType.CLOUD_SHELL) {
try {
console.log("Attempting to authenticate via Cloud Shell VM's ADC.");
const computeClient = new Compute({
// We can leave this empty, since the metadata server will provide
// the service account email.
});
await computeClient.getAccessToken();
console.log('Authentication successful.');
// Do not cache creds in this case; note that Compute client will handle its own refresh
return computeClient;
} catch (e) {
throw new Error(
`Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ${getErrorMessage(
e,
)}`,
);
}
}
if (config.isBrowserLaunchSuppressed()) {
let success = false;
const maxRetries = 2;
for (let i = 0; !success && i < maxRetries; i++) {
success = await authWithUserCode(client);
if (!success) {
console.error(
'\nFailed to authenticate with user code.',
i === maxRetries - 1 ? '' : 'Retrying...\n',
);
}
}
if (!success) {
throw new FatalAuthenticationError(
'Failed to authenticate with user code.',
);
}
} else {
const webLogin = await authWithWeb(client);
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
);
try {
// Attempt to open the authentication URL in the default browser.
// We do not use the `wait` option here because the main script's execution
// is already paused by `loginCompletePromise`, which awaits the server callback.
const childProcess = await open(webLogin.authUrl);
// IMPORTANT: Attach an error handler to the returned child process.
// Without this, if `open` fails to spawn a process (e.g., `xdg-open` is not found
// in a minimal Docker container), it will emit an unhandled 'error' event,
// causing the entire Node.js process to crash.
childProcess.on('error', (error) => {
console.error(
'Failed to open browser automatically. Please try running again with NO_BROWSER=true set.',
);
console.error('Browser error details:', getErrorMessage(error));
});
} catch (err) {
console.error(
'An unexpected error occurred while trying to open the browser:',
getErrorMessage(err),
'\nThis might be due to browser compatibility issues or system configuration.',
'\nPlease try running again with NO_BROWSER=true set for manual authentication.',
);
throw new FatalAuthenticationError(
`Failed to open browser: ${getErrorMessage(err)}`,
);
}
console.log('Waiting for authentication...');
// Add timeout to prevent infinite waiting when browser tab gets stuck
const authTimeout = 5 * 60 * 1000; // 5 minutes timeout
const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout(() => {
reject(
new FatalAuthenticationError(
'Authentication timed out after 5 minutes. The browser tab may have gotten stuck in a loading state. ' +
'Please try again or use NO_BROWSER=true for manual authentication.',
),
);
}, authTimeout);
});
await Promise.race([webLogin.loginCompletePromise, timeoutPromise]);
}
return client;
}
export async function getOauthClient(
authType: AuthType,
config: Config,
): Promise<OAuth2Client> {
if (!oauthClientPromises.has(authType)) {
oauthClientPromises.set(authType, initOauthClient(authType, config));
}
return oauthClientPromises.get(authType)!;
}
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
const redirectUri = 'https://codeassist.google.com/authcode';
const codeVerifier = await client.generateCodeVerifierAsync();
const state = crypto.randomBytes(32).toString('hex');
const authUrl: string = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
code_challenge_method: CodeChallengeMethod.S256,
code_challenge: codeVerifier.codeChallenge,
state,
});
console.log('Please visit the following URL to authorize the application:');
console.log('');
console.log(authUrl);
console.log('');
const code = await new Promise<string>((resolve) => {
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
});
rl.question('Enter the authorization code: ', (code) => {
rl.close();
resolve(code.trim());
});
});
if (!code) {
console.error('Authorization code is required.');
return false;
}
try {
const { tokens } = await client.getToken({
code,
codeVerifier: codeVerifier.codeVerifier,
redirect_uri: redirectUri,
});
client.setCredentials(tokens);
} catch (error) {
console.error(
'Failed to authenticate with authorization code:',
getErrorMessage(error),
);
return false;
}
return true;
}
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
const port = await getAvailablePort();
// The hostname used for the HTTP server binding (e.g., '0.0.0.0' in Docker).
const host = process.env['OAUTH_CALLBACK_HOST'] || 'localhost';
// The `redirectUri` sent to Google's authorization server MUST use a loopback IP literal
// (i.e., 'localhost' or '127.0.0.1'). This is a strict security policy for credentials of
// type 'Desktop app' or 'Web application' (when using loopback flow) to mitigate
// authorization code interception attacks.
const redirectUri = `http://localhost:${port}/oauth2callback`;
const state = crypto.randomBytes(32).toString('hex');
const authUrl = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(
new FatalAuthenticationError(
'OAuth callback not received. Unexpected request: ' + req.url,
),
);
}
// acquire the code from the querystring, and close the web server.
const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
if (qs.get('error')) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
const errorCode = qs.get('error');
const errorDescription =
qs.get('error_description') || 'No additional details provided';
reject(
new FatalAuthenticationError(
`Google OAuth error: ${errorCode}. ${errorDescription}`,
),
);
} else if (qs.get('state') !== state) {
res.end('State mismatch. Possible CSRF attack');
reject(
new FatalAuthenticationError(
'OAuth state mismatch. Possible CSRF attack or browser session issue.',
),
);
} else if (qs.get('code')) {
try {
const { tokens } = await client.getToken({
code: qs.get('code')!,
redirect_uri: redirectUri,
});
client.setCredentials(tokens);
// Retrieve and cache Google Account ID during authentication
try {
await fetchAndCacheUserInfo(client);
} catch (error) {
console.warn(
'Failed to retrieve Google Account ID during authentication:',
getErrorMessage(error),
);
// Don't fail the auth flow if Google Account ID retrieval fails
}
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
resolve();
} catch (error) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(
new FatalAuthenticationError(
`Failed to exchange authorization code for tokens: ${getErrorMessage(error)}`,
),
);
}
} else {
reject(
new FatalAuthenticationError(
'No authorization code received from Google OAuth. Please try authenticating again.',
),
);
}
} catch (e) {
// Provide more specific error message for unexpected errors during OAuth flow
if (e instanceof FatalAuthenticationError) {
reject(e);
} else {
reject(
new FatalAuthenticationError(
`Unexpected error during OAuth authentication: ${getErrorMessage(e)}`,
),
);
}
} finally {
server.close();
}
});
server.listen(port, host, () => {
// Server started successfully
});
server.on('error', (err) => {
reject(
new FatalAuthenticationError(
`OAuth callback server error: ${getErrorMessage(err)}`,
),
);
});
});
return {
authUrl,
loginCompletePromise,
};
}
export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
const portStr = process.env['OAUTH_CALLBACK_PORT'];
if (portStr) {
port = parseInt(portStr, 10);
if (isNaN(port) || port <= 0 || port > 65535) {
return reject(
new Error(`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`),
);
}
return resolve(port);
}
const server = net.createServer();
server.listen(0, () => {
const address = server.address()! as net.AddressInfo;
port = address.port;
});
server.on('listening', () => {
server.close();
server.unref();
});
server.on('error', (e) => reject(e));
server.on('close', () => resolve(port));
} catch (e) {
reject(e);
}
});
}
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
const credentials = await OAuthCredentialStorage.loadCredentials();
if (credentials) {
client.setCredentials(credentials);
return true;
}
return false;
}
const pathsToTry = [
Storage.getOAuthCredsPath(),
process.env['GOOGLE_APPLICATION_CREDENTIALS'],
].filter((p): p is string => !!p);
for (const keyFile of pathsToTry) {
try {
const creds = await fs.readFile(keyFile, 'utf-8');
client.setCredentials(JSON.parse(creds));
// This will verify locally that the credentials look good.
const { token } = await client.getAccessToken();
if (!token) {
continue;
}
// This will check with the server to see if it hasn't been revoked.
await client.getTokenInfo(token);
return true;
} catch (error) {
// Log specific error for debugging, but continue trying other paths
console.debug(
`Failed to load credentials from ${keyFile}:`,
getErrorMessage(error),
);
}
}
return false;
}
async function cacheCredentials(credentials: Credentials) {
const filePath = Storage.getOAuthCredsPath();
await fs.mkdir(path.dirname(filePath), { recursive: true });
const credString = JSON.stringify(credentials, null, 2);
await fs.writeFile(filePath, credString, { mode: 0o600 });
try {
await fs.chmod(filePath, 0o600);
} catch {
/* empty */
}
}
export function clearOauthClientCache() {
oauthClientPromises.clear();
}
export async function clearCachedCredentialFile() {
try {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
await OAuthCredentialStorage.clearCredentials();
} else {
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
}
// Clear the Google Account ID cache when credentials are cleared
await userAccountManager.clearCachedGoogleAccount();
// Clear the in-memory OAuth client cache to force re-authentication
clearOauthClientCache();
/**
* Also clear Qwen SharedTokenManager cache and credentials file to prevent stale credentials
* when switching between auth types
* TODO: We do not depend on code_assist, we'll have to build an independent auth-cleaning procedure.
*/
try {
const { SharedTokenManager } = await import(
'../qwen/sharedTokenManager.js'
);
const { clearQwenCredentials } = await import('../qwen/qwenOAuth2.js');
const sharedManager = SharedTokenManager.getInstance();
sharedManager.clearCache();
await clearQwenCredentials();
} catch (qwenError) {
console.debug('Could not clear Qwen credentials:', qwenError);
}
} catch (e) {
console.error('Failed to clear cached credentials:', e);
}
}
async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
try {
const { token } = await client.getAccessToken();
if (!token) {
return;
}
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
headers: {
Authorization: `Bearer ${token}`,
},
},
);
if (!response.ok) {
console.error(
'Failed to fetch user info:',
response.status,
response.statusText,
);
return;
}
const userInfo = await response.json();
await userAccountManager.cacheGoogleAccount(userInfo.email);
} catch (error) {
console.error('Error retrieving user info:', error);
}
}
// Helper to ensure test isolation
export function resetOauthClientForTesting() {
oauthClientPromises.clear();
}

View File

@@ -0,0 +1,255 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, describe, it, expect, vi } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
vi.mock('google-auth-library');
describe('CodeAssistServer', () => {
beforeEach(() => {
vi.resetAllMocks();
});
it('should be able to be constructed', () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(
auth,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
expect(server).toBeInstanceOf(CodeAssistServer);
});
it('should call the generateContent endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
safetyRatings: [],
},
],
},
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
expect(server.requestPost).toHaveBeenCalledWith(
'generateContent',
expect.any(Object),
undefined,
);
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'response',
);
});
it('should call the generateContentStream endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = (async function* () {
yield {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
safetyRatings: [],
},
],
},
};
})();
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
for await (const res of stream) {
expect(server.requestStreamingPost).toHaveBeenCalledWith(
'streamGenerateContent',
expect.any(Object),
undefined,
);
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
}
});
it('should call the onboardUser endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
name: 'operations/123',
done: true,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.onboardUser({
tierId: 'test-tier',
cloudaicompanionProject: 'test-project',
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'onboardUser',
expect.any(Object),
);
expect(response.name).toBe('operations/123');
});
it('should call the loadCodeAssist endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
currentTier: {
id: UserTierId.FREE,
name: 'Free',
description: 'free tier',
},
allowedTiers: [],
ineligibleTiers: [],
cloudaicompanionProject: 'projects/test',
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual(mockResponse);
});
it('should return 0 for countTokens', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockResponse = {
totalTokens: 100,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.countTokens({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
expect(response.totalTokens).toBe(100);
});
it('should throw an error for embedContent', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
await expect(
server.embedContent({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}),
).rejects.toThrow();
});
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const mockVpcScError = {
response: {
data: {
error: {
details: [
{
reason: 'SECURITY_POLICY_VIOLATED',
},
],
},
},
},
};
vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual({
currentTier: { id: UserTierId.STANDARD },
});
});
});

View File

@@ -0,0 +1,253 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuth2Client } from 'google-auth-library';
import type {
CodeAssistGlobalUserSettingResponse,
GoogleRpcResponse,
LoadCodeAssistRequest,
LoadCodeAssistResponse,
LongRunningOperationResponse,
OnboardUserRequest,
SetCodeAssistGlobalUserSettingRequest,
} from './types.js';
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import * as readline from 'node:readline';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { UserTierId } from './types.js';
import type {
CaCountTokenResponse,
CaGenerateContentResponse,
} from './converter.js';
import {
fromCountTokenResponse,
fromGenerateContentResponse,
toCountTokenRequest,
toGenerateContentRequest,
} from './converter.js';
/** HTTP options to be used in each of the requests. */
export interface HttpOptions {
/** Additional HTTP headers to be sent with the request. */
headers?: Record<string, string>;
}
export const CODE_ASSIST_ENDPOINT = 'https://localhost:0'; // Disable Google Code Assist API Request
export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator {
constructor(
readonly client: OAuth2Client,
readonly projectId?: string,
readonly httpOptions: HttpOptions = {},
readonly sessionId?: string,
readonly userTier?: UserTierId,
) {}
async generateContentStream(
req: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
yield fromGenerateContentResponse(resp);
}
})();
}
async generateContent(
req: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
}
async onboardUser(
req: OnboardUserRequest,
): Promise<LongRunningOperationResponse> {
return await this.requestPost<LongRunningOperationResponse>(
'onboardUser',
req,
);
}
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
try {
return await this.requestPost<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
);
} catch (e) {
if (isVpcScAffectedUser(e)) {
return {
currentTier: { id: UserTierId.STANDARD },
};
} else {
throw e;
}
}
}
async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestGet<CodeAssistGlobalUserSettingResponse>(
'getCodeAssistGlobalUserSetting',
);
}
async setCodeAssistGlobalUserSetting(
req: SetCodeAssistGlobalUserSettingRequest,
): Promise<CodeAssistGlobalUserSettingResponse> {
return await this.requestPost<CodeAssistGlobalUserSettingResponse>(
'setCodeAssistGlobalUserSetting',
req,
);
}
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
const resp = await this.requestPost<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
return fromCountTokenResponse(resp);
}
async embedContent(
_req: EmbedContentParameters,
): Promise<EmbedContentResponse> {
throw Error();
}
async requestPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'POST',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
body: JSON.stringify(req),
signal,
});
return res.data as T;
}
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'GET',
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'json',
signal,
});
return res.data as T;
}
async requestStreamingPost<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<AsyncGenerator<T>> {
const res = await this.client.request({
url: this.getMethodUrl(method),
method: 'POST',
params: {
alt: 'sse',
},
headers: {
'Content-Type': 'application/json',
...this.httpOptions.headers,
},
responseType: 'stream',
body: JSON.stringify(req),
signal,
});
return (async function* (): AsyncGenerator<T> {
const rl = readline.createInterface({
input: res.data as NodeJS.ReadableStream,
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
});
let bufferedLines: string[] = [];
for await (const line of rl) {
// blank lines are used to separate JSON objects in the stream
if (line === '') {
if (bufferedLines.length === 0) {
continue; // no data to yield
}
yield JSON.parse(bufferedLines.join('\n')) as T;
bufferedLines = []; // Reset the buffer after yielding
} else if (line.startsWith('data: ')) {
bufferedLines.push(line.slice(6).trim());
} else {
throw new Error(`Unexpected line format in response: ${line}`);
}
}
})();
}
getMethodUrl(method: string): string {
const endpoint =
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
}
}
function isVpcScAffectedUser(error: unknown): boolean {
if (error && typeof error === 'object' && 'response' in error) {
const gaxiosError = error as {
response?: {
data?: unknown;
};
};
const response = gaxiosError.response?.data as
| GoogleRpcResponse
| undefined;
if (Array.isArray(response?.error?.details)) {
return response.error.details.some(
(detail) => detail.reason === 'SECURITY_POLICY_VIOLATED',
);
}
}
return false;
}

View File

@@ -0,0 +1,224 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { setupUser, ProjectIdRequiredError } from './setup.js';
import { CodeAssistServer } from '../code_assist/server.js';
import type { OAuth2Client } from 'google-auth-library';
import type { GeminiUserTier } from './types.js';
import { UserTierId } from './types.js';
vi.mock('../code_assist/server.js');
const mockPaidTier: GeminiUserTier = {
id: UserTierId.STANDARD,
name: 'paid',
description: 'Paid tier',
isDefault: true,
};
const mockFreeTier: GeminiUserTier = {
id: UserTierId.FREE,
name: 'free',
description: 'Free tier',
isDefault: true,
};
describe('setupUser for existing user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and project from server is undefined', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
currentTier: mockPaidTier,
});
await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
});
it('should ignore GOOGLE_CLOUD_PROJECT when project from server is set', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
cloudaicompanionProject: 'server-project',
currentTier: mockPaidTier,
});
const projectId = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(projectId).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
// And the server itself requires a project ID internally
vi.mocked(CodeAssistServer).mockImplementation(() => {
throw new ProjectIdRequiredError();
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should use GOOGLE_CLOUD_PROJECT when set and onboard a new paid user', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
'test-project',
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'standard-tier',
cloudaicompanionProject: 'test-project',
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: 'test-project',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should onboard a new free user when GOOGLE_CLOUD_PROJECT is not set', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockFreeTier],
});
const userData = await setupUser({} as OAuth2Client);
expect(CodeAssistServer).toHaveBeenCalledWith(
{},
undefined,
{},
'',
undefined,
);
expect(mockLoad).toHaveBeenCalled();
expect(mockOnboardUser).toHaveBeenCalledWith({
tierId: 'free-tier',
cloudaicompanionProject: undefined,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
},
});
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'free-tier',
});
});
it('should use GOOGLE_CLOUD_PROJECT when onboard response has no project ID', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {
cloudaicompanionProject: undefined,
},
});
const userData = await setupUser({} as OAuth2Client);
expect(userData).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
});
});
it('should throw ProjectIdRequiredError when no project ID is available', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', '');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValue({
done: true,
response: {},
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ProjectIdRequiredError,
);
});
});

View File

@@ -0,0 +1,124 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
ClientMetadata,
GeminiUserTier,
LoadCodeAssistResponse,
OnboardUserRequest,
} from './types.js';
import { UserTierId } from './types.js';
import { CodeAssistServer } from './server.js';
import type { OAuth2Client } from 'google-auth-library';
export class ProjectIdRequiredError extends Error {
constructor() {
super(
'This account requires setting the GOOGLE_CLOUD_PROJECT env var. See https://goo.gle/gemini-cli-auth-docs#workspace-gca',
);
}
}
export interface UserData {
projectId: string;
userTier: UserTierId;
}
/**
*
* @param projectId the user's project id, if any
* @returns the user's actual project id
*/
export async function setupUser(client: OAuth2Client): Promise<UserData> {
const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
const coreClientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
};
const loadRes = await caServer.loadCodeAssist({
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
});
if (loadRes.currentTier) {
if (!loadRes.cloudaicompanionProject) {
if (projectId) {
return {
projectId,
userTier: loadRes.currentTier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: loadRes.cloudaicompanionProject,
userTier: loadRes.currentTier.id,
};
}
const tier = getOnboardTier(loadRes);
let onboardReq: OnboardUserRequest;
if (tier.id === UserTierId.FREE) {
// The free tier uses a managed google cloud project. Setting a project in the `onboardUser` request causes a `Precondition Failed` error.
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: undefined,
metadata: coreClientMetadata,
};
} else {
onboardReq = {
tierId: tier.id,
cloudaicompanionProject: projectId,
metadata: {
...coreClientMetadata,
duetProject: projectId,
},
};
}
// Poll onboardUser until long running operation is complete.
let lroRes = await caServer.onboardUser(onboardReq);
while (!lroRes.done) {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await caServer.onboardUser(onboardReq);
}
if (!lroRes.response?.cloudaicompanionProject?.id) {
if (projectId) {
return {
projectId,
userTier: tier.id,
};
}
throw new ProjectIdRequiredError();
}
return {
projectId: lroRes.response.cloudaicompanionProject.id,
userTier: tier.id,
};
}
function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
for (const tier of res.allowedTiers || []) {
if (tier.isDefault) {
return tier;
}
}
return {
name: '',
description: '',
id: UserTierId.LEGACY,
userDefinedCloudaicompanionProject: true,
};
}

View File

@@ -0,0 +1,201 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export interface ClientMetadata {
ideType?: ClientMetadataIdeType;
ideVersion?: string;
pluginVersion?: string;
platform?: ClientMetadataPlatform;
updateChannel?: string;
duetProject?: string;
pluginType?: ClientMetadataPluginType;
ideName?: string;
}
export type ClientMetadataIdeType =
| 'IDE_UNSPECIFIED'
| 'VSCODE'
| 'INTELLIJ'
| 'VSCODE_CLOUD_WORKSTATION'
| 'INTELLIJ_CLOUD_WORKSTATION'
| 'CLOUD_SHELL';
export type ClientMetadataPlatform =
| 'PLATFORM_UNSPECIFIED'
| 'DARWIN_AMD64'
| 'DARWIN_ARM64'
| 'LINUX_AMD64'
| 'LINUX_ARM64'
| 'WINDOWS_AMD64';
export type ClientMetadataPluginType =
| 'PLUGIN_UNSPECIFIED'
| 'CLOUD_CODE'
| 'GEMINI'
| 'AIPLUGIN_INTELLIJ'
| 'AIPLUGIN_STUDIO';
export interface LoadCodeAssistRequest {
cloudaicompanionProject?: string;
metadata: ClientMetadata;
}
/**
* Represents LoadCodeAssistResponse proto json field
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
*/
export interface LoadCodeAssistResponse {
currentTier?: GeminiUserTier | null;
allowedTiers?: GeminiUserTier[] | null;
ineligibleTiers?: IneligibleTier[] | null;
cloudaicompanionProject?: string | null;
}
/**
* GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist.
*/
export interface GeminiUserTier {
id: UserTierId;
name?: string;
description?: string;
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
userDefinedCloudaicompanionProject?: boolean | null;
isDefault?: boolean;
privacyNotice?: PrivacyNotice;
hasAcceptedTos?: boolean;
hasOnboardedPreviously?: boolean;
}
/**
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface IneligibleTier {
reasonCode: IneligibleTierReasonCode;
reasonMessage: string;
tierId: UserTierId;
tierName: string;
}
/**
* List of predefined reason codes when a tier is blocked from a specific tier.
* https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
*/
export enum IneligibleTierReasonCode {
// go/keep-sorted start
DASHER_USER = 'DASHER_USER',
INELIGIBLE_ACCOUNT = 'INELIGIBLE_ACCOUNT',
NON_USER_ACCOUNT = 'NON_USER_ACCOUNT',
RESTRICTED_AGE = 'RESTRICTED_AGE',
RESTRICTED_NETWORK = 'RESTRICTED_NETWORK',
UNKNOWN = 'UNKNOWN',
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
// go/keep-sorted end
}
/**
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
*
* //depot/google3/cloud/developer_experience/cloudcode/pa/service/usertier.go;l=16
*/
export enum UserTierId {
FREE = 'free-tier',
LEGACY = 'legacy-tier',
STANDARD = 'standard-tier',
}
/**
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
* privacy notice.
*/
export interface PrivacyNotice {
showNotice: boolean;
noticeText?: string;
}
/**
* Proto signature of OnboardUserRequest as payload to OnboardUser call
*/
export interface OnboardUserRequest {
tierId: string | undefined;
cloudaicompanionProject: string | undefined;
metadata: ClientMetadata | undefined;
}
/**
* Represents LongRunningOperation proto
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
*/
export interface LongRunningOperationResponse {
name: string;
done?: boolean;
response?: OnboardUserResponse;
}
/**
* Represents OnboardUserResponse proto
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
*/
export interface OnboardUserResponse {
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
/**
* Status code of user license status
* it does not strictly correspond to the proto
* Error value is an additional value assigned to error responses from OnboardUser
*/
export enum OnboardUserStatusCode {
Default = 'DEFAULT',
Notice = 'NOTICE',
Warning = 'WARNING',
Error = 'ERROR',
}
/**
* Status of user onboarded to gemini
*/
export interface OnboardUserStatus {
statusCode: OnboardUserStatusCode;
displayMessage: string;
helpLink: HelpLinkUrl | undefined;
}
export interface HelpLinkUrl {
description: string;
url: string;
}
export interface SetCodeAssistGlobalUserSettingRequest {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
export interface CodeAssistGlobalUserSettingResponse {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
/**
* Relevant fields that can be returned from a Google RPC response
*/
export interface GoogleRpcResponse {
error?: {
details?: GoogleRpcErrorInfo[];
};
}
/**
* Relevant fields that can be returned in the details of an error returned from GoogleRPCs
*/
interface GoogleRpcErrorInfo {
reason?: string;
}

View File

@@ -283,6 +283,23 @@ describe('Server Config (config.ts)', () => {
expect(config.isInFallbackMode()).toBe(false);
});
it('should strip thoughts when switching from GenAI to Vertex', async () => {
const config = new Config(baseParams);
vi.mocked(createContentGeneratorConfig).mockImplementation(
(_: Config, authType: AuthType | undefined) =>
({ authType }) as unknown as ContentGeneratorConfig,
);
await config.refreshAuth(AuthType.USE_GEMINI);
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
expect(
config.getGeminiClient().stripThoughtsFromHistory,
).toHaveBeenCalledWith();
});
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
const config = new Config(baseParams);

View File

@@ -16,7 +16,6 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
import type {
ContentGenerator,
ContentGeneratorConfig,
AuthType,
} from '../core/contentGenerator.js';
import type { FallbackModelHandler } from '../fallback/types.js';
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
@@ -27,6 +26,7 @@ import type { AnyToolInvocation } from '../tools/tools.js';
import { BaseLlmClient } from '../core/baseLlmClient.js';
import { GeminiClient } from '../core/client.js';
import {
AuthType,
createContentGenerator,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
@@ -54,6 +54,7 @@ import { canUseRipgrep } from '../utils/ripgrepUtils.js';
import { RipGrepTool } from '../tools/ripGrep.js';
import { ShellTool } from '../tools/shell.js';
import { SmartEditTool } from '../tools/smart-edit.js';
import { SkillTool } from '../tools/skill.js';
import { TaskTool } from '../tools/task.js';
import { TodoWriteTool } from '../tools/todoWrite.js';
import { ToolRegistry } from '../tools/tool-registry.js';
@@ -65,6 +66,7 @@ import { WriteFileTool } from '../tools/write-file.js';
import { ideContextStore } from '../ide/ideContext.js';
import { InputFormat, OutputFormat } from '../output/types.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { SkillManager } from '../skills/skill-manager.js';
import { SubagentManager } from '../subagents/subagent-manager.js';
import type { SubagentConfig } from '../subagents/types.js';
import {
@@ -305,6 +307,7 @@ export interface ConfigParameters {
extensionContextFilePaths?: string[];
maxSessionTurns?: number;
sessionTokenLimit?: number;
experimentalSkills?: boolean;
experimentalZedIntegration?: boolean;
listExtensions?: boolean;
extensions?: GeminiCLIExtension[];
@@ -389,6 +392,7 @@ export class Config {
private toolRegistry!: ToolRegistry;
private promptRegistry!: PromptRegistry;
private subagentManager!: SubagentManager;
private skillManager!: SkillManager;
private fileSystemService: FileSystemService;
private contentGeneratorConfig!: ContentGeneratorConfig;
private contentGenerator!: ContentGenerator;
@@ -458,6 +462,7 @@ export class Config {
| undefined;
private readonly cliVersion?: string;
private readonly experimentalZedIntegration: boolean = false;
private readonly experimentalSkills: boolean = false;
private readonly chatRecordingEnabled: boolean;
private readonly loadMemoryFromIncludeDirectories: boolean = false;
private readonly webSearch?: {
@@ -557,6 +562,7 @@ export class Config {
this.sessionTokenLimit = params.sessionTokenLimit ?? -1;
this.experimentalZedIntegration =
params.experimentalZedIntegration ?? false;
this.experimentalSkills = params.experimentalSkills ?? false;
this.listExtensions = params.listExtensions ?? false;
this._extensions = params.extensions ?? [];
this._blockedMcpServers = params.blockedMcpServers ?? [];
@@ -644,6 +650,7 @@ export class Config {
}
this.promptRegistry = new PromptRegistry();
this.subagentManager = new SubagentManager(this);
this.skillManager = new SkillManager(this);
// Load session subagents if they were provided before initialization
if (this.sessionSubagents.length > 0) {
@@ -684,6 +691,16 @@ export class Config {
}
async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) {
// Vertex and Genai have incompatible encryption and sending history with
// throughtSignature from Genai to Vertex will fail, we need to strip them
if (
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE
) {
// Restore the conversation history to the new client
this.geminiClient.stripThoughtsFromHistory();
}
const newContentGeneratorConfig = createContentGeneratorConfig(
this,
authMethod,
@@ -1066,6 +1083,10 @@ export class Config {
return this.experimentalZedIntegration;
}
getExperimentalSkills(): boolean {
return this.experimentalSkills;
}
getListExtensions(): boolean {
return this.listExtensions;
}
@@ -1296,6 +1317,10 @@ export class Config {
return this.subagentManager;
}
getSkillManager(): SkillManager {
return this.skillManager;
}
async createToolRegistry(
sendSdkMcpMessage?: SendSdkMcpMessage,
): Promise<ToolRegistry> {
@@ -1338,6 +1363,9 @@ export class Config {
};
registerCoreTool(TaskTool, this);
if (this.getExperimentalSkills()) {
registerCoreTool(SkillTool, this);
}
registerCoreTool(LSTool, this);
registerCoreTool(ReadFileTool, this);

View File

@@ -31,7 +31,7 @@ describe('Flash Model Fallback Configuration', () => {
config as unknown as { contentGeneratorConfig: unknown }
).contentGeneratorConfig = {
model: DEFAULT_GEMINI_MODEL,
authType: 'gemini-api-key',
authType: 'oauth-personal',
};
});

View File

@@ -126,6 +126,10 @@ export class Storage {
return path.join(this.getExtensionsDir(), 'qwen-extension.json');
}
getUserSkillsDir(): string {
return path.join(Storage.getGlobalQwenDir(), 'skills');
}
getHistoryFilePath(): string {
return path.join(this.getProjectTempDir(), 'shell_history');
}

View File

@@ -73,7 +73,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Create generator instance
@@ -300,7 +299,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(
@@ -335,7 +333,6 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
}),
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
new OpenAIContentGenerator(

View File

@@ -146,11 +146,12 @@ describe('BaseLlmClient', () => {
// Validate the parameters passed to the underlying generator
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
{
model: 'test-model',
contents: defaultOptions.contents,
config: expect.objectContaining({
config: {
abortSignal: defaultOptions.abortSignal,
topP: 0.8,
tools: [
{
functionDeclarations: [
@@ -162,8 +163,9 @@ describe('BaseLlmClient', () => {
],
},
],
}),
}),
// Crucial: systemInstruction should NOT be in the config object if not provided
},
},
'test-prompt-id',
);
});
@@ -186,6 +188,7 @@ describe('BaseLlmClient', () => {
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.8,
topP: 0.8, // Default should remain if not overridden
topK: 10,
tools: expect.any(Array),
}),

View File

@@ -64,6 +64,11 @@ export interface GenerateJsonOptions {
* A client dedicated to stateless, utility-focused LLM calls.
*/
export class BaseLlmClient {
// Default configuration for utility tasks
private readonly defaultUtilityConfig: GenerateContentConfig = {
topP: 0.8,
};
constructor(
private readonly contentGenerator: ContentGenerator,
private readonly config: Config,
@@ -84,6 +89,7 @@ export class BaseLlmClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...this.defaultUtilityConfig,
...options.config,
...(systemInstruction && { systemInstruction }),
};

View File

@@ -15,7 +15,11 @@ import {
} from 'vitest';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import { GeminiClient } from './client.js';
import {
isThinkingDefault,
isThinkingSupported,
GeminiClient,
} from './client.js';
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
import {
AuthType,
@@ -243,6 +247,40 @@ describe('findCompressSplitPoint', () => {
});
});
describe('isThinkingSupported', () => {
it('should return true for gemini-2.5', () => {
expect(isThinkingSupported('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingSupported('gemini-1.5-flash')).toBe(false);
expect(isThinkingSupported('some-other-model')).toBe(false);
});
});
describe('isThinkingDefault', () => {
it('should return false for gemini-2.5-flash-lite', () => {
expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
});
it('should return true for gemini-2.5', () => {
expect(isThinkingDefault('gemini-2.5')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingDefault('gemini-2.5-pro')).toBe(true);
});
it('should return false for other models', () => {
expect(isThinkingDefault('gemini-1.5-flash')).toBe(false);
expect(isThinkingDefault('some-other-model')).toBe(false);
});
});
describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
@@ -2266,15 +2304,16 @@ ${JSON.stringify(
);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
{
model: DEFAULT_GEMINI_FLASH_MODEL,
config: expect.objectContaining({
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
temperature: 0.5,
}),
topP: 0.8,
},
contents,
}),
},
'test-session-id',
);
});

View File

@@ -15,7 +15,11 @@ import type {
// Config
import { ApprovalMode, type Config } from '../config/config.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_THINKING_MODE,
} from '../config/models.js';
// Core modules
import type { ContentGenerator } from './contentGenerator.js';
@@ -74,10 +78,24 @@ import { type File, type IdeContext } from '../ide/types.js';
// Fallback handling
import { handleFallback } from '../fallback/handler.js';
export function isThinkingSupported(model: string) {
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
export function isThinkingDefault(model: string) {
if (model.startsWith('gemini-2.5-flash-lite')) {
return false;
}
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
const MAX_TURNS = 100;
export class GeminiClient {
private chat?: GeminiChat;
private readonly generateContentConfig: GenerateContentConfig = {
topP: 0.8,
};
private sessionTurnCount = 0;
private readonly loopDetector: LoopDetectionService;
@@ -189,10 +207,20 @@ export class GeminiClient {
const model = this.config.getModel();
const systemInstruction = getCoreSystemPrompt(userMemory, model);
const config: GenerateContentConfig = { ...this.generateContentConfig };
if (isThinkingSupported(model)) {
config.thinkingConfig = {
includeThoughts: true,
thinkingBudget: DEFAULT_THINKING_MODE,
};
}
return new GeminiChat(
this.config,
{
systemInstruction,
...config,
tools,
},
history,
@@ -589,6 +617,11 @@ export class GeminiClient {
): Promise<GenerateContentResponse> {
let currentAttemptModel: string = model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
};
try {
const userMemory = this.config.getUserMemory();
const finalSystemInstruction = generationConfig.systemInstruction
@@ -597,7 +630,7 @@ export class GeminiClient {
const requestConfig: GenerateContentConfig = {
abortSignal,
...generationConfig,
...configToUse,
systemInstruction: finalSystemInstruction,
};
@@ -638,7 +671,7 @@ export class GeminiClient {
`Error generating content via API with model ${currentAttemptModel}.`,
{
requestContents: contents,
requestConfig: generationConfig,
requestConfig: configToUse,
},
'generateContent-api',
);

View File

@@ -5,19 +5,42 @@
*/
import { describe, it, expect, vi } from 'vitest';
import type { ContentGenerator } from './contentGenerator.js';
import { createContentGenerator, AuthType } from './contentGenerator.js';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { GoogleGenAI } from '@google/genai';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './geminiContentGenerator/loggingContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
const mockConfig = {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
describe('createContentGenerator', () => {
it('should create a Gemini content generator', async () => {
it('should create a CodeAssistContentGenerator', async () => {
const mockGenerator = {} as unknown as ContentGenerator;
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
mockGenerator as never,
);
const generator = await createContentGenerator(
{
model: 'test-model',
authType: AuthType.LOGIN_WITH_GOOGLE,
},
mockConfig,
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
);
});
it('should create a GoogleGenAI content generator', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => true,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
@@ -42,17 +65,17 @@ describe('createContentGenerator', () => {
},
},
});
// We expect it to be a LoggingContentGenerator wrapping a GeminiContentGenerator
expect(generator).toBeInstanceOf(LoggingContentGenerator);
const wrapped = (generator as LoggingContentGenerator).getWrapped();
expect(wrapped).toBeDefined();
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
});
it('should create a Gemini content generator with client install id logging disabled', async () => {
it('should create a GoogleGenAI content generator with client install id logging disabled', async () => {
const mockConfig = {
getUsageStatisticsEnabled: () => false,
getContentGeneratorConfig: () => ({}),
getCliVersion: () => '1.0.0',
} as unknown as Config;
const mockGenerator = {
models: {},
@@ -75,6 +98,11 @@ describe('createContentGenerator', () => {
},
},
});
expect(generator).toBeInstanceOf(LoggingContentGenerator);
expect(generator).toEqual(
new LoggingContentGenerator(
(mockGenerator as GoogleGenAI).models,
mockConfig,
),
);
});
});

View File

@@ -12,9 +12,15 @@ import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js';
import { InstallationManager } from '../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
*/
@@ -33,12 +39,14 @@ export interface ContentGenerator {
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
useSummarizedThinking(): boolean;
userTier?: UserTierId;
}
export enum AuthType {
LOGIN_WITH_GOOGLE = 'oauth-personal',
USE_GEMINI = 'gemini-api-key',
USE_VERTEX_AI = 'vertex-ai',
CLOUD_SHELL = 'cloud-shell',
USE_OPENAI = 'openai',
QWEN_OAUTH = 'qwen-oauth',
}
@@ -51,9 +59,12 @@ export type ContentGeneratorConfig = {
authType?: AuthType | undefined;
enableOpenAILogging?: boolean;
openAILoggingDir?: string;
timeout?: number; // Timeout configuration in milliseconds
maxRetries?: number; // Maximum retries for failed requests
disableCacheControl?: boolean; // Disable cache control for DashScope providers
// Timeout configuration in milliseconds
timeout?: number;
// Maximum retries for failed requests
maxRetries?: number;
// Disable cache control for DashScope providers
disableCacheControl?: boolean;
samplingParams?: {
top_p?: number;
top_k?: number;
@@ -63,9 +74,6 @@ export type ContentGeneratorConfig = {
temperature?: number;
max_tokens?: number;
};
reasoning?: {
effort?: 'low' | 'medium' | 'high';
};
proxy?: string | undefined;
userAgent?: string;
// Schema compliance mode for tool definitions
@@ -115,14 +123,48 @@ export async function createContentGenerator(
gcConfig: Config,
isInitialAuth?: boolean,
): Promise<ContentGenerator> {
const version = process.env['CLI_VERSION'] || process.version;
const userAgent = `QwenCode/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
),
gcConfig,
);
}
if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI
) {
const { createGeminiContentGenerator } = await import(
'./geminiContentGenerator/index.js'
);
return createGeminiContentGenerator(config, gcConfig);
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
}
if (config.authType === AuthType.USE_OPENAI) {

View File

@@ -240,7 +240,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -318,7 +318,7 @@ describe('CoreToolScheduler', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -497,7 +497,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit', 'run_shell_command'],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -584,7 +584,7 @@ describe('CoreToolScheduler', () => {
getExcludeTools: () => ['write_file', 'edit'], // Different excluded tools
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -674,7 +674,7 @@ describe('CoreToolScheduler with payload', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1001,7 +1001,7 @@ describe('CoreToolScheduler edit cancellation', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1108,7 +1108,7 @@ describe('CoreToolScheduler YOLO mode', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1258,7 +1258,7 @@ describe('CoreToolScheduler cancellation during executing with live output', ()
getApprovalMode: () => ApprovalMode.DEFAULT,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
getShellExecutionConfig: () => ({
@@ -1350,7 +1350,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1482,7 +1482,7 @@ describe('CoreToolScheduler request queueing', () => {
getToolRegistry: () => toolRegistry,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 80,
@@ -1586,7 +1586,7 @@ describe('CoreToolScheduler request queueing', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1854,7 +1854,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
@@ -1975,7 +1975,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -100,7 +100,6 @@ describe('GeminiChat', () => {
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
useSummarizedThinking: vi.fn().mockReturnValue(false),
} as unknown as ContentGenerator;
mockHandleFallback.mockClear();
@@ -112,7 +111,7 @@ describe('GeminiChat', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'gemini-api-key', // Ensure this is set for fallback tests
authType: 'oauth-personal', // Ensure this is set for fallback tests
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -719,99 +718,6 @@ describe('GeminiChat', () => {
1,
);
});
it('should handle summarized thinking by conditionally including thoughts in history', async () => {
// Case 1: useSummarizedThinking is true -> thoughts NOT in history
vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
true,
);
const stream1 = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ thought: true, text: 'T1' }, { text: 'A1' }],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream1,
);
const res1 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
for await (const _ of res1);
const history1 = chat.getHistory();
expect(history1[1].parts).toEqual([{ text: 'A1' }]);
// Case 2: useSummarizedThinking is false -> thoughts ARE in history
chat.clearHistory();
vi.mocked(mockContentGenerator.useSummarizedThinking).mockReturnValue(
false,
);
const stream2 = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ thought: true, text: 'T2' }, { text: 'A2' }],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream2,
);
const res2 = await chat.sendMessageStream('m1', { message: 'h1' }, 'p2');
for await (const _ of res2);
const history2 = chat.getHistory();
expect(history2[1].parts).toEqual([
{ text: 'T2', thought: true },
{ text: 'A2' },
]);
});
it('should keep parts with thoughtSignature when consolidating history', async () => {
const stream = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [
{
text: 'p1',
thoughtSignature: 's1',
} as unknown as { text: string; thoughtSignature: string },
],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream,
);
const res = await chat.sendMessageStream('m1', { message: 'h1' }, 'p1');
for await (const _ of res);
const history = chat.getHistory();
expect(history[1].parts![0]).toEqual({
text: 'p1',
thoughtSignature: 's1',
});
});
});
describe('addHistory', () => {
@@ -1476,7 +1382,7 @@ describe('GeminiChat', () => {
});
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
const authType = AuthType.USE_GEMINI;
const authType = AuthType.LOGIN_WITH_GOOGLE;
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
model: 'test-model',
authType,
@@ -1626,7 +1532,7 @@ describe('GeminiChat', () => {
});
describe('stripThoughtsFromHistory', () => {
it('should strip thoughts and thought signatures, and remove empty content objects', () => {
it('should strip thought signatures', () => {
chat.setHistory([
{
role: 'user',
@@ -1638,15 +1544,10 @@ describe('GeminiChat', () => {
{ text: 'thinking...', thought: true },
{ text: 'hi' },
{
text: 'hidden metadata',
thoughtSignature: 'abc',
} as unknown as { text: string; thoughtSignature: string },
functionCall: { name: 'test', args: {} },
},
],
},
{
role: 'model',
parts: [{ text: 'only thinking', thought: true }],
},
]);
chat.stripThoughtsFromHistory();
@@ -1658,7 +1559,7 @@ describe('GeminiChat', () => {
},
{
role: 'model',
parts: [{ text: 'hi' }, { text: 'hidden metadata' }],
parts: [{ text: 'hi' }, { functionCall: { name: 'test', args: {} } }],
},
]);
});

View File

@@ -92,7 +92,6 @@ export function isValidNonThoughtTextPart(part: Part): boolean {
return (
typeof part.text === 'string' &&
!part.thought &&
!part.thoughtSignature &&
// Technically, the model should never generate parts that have text and
// any of these but we don't trust them so check anyways.
!part.functionCall &&
@@ -110,24 +109,18 @@ function isValidContent(content: Content): boolean {
if (part === undefined || Object.keys(part).length === 0) {
return false;
}
if (!isValidContentPart(part)) {
if (
!part.thought &&
part.text !== undefined &&
part.text === '' &&
part.functionCall === undefined
) {
return false;
}
}
return true;
}
function isValidContentPart(part: Part): boolean {
const isInvalid =
!part.thought &&
!part.thoughtSignature &&
part.text !== undefined &&
part.text === '' &&
part.functionCall === undefined;
return !isInvalid;
}
/**
* Validates the history contains the correct roles.
*
@@ -455,29 +448,15 @@ export class GeminiChat {
if (!content.parts) return content;
// Filter out thought parts entirely
const filteredParts = content.parts
.filter(
(part) =>
!(
part &&
typeof part === 'object' &&
'thought' in part &&
part.thought
),
)
.map((part) => {
if (
const filteredParts = content.parts.filter(
(part) =>
!(
part &&
typeof part === 'object' &&
'thoughtSignature' in part
) {
const newPart = { ...part };
delete (newPart as { thoughtSignature?: string })
.thoughtSignature;
return newPart;
}
return part;
});
'thought' in part &&
part.thought
),
);
return {
...content,
@@ -559,15 +538,11 @@ export class GeminiChat {
yield chunk; // Yield every chunk to the UI immediately.
}
let thoughtText = '';
// Only include thoughts if not using summarized thinking.
if (!this.config.getContentGenerator().useSummarizedThinking()) {
thoughtText = allModelParts
.filter((part) => part.thought)
.map((part) => part.text)
.join('')
.trim();
}
const thoughtParts = allModelParts.filter((part) => part.thought);
const thoughtText = thoughtParts
.map((part) => part.text)
.join('')
.trim();
const contentParts = allModelParts.filter((part) => !part.thought);
const consolidatedHistoryParts: Part[] = [];
@@ -580,7 +555,7 @@ export class GeminiChat {
isValidNonThoughtTextPart(part)
) {
lastPart.text += part.text;
} else if (isValidContentPart(part)) {
} else {
consolidatedHistoryParts.push(part);
}
}

View File

@@ -1,173 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { GoogleGenAI } from '@google/genai';
vi.mock('@google/genai', () => {
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();
const mockCountTokens = vi.fn();
const mockEmbedContent = vi.fn();
return {
GoogleGenAI: vi.fn().mockImplementation(() => ({
models: {
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
countTokens: mockCountTokens,
embedContent: mockEmbedContent,
},
})),
};
});
describe('GeminiContentGenerator', () => {
let generator: GeminiContentGenerator;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let mockGoogleGenAI: any;
beforeEach(() => {
vi.clearAllMocks();
generator = new GeminiContentGenerator({
apiKey: 'test-api-key',
});
mockGoogleGenAI = vi.mocked(GoogleGenAI).mock.results[0].value;
});
it('should call generateContent on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { responseId: 'test-id' };
mockGoogleGenAI.models.generateContent.mockResolvedValue(expectedResponse);
const response = await generator.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
},
}),
}),
);
expect(response).toBe(expectedResponse);
});
it('should call generateContentStream on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const mockStream = (async function* () {
yield { responseId: '1' };
})();
mockGoogleGenAI.models.generateContentStream.mockResolvedValue(mockStream);
const stream = await generator.generateContentStream(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
...request,
config: expect.objectContaining({
temperature: 1,
topP: 0.95,
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED',
},
}),
}),
);
expect(stream).toBe(mockStream);
});
it('should call countTokens on the underlying model', async () => {
const request = { model: 'gemini-1.5-flash', contents: [] };
const expectedResponse = { totalTokens: 10 };
mockGoogleGenAI.models.countTokens.mockResolvedValue(expectedResponse);
const response = await generator.countTokens(request);
expect(mockGoogleGenAI.models.countTokens).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
it('should call embedContent on the underlying model', async () => {
const request = { model: 'embedding-model', contents: [] };
const expectedResponse = { embeddings: [] };
mockGoogleGenAI.models.embedContent.mockResolvedValue(expectedResponse);
const response = await generator.embedContent(request);
expect(mockGoogleGenAI.models.embedContent).toHaveBeenCalledWith(request);
expect(response).toBe(expectedResponse);
});
it('should prioritize contentGeneratorConfig samplingParams over request config', async () => {
const generatorWithParams = new GeminiContentGenerator({ apiKey: 'test' }, {
model: 'gemini-1.5-flash',
samplingParams: {
temperature: 0.1,
top_p: 0.2,
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
const request = {
model: 'gemini-1.5-flash',
contents: [],
config: {
temperature: 0.9,
topP: 0.9,
},
};
await generatorWithParams.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.1,
topP: 0.2,
}),
}),
);
});
it('should map reasoning effort to thinkingConfig', async () => {
const generatorWithReasoning = new GeminiContentGenerator(
{ apiKey: 'test' },
{
model: 'gemini-2.5-pro',
reasoning: {
effort: 'high',
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any,
);
const request = {
model: 'gemini-2.5-pro',
contents: [],
};
await generatorWithReasoning.generateContent(request, 'prompt-id');
expect(mockGoogleGenAI.models.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
thinkingConfig: {
includeThoughts: true,
thinkingLevel: 'HIGH',
},
}),
}),
);
});
});

View File

@@ -1,144 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
GenerateContentConfig,
ThinkingLevel,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
/**
* A wrapper for GoogleGenAI that implements the ContentGenerator interface.
*/
export class GeminiContentGenerator implements ContentGenerator {
private readonly googleGenAI: GoogleGenAI;
private readonly contentGeneratorConfig?: ContentGeneratorConfig;
constructor(
options: {
apiKey?: string;
vertexai?: boolean;
httpOptions?: { headers: Record<string, string> };
},
contentGeneratorConfig?: ContentGeneratorConfig,
) {
this.googleGenAI = new GoogleGenAI(options);
this.contentGeneratorConfig = contentGeneratorConfig;
}
private buildSamplingParameters(
request: GenerateContentParameters,
): GenerateContentConfig {
const configSamplingParams = this.contentGeneratorConfig?.samplingParams;
const requestConfig = request.config || {};
// Helper function to get parameter value with priority: config > request > default
const getParameterValue = <T>(
configValue: T | undefined,
requestKey: keyof GenerateContentConfig,
defaultValue?: T,
): T | undefined => {
const requestValue = requestConfig[requestKey] as T | undefined;
if (configValue !== undefined) return configValue;
if (requestValue !== undefined) return requestValue;
return defaultValue;
};
return {
...requestConfig,
temperature: getParameterValue<number>(
configSamplingParams?.temperature,
'temperature',
1,
),
topP: getParameterValue<number>(
configSamplingParams?.top_p,
'topP',
0.95,
),
topK: getParameterValue<number>(configSamplingParams?.top_k, 'topK', 64),
maxOutputTokens: getParameterValue<number>(
configSamplingParams?.max_tokens,
'maxOutputTokens',
),
presencePenalty: getParameterValue<number>(
configSamplingParams?.presence_penalty,
'presencePenalty',
),
frequencyPenalty: getParameterValue<number>(
configSamplingParams?.frequency_penalty,
'frequencyPenalty',
),
thinkingConfig: getParameterValue(
this.contentGeneratorConfig?.reasoning
? {
includeThoughts: true,
thinkingLevel: (this.contentGeneratorConfig.reasoning.effort ===
'low'
? 'LOW'
: this.contentGeneratorConfig.reasoning.effort === 'high'
? 'HIGH'
: 'THINKING_LEVEL_UNSPECIFIED') as ThinkingLevel,
}
: undefined,
'thinkingConfig',
{
includeThoughts: true,
thinkingLevel: 'THINKING_LEVEL_UNSPECIFIED' as ThinkingLevel,
},
),
};
}
async generateContent(
request: GenerateContentParameters,
_userPromptId: string,
): Promise<GenerateContentResponse> {
const finalRequest = {
...request,
config: this.buildSamplingParameters(request),
};
return this.googleGenAI.models.generateContent(finalRequest);
}
async generateContentStream(
request: GenerateContentParameters,
_userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const finalRequest = {
...request,
config: this.buildSamplingParameters(request),
};
return this.googleGenAI.models.generateContentStream(finalRequest);
}
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.googleGenAI.models.countTokens(request);
}
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.googleGenAI.models.embedContent(request);
}
useSummarizedThinking(): boolean {
return true;
}
}

View File

@@ -1,47 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createGeminiContentGenerator } from './index.js';
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../../config/config.js';
import { AuthType } from '../contentGenerator.js';
vi.mock('./geminiContentGenerator.js', () => ({
GeminiContentGenerator: vi.fn().mockImplementation(() => ({})),
}));
vi.mock('./loggingContentGenerator.js', () => ({
LoggingContentGenerator: vi.fn().mockImplementation((wrapped) => wrapped),
}));
describe('createGeminiContentGenerator', () => {
let mockConfig: Config;
beforeEach(() => {
vi.clearAllMocks();
mockConfig = {
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
} as unknown as Config;
});
it('should create a GeminiContentGenerator wrapped in LoggingContentGenerator', () => {
const config = {
model: 'gemini-1.5-flash',
apiKey: 'test-key',
authType: AuthType.USE_GEMINI,
};
const generator = createGeminiContentGenerator(config, mockConfig);
expect(GeminiContentGenerator).toHaveBeenCalled();
expect(LoggingContentGenerator).toHaveBeenCalled();
expect(generator).toBeDefined();
});
});

View File

@@ -1,55 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GeminiContentGenerator } from './geminiContentGenerator.js';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../contentGenerator.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
export { GeminiContentGenerator } from './geminiContentGenerator.js';
export { LoggingContentGenerator } from './loggingContentGenerator.js';
/**
* Create a Gemini content generator.
*/
export function createGeminiContentGenerator(
config: ContentGeneratorConfig,
gcConfig: Config,
): ContentGenerator {
const version = process.env['CLI_VERSION'] || process.version;
const userAgent =
config.userAgent ||
`QwenCode/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const geminiContentGenerator = new GeminiContentGenerator(
{
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
},
config,
);
return new LoggingContentGenerator(geminiContentGenerator, gcConfig);
}

View File

@@ -13,24 +13,21 @@ import type {
GenerateContentParameters,
GenerateContentResponseUsageMetadata,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
Part,
PartUnion,
} from '@google/genai';
import {
ApiRequestEvent,
ApiResponseEvent,
ApiErrorEvent,
} from '../../telemetry/types.js';
import type { Config } from '../../config/config.js';
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
import {
logApiError,
logApiRequest,
logApiResponse,
} from '../../telemetry/loggers.js';
import type { ContentGenerator } from '../contentGenerator.js';
import { isStructuredError } from '../../utils/quotaErrorDetection.js';
} from '../telemetry/loggers.js';
import type { ContentGenerator } from './contentGenerator.js';
import { toContents } from '../code_assist/converter.js';
import { isStructuredError } from '../utils/quotaErrorDetection.js';
interface StructuredError {
status: number;
@@ -115,7 +112,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<GenerateContentResponse> {
const startTime = Date.now();
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
try {
const response = await this.wrapped.generateContent(req, userPromptId);
const durationMs = Date.now() - startTime;
@@ -140,7 +137,7 @@ export class LoggingContentGenerator implements ContentGenerator {
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const startTime = Date.now();
this.logApiRequest(this.toContents(req.contents), req.model, userPromptId);
this.logApiRequest(toContents(req.contents), req.model, userPromptId);
let stream: AsyncGenerator<GenerateContentResponse>;
try {
@@ -208,95 +205,4 @@ export class LoggingContentGenerator implements ContentGenerator {
): Promise<EmbedContentResponse> {
return this.wrapped.embedContent(req);
}
useSummarizedThinking(): boolean {
return this.wrapped.useSummarizedThinking();
}
private toContents(contents: ContentListUnion): Content[] {
if (Array.isArray(contents)) {
// it's a Content[] or a PartsUnion[]
return contents.map((c) => this.toContent(c));
}
// it's a Content or a PartsUnion
return [this.toContent(contents)];
}
private toContent(content: ContentUnion): Content {
if (Array.isArray(content)) {
// it's a PartsUnion[]
return {
role: 'user',
parts: this.toParts(content),
};
}
if (typeof content === 'string') {
// it's a string
return {
role: 'user',
parts: [{ text: content }],
};
}
if ('parts' in content) {
// it's a Content - process parts to handle thought filtering
return {
...content,
parts: content.parts
? this.toParts(content.parts.filter((p) => p != null))
: [],
};
}
// it's a Part
return {
role: 'user',
parts: [this.toPart(content as Part)],
};
}
private toParts(parts: PartUnion[]): Part[] {
return parts.map((p) => this.toPart(p));
}
private toPart(part: PartUnion): Part {
if (typeof part === 'string') {
// it's a string
return { text: part };
}
// Handle thought parts for CountToken API compatibility
// The CountToken API expects parts to have certain required "oneof" fields initialized,
// but thought parts don't conform to this schema and cause API failures
if ('thought' in part && part.thought) {
const thoughtText = `[Thought: ${part.thought}]`;
const newPart = { ...part };
delete (newPart as Record<string, unknown>)['thought'];
const hasApiContent =
'functionCall' in newPart ||
'functionResponse' in newPart ||
'inlineData' in newPart ||
'fileData' in newPart;
if (hasApiContent) {
// It's a functionCall or other non-text part. Just strip the thought.
return newPart;
}
// If no other valid API content, this must be a text part.
// Combine existing text (if any) with the thought, preserving other properties.
const text = (newPart as { text?: unknown }).text;
const existingText = text ? String(text) : '';
const combinedText = existingText
? `${existingText}\n${thoughtText}`
: thoughtText;
return {
...newPart,
text: combinedText,
};
}
return part;
}
}

View File

@@ -47,7 +47,7 @@ describe('executeToolCall', () => {
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'gemini-api-key',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,

View File

@@ -99,7 +99,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
generator = new OpenAIContentGenerator(
@@ -212,7 +211,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(
@@ -279,7 +277,6 @@ describe('OpenAIContentGenerator (Refactored)', () => {
},
} as unknown as OpenAI),
buildRequest: vi.fn().mockImplementation((req) => req),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
const testGenerator = new TestGenerator(

View File

@@ -154,8 +154,4 @@ export class OpenAIContentGenerator implements ContentGenerator {
);
}
}
useSummarizedThinking(): boolean {
return false;
}
}

View File

@@ -60,7 +60,6 @@ describe('ContentGenerationPipeline', () => {
buildClient: vi.fn().mockReturnValue(mockClient),
buildRequest: vi.fn().mockImplementation((req) => req),
buildHeaders: vi.fn().mockReturnValue({}),
getDefaultGenerationConfig: vi.fn().mockReturnValue({}),
};
// Mock telemetry service

View File

@@ -283,22 +283,16 @@ export class ContentGenerationPipeline {
private buildSamplingParameters(
request: GenerateContentParameters,
): Record<string, unknown> {
const defaultSamplingParams =
this.config.provider.getDefaultGenerationConfig();
const configSamplingParams = this.contentGeneratorConfig.samplingParams;
// Helper function to get parameter value with priority: config > request > default
const getParameterValue = <T>(
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey?: keyof NonNullable<typeof request.config>,
requestKey: keyof NonNullable<typeof request.config>,
defaultValue?: T,
): T | undefined => {
const configValue = configSamplingParams?.[configKey] as T | undefined;
const requestValue = requestKey
? (request.config?.[requestKey] as T | undefined)
: undefined;
const defaultValue = requestKey
? (defaultSamplingParams[requestKey] as T)
: undefined;
const requestValue = request.config?.[requestKey] as T | undefined;
if (configValue !== undefined) return configValue;
if (requestValue !== undefined) return requestValue;
@@ -310,8 +304,12 @@ export class ContentGenerationPipeline {
key: string,
configKey: keyof NonNullable<typeof configSamplingParams>,
requestKey?: keyof NonNullable<typeof request.config>,
): Record<string, T | undefined> => {
const value = getParameterValue<T>(configKey, requestKey);
defaultValue?: T,
): Record<string, T> | Record<string, never> => {
const value = requestKey
? getParameterValue(configKey, requestKey, defaultValue)
: ((configSamplingParams?.[configKey] as T | undefined) ??
defaultValue);
return value !== undefined ? { [key]: value } : {};
};
@@ -325,18 +323,10 @@ export class ContentGenerationPipeline {
...addParameterIfDefined('max_tokens', 'max_tokens', 'maxOutputTokens'),
// Config-only parameters (no request fallback)
...addParameterIfDefined('top_k', 'top_k', 'topK'),
...addParameterIfDefined('top_k', 'top_k'),
...addParameterIfDefined('repetition_penalty', 'repetition_penalty'),
...addParameterIfDefined(
'presence_penalty',
'presence_penalty',
'presencePenalty',
),
...addParameterIfDefined(
'frequency_penalty',
'frequency_penalty',
'frequencyPenalty',
),
...addParameterIfDefined('presence_penalty', 'presence_penalty'),
...addParameterIfDefined('frequency_penalty', 'frequency_penalty'),
};
return params;

View File

@@ -1,5 +1,4 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { AuthType } from '../../contentGenerator.js';
@@ -142,14 +141,6 @@ export class DashScopeOpenAICompatibleProvider
};
}
getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0.7,
topP: 0.8,
topK: 20,
};
}
/**
* Add cache control flag to specified message(s) for DashScope providers
*/

View File

@@ -8,7 +8,6 @@ import type OpenAI from 'openai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DefaultOpenAICompatibleProvider } from './default.js';
import type { GenerateContentConfig } from '@google/genai';
export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatibleProvider {
constructor(
@@ -77,10 +76,4 @@ export class DeepSeekOpenAICompatibleProvider extends DefaultOpenAICompatiblePro
messages,
};
}
override getDefaultGenerationConfig(): GenerateContentConfig {
return {
temperature: 0,
};
}
}

View File

@@ -1,5 +1,4 @@
import OpenAI from 'openai';
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../../../config/config.js';
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
@@ -56,10 +55,4 @@ export class DefaultOpenAICompatibleProvider
...request, // Preserve all original parameters including sampling params
};
}
getDefaultGenerationConfig(): GenerateContentConfig {
return {
topP: 0.95,
};
}
}

View File

@@ -1,4 +1,3 @@
import type { GenerateContentConfig } from '@google/genai';
import type OpenAI from 'openai';
// Extended types to support cache_control for DashScope
@@ -23,7 +22,6 @@ export interface OpenAICompatibleProvider {
request: OpenAI.Chat.ChatCompletionCreateParams,
userPromptId: string,
): OpenAI.Chat.ChatCompletionCreateParams;
getDefaultGenerationConfig(): GenerateContentConfig;
}
export type DashScopeRequestMetadata = {

View File

@@ -36,6 +36,13 @@ vi.mock('../utils/errorReporting', () => ({
reportError: vi.fn(),
}));
// Use the actual implementation from partUtils now that it's provided.
vi.mock('../utils/generateContentResponseUtilities', () => ({
getResponseText: (resp: GenerateContentResponse) =>
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
undefined,
}));
describe('Turn', () => {
let turn: Turn;
// Define a type for the mocked Chat instance for clarity
@@ -149,7 +156,6 @@ describe('Turn', () => {
type: GeminiEventType.Thought,
value: { subject: '', description: 'reasoning...' },
},
{ type: GeminiEventType.Content, value: 'final answer' },
]);
});

View File

@@ -27,11 +27,7 @@ import {
toFriendlyError,
} from '../utils/errors.js';
import type { GeminiChat } from './geminiChat.js';
import {
getThoughtText,
parseThought,
type ThoughtSummary,
} from '../utils/thoughtUtils.js';
import { getThoughtText, type ThoughtSummary } from '../utils/thoughtUtils.js';
// Define a structure for tools passed to the server
export interface ServerTool {
@@ -270,12 +266,13 @@ export class Turn {
this.currentResponseId = resp.responseId;
}
const thoughtText = getThoughtText(resp);
if (thoughtText) {
const thoughtPart = getThoughtText(resp);
if (thoughtPart) {
yield {
type: GeminiEventType.Thought,
value: parseThought(thoughtText),
value: { subject: '', description: thoughtPart },
};
continue;
}
const text = getResponseText(resp);

View File

@@ -4,10 +4,36 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
describe,
it,
expect,
vi,
beforeEach,
type Mock,
type MockInstance,
afterEach,
} from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';
// Mock the telemetry logger and event class
vi.mock('../telemetry/index.js', () => ({
logFlashFallback: vi.fn(),
FlashFallbackEvent: class {},
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
@@ -19,28 +45,174 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
describe('handleFallback', () => {
let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks();
mockConfig = createMockConfig();
mockHandler = vi.fn();
// Default setup: OAuth user, Pro model failed, handler injected
mockConfig = createMockConfig({
fallbackModelHandler: mockHandler,
});
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
});
it('should return null for unknown auth types', async () => {
afterEach(() => {
consoleErrorSpy.mockRestore();
});
it('should return null immediately if authType is not OAuth', async () => {
const result = await handleFallback(
mockConfig,
'test-model',
'unknown-auth',
MOCK_PRO_MODEL,
AUTH_API_KEY,
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
it('should return null if the failed model is already the fallback model', async () => {
const result = await handleFallback(
mockConfig,
FALLBACK_MODEL, // Failed model is Flash
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
});
it('should return null if no fallbackHandler is injected in config', async () => {
const configWithoutHandler = createMockConfig({
fallbackModelHandler: undefined,
});
const result = await handleFallback(
configWithoutHandler,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
});
it('should handle Qwen OAuth error', async () => {
const result = await handleFallback(
mockConfig,
'test-model',
AuthType.QWEN_OAUTH,
new Error('unauthorized'),
describe('when handler returns "retry"', () => {
it('should activate fallback mode, log telemetry, and return true', async () => {
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(true);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "stop"', () => {
it('should activate fallback mode, log telemetry, and return false', async () => {
mockHandler.mockResolvedValue('stop');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "auth"', () => {
it('should NOT activate fallback mode and return false', async () => {
mockHandler.mockResolvedValue('auth');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
});
describe('when handler returns an unexpected value', () => {
it('should log an error and return null', async () => {
mockHandler.mockResolvedValue(null);
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
new Error(
'Unexpected fallback intent received from fallbackModelHandler: "null"',
),
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
const mockError = new Error('Quota Exceeded');
mockHandler.mockResolvedValue('retry');
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
expect(mockHandler).toHaveBeenCalledWith(
MOCK_PRO_MODEL,
FALLBACK_MODEL,
mockError,
);
});
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
// Setup config where fallback mode is already active
const activeFallbackConfig = createMockConfig({
fallbackModelHandler: mockHandler,
isInFallbackMode: vi.fn(() => true), // Already active
setFallbackMode: vi.fn(),
});
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
activeFallbackConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
// Should still return true to allow the retry (which will use the active fallback mode)
expect(result).toBe(true);
// Should still consult the handler
expect(mockHandler).toHaveBeenCalled();
// But should not mutate state or log telemetry again
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
const handlerError = new Error('UI interaction failed');
mockHandler.mockRejectedValue(handlerError);
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
handlerError,
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});

View File

@@ -6,6 +6,8 @@
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
export async function handleFallback(
config: Config,
@@ -18,7 +20,48 @@ export async function handleFallback(
return handleQwenOAuthError(error);
}
return null;
// Applicability Checks
if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
if (failedModel === fallbackModel) return null;
// Consult UI Handler for Intent
const fallbackModelHandler = config.fallbackModelHandler;
if (typeof fallbackModelHandler !== 'function') return null;
try {
// Pass the specific failed model to the UI handler.
const intent = await fallbackModelHandler(
failedModel,
fallbackModel,
error,
);
// Process Intent and Update State
switch (intent) {
case 'retry':
// Activate fallback mode. The NEXT retry attempt will pick this up.
activateFallbackMode(config, authType);
return true; // Signal retryWithBackoff to continue.
case 'stop':
activateFallbackMode(config, authType);
return false;
case 'auth':
return false;
default:
throw new Error(
`Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
);
}
} catch (handlerError) {
console.error('Fallback UI handler failed:', handlerError);
return null;
}
}
/**
@@ -75,3 +118,12 @@ async function handleQwenOAuthError(error?: unknown): Promise<string | null> {
// For other errors, don't handle them specially
return null;
}
function activateFallbackMode(config: Config, authType: string | undefined) {
if (!config.isInFallbackMode()) {
config.setFallbackMode(true);
if (authType) {
logFlashFallback(config, new FlashFallbackEvent(authType));
}
}
}

View File

@@ -20,11 +20,26 @@ async function getProcessInfo(pid: number): Promise<{
command: string;
}> {
// Only used for Unix systems (macOS and Linux)
const { stdout } = await execAsync(`ps -p ${pid} -o ppid=,comm=`);
const [ppidStr, ...commandParts] = stdout.trim().split(/\s+/);
const parentPid = parseInt(ppidStr, 10);
const command = commandParts.join(' ');
return { parentPid, name: path.basename(command), command };
try {
const command = `ps -o ppid=,command= -p ${pid}`;
const { stdout } = await execAsync(command);
const trimmedStdout = stdout.trim();
if (!trimmedStdout) {
return { parentPid: 0, name: '', command: '' };
}
const parts = trimmedStdout.split(/\s+/);
const ppidString = parts[0];
const parentPid = parseInt(ppidString, 10);
const fullCommand = trimmedStdout.substring(ppidString.length).trim();
const processName = path.basename(fullCommand.split(' ')[0]);
return {
parentPid: isNaN(parentPid) ? 1 : parentPid,
name: processName,
command: fullCommand,
};
} catch (_e) {
return { parentPid: 0, name: '', command: '' };
}
}
/**
* Finds the IDE process info on Unix-like systems.

View File

@@ -12,6 +12,7 @@ export * from './output/json-formatter.js';
// Export Core Logic
export * from './core/client.js';
export * from './core/contentGenerator.js';
export * from './core/loggingContentGenerator.js';
export * from './core/geminiChat.js';
export * from './core/logger.js';
export * from './core/prompts.js';
@@ -23,7 +24,11 @@ export * from './core/nonInteractiveToolExecutor.js';
export * from './fallback/types.js';
export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js';
export * from './qwen/qwenOAuth2.js';
export * from './code_assist/server.js';
export * from './code_assist/types.js';
// Export utilities
export * from './utils/paths.js';
@@ -80,6 +85,9 @@ export * from './tools/tool-registry.js';
// Export subagents (Phase 1)
export * from './subagents/index.js';
// Export skills
export * from './skills/index.js';
// Export prompt logic
export * from './prompts/mcp-prompts.js';
@@ -101,6 +109,7 @@ export * from './tools/mcp-client-manager.js';
export * from './tools/mcp-tool.js';
export * from './tools/sdk-control-client-transport.js';
export * from './tools/task.js';
export * from './tools/skill.js';
export * from './tools/todoWrite.js';
export * from './tools/exitPlanMode.js';

View File

@@ -907,5 +907,3 @@ export async function clearQwenCredentials(): Promise<void> {
function getQwenCachedCredentialPath(): string {
return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME);
}
export const clearCachedCredentialFile = clearQwenCredentials;

View File

@@ -0,0 +1,31 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/**
* @fileoverview Skills feature implementation
*
* This module provides the foundation for the skills feature, which allows
* users to define reusable skill configurations that can be loaded by the
* model via a dedicated Skills tool.
*
* Skills are stored as directories in `.qwen/skills/` (project-level) or
* `~/.qwen/skills/` (user-level), with each directory containing a SKILL.md
* file with YAML frontmatter for metadata.
*/
// Core types and interfaces
export type {
SkillConfig,
SkillLevel,
SkillValidationResult,
ListSkillsOptions,
SkillErrorCode,
} from './types.js';
export { SkillError } from './types.js';
// Main management class
export { SkillManager } from './skill-manager.js';

View File

@@ -0,0 +1,463 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as fs from 'fs/promises';
import * as path from 'path';
import * as os from 'os';
import { SkillManager } from './skill-manager.js';
import { type SkillConfig, SkillError } from './types.js';
import type { Config } from '../config/config.js';
import { makeFakeConfig } from '../test-utils/config.js';
// Mock file system operations
vi.mock('fs/promises');
vi.mock('os');
// Mock yaml parser - use vi.hoisted for proper hoisting
const mockParseYaml = vi.hoisted(() => vi.fn());
vi.mock('../utils/yaml-parser.js', () => ({
parse: mockParseYaml,
stringify: vi.fn(),
}));
describe('SkillManager', () => {
let manager: SkillManager;
let mockConfig: Config;
beforeEach(() => {
// Create mock Config object using test utility
mockConfig = makeFakeConfig({});
// Mock the project root method
vi.spyOn(mockConfig, 'getProjectRoot').mockReturnValue('/test/project');
// Mock os.homedir
vi.mocked(os.homedir).mockReturnValue('/home/user');
// Reset and setup mocks
vi.clearAllMocks();
// Setup yaml parser mocks with sophisticated behavior
mockParseYaml.mockImplementation((yamlString: string) => {
// Handle different test cases based on YAML content
if (yamlString.includes('allowedTools:')) {
return {
name: 'test-skill',
description: 'A test skill',
allowedTools: ['read_file', 'write_file'],
};
}
if (yamlString.includes('name: skill1')) {
return { name: 'skill1', description: 'First skill' };
}
if (yamlString.includes('name: skill2')) {
return { name: 'skill2', description: 'Second skill' };
}
if (yamlString.includes('name: skill3')) {
return { name: 'skill3', description: 'Third skill' };
}
if (!yamlString.includes('name:')) {
return { description: 'A test skill' }; // Missing name case
}
if (!yamlString.includes('description:')) {
return { name: 'test-skill' }; // Missing description case
}
// Default case
return {
name: 'test-skill',
description: 'A test skill',
};
});
manager = new SkillManager(mockConfig);
});
afterEach(() => {
vi.restoreAllMocks();
});
const validSkillConfig: SkillConfig = {
name: 'test-skill',
description: 'A test skill',
level: 'project',
filePath: '/test/project/.qwen/skills/test-skill/SKILL.md',
body: 'You are a helpful assistant with this skill.',
};
const validMarkdown = `---
name: test-skill
description: A test skill
---
You are a helpful assistant with this skill.
`;
describe('parseSkillContent', () => {
it('should parse valid markdown content', () => {
const config = manager.parseSkillContent(
validMarkdown,
validSkillConfig.filePath,
'project',
);
expect(config.name).toBe('test-skill');
expect(config.description).toBe('A test skill');
expect(config.body).toBe('You are a helpful assistant with this skill.');
expect(config.level).toBe('project');
expect(config.filePath).toBe(validSkillConfig.filePath);
});
it('should parse content with allowedTools', () => {
const markdownWithTools = `---
name: test-skill
description: A test skill
allowedTools:
- read_file
- write_file
---
You are a helpful assistant with this skill.
`;
const config = manager.parseSkillContent(
markdownWithTools,
validSkillConfig.filePath,
'project',
);
expect(config.allowedTools).toEqual(['read_file', 'write_file']);
});
it('should determine level from file path', () => {
const projectPath = '/test/project/.qwen/skills/test-skill/SKILL.md';
const userPath = '/home/user/.qwen/skills/test-skill/SKILL.md';
const projectConfig = manager.parseSkillContent(
validMarkdown,
projectPath,
'project',
);
const userConfig = manager.parseSkillContent(
validMarkdown,
userPath,
'user',
);
expect(projectConfig.level).toBe('project');
expect(userConfig.level).toBe('user');
});
it('should throw error for invalid frontmatter format', () => {
const invalidMarkdown = `No frontmatter here
Just content`;
expect(() =>
manager.parseSkillContent(
invalidMarkdown,
validSkillConfig.filePath,
'project',
),
).toThrow(SkillError);
});
it('should throw error for missing name', () => {
const markdownWithoutName = `---
description: A test skill
---
You are a helpful assistant.
`;
expect(() =>
manager.parseSkillContent(
markdownWithoutName,
validSkillConfig.filePath,
'project',
),
).toThrow(SkillError);
});
it('should throw error for missing description', () => {
const markdownWithoutDescription = `---
name: test-skill
---
You are a helpful assistant.
`;
expect(() =>
manager.parseSkillContent(
markdownWithoutDescription,
validSkillConfig.filePath,
'project',
),
).toThrow(SkillError);
});
});
describe('validateConfig', () => {
it('should validate valid configuration', () => {
const result = manager.validateConfig(validSkillConfig);
expect(result.isValid).toBe(true);
expect(result.errors).toHaveLength(0);
});
it('should report error for missing name', () => {
const invalidConfig = { ...validSkillConfig, name: '' };
const result = manager.validateConfig(invalidConfig);
expect(result.isValid).toBe(false);
expect(result.errors).toContain('"name" cannot be empty');
});
it('should report error for missing description', () => {
const invalidConfig = { ...validSkillConfig, description: '' };
const result = manager.validateConfig(invalidConfig);
expect(result.isValid).toBe(false);
expect(result.errors).toContain('"description" cannot be empty');
});
it('should report error for invalid allowedTools type', () => {
const invalidConfig = {
...validSkillConfig,
allowedTools: 'not-an-array' as unknown as string[],
};
const result = manager.validateConfig(invalidConfig);
expect(result.isValid).toBe(false);
expect(result.errors).toContain('"allowedTools" must be an array');
});
it('should warn for empty body', () => {
const configWithEmptyBody = { ...validSkillConfig, body: '' };
const result = manager.validateConfig(configWithEmptyBody);
expect(result.isValid).toBe(true); // Still valid
expect(result.warnings).toContain('Skill body is empty');
});
});
describe('loadSkill', () => {
it('should load skill from project level first', async () => {
vi.mocked(fs.readdir).mockResolvedValue([
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
vi.mocked(fs.access).mockResolvedValue(undefined);
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown);
const config = await manager.loadSkill('test-skill');
expect(config).toBeDefined();
expect(config!.name).toBe('test-skill');
});
it('should fall back to user level if project level fails', async () => {
vi.mocked(fs.readdir)
.mockRejectedValueOnce(new Error('Project dir not found')) // project level fails
.mockResolvedValueOnce([
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
] as unknown as Awaited<ReturnType<typeof fs.readdir>>); // user level succeeds
vi.mocked(fs.access).mockResolvedValue(undefined);
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown);
const config = await manager.loadSkill('test-skill');
expect(config).toBeDefined();
expect(config!.name).toBe('test-skill');
});
it('should return null if not found at either level', async () => {
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
const config = await manager.loadSkill('nonexistent');
expect(config).toBeNull();
});
});
describe('loadSkillForRuntime', () => {
it('should load skill for runtime', async () => {
vi.mocked(fs.readdir).mockResolvedValueOnce([
{ name: 'test-skill', isDirectory: () => true, isFile: () => false },
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
vi.mocked(fs.access).mockResolvedValue(undefined);
vi.mocked(fs.readFile).mockResolvedValue(validMarkdown); // SKILL.md
const config = await manager.loadSkillForRuntime('test-skill');
expect(config).toBeDefined();
expect(config!.name).toBe('test-skill');
});
it('should return null if skill not found', async () => {
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
const config = await manager.loadSkillForRuntime('nonexistent');
expect(config).toBeNull();
});
});
describe('listSkills', () => {
beforeEach(() => {
// Mock directory listing for skills directories (with Dirent objects)
vi.mocked(fs.readdir)
.mockResolvedValueOnce([
{ name: 'skill1', isDirectory: () => true, isFile: () => false },
{ name: 'skill2', isDirectory: () => true, isFile: () => false },
{
name: 'not-a-dir.txt',
isDirectory: () => false,
isFile: () => true,
},
] as unknown as Awaited<ReturnType<typeof fs.readdir>>)
.mockResolvedValueOnce([
{ name: 'skill3', isDirectory: () => true, isFile: () => false },
{ name: 'skill1', isDirectory: () => true, isFile: () => false },
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
vi.mocked(fs.access).mockResolvedValue(undefined);
// Mock file reading for valid skills
vi.mocked(fs.readFile).mockImplementation((filePath) => {
const pathStr = String(filePath);
if (pathStr.includes('skill1')) {
return Promise.resolve(`---
name: skill1
description: First skill
---
Skill 1 content`);
} else if (pathStr.includes('skill2')) {
return Promise.resolve(`---
name: skill2
description: Second skill
---
Skill 2 content`);
} else if (pathStr.includes('skill3')) {
return Promise.resolve(`---
name: skill3
description: Third skill
---
Skill 3 content`);
}
return Promise.reject(new Error('File not found'));
});
});
it('should list skills from both levels', async () => {
const skills = await manager.listSkills();
expect(skills).toHaveLength(3); // skill1 (project takes precedence), skill2, skill3
expect(skills.map((s) => s.name).sort()).toEqual([
'skill1',
'skill2',
'skill3',
]);
});
it('should prioritize project level over user level', async () => {
const skills = await manager.listSkills();
const skill1 = skills.find((s) => s.name === 'skill1');
expect(skill1!.level).toBe('project');
});
it('should filter by level', async () => {
const projectSkills = await manager.listSkills({
level: 'project',
});
expect(projectSkills).toHaveLength(2); // skill1, skill2
expect(projectSkills.every((s) => s.level === 'project')).toBe(true);
});
it('should handle empty directories', async () => {
vi.mocked(fs.readdir).mockReset();
vi.mocked(fs.readdir).mockResolvedValue(
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
);
const skills = await manager.listSkills({ force: true });
expect(skills).toHaveLength(0);
});
it('should handle directory read errors', async () => {
vi.mocked(fs.readdir).mockReset();
vi.mocked(fs.readdir).mockRejectedValue(new Error('Directory not found'));
const skills = await manager.listSkills({ force: true });
expect(skills).toHaveLength(0);
});
});
describe('getSkillsBaseDir', () => {
it('should return project-level base dir', () => {
const baseDir = manager.getSkillsBaseDir('project');
expect(baseDir).toBe(path.join('/test/project', '.qwen', 'skills'));
});
it('should return user-level base dir', () => {
const baseDir = manager.getSkillsBaseDir('user');
expect(baseDir).toBe(path.join('/home/user', '.qwen', 'skills'));
});
});
describe('change listeners', () => {
it('should notify listeners when cache is refreshed', async () => {
const listener = vi.fn();
manager.addChangeListener(listener);
vi.mocked(fs.readdir).mockResolvedValue(
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
);
await manager.refreshCache();
expect(listener).toHaveBeenCalled();
});
it('should remove listener when cleanup function is called', async () => {
const listener = vi.fn();
const removeListener = manager.addChangeListener(listener);
removeListener();
vi.mocked(fs.readdir).mockResolvedValue(
[] as unknown as Awaited<ReturnType<typeof fs.readdir>>,
);
await manager.refreshCache();
expect(listener).not.toHaveBeenCalled();
});
});
describe('parse errors', () => {
it('should track parse errors', async () => {
vi.mocked(fs.readdir).mockResolvedValue([
{ name: 'bad-skill', isDirectory: () => true, isFile: () => false },
] as unknown as Awaited<ReturnType<typeof fs.readdir>>);
vi.mocked(fs.access).mockResolvedValue(undefined);
vi.mocked(fs.readFile).mockResolvedValue(
'invalid content without frontmatter',
);
await manager.listSkills({ force: true });
const errors = manager.getParseErrors();
expect(errors.size).toBeGreaterThan(0);
});
});
});

View File

@@ -0,0 +1,452 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import * as path from 'path';
import * as os from 'os';
import { parse as parseYaml } from '../utils/yaml-parser.js';
import type {
SkillConfig,
SkillLevel,
ListSkillsOptions,
SkillValidationResult,
} from './types.js';
import { SkillError, SkillErrorCode } from './types.js';
import type { Config } from '../config/config.js';
const QWEN_CONFIG_DIR = '.qwen';
const SKILLS_CONFIG_DIR = 'skills';
const SKILL_MANIFEST_FILE = 'SKILL.md';
/**
* Manages skill configurations stored as directories containing SKILL.md files.
* Provides discovery, parsing, validation, and caching for skills.
*/
export class SkillManager {
private skillsCache: Map<SkillLevel, SkillConfig[]> | null = null;
private readonly changeListeners: Set<() => void> = new Set();
private parseErrors: Map<string, SkillError> = new Map();
constructor(private readonly config: Config) {}
/**
* Adds a listener that will be called when skills change.
* @returns A function to remove the listener.
*/
addChangeListener(listener: () => void): () => void {
this.changeListeners.add(listener);
return () => {
this.changeListeners.delete(listener);
};
}
/**
* Notifies all registered change listeners.
*/
private notifyChangeListeners(): void {
for (const listener of this.changeListeners) {
try {
listener();
} catch (error) {
console.warn('Skill change listener threw an error:', error);
}
}
}
/**
* Gets any parse errors that occurred during skill loading.
* @returns Map of skill paths to their parse errors.
*/
getParseErrors(): Map<string, SkillError> {
return new Map(this.parseErrors);
}
/**
* Lists all available skills.
*
* @param options - Filtering options
* @returns Array of skill configurations
*/
async listSkills(options: ListSkillsOptions = {}): Promise<SkillConfig[]> {
const skills: SkillConfig[] = [];
const seenNames = new Set<string>();
const levelsToCheck: SkillLevel[] = options.level
? [options.level]
: ['project', 'user'];
// Check if we should use cache or force refresh
const shouldUseCache = !options.force && this.skillsCache !== null;
// Initialize cache if it doesn't exist or we're forcing a refresh
if (!shouldUseCache) {
await this.refreshCache();
}
// Collect skills from each level (project takes precedence over user)
for (const level of levelsToCheck) {
const levelSkills = this.skillsCache?.get(level) || [];
for (const skill of levelSkills) {
// Skip if we've already seen this name (precedence: project > user)
if (seenNames.has(skill.name)) {
continue;
}
skills.push(skill);
seenNames.add(skill.name);
}
}
// Sort by name for consistent ordering
skills.sort((a, b) => a.name.localeCompare(b.name));
return skills;
}
/**
* Loads a skill configuration by name.
* If level is specified, only searches that level.
* If level is omitted, searches project-level first, then user-level.
*
* @param name - Name of the skill to load
* @param level - Optional level to limit search to
* @returns SkillConfig or null if not found
*/
async loadSkill(
name: string,
level?: SkillLevel,
): Promise<SkillConfig | null> {
if (level) {
return this.findSkillByNameAtLevel(name, level);
}
// Try project level first
const projectSkill = await this.findSkillByNameAtLevel(name, 'project');
if (projectSkill) {
return projectSkill;
}
// Try user level
return this.findSkillByNameAtLevel(name, 'user');
}
/**
* Loads a skill with its full content, ready for runtime use.
* This includes loading additional files from the skill directory.
*
* @param name - Name of the skill to load
* @param level - Optional level to limit search to
* @returns SkillConfig or null if not found
*/
async loadSkillForRuntime(
name: string,
level?: SkillLevel,
): Promise<SkillConfig | null> {
const skill = await this.loadSkill(name, level);
if (!skill) {
return null;
}
return skill;
}
/**
* Validates a skill configuration.
*
* @param config - Configuration to validate
* @returns Validation result
*/
validateConfig(config: Partial<SkillConfig>): SkillValidationResult {
const errors: string[] = [];
const warnings: string[] = [];
// Check required fields
if (typeof config.name !== 'string') {
errors.push('Missing or invalid "name" field');
} else if (config.name.trim() === '') {
errors.push('"name" cannot be empty');
}
if (typeof config.description !== 'string') {
errors.push('Missing or invalid "description" field');
} else if (config.description.trim() === '') {
errors.push('"description" cannot be empty');
}
// Validate allowedTools if present
if (config.allowedTools !== undefined) {
if (!Array.isArray(config.allowedTools)) {
errors.push('"allowedTools" must be an array');
} else {
for (const tool of config.allowedTools) {
if (typeof tool !== 'string') {
errors.push('"allowedTools" must contain only strings');
break;
}
}
}
}
// Warn if body is empty
if (!config.body || config.body.trim() === '') {
warnings.push('Skill body is empty');
}
return {
isValid: errors.length === 0,
errors,
warnings,
};
}
/**
* Refreshes the skills cache by loading all skills from disk.
*/
async refreshCache(): Promise<void> {
const skillsCache = new Map<SkillLevel, SkillConfig[]>();
this.parseErrors.clear();
const levels: SkillLevel[] = ['project', 'user'];
for (const level of levels) {
const levelSkills = await this.listSkillsAtLevel(level);
skillsCache.set(level, levelSkills);
}
this.skillsCache = skillsCache;
this.notifyChangeListeners();
}
/**
* Parses a SKILL.md file and returns the configuration.
*
* @param filePath - Path to the SKILL.md file
* @param level - Storage level
* @returns SkillConfig
* @throws SkillError if parsing fails
*/
parseSkillFile(filePath: string, level: SkillLevel): Promise<SkillConfig> {
return this.parseSkillFileInternal(filePath, level);
}
/**
* Internal implementation of skill file parsing.
*/
private async parseSkillFileInternal(
filePath: string,
level: SkillLevel,
): Promise<SkillConfig> {
let content: string;
try {
content = await fs.readFile(filePath, 'utf8');
} catch (error) {
const skillError = new SkillError(
`Failed to read skill file: ${error instanceof Error ? error.message : 'Unknown error'}`,
SkillErrorCode.FILE_ERROR,
);
this.parseErrors.set(filePath, skillError);
throw skillError;
}
return this.parseSkillContent(content, filePath, level);
}
/**
* Parses skill content from a string.
*
* @param content - File content
* @param filePath - File path for error reporting
* @param level - Storage level
* @returns SkillConfig
* @throws SkillError if parsing fails
*/
parseSkillContent(
content: string,
filePath: string,
level: SkillLevel,
): SkillConfig {
try {
// Split frontmatter and content
const frontmatterRegex = /^---\n([\s\S]*?)\n---\n([\s\S]*)$/;
const match = content.match(frontmatterRegex);
if (!match) {
throw new Error('Invalid format: missing YAML frontmatter');
}
const [, frontmatterYaml, body] = match;
// Parse YAML frontmatter
const frontmatter = parseYaml(frontmatterYaml) as Record<string, unknown>;
// Extract required fields
const nameRaw = frontmatter['name'];
const descriptionRaw = frontmatter['description'];
if (nameRaw == null || nameRaw === '') {
throw new Error('Missing "name" in frontmatter');
}
if (descriptionRaw == null || descriptionRaw === '') {
throw new Error('Missing "description" in frontmatter');
}
// Convert to strings
const name = String(nameRaw);
const description = String(descriptionRaw);
// Extract optional fields
const allowedToolsRaw = frontmatter['allowedTools'] as
| unknown[]
| undefined;
let allowedTools: string[] | undefined;
if (allowedToolsRaw !== undefined) {
if (Array.isArray(allowedToolsRaw)) {
allowedTools = allowedToolsRaw.map(String);
} else {
throw new Error('"allowedTools" must be an array');
}
}
const config: SkillConfig = {
name,
description,
allowedTools,
level,
filePath,
body: body.trim(),
};
// Validate the parsed configuration
const validation = this.validateConfig(config);
if (!validation.isValid) {
throw new Error(`Validation failed: ${validation.errors.join(', ')}`);
}
return config;
} catch (error) {
const skillError = new SkillError(
`Failed to parse skill file: ${error instanceof Error ? error.message : 'Unknown error'}`,
SkillErrorCode.PARSE_ERROR,
);
this.parseErrors.set(filePath, skillError);
throw skillError;
}
}
/**
* Gets the base directory for skills at a specific level.
*
* @param level - Storage level
* @returns Absolute directory path
*/
getSkillsBaseDir(level: SkillLevel): string {
const baseDir =
level === 'project'
? path.join(
this.config.getProjectRoot(),
QWEN_CONFIG_DIR,
SKILLS_CONFIG_DIR,
)
: path.join(os.homedir(), QWEN_CONFIG_DIR, SKILLS_CONFIG_DIR);
return baseDir;
}
/**
* Lists skills at a specific level.
*
* @param level - Storage level to scan
* @returns Array of skill configurations
*/
private async listSkillsAtLevel(level: SkillLevel): Promise<SkillConfig[]> {
const projectRoot = this.config.getProjectRoot();
const homeDir = os.homedir();
const isHomeDirectory = path.resolve(projectRoot) === path.resolve(homeDir);
// If project level is requested but project root is same as home directory,
// return empty array to avoid conflicts between project and global skills
if (level === 'project' && isHomeDirectory) {
return [];
}
const baseDir = this.getSkillsBaseDir(level);
try {
const entries = await fs.readdir(baseDir, { withFileTypes: true });
const skills: SkillConfig[] = [];
for (const entry of entries) {
// Only process directories (each skill is a directory)
if (!entry.isDirectory()) continue;
const skillDir = path.join(baseDir, entry.name);
const skillManifest = path.join(skillDir, SKILL_MANIFEST_FILE);
try {
// Check if SKILL.md exists
await fs.access(skillManifest);
const config = await this.parseSkillFileInternal(
skillManifest,
level,
);
skills.push(config);
} catch (error) {
// Skip directories without valid SKILL.md
if (error instanceof SkillError) {
// Parse error was already recorded
console.warn(
`Failed to parse skill at ${skillDir}: ${error.message}`,
);
}
continue;
}
}
return skills;
} catch (_error) {
// Directory doesn't exist or can't be read
return [];
}
}
/**
* Finds a skill by name at a specific level.
*
* @param name - Name of the skill to find
* @param level - Storage level to search
* @returns SkillConfig or null if not found
*/
private async findSkillByNameAtLevel(
name: string,
level: SkillLevel,
): Promise<SkillConfig | null> {
await this.ensureLevelCache(level);
const levelSkills = this.skillsCache?.get(level) || [];
// Find the skill with matching name
return levelSkills.find((skill) => skill.name === name) || null;
}
/**
* Ensures the cache is populated for a specific level without loading other levels.
*/
private async ensureLevelCache(level: SkillLevel): Promise<void> {
if (!this.skillsCache) {
this.skillsCache = new Map<SkillLevel, SkillConfig[]>();
}
if (!this.skillsCache.has(level)) {
const levelSkills = await this.listSkillsAtLevel(level);
this.skillsCache.set(level, levelSkills);
}
}
}

View File

@@ -0,0 +1,105 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Represents the storage level for a skill configuration.
* - 'project': Stored in `.qwen/skills/` within the project directory
* - 'user': Stored in `~/.qwen/skills/` in the user's home directory
*/
export type SkillLevel = 'project' | 'user';
/**
* Core configuration for a skill as stored in SKILL.md files.
* Each skill directory contains a SKILL.md file with YAML frontmatter
* containing metadata, followed by markdown content describing the skill.
*/
export interface SkillConfig {
/** Unique name identifier for the skill */
name: string;
/** Human-readable description of what this skill provides */
description: string;
/**
* Optional list of tool names that this skill is allowed to use.
* For v1, this is informational only (no gating).
*/
allowedTools?: string[];
/**
* Storage level - determines where the configuration file is stored
*/
level: SkillLevel;
/**
* Absolute path to the skill directory containing SKILL.md
*/
filePath: string;
/**
* The markdown body content from SKILL.md (after the frontmatter)
*/
body: string;
}
/**
* Runtime configuration for a skill when it's being actively used.
* Extends SkillConfig with additional runtime-specific fields.
*/
export type SkillRuntimeConfig = SkillConfig;
/**
* Result of a validation operation on a skill configuration.
*/
export interface SkillValidationResult {
/** Whether the configuration is valid */
isValid: boolean;
/** Array of error messages if validation failed */
errors: string[];
/** Array of warning messages (non-blocking issues) */
warnings: string[];
}
/**
* Options for listing skills.
*/
export interface ListSkillsOptions {
/** Filter by storage level */
level?: SkillLevel;
/** Force refresh from disk, bypassing cache. Defaults to false. */
force?: boolean;
}
/**
* Error thrown when a skill operation fails.
*/
export class SkillError extends Error {
constructor(
message: string,
readonly code: SkillErrorCode,
readonly skillName?: string,
) {
super(message);
this.name = 'SkillError';
}
}
/**
* Error codes for skill operations.
*/
export const SkillErrorCode = {
NOT_FOUND: 'NOT_FOUND',
INVALID_CONFIG: 'INVALID_CONFIG',
INVALID_NAME: 'INVALID_NAME',
FILE_ERROR: 'FILE_ERROR',
PARSE_ERROR: 'PARSE_ERROR',
} as const;
export type SkillErrorCode =
(typeof SkillErrorCode)[keyof typeof SkillErrorCode];

View File

@@ -30,6 +30,7 @@ import {
ToolCallEvent,
} from '../types.js';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
@@ -89,8 +90,10 @@ expect.extend({
},
});
vi.mock('../../utils/userAccountManager.js');
vi.mock('../../utils/installationManager.js');
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
// TODO(richieforeman): Consider moving this to test setup globally.
@@ -125,7 +128,11 @@ describe('ClearcutLogger', () => {
vi.unstubAllEnvs();
});
function setup({ config = {} as Partial<ConfigParameters> } = {}) {
function setup({
config = {} as Partial<ConfigParameters>,
lifetimeGoogleAccounts = 1,
cachedGoogleAccount = 'test@google.com',
} = {}) {
server.resetHandlers(
http.post(CLEARCUT_URL, () => HttpResponse.text(EXAMPLE_RESPONSE)),
);
@@ -139,6 +146,10 @@ describe('ClearcutLogger', () => {
});
ClearcutLogger.clearInstance();
mockUserAccount.getCachedGoogleAccount.mockReturnValue(cachedGoogleAccount);
mockUserAccount.getLifetimeGoogleAccounts.mockReturnValue(
lifetimeGoogleAccounts,
);
mockInstallMgr.getInstallationId = vi
.fn()
.mockReturnValue('test-installation-id');
@@ -184,6 +195,19 @@ describe('ClearcutLogger', () => {
});
describe('createLogEvent', () => {
it('logs the total number of google accounts', () => {
const { logger } = setup({
lifetimeGoogleAccounts: 9001,
});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: '9001',
});
});
it('logs the current surface from a github action', () => {
const { logger } = setup({});
@@ -227,6 +251,7 @@ describe('ClearcutLogger', () => {
// Define expected values
const session_id = 'test-session-id';
const auth_type = AuthType.USE_GEMINI;
const google_accounts = 123;
const surface = 'ide-1234';
const cli_version = CLI_VERSION;
const git_commit_hash = GIT_COMMIT_INFO;
@@ -235,6 +260,7 @@ describe('ClearcutLogger', () => {
// Setup logger with expected values
const { logger, loggerConfig } = setup({
lifetimeGoogleAccounts: google_accounts,
config: {},
});
vi.spyOn(loggerConfig, 'getContentGeneratorConfig').mockReturnValue({
@@ -257,6 +283,10 @@ describe('ClearcutLogger', () => {
gemini_cli_key: EventMetadataKey.GEMINI_CLI_AUTH_TYPE,
value: JSON.stringify(auth_type),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${google_accounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,
@@ -374,14 +404,10 @@ describe('ClearcutLogger', () => {
vi.stubEnv(key, value);
}
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toEqual(
expect.arrayContaining([
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
},
]),
);
expect(event?.event_metadata[0][3]).toEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: expectedValue,
});
},
);
});

View File

@@ -34,6 +34,7 @@ import type {
import { EventMetadataKey } from './event-metadata-key.js';
import type { Config } from '../../config/config.js';
import { InstallationManager } from '../../utils/installationManager.js';
import { UserAccountManager } from '../../utils/userAccountManager.js';
import { safeJsonStringify } from '../../utils/safeJsonStringify.js';
import { FixedDeque } from 'mnemonist';
import { GIT_COMMIT_INFO, CLI_VERSION } from '../../generated/git-commit.js';
@@ -156,6 +157,7 @@ export class ClearcutLogger {
private sessionData: EventValue[] = [];
private promptId: string = '';
private readonly installationManager: InstallationManager;
private readonly userAccountManager: UserAccountManager;
/**
* Queue of pending events that need to be flushed to the server. New events
@@ -184,6 +186,7 @@ export class ClearcutLogger {
this.events = new FixedDeque<LogEventEntry[]>(Array, MAX_EVENTS);
this.promptId = config?.getSessionId() ?? '';
this.installationManager = new InstallationManager();
this.userAccountManager = new UserAccountManager();
}
static getInstance(config?: Config): ClearcutLogger | undefined {
@@ -230,11 +233,14 @@ export class ClearcutLogger {
}
createLogEvent(eventName: EventNames, data: EventValue[] = []): LogEvent {
const email = this.userAccountManager.getCachedGoogleAccount();
if (eventName !== EventNames.START_SESSION) {
data.push(...this.sessionData);
}
const totalAccounts = this.userAccountManager.getLifetimeGoogleAccounts();
data = this.addDefaultFields(data);
data = this.addDefaultFields(data, totalAccounts);
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
@@ -243,7 +249,12 @@ export class ClearcutLogger {
event_metadata: [data],
};
logEvent.client_install_id = this.installationManager.getInstallationId();
// Should log either email or install ID, not both. See go/cloudmill-1p-oss-instrumentation#define-sessionable-id
if (email) {
logEvent.client_email = email;
} else {
logEvent.client_install_id = this.installationManager.getInstallationId();
}
return logEvent;
}
@@ -1007,7 +1018,7 @@ export class ClearcutLogger {
* Adds default fields to data, and returns a new data array. This fields
* should exist on all log events.
*/
addDefaultFields(data: EventValue[]): EventValue[] {
addDefaultFields(data: EventValue[], totalAccounts: number): EventValue[] {
const surface = determineSurface();
const defaultLogMetadata: EventValue[] = [
@@ -1021,6 +1032,10 @@ export class ClearcutLogger {
this.config?.getContentGeneratorConfig()?.authType,
),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GOOGLE_ACCOUNTS_COUNT,
value: `${totalAccounts}`,
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_SURFACE,
value: surface,

View File

@@ -33,6 +33,7 @@ export const EVENT_MALFORMED_JSON_RESPONSE =
export const EVENT_FILE_OPERATION = 'qwen-code.file_operation';
export const EVENT_MODEL_SLASH_COMMAND = 'qwen-code.slash_command.model';
export const EVENT_SUBAGENT_EXECUTION = 'qwen-code.subagent_execution';
export const EVENT_SKILL_LAUNCH = 'qwen-code.skill_launch';
export const EVENT_AUTH = 'qwen-code.auth';
// Performance Events

View File

@@ -44,6 +44,7 @@ export {
logRipgrepFallback,
logNextSpeakerCheck,
logAuth,
logSkillLaunch,
} from './loggers.js';
export type { SlashCommandEvent, ChatCompressionEvent } from './types.js';
export {
@@ -63,6 +64,7 @@ export {
RipgrepFallbackEvent,
NextSpeakerCheckEvent,
AuthEvent,
SkillLaunchEvent,
} from './types.js';
export { makeSlashCommandEvent, makeChatCompressionEvent } from './types.js';
export type { TelemetryEvent } from './types.js';

View File

@@ -83,6 +83,7 @@ import type {
} from '@google/genai';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import * as uiTelemetry from './uiTelemetry.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import { makeFakeConfig } from '../test-utils/config.js';
describe('loggers', () => {
@@ -100,6 +101,10 @@ describe('loggers', () => {
vi.spyOn(uiTelemetry.uiTelemetryService, 'addEvent').mockImplementation(
mockUiEvent.addEvent,
);
vi.spyOn(
UserAccountManager.prototype,
'getCachedGoogleAccount',
).mockReturnValue('test-user@example.com');
vi.useFakeTimers();
vi.setSystemTime(new Date('2025-01-01T00:00:00.000Z'));
});
@@ -183,6 +188,7 @@ describe('loggers', () => {
body: 'CLI configuration loaded.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_CLI_CONFIG,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -200,6 +206,8 @@ describe('loggers', () => {
mcp_tools: undefined,
mcp_tools_count: undefined,
output_format: 'json',
skills: undefined,
subagents: undefined,
},
});
});
@@ -227,6 +235,7 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
@@ -248,7 +257,7 @@ describe('loggers', () => {
const event = new UserPromptEvent(
11,
'prompt-id-9',
AuthType.USE_GEMINI,
AuthType.CLOUD_SHELL,
'test-prompt',
);
@@ -258,11 +267,12 @@ describe('loggers', () => {
body: 'User prompt. Length: 11.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_USER_PROMPT,
'event.timestamp': '2025-01-01T00:00:00.000Z',
prompt_length: 11,
prompt_id: 'prompt-id-9',
auth_type: 'gemini-api-key',
auth_type: 'cloud-shell',
},
});
});
@@ -305,7 +315,7 @@ describe('loggers', () => {
'test-model',
100,
'prompt-id-1',
AuthType.USE_GEMINI,
AuthType.LOGIN_WITH_GOOGLE,
usageData,
'test-response',
);
@@ -316,6 +326,7 @@ describe('loggers', () => {
body: 'API response from test-model. Status: 200. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
[SemanticAttributes.HTTP_STATUS_CODE]: 200,
@@ -331,7 +342,7 @@ describe('loggers', () => {
total_token_count: 0,
response_text: 'test-response',
prompt_id: 'prompt-id-1',
auth_type: 'gemini-api-key',
auth_type: 'oauth-personal',
},
});
@@ -377,6 +388,7 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -395,6 +407,7 @@ describe('loggers', () => {
body: 'API request to test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_API_REQUEST,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -419,6 +432,7 @@ describe('loggers', () => {
body: 'Switching to flash as Fallback.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FLASH_FALLBACK,
'event.timestamp': '2025-01-01T00:00:00.000Z',
auth_type: 'vertex-ai',
@@ -453,6 +467,7 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'ripgrep is not available',
}),
@@ -471,6 +486,7 @@ describe('loggers', () => {
expect(emittedEvent.attributes).toEqual(
expect.objectContaining({
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_RIPGREP_FALLBACK,
error: 'rg not found',
}),
@@ -584,6 +600,7 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: accept. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -667,6 +684,7 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: reject. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -743,6 +761,7 @@ describe('loggers', () => {
body: 'Tool call: test-function. Decision: modify. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -818,6 +837,7 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -892,6 +912,7 @@ describe('loggers', () => {
body: 'Tool call: test-function. Success: false. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'test-function',
@@ -980,6 +1001,7 @@ describe('loggers', () => {
body: 'Tool call: mock_mcp_tool. Success: true. Duration: 100ms.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_TOOL_CALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
function_name: 'mock_mcp_tool',
@@ -1027,6 +1049,7 @@ describe('loggers', () => {
body: 'Malformed JSON response from test-model.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_MALFORMED_JSON_RESPONSE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
model: 'test-model',
@@ -1070,6 +1093,7 @@ describe('loggers', () => {
body: 'File operation: read. Lines: 10.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_FILE_OPERATION,
'event.timestamp': '2025-01-01T00:00:00.000Z',
tool_name: 'test-tool',
@@ -1115,6 +1139,7 @@ describe('loggers', () => {
body: 'Tool output truncated for test-tool.',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': 'tool_output_truncated',
'event.timestamp': '2025-01-01T00:00:00.000Z',
eventName: 'tool_output_truncated',
@@ -1161,6 +1186,7 @@ describe('loggers', () => {
body: 'Installed extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_INSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1199,6 +1225,7 @@ describe('loggers', () => {
body: 'Uninstalled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_UNINSTALL,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1235,6 +1262,7 @@ describe('loggers', () => {
body: 'Enabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_ENABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',
@@ -1271,6 +1299,7 @@ describe('loggers', () => {
body: 'Disabled extension vscode',
attributes: {
'session.id': 'test-session-id',
'user.email': 'test-user@example.com',
'event.name': EVENT_EXTENSION_DISABLE,
'event.timestamp': '2025-01-01T00:00:00.000Z',
extension_name: 'vscode',

View File

@@ -9,6 +9,7 @@ import { logs } from '@opentelemetry/api-logs';
import { SemanticAttributes } from '@opentelemetry/semantic-conventions';
import type { Config } from '../config/config.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import { UserAccountManager } from '../utils/userAccountManager.js';
import {
EVENT_API_ERROR,
EVENT_API_CANCEL,
@@ -37,6 +38,7 @@ import {
EVENT_MALFORMED_JSON_RESPONSE,
EVENT_INVALID_CHUNK,
EVENT_AUTH,
EVENT_SKILL_LAUNCH,
} from './constants.js';
import {
recordApiErrorMetrics,
@@ -84,6 +86,7 @@ import type {
MalformedJsonResponseEvent,
InvalidChunkEvent,
AuthEvent,
SkillLaunchEvent,
} from './types.js';
import type { UiEvent } from './uiTelemetry.js';
import { uiTelemetryService } from './uiTelemetry.js';
@@ -92,8 +95,11 @@ const shouldLogUserPrompts = (config: Config): boolean =>
config.getTelemetryLogPromptsEnabled();
function getCommonAttributes(config: Config): LogAttributes {
const userAccountManager = new UserAccountManager();
const email = userAccountManager.getCachedGoogleAccount();
return {
'session.id': config.getSessionId(),
...(email && { 'user.email': email }),
};
}
@@ -123,6 +129,8 @@ export function logStartSession(
mcp_tools: event.mcp_tools,
mcp_tools_count: event.mcp_tools_count,
output_format: event.output_format,
skills: event.skills,
subagents: event.subagents,
};
const logger = logs.getLogger(SERVICE_NAME);
@@ -865,3 +873,21 @@ export function logAuth(config: Config, event: AuthEvent): void {
};
logger.emit(logRecord);
}
export function logSkillLaunch(config: Config, event: SkillLaunchEvent): void {
if (!isTelemetrySdkInitialized()) return;
const attributes: LogAttributes = {
...getCommonAttributes(config),
...event,
'event.name': EVENT_SKILL_LAUNCH,
'event.timestamp': new Date().toISOString(),
};
const logger = logs.getLogger(SERVICE_NAME);
const logRecord: LogRecord = {
body: `Skill launch: ${event.skill_name}. Success: ${event.success}.`,
attributes,
};
logger.emit(logRecord);
}

View File

@@ -38,6 +38,7 @@ import type {
ModelSlashCommandEvent,
ExtensionDisableEvent,
AuthEvent,
SkillLaunchEvent,
RipgrepFallbackEvent,
EndSessionEvent,
} from '../types.js';
@@ -391,6 +392,8 @@ export class QwenLogger {
telemetry_enabled: event.telemetry_enabled,
telemetry_log_user_prompts_enabled:
event.telemetry_log_user_prompts_enabled,
skills: event.skills,
subagents: event.subagents,
},
});
@@ -827,6 +830,18 @@ export class QwenLogger {
this.flushIfNeeded();
}
logSkillLaunchEvent(event: SkillLaunchEvent): void {
const rumEvent = this.createActionEvent('misc', 'skill_launch', {
properties: {
skill_name: event.skill_name,
success: event.success ? 1 : 0,
},
});
this.enqueueLogEvent(rumEvent);
this.flushIfNeeded();
}
logChatCompressionEvent(event: ChatCompressionEvent): void {
const rumEvent = this.createActionEvent('misc', 'chat_compression', {
properties: {

View File

@@ -18,6 +18,9 @@ import {
import type { FileOperation } from './metrics.js';
export { ToolCallDecision };
import type { OutputFormat } from '../output/types.js';
import { ToolNames } from '../tools/tool-names.js';
import type { SkillTool } from '../tools/skill.js';
import type { TaskTool } from '../tools/task.js';
export interface BaseTelemetryEvent {
'event.name': string;
@@ -47,6 +50,8 @@ export class StartSessionEvent implements BaseTelemetryEvent {
mcp_tools_count?: number;
mcp_tools?: string;
output_format: OutputFormat;
skills?: string;
subagents?: string;
constructor(config: Config) {
const generatorConfig = config.getContentGeneratorConfig();
@@ -79,6 +84,7 @@ export class StartSessionEvent implements BaseTelemetryEvent {
config.getFileFilteringRespectGitIgnore();
this.mcp_servers_count = mcpServers ? Object.keys(mcpServers).length : 0;
this.output_format = config.getOutputFormat();
if (toolRegistry) {
const mcpTools = toolRegistry
.getAllTools()
@@ -87,6 +93,22 @@ export class StartSessionEvent implements BaseTelemetryEvent {
this.mcp_tools = mcpTools
.map((tool) => (tool as DiscoveredMCPTool).name)
.join(',');
const skillTool = toolRegistry.getTool(ToolNames.SKILL) as
| SkillTool
| undefined;
const skillNames = skillTool?.getAvailableSkillNames?.();
if (skillNames && skillNames.length > 0) {
this.skills = skillNames.join(',');
}
const taskTool = toolRegistry.getTool(ToolNames.TASK) as
| TaskTool
| undefined;
const subagentNames = taskTool?.getAvailableSubagentNames?.();
if (subagentNames && subagentNames.length > 0) {
this.subagents = subagentNames.join(',');
}
}
}
}
@@ -721,6 +743,20 @@ export class AuthEvent implements BaseTelemetryEvent {
}
}
export class SkillLaunchEvent implements BaseTelemetryEvent {
'event.name': 'skill_launch';
'event.timestamp': string;
skill_name: string;
success: boolean;
constructor(skill_name: string, success: boolean) {
this['event.name'] = 'skill_launch';
this['event.timestamp'] = new Date().toISOString();
this.skill_name = skill_name;
this.success = success;
}
}
export type TelemetryEvent =
| StartSessionEvent
| EndSessionEvent
@@ -749,7 +785,8 @@ export type TelemetryEvent =
| ExtensionUninstallEvent
| ToolOutputTruncatedEvent
| ModelSlashCommandEvent
| AuthEvent;
| AuthEvent
| SkillLaunchEvent;
export class ExtensionDisableEvent implements BaseTelemetryEvent {
'event.name': 'extension_disable';

View File

@@ -31,6 +31,8 @@ describe('LSTool', () => {
tempSecondaryDir,
]);
const userSkillsBase = path.join(os.homedir(), '.qwen', 'skills');
mockConfig = {
getTargetDir: () => tempRootDir,
getWorkspaceContext: () => mockWorkspaceContext,
@@ -39,6 +41,9 @@ describe('LSTool', () => {
respectGitIgnore: true,
respectQwenIgnore: true,
}),
storage: {
getUserSkillsDir: () => userSkillsBase,
},
} as unknown as Config;
lsTool = new LSTool(mockConfig);
@@ -288,7 +293,7 @@ describe('LSTool', () => {
};
const invocation = lsTool.build(params);
const description = invocation.getDescription();
const expected = path.relative(tempRootDir, params.path);
const expected = path.resolve(params.path);
expect(description).toBe(expected);
});
});

View File

@@ -9,6 +9,7 @@ import path from 'node:path';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isSubpath } from '../utils/paths.js';
import type { Config } from '../config/config.js';
import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js';
import { ToolErrorType } from './tool-error.js';
@@ -311,8 +312,14 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
return `Path must be absolute: ${params.path}`;
}
const userSkillsBase = this.config.storage.getUserSkillsDir();
const isUnderUserSkills = isSubpath(userSkillsBase, params.path);
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.path)) {
if (
!workspaceContext.isPathWithinWorkspace(params.path) &&
!isUnderUserSkills
) {
const directories = workspaceContext.getDirectories();
return `Path must be within one of the workspace directories: ${directories.join(
', ',

View File

@@ -217,9 +217,9 @@ describe('mcp-client', () => {
false,
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
);
});
it('with headers', async () => {
@@ -232,13 +232,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
expect(transport).toEqual(
new StreamableHTTPClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
});
});
@@ -251,9 +251,9 @@ describe('mcp-client', () => {
},
false,
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {}),
);
});
it('with headers', async () => {
@@ -266,13 +266,13 @@ describe('mcp-client', () => {
false,
);
expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._url).toEqual(new URL('http://test-server'));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((transport as any)._requestInit?.headers).toEqual({
Authorization: 'derp',
});
expect(transport).toEqual(
new SSEClientTransport(new URL('http://test-server'), {
requestInit: {
headers: { Authorization: 'derp' },
},
}),
);
});
});

View File

@@ -40,6 +40,7 @@ describe('ReadFileTool', () => {
getWorkspaceContext: () => createMockWorkspaceContext(tempRootDir),
storage: {
getProjectTempDir: () => path.join(tempRootDir, '.temp'),
getUserSkillsDir: () => path.join(os.homedir(), '.qwen', 'skills'),
},
getTruncateToolOutputThreshold: () => 2500,
getTruncateToolOutputLines: () => 500,

View File

@@ -20,6 +20,7 @@ import { FileOperation } from '../telemetry/metrics.js';
import { getProgrammingLanguage } from '../telemetry/telemetry-utils.js';
import { logFileOperation } from '../telemetry/loggers.js';
import { FileOperationEvent } from '../telemetry/types.js';
import { isSubpath } from '../utils/paths.js';
/**
* Parameters for the ReadFile tool
@@ -183,15 +184,20 @@ export class ReadFileTool extends BaseDeclarativeTool<
const workspaceContext = this.config.getWorkspaceContext();
const projectTempDir = this.config.storage.getProjectTempDir();
const userSkillsDir = this.config.storage.getUserSkillsDir();
const resolvedFilePath = path.resolve(filePath);
const resolvedProjectTempDir = path.resolve(projectTempDir);
const isWithinTempDir =
resolvedFilePath.startsWith(resolvedProjectTempDir + path.sep) ||
resolvedFilePath === resolvedProjectTempDir;
const isWithinTempDir = isSubpath(projectTempDir, resolvedFilePath);
const isWithinUserSkills = isSubpath(userSkillsDir, resolvedFilePath);
if (!workspaceContext.isPathWithinWorkspace(filePath) && !isWithinTempDir) {
if (
!workspaceContext.isPathWithinWorkspace(filePath) &&
!isWithinTempDir &&
!isWithinUserSkills
) {
const directories = workspaceContext.getDirectories();
return `File path must be within one of the workspace directories: ${directories.join(', ')} or within the project temp directory: ${projectTempDir}`;
return `File path must be within one of the workspace directories: ${directories.join(
', ',
)} or within the project temp directory: ${projectTempDir}`;
}
if (params.offset !== undefined && params.offset < 0) {
return 'Offset must be a non-negative number';

View File

@@ -0,0 +1,442 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { SkillTool, type SkillParams } from './skill.js';
import type { PartListUnion } from '@google/genai';
import type { ToolResultDisplay } from './tools.js';
import type { Config } from '../config/config.js';
import { SkillManager } from '../skills/skill-manager.js';
import type { SkillConfig } from '../skills/types.js';
import { partToString } from '../utils/partUtils.js';
// Type for accessing protected methods in tests
type SkillToolWithProtectedMethods = SkillTool & {
createInvocation: (params: SkillParams) => {
execute: (
signal?: AbortSignal,
updateOutput?: (output: ToolResultDisplay) => void,
) => Promise<{
llmContent: PartListUnion;
returnDisplay: ToolResultDisplay;
}>;
getDescription: () => string;
shouldConfirmExecute: () => Promise<boolean>;
};
};
// Mock dependencies
vi.mock('../skills/skill-manager.js');
vi.mock('../telemetry/index.js', () => ({
logSkillLaunch: vi.fn(),
SkillLaunchEvent: class {
constructor(
public skill_name: string,
public success: boolean,
) {}
},
}));
const MockedSkillManager = vi.mocked(SkillManager);
describe('SkillTool', () => {
let config: Config;
let skillTool: SkillTool;
let mockSkillManager: SkillManager;
let changeListeners: Array<() => void>;
const mockSkills: SkillConfig[] = [
{
name: 'code-review',
description: 'Specialized skill for reviewing code quality',
level: 'project',
filePath: '/project/.qwen/skills/code-review/SKILL.md',
body: 'Review code for quality and best practices.',
},
{
name: 'testing',
description: 'Skill for writing and running tests',
level: 'user',
filePath: '/home/user/.qwen/skills/testing/SKILL.md',
body: 'Help write comprehensive tests.',
allowedTools: ['read_file', 'write_file', 'shell'],
},
];
beforeEach(async () => {
// Setup fake timers
vi.useFakeTimers();
// Create mock config
config = {
getProjectRoot: vi.fn().mockReturnValue('/test/project'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getSkillManager: vi.fn(),
getGeminiClient: vi.fn().mockReturnValue(undefined),
} as unknown as Config;
changeListeners = [];
// Setup SkillManager mock
mockSkillManager = {
listSkills: vi.fn().mockResolvedValue(mockSkills),
loadSkill: vi.fn(),
loadSkillForRuntime: vi.fn(),
addChangeListener: vi.fn((listener: () => void) => {
changeListeners.push(listener);
return () => {
const index = changeListeners.indexOf(listener);
if (index >= 0) {
changeListeners.splice(index, 1);
}
};
}),
getParseErrors: vi.fn().mockReturnValue(new Map()),
} as unknown as SkillManager;
MockedSkillManager.mockImplementation(() => mockSkillManager);
// Make config return the mock SkillManager
vi.mocked(config.getSkillManager).mockReturnValue(mockSkillManager);
// Create SkillTool instance
skillTool = new SkillTool(config);
// Allow async initialization to complete
await vi.runAllTimersAsync();
});
afterEach(() => {
vi.useRealTimers();
vi.clearAllMocks();
});
describe('initialization', () => {
it('should initialize with correct name and properties', () => {
expect(skillTool.name).toBe('skill');
expect(skillTool.displayName).toBe('Skill');
expect(skillTool.kind).toBe('read');
});
it('should load available skills during initialization', () => {
expect(mockSkillManager.listSkills).toHaveBeenCalled();
});
it('should subscribe to skill manager changes', () => {
expect(mockSkillManager.addChangeListener).toHaveBeenCalledTimes(1);
});
it('should update description with available skills', () => {
expect(skillTool.description).toContain('code-review');
expect(skillTool.description).toContain(
'Specialized skill for reviewing code quality',
);
expect(skillTool.description).toContain('testing');
expect(skillTool.description).toContain(
'Skill for writing and running tests',
);
});
it('should handle empty skills list gracefully', async () => {
vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);
const emptySkillTool = new SkillTool(config);
await vi.runAllTimersAsync();
expect(emptySkillTool.description).toContain(
'No skills are currently configured',
);
});
it('should handle skill loading errors gracefully', async () => {
vi.mocked(mockSkillManager.listSkills).mockRejectedValue(
new Error('Loading failed'),
);
const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
new SkillTool(config);
await vi.runAllTimersAsync();
expect(consoleSpy).toHaveBeenCalledWith(
'Failed to load skills for Skills tool:',
expect.any(Error),
);
consoleSpy.mockRestore();
});
});
describe('schema generation', () => {
it('should expose static schema without dynamic enums', () => {
const schema = skillTool.schema;
const properties = schema.parametersJsonSchema as {
properties: {
skill: {
type: string;
description: string;
enum?: string[];
};
};
};
expect(properties.properties.skill.type).toBe('string');
expect(properties.properties.skill.description).toBe(
'The skill name (no arguments). E.g., "pdf" or "xlsx"',
);
expect(properties.properties.skill.enum).toBeUndefined();
});
it('should keep schema static even when no skills available', async () => {
vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);
const emptySkillTool = new SkillTool(config);
await vi.runAllTimersAsync();
const schema = emptySkillTool.schema;
const properties = schema.parametersJsonSchema as {
properties: {
skill: {
type: string;
description: string;
enum?: string[];
};
};
};
expect(properties.properties.skill.type).toBe('string');
expect(properties.properties.skill.description).toBe(
'The skill name (no arguments). E.g., "pdf" or "xlsx"',
);
expect(properties.properties.skill.enum).toBeUndefined();
});
});
describe('validateToolParams', () => {
it('should validate valid parameters', () => {
const result = skillTool.validateToolParams({ skill: 'code-review' });
expect(result).toBeNull();
});
it('should reject empty skill', () => {
const result = skillTool.validateToolParams({ skill: '' });
expect(result).toBe('Parameter "skill" must be a non-empty string.');
});
it('should reject non-existent skill', () => {
const result = skillTool.validateToolParams({
skill: 'non-existent',
});
expect(result).toBe(
'Skill "non-existent" not found. Available skills: code-review, testing',
);
});
it('should show appropriate message when no skills available', async () => {
vi.mocked(mockSkillManager.listSkills).mockResolvedValue([]);
const emptySkillTool = new SkillTool(config);
await vi.runAllTimersAsync();
const result = emptySkillTool.validateToolParams({
skill: 'non-existent',
});
expect(result).toBe(
'Skill "non-existent" not found. No skills are currently available.',
);
});
});
describe('refreshSkills', () => {
it('should refresh when change listener fires', async () => {
const newSkills: SkillConfig[] = [
{
name: 'new-skill',
description: 'A brand new skill',
level: 'project',
filePath: '/project/.qwen/skills/new-skill/SKILL.md',
body: 'New skill content.',
},
];
vi.mocked(mockSkillManager.listSkills).mockResolvedValueOnce(newSkills);
const listener = changeListeners[0];
expect(listener).toBeDefined();
listener?.();
await vi.runAllTimersAsync();
expect(skillTool.description).toContain('new-skill');
expect(skillTool.description).toContain('A brand new skill');
});
it('should refresh available skills and update description', async () => {
const newSkills: SkillConfig[] = [
{
name: 'test-skill',
description: 'A test skill',
level: 'project',
filePath: '/project/.qwen/skills/test-skill/SKILL.md',
body: 'Test content.',
},
];
vi.mocked(mockSkillManager.listSkills).mockResolvedValue(newSkills);
await skillTool.refreshSkills();
expect(skillTool.description).toContain('test-skill');
expect(skillTool.description).toContain('A test skill');
});
});
describe('SkillToolInvocation', () => {
const mockRuntimeConfig: SkillConfig = {
...mockSkills[0],
};
beforeEach(() => {
vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
mockRuntimeConfig,
);
});
it('should execute skill load successfully', async () => {
const params: SkillParams = {
skill: 'code-review',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const result = await invocation.execute();
expect(mockSkillManager.loadSkillForRuntime).toHaveBeenCalledWith(
'code-review',
);
const llmText = partToString(result.llmContent);
expect(llmText).toContain(
'Base directory for this skill: /project/.qwen/skills/code-review',
);
expect(llmText.trim()).toContain(
'Review code for quality and best practices.',
);
expect(result.returnDisplay).toBe('Launching skill: code-review');
});
it('should include allowedTools in result when present', async () => {
const skillWithTools: SkillConfig = {
...mockSkills[1],
};
vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
skillWithTools,
);
const params: SkillParams = {
skill: 'testing',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const result = await invocation.execute();
const llmText = partToString(result.llmContent);
expect(llmText).toContain('testing');
// Base description is omitted from llmContent; ensure body is present.
expect(llmText).toContain('Help write comprehensive tests.');
expect(result.returnDisplay).toBe('Launching skill: testing');
});
it('should handle skill not found error', async () => {
vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(null);
const params: SkillParams = {
skill: 'non-existent',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const result = await invocation.execute();
const llmText = partToString(result.llmContent);
expect(llmText).toContain('Skill "non-existent" not found');
});
it('should handle execution errors gracefully', async () => {
vi.mocked(mockSkillManager.loadSkillForRuntime).mockRejectedValue(
new Error('Loading failed'),
);
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
const params: SkillParams = {
skill: 'code-review',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const result = await invocation.execute();
const llmText = partToString(result.llmContent);
expect(llmText).toContain('Failed to load skill');
expect(llmText).toContain('Loading failed');
consoleSpy.mockRestore();
});
it('should not require confirmation', async () => {
const params: SkillParams = {
skill: 'code-review',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const shouldConfirm = await invocation.shouldConfirmExecute();
expect(shouldConfirm).toBe(false);
});
it('should provide correct description', () => {
const params: SkillParams = {
skill: 'code-review',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const description = invocation.getDescription();
expect(description).toBe('Launching skill: "code-review"');
});
it('should handle skill without additional files', async () => {
vi.mocked(mockSkillManager.loadSkillForRuntime).mockResolvedValue(
mockSkills[0],
);
const params: SkillParams = {
skill: 'code-review',
};
const invocation = (
skillTool as SkillToolWithProtectedMethods
).createInvocation(params);
const result = await invocation.execute();
const llmText = partToString(result.llmContent);
expect(llmText).not.toContain('## Additional Files');
expect(result.returnDisplay).toBe('Launching skill: code-review');
});
});
});

View File

@@ -0,0 +1,264 @@
/**
* @license
* Copyright 2025 Qwen
* SPDX-License-Identifier: Apache-2.0
*/
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolNames, ToolDisplayNames } from './tool-names.js';
import type { ToolResult, ToolResultDisplay } from './tools.js';
import type { Config } from '../config/config.js';
import type { SkillManager } from '../skills/skill-manager.js';
import type { SkillConfig } from '../skills/types.js';
import { logSkillLaunch, SkillLaunchEvent } from '../telemetry/index.js';
import path from 'path';
export interface SkillParams {
skill: string;
}
/**
* Skill tool that enables the model to access skill definitions.
* The tool dynamically loads available skills and includes them in its description
* for the model to choose from.
*/
export class SkillTool extends BaseDeclarativeTool<SkillParams, ToolResult> {
static readonly Name: string = ToolNames.SKILL;
private skillManager: SkillManager;
private availableSkills: SkillConfig[] = [];
constructor(private readonly config: Config) {
// Initialize with a basic schema first
const initialSchema = {
type: 'object',
properties: {
skill: {
type: 'string',
description: 'The skill name (no arguments). E.g., "pdf" or "xlsx"',
},
},
required: ['skill'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
};
super(
SkillTool.Name,
ToolDisplayNames.SKILL,
'Execute a skill within the main conversation. Loading available skills...', // Initial description
Kind.Read,
initialSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
);
this.skillManager = config.getSkillManager();
this.skillManager.addChangeListener(() => {
void this.refreshSkills();
});
// Initialize the tool asynchronously
this.refreshSkills();
}
/**
* Asynchronously initializes the tool by loading available skills
* and updating the description and schema.
*/
async refreshSkills(): Promise<void> {
try {
this.availableSkills = await this.skillManager.listSkills();
this.updateDescriptionAndSchema();
} catch (error) {
console.warn('Failed to load skills for Skills tool:', error);
this.availableSkills = [];
this.updateDescriptionAndSchema();
} finally {
// Update the client with the new tools
const geminiClient = this.config.getGeminiClient();
if (geminiClient && geminiClient.isInitialized()) {
await geminiClient.setTools();
}
}
}
/**
* Updates the tool's description and schema based on available skills.
*/
private updateDescriptionAndSchema(): void {
let skillDescriptions = '';
if (this.availableSkills.length === 0) {
skillDescriptions =
'No skills are currently configured. Skills can be created by adding directories with SKILL.md files to .qwen/skills/ or ~/.qwen/skills/.';
} else {
skillDescriptions = this.availableSkills
.map(
(skill) => `<skill>
<name>
${skill.name}
</name>
<description>
${skill.description} (${skill.level})
</description>
<location>
${skill.level}
</location>
</skill>`,
)
.join('\n');
}
const baseDescription = `Execute a skill within the main conversation
<skills_instructions>
When users ask you to perform tasks, check if any of the available skills below can help complete the task more effectively. Skills provide specialized capabilities and domain knowledge.
How to invoke:
- Use this tool with the skill name only (no arguments)
- Examples:
- \`skill: "pdf"\` - invoke the pdf skill
- \`skill: "xlsx"\` - invoke the xlsx skill
- \`skill: "ms-office-suite:pdf"\` - invoke using fully qualified name
Important:
- When a skill is relevant, you must invoke this tool IMMEDIATELY as your first action
- NEVER just announce or mention a skill in your text response without actually calling this tool
- This is a BLOCKING REQUIREMENT: invoke the relevant Skill tool BEFORE generating any other response about the task
- Only use skills listed in <available_skills> below
- Do not invoke a skill that is already running
- Do not use this tool for built-in CLI commands (like /help, /clear, etc.)
</skills_instructions>
<available_skills>
${skillDescriptions}
</available_skills>
`;
// Update description using object property assignment
(this as { description: string }).description = baseDescription;
}
override validateToolParams(params: SkillParams): string | null {
// Validate required fields
if (
!params.skill ||
typeof params.skill !== 'string' ||
params.skill.trim() === ''
) {
return 'Parameter "skill" must be a non-empty string.';
}
// Validate that the skill exists
const skillExists = this.availableSkills.some(
(skill) => skill.name === params.skill,
);
if (!skillExists) {
const availableNames = this.availableSkills.map((s) => s.name);
if (availableNames.length === 0) {
return `Skill "${params.skill}" not found. No skills are currently available.`;
}
return `Skill "${params.skill}" not found. Available skills: ${availableNames.join(', ')}`;
}
return null;
}
protected createInvocation(params: SkillParams) {
return new SkillToolInvocation(this.config, this.skillManager, params);
}
getAvailableSkillNames(): string[] {
return this.availableSkills.map((skill) => skill.name);
}
}
class SkillToolInvocation extends BaseToolInvocation<SkillParams, ToolResult> {
constructor(
private readonly config: Config,
private readonly skillManager: SkillManager,
params: SkillParams,
) {
super(params);
}
getDescription(): string {
return `Launching skill: "${this.params.skill}"`;
}
override async shouldConfirmExecute(): Promise<false> {
// Skill loading is a read-only operation, no confirmation needed
return false;
}
async execute(
_signal?: AbortSignal,
_updateOutput?: (output: ToolResultDisplay) => void,
): Promise<ToolResult> {
try {
// Load the skill with runtime config (includes additional files)
const skill = await this.skillManager.loadSkillForRuntime(
this.params.skill,
);
if (!skill) {
// Log failed skill launch
logSkillLaunch(
this.config,
new SkillLaunchEvent(this.params.skill, false),
);
// Get parse errors if any
const parseErrors = this.skillManager.getParseErrors();
const errorMessages: string[] = [];
for (const [filePath, error] of parseErrors) {
if (filePath.includes(this.params.skill)) {
errorMessages.push(`Parse error at ${filePath}: ${error.message}`);
}
}
const errorDetail =
errorMessages.length > 0
? `\nErrors:\n${errorMessages.join('\n')}`
: '';
return {
llmContent: `Skill "${this.params.skill}" not found.${errorDetail}`,
returnDisplay: `Skill "${this.params.skill}" not found.${errorDetail}`,
};
}
// Log successful skill launch
logSkillLaunch(
this.config,
new SkillLaunchEvent(this.params.skill, true),
);
const baseDir = path.dirname(skill.filePath);
// Build markdown content for LLM (show base dir, then body)
const llmContent = `Base directory for this skill: ${baseDir}\n\n${skill.body}\n`;
return {
llmContent: [{ text: llmContent }],
returnDisplay: `Launching skill: ${skill.name}`,
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(`[SkillsTool] Error launching skill: ${errorMessage}`);
// Log failed skill launch
logSkillLaunch(
this.config,
new SkillLaunchEvent(this.params.skill, false),
);
return {
llmContent: `Failed to load skill "${this.params.skill}": ${errorMessage}`,
returnDisplay: `Failed to load skill "${this.params.skill}": ${errorMessage}`,
};
}
}
}

Some files were not shown because too many files have changed in this diff Show More