mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-23 18:19:15 +00:00
Compare commits
14 Commits
fix/openai
...
v0.0.14-ni
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0b5689da2 | ||
|
|
c405434c41 | ||
|
|
673854b446 | ||
|
|
4e7a7e2656 | ||
|
|
8379bc4d81 | ||
|
|
e148e4be28 | ||
|
|
48d8587bf9 | ||
|
|
5ecb4a2430 | ||
|
|
9c1d7228cb | ||
|
|
deb99a3b21 | ||
|
|
014059e8a6 | ||
|
|
3579d6555a | ||
|
|
9a56560eb4 | ||
|
|
da0863b943 |
13
.vscode/launch.json
vendored
13
.vscode/launch.json
vendored
@@ -101,6 +101,13 @@
|
||||
"env": {
|
||||
"GEMINI_SANDBOX": "false"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Attach by Process ID",
|
||||
"processId": "${command:PickProcess}",
|
||||
"request": "attach",
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"type": "node"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
@@ -115,6 +122,12 @@
|
||||
"type": "promptString",
|
||||
"description": "Enter your prompt for non-interactive mode",
|
||||
"default": "Explain this code"
|
||||
},
|
||||
{
|
||||
"id": "debugPort",
|
||||
"type": "promptString",
|
||||
"description": "Enter the debug port number (default: 9229)",
|
||||
"default": "9229"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
31
CHANGELOG.md
31
CHANGELOG.md
@@ -1,5 +1,36 @@
|
||||
# Changelog
|
||||
|
||||
## 0.0.13
|
||||
|
||||
- Added YOLO mode support for automatic vision model switching with CLI arguments and environment variables.
|
||||
- Fixed ripgrep lazy loading to resolve VS Code IDE companion startup issues.
|
||||
- Fixed authentication hang when selecting Qwen OAuth.
|
||||
- Added OpenAI and Qwen OAuth authentication support to Zed ACP integration.
|
||||
- Fixed output token limit for Qwen models.
|
||||
- Fixed Markdown list display issues on Windows.
|
||||
- Enhanced vision model instructions and documentation.
|
||||
- Improved authentication method compatibility across different IDE integrations.
|
||||
|
||||
## 0.0.12
|
||||
|
||||
- Added vision model support for Qwen-OAuth authentication.
|
||||
- Synced upstream `gemini-cli` to v0.3.4 with numerous improvements and bug fixes.
|
||||
- Enhanced subagent functionality with system reminders and improved user experience.
|
||||
- Added tool call type coercion for better compatibility.
|
||||
- Fixed arrow key navigation issues on Windows.
|
||||
- Fixed missing tool call chunks for OpenAI logging.
|
||||
- Fixed system prompt issues to avoid malformed tool calls.
|
||||
- Fixed terminal flicker when subagent is executing.
|
||||
- Fixed duplicate subagents configuration when running in home directory.
|
||||
- Fixed Esc key unable to cancel subagent dialog.
|
||||
- Added confirmation prompt for `/init` command when context file exists.
|
||||
- Added `skipLoopDetection` configuration option.
|
||||
- Fixed `is_background` parameter reset issues.
|
||||
- Enhanced Windows compatibility with multi-line paste handling.
|
||||
- Improved subagent documentation and branding consistency.
|
||||
- Fixed various linting errors and improved code quality.
|
||||
- Miscellaneous improvements and bug fixes.
|
||||
|
||||
## 0.0.11
|
||||
|
||||
- Added subagents feature with file-based configuration system for specialized AI assistants.
|
||||
|
||||
53
README.md
53
README.md
@@ -54,6 +54,7 @@ For detailed setup instructions, see [Authorization](#authorization).
|
||||
- **Code Understanding & Editing** - Query and edit large codebases beyond traditional context window limits
|
||||
- **Workflow Automation** - Automate operational tasks like handling pull requests and complex rebases
|
||||
- **Enhanced Parser** - Adapted parser specifically optimized for Qwen-Coder models
|
||||
- **Vision Model Support** - Automatically detect images in your input and seamlessly switch to vision-capable models for multimodal analysis
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -121,6 +122,58 @@ Create or edit `.qwen/settings.json` in your home directory:
|
||||
|
||||
> 📝 **Note**: Session token limit applies to a single conversation, not cumulative API calls.
|
||||
|
||||
### Vision Model Configuration
|
||||
|
||||
Qwen Code includes intelligent vision model auto-switching that detects images in your input and can automatically switch to vision-capable models for multimodal analysis. **This feature is enabled by default** - when you include images in your queries, you'll see a dialog asking how you'd like to handle the vision model switch.
|
||||
|
||||
#### Skip the Switch Dialog (Optional)
|
||||
|
||||
If you don't want to see the interactive dialog each time, configure the default behavior in your `.qwen/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"vlmSwitchMode": "once"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Available modes:**
|
||||
|
||||
- **`"once"`** - Switch to vision model for this query only, then revert
|
||||
- **`"session"`** - Switch to vision model for the entire session
|
||||
- **`"persist"`** - Continue with current model (no switching)
|
||||
- **Not set** - Show interactive dialog each time (default)
|
||||
|
||||
#### Command Line Override
|
||||
|
||||
You can also set the behavior via command line:
|
||||
|
||||
```bash
|
||||
# Switch once per query
|
||||
qwen --vlm-switch-mode once
|
||||
|
||||
# Switch for entire session
|
||||
qwen --vlm-switch-mode session
|
||||
|
||||
# Never switch automatically
|
||||
qwen --vlm-switch-mode persist
|
||||
```
|
||||
|
||||
#### Disable Vision Models (Optional)
|
||||
|
||||
To completely disable vision model support, add to your `.qwen/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"visionModelPreview": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 💡 **Tip**: In YOLO mode (`--yolo`), vision switching happens automatically without prompts when images are detected.
|
||||
|
||||
### Authorization
|
||||
|
||||
Choose your preferred authentication method based on your needs:
|
||||
|
||||
@@ -124,6 +124,18 @@ Slash commands provide meta-level control over the CLI itself.
|
||||
- **`/auth`**
|
||||
- **Description:** Open a dialog that lets you change the authentication method.
|
||||
|
||||
- **`/approval-mode`**
|
||||
- **Description:** Change the approval mode for tool usage.
|
||||
- **Usage:** `/approval-mode [mode] [--session|--project|--user]`
|
||||
- **Available Modes:**
|
||||
- **`plan`**: Analyze only; do not modify files or execute commands
|
||||
- **`default`**: Require approval for file edits or shell commands
|
||||
- **`auto-edit`**: Automatically approve file edits
|
||||
- **`yolo`**: Automatically approve all tools
|
||||
- **Examples:**
|
||||
- `/approval-mode plan --project` (persist plan mode for this project)
|
||||
- `/approval-mode yolo --user` (persist YOLO mode for this user across projects)
|
||||
|
||||
- **`/about`**
|
||||
- **Description:** Show version info. Please share this information when filing issues.
|
||||
|
||||
|
||||
@@ -362,6 +362,18 @@ If you are experiencing performance issues with file searching (e.g., with `@` c
|
||||
"skipLoopDetection": true
|
||||
```
|
||||
|
||||
- **`approvalMode`** (string):
|
||||
- **Description:** Sets the default approval mode for tool usage. Accepted values are:
|
||||
- `plan`: Analyze only, do not modify files or execute commands.
|
||||
- `default`: Require approval before file edits or shell commands run.
|
||||
- `auto-edit`: Automatically approve file edits.
|
||||
- `yolo`: Automatically approve all tool calls.
|
||||
- **Default:** `"default"`
|
||||
- **Example:**
|
||||
```json
|
||||
"approvalMode": "plan"
|
||||
```
|
||||
|
||||
### Example `settings.json`:
|
||||
|
||||
```json
|
||||
@@ -486,12 +498,13 @@ Arguments passed directly when running the CLI can override other configurations
|
||||
- **`--yolo`**:
|
||||
- Enables YOLO mode, which automatically approves all tool calls.
|
||||
- **`--approval-mode <mode>`**:
|
||||
- Sets the approval mode for tool calls. Available modes:
|
||||
- `default`: Prompt for approval on each tool call (default behavior)
|
||||
- `auto_edit`: Automatically approve edit tools (edit, write_file) while prompting for others
|
||||
- `yolo`: Automatically approve all tool calls (equivalent to `--yolo`)
|
||||
- Sets the approval mode for tool calls. Supported modes:
|
||||
- `plan`: Analyze only—do not modify files or execute commands.
|
||||
- `default`: Require approval for file edits or shell commands (default behavior).
|
||||
- `auto-edit`: Automatically approve edit tools (edit, write_file) while prompting for others.
|
||||
- `yolo`: Automatically approve all tool calls (equivalent to `--yolo`).
|
||||
- Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach.
|
||||
- Example: `qwen --approval-mode auto_edit`
|
||||
- Example: `qwen --approval-mode auto-edit`
|
||||
- **`--allowed-tools <tool1,tool2,...>`**:
|
||||
- A comma-separated list of tool names that will bypass the confirmation dialog.
|
||||
- Example: `qwen --allowed-tools "ShellTool(git status)"`
|
||||
|
||||
@@ -4,16 +4,16 @@ This document lists the available keyboard shortcuts in Qwen Code.
|
||||
|
||||
## General
|
||||
|
||||
| Shortcut | Description |
|
||||
| -------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `Esc` | Close dialogs and suggestions. |
|
||||
| `Ctrl+C` | Cancel the ongoing request and clear the input. Press twice to exit the application. |
|
||||
| `Ctrl+D` | Exit the application if the input is empty. Press twice to confirm. |
|
||||
| `Ctrl+L` | Clear the screen. |
|
||||
| `Ctrl+O` | Toggle the display of the debug console. |
|
||||
| `Ctrl+S` | Allows long responses to print fully, disabling truncation. Use your terminal's scrollback to view the entire output. |
|
||||
| `Ctrl+T` | Toggle the display of tool descriptions. |
|
||||
| `Ctrl+Y` | Toggle auto-approval (YOLO mode) for all tool calls. |
|
||||
| Shortcut | Description |
|
||||
| ----------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `Esc` | Close dialogs and suggestions. |
|
||||
| `Ctrl+C` | Cancel the ongoing request and clear the input. Press twice to exit the application. |
|
||||
| `Ctrl+D` | Exit the application if the input is empty. Press twice to confirm. |
|
||||
| `Ctrl+L` | Clear the screen. |
|
||||
| `Ctrl+O` | Toggle the display of the debug console. |
|
||||
| `Ctrl+S` | Allows long responses to print fully, disabling truncation. Use your terminal's scrollback to view the entire output. |
|
||||
| `Ctrl+T` | Toggle the display of tool descriptions. |
|
||||
| `Shift+Tab` | Cycle approval modes (`plan` → `default` → `auto-edit` → `yolo`). |
|
||||
|
||||
## Input Prompt
|
||||
|
||||
|
||||
12
package-lock.json
generated
12
package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"workspaces": [
|
||||
"packages/*"
|
||||
],
|
||||
@@ -13454,7 +13454,7 @@
|
||||
},
|
||||
"packages/cli": {
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"dependencies": {
|
||||
"@google/genai": "1.9.0",
|
||||
"@iarna/toml": "^2.2.5",
|
||||
@@ -13662,7 +13662,7 @@
|
||||
},
|
||||
"packages/core": {
|
||||
"name": "@qwen-code/qwen-code-core",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"dependencies": {
|
||||
"@google/genai": "1.13.0",
|
||||
"@lvce-editor/ripgrep": "^1.6.0",
|
||||
@@ -13788,7 +13788,7 @@
|
||||
},
|
||||
"packages/test-utils": {
|
||||
"name": "@qwen-code/qwen-code-test-utils",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"devDependencies": {
|
||||
@@ -13800,7 +13800,7 @@
|
||||
},
|
||||
"packages/vscode-ide-companion": {
|
||||
"name": "qwen-code-vscode-ide-companion",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"license": "LICENSE",
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.15.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
@@ -13,7 +13,7 @@
|
||||
"url": "git+https://github.com/QwenLM/qwen-code.git"
|
||||
},
|
||||
"config": {
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.11"
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.14-nightly.1"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "node scripts/start.js",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"description": "Qwen Code",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
@@ -25,7 +25,7 @@
|
||||
"dist"
|
||||
],
|
||||
"config": {
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.11"
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.14-nightly.1"
|
||||
},
|
||||
"dependencies": {
|
||||
"@google/genai": "1.9.0",
|
||||
|
||||
@@ -269,7 +269,7 @@ describe('Configuration Integration Tests', () => {
|
||||
parseArguments = parseArgs;
|
||||
});
|
||||
|
||||
it('should parse --approval-mode=auto_edit correctly through the full argument parsing flow', async () => {
|
||||
it('should parse --approval-mode=auto-edit correctly through the full argument parsing flow', async () => {
|
||||
const originalArgv = process.argv;
|
||||
|
||||
try {
|
||||
@@ -277,7 +277,7 @@ describe('Configuration Integration Tests', () => {
|
||||
'node',
|
||||
'script.js',
|
||||
'--approval-mode',
|
||||
'auto_edit',
|
||||
'auto-edit',
|
||||
'-p',
|
||||
'test',
|
||||
];
|
||||
@@ -285,7 +285,30 @@ describe('Configuration Integration Tests', () => {
|
||||
const argv = await parseArguments({} as Settings);
|
||||
|
||||
// Verify that the argument was parsed correctly
|
||||
expect(argv.approvalMode).toBe('auto_edit');
|
||||
expect(argv.approvalMode).toBe('auto-edit');
|
||||
expect(argv.prompt).toBe('test');
|
||||
expect(argv.yolo).toBe(false);
|
||||
} finally {
|
||||
process.argv = originalArgv;
|
||||
}
|
||||
});
|
||||
|
||||
it('should parse --approval-mode=plan correctly through the full argument parsing flow', async () => {
|
||||
const originalArgv = process.argv;
|
||||
|
||||
try {
|
||||
process.argv = [
|
||||
'node',
|
||||
'script.js',
|
||||
'--approval-mode',
|
||||
'plan',
|
||||
'-p',
|
||||
'test',
|
||||
];
|
||||
|
||||
const argv = await parseArguments({} as Settings);
|
||||
|
||||
expect(argv.approvalMode).toBe('plan');
|
||||
expect(argv.prompt).toBe('test');
|
||||
expect(argv.yolo).toBe(false);
|
||||
} finally {
|
||||
|
||||
@@ -262,9 +262,9 @@ describe('parseArguments', () => {
|
||||
});
|
||||
|
||||
it('should allow --approval-mode without --yolo', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto_edit'];
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto-edit'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
expect(argv.approvalMode).toBe('auto_edit');
|
||||
expect(argv.approvalMode).toBe('auto-edit');
|
||||
expect(argv.yolo).toBe(false);
|
||||
});
|
||||
|
||||
@@ -1087,6 +1087,32 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
expect(excludedTools).toContain(WriteFileTool.Name);
|
||||
});
|
||||
|
||||
it('should exclude all interactive tools in non-interactive mode with plan approval mode', async () => {
|
||||
process.argv = [
|
||||
'node',
|
||||
'script.js',
|
||||
'--approval-mode',
|
||||
'plan',
|
||||
'-p',
|
||||
'test',
|
||||
];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const settings: Settings = {};
|
||||
const extensions: Extension[] = [];
|
||||
|
||||
const config = await loadCliConfig(
|
||||
settings,
|
||||
extensions,
|
||||
'test-session',
|
||||
argv,
|
||||
);
|
||||
|
||||
const excludedTools = config.getExcludeTools();
|
||||
expect(excludedTools).toContain(ShellTool.Name);
|
||||
expect(excludedTools).toContain(EditTool.Name);
|
||||
expect(excludedTools).toContain(WriteFileTool.Name);
|
||||
});
|
||||
|
||||
it('should exclude all interactive tools in non-interactive mode with explicit default approval mode', async () => {
|
||||
process.argv = [
|
||||
'node',
|
||||
@@ -1113,12 +1139,12 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
expect(excludedTools).toContain(WriteFileTool.Name);
|
||||
});
|
||||
|
||||
it('should exclude only shell tools in non-interactive mode with auto_edit approval mode', async () => {
|
||||
it('should exclude only shell tools in non-interactive mode with auto-edit approval mode', async () => {
|
||||
process.argv = [
|
||||
'node',
|
||||
'script.js',
|
||||
'--approval-mode',
|
||||
'auto_edit',
|
||||
'auto-edit',
|
||||
'-p',
|
||||
'test',
|
||||
];
|
||||
@@ -1189,8 +1215,9 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
|
||||
const testCases = [
|
||||
{ args: ['node', 'script.js'] }, // default
|
||||
{ args: ['node', 'script.js', '--approval-mode', 'plan'] },
|
||||
{ args: ['node', 'script.js', '--approval-mode', 'default'] },
|
||||
{ args: ['node', 'script.js', '--approval-mode', 'auto_edit'] },
|
||||
{ args: ['node', 'script.js', '--approval-mode', 'auto-edit'] },
|
||||
{ args: ['node', 'script.js', '--approval-mode', 'yolo'] },
|
||||
{ args: ['node', 'script.js', '--yolo'] },
|
||||
];
|
||||
@@ -1215,12 +1242,12 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
}
|
||||
});
|
||||
|
||||
it('should merge approval mode exclusions with settings exclusions in auto_edit mode', async () => {
|
||||
it('should merge approval mode exclusions with settings exclusions in auto-edit mode', async () => {
|
||||
process.argv = [
|
||||
'node',
|
||||
'script.js',
|
||||
'--approval-mode',
|
||||
'auto_edit',
|
||||
'auto-edit',
|
||||
'-p',
|
||||
'test',
|
||||
];
|
||||
@@ -1238,8 +1265,8 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
const excludedTools = config.getExcludeTools();
|
||||
expect(excludedTools).toContain('custom_tool'); // From settings
|
||||
expect(excludedTools).toContain(ShellTool.Name); // From approval mode
|
||||
expect(excludedTools).not.toContain(EditTool.Name); // Should be allowed in auto_edit
|
||||
expect(excludedTools).not.toContain(WriteFileTool.Name); // Should be allowed in auto_edit
|
||||
expect(excludedTools).not.toContain(EditTool.Name); // Should be allowed in auto-edit
|
||||
expect(excludedTools).not.toContain(WriteFileTool.Name); // Should be allowed in auto-edit
|
||||
});
|
||||
|
||||
it('should throw an error for invalid approval mode values in loadCliConfig', async () => {
|
||||
@@ -1262,7 +1289,7 @@ describe('Approval mode tool exclusion logic', () => {
|
||||
invalidArgv as CliArgs,
|
||||
),
|
||||
).rejects.toThrow(
|
||||
'Invalid approval mode: invalid_mode. Valid values are: yolo, auto_edit, default',
|
||||
'Invalid approval mode: invalid_mode. Valid values are: plan, default, auto-edit, yolo',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1514,7 +1541,7 @@ describe('loadCliConfig model selection', () => {
|
||||
argv,
|
||||
);
|
||||
|
||||
expect(config.getModel()).toBe('qwen3-coder-plus');
|
||||
expect(config.getModel()).toBe('coder-model');
|
||||
});
|
||||
|
||||
it('always prefers model from argvs', async () => {
|
||||
@@ -1929,6 +1956,13 @@ describe('loadCliConfig approval mode', () => {
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should set PLAN approval mode when --approval-mode=plan', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'plan'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig({}, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.PLAN);
|
||||
});
|
||||
|
||||
it('should set YOLO approval mode when --yolo flag is used', async () => {
|
||||
process.argv = ['node', 'script.js', '--yolo'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
@@ -1950,8 +1984,8 @@ describe('loadCliConfig approval mode', () => {
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should set AUTO_EDIT approval mode when --approval-mode=auto_edit', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto_edit'];
|
||||
it('should set AUTO_EDIT approval mode when --approval-mode=auto-edit', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto-edit'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig({}, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.AUTO_EDIT);
|
||||
@@ -1964,6 +1998,33 @@ describe('loadCliConfig approval mode', () => {
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.YOLO);
|
||||
});
|
||||
|
||||
it('should use approval mode from settings when CLI flags are not provided', async () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const settings: Settings = { approvalMode: 'plan' };
|
||||
const config = await loadCliConfig(settings, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.PLAN);
|
||||
});
|
||||
|
||||
it('should normalize approval mode values from settings', async () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const settings: Settings = { approvalMode: 'AutoEdit' };
|
||||
const config = await loadCliConfig(settings, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.AUTO_EDIT);
|
||||
});
|
||||
|
||||
it('should throw when approval mode in settings is invalid', async () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const settings: Settings = { approvalMode: 'invalid_mode' };
|
||||
await expect(
|
||||
loadCliConfig(settings, [], 'test-session', argv),
|
||||
).rejects.toThrow(
|
||||
'Invalid approval mode: invalid_mode. Valid values are: plan, default, auto-edit, yolo',
|
||||
);
|
||||
});
|
||||
|
||||
it('should prioritize --approval-mode over --yolo when both would be valid (but validation prevents this)', async () => {
|
||||
// Note: This test documents the intended behavior, but in practice the validation
|
||||
// prevents both flags from being used together
|
||||
@@ -1995,8 +2056,8 @@ describe('loadCliConfig approval mode', () => {
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should override --approval-mode=auto_edit to DEFAULT', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto_edit'];
|
||||
it('should override --approval-mode=auto-edit to DEFAULT', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'auto-edit'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig({}, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.DEFAULT);
|
||||
@@ -2015,6 +2076,13 @@ describe('loadCliConfig approval mode', () => {
|
||||
const config = await loadCliConfig({}, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should allow PLAN approval mode in untrusted folders', async () => {
|
||||
process.argv = ['node', 'script.js', '--approval-mode', 'plan'];
|
||||
const argv = await parseArguments({} as Settings);
|
||||
const config = await loadCliConfig({}, [], 'test-session', argv);
|
||||
expect(config.getApprovalMode()).toBe(ServerConfig.ApprovalMode.PLAN);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -52,6 +52,39 @@ const logger = {
|
||||
error: (...args: any[]) => console.error('[ERROR]', ...args),
|
||||
};
|
||||
|
||||
const VALID_APPROVAL_MODE_VALUES = [
|
||||
'plan',
|
||||
'default',
|
||||
'auto-edit',
|
||||
'yolo',
|
||||
] as const;
|
||||
|
||||
function formatApprovalModeError(value: string): Error {
|
||||
return new Error(
|
||||
`Invalid approval mode: ${value}. Valid values are: ${VALID_APPROVAL_MODE_VALUES.join(
|
||||
', ',
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
|
||||
function parseApprovalModeValue(value: string): ApprovalMode {
|
||||
const normalized = value.trim().toLowerCase();
|
||||
switch (normalized) {
|
||||
case 'plan':
|
||||
return ApprovalMode.PLAN;
|
||||
case 'default':
|
||||
return ApprovalMode.DEFAULT;
|
||||
case 'yolo':
|
||||
return ApprovalMode.YOLO;
|
||||
case 'auto_edit':
|
||||
case 'autoedit':
|
||||
case 'auto-edit':
|
||||
return ApprovalMode.AUTO_EDIT;
|
||||
default:
|
||||
throw formatApprovalModeError(value);
|
||||
}
|
||||
}
|
||||
|
||||
export interface CliArgs {
|
||||
model: string | undefined;
|
||||
sandbox: boolean | string | undefined;
|
||||
@@ -82,6 +115,7 @@ export interface CliArgs {
|
||||
includeDirectories: string[] | undefined;
|
||||
tavilyApiKey: string | undefined;
|
||||
screenReader: boolean | undefined;
|
||||
vlmSwitchMode: string | undefined;
|
||||
}
|
||||
|
||||
export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
@@ -146,9 +180,9 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
})
|
||||
.option('approval-mode', {
|
||||
type: 'string',
|
||||
choices: ['default', 'auto_edit', 'yolo'],
|
||||
choices: ['plan', 'default', 'auto-edit', 'yolo'],
|
||||
description:
|
||||
'Set the approval mode: default (prompt for approval), auto_edit (auto-approve edit tools), yolo (auto-approve all tools)',
|
||||
'Set the approval mode: plan (plan only), default (prompt for approval), auto-edit (auto-approve edit tools), yolo (auto-approve all tools)',
|
||||
})
|
||||
.option('telemetry', {
|
||||
type: 'boolean',
|
||||
@@ -249,6 +283,13 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
description: 'Enable screen reader mode for accessibility.',
|
||||
default: false,
|
||||
})
|
||||
.option('vlm-switch-mode', {
|
||||
type: 'string',
|
||||
choices: ['once', 'session', 'persist'],
|
||||
description:
|
||||
'Default behavior when images are detected in input. Values: once (one-time switch), session (switch for entire session), persist (continue with current model). Overrides settings files.',
|
||||
default: process.env['VLM_SWITCH_MODE'],
|
||||
})
|
||||
.check((argv) => {
|
||||
if (argv.prompt && argv['promptInteractive']) {
|
||||
throw new Error(
|
||||
@@ -430,30 +471,21 @@ export async function loadCliConfig(
|
||||
// Determine approval mode with backward compatibility
|
||||
let approvalMode: ApprovalMode;
|
||||
if (argv.approvalMode) {
|
||||
// New --approval-mode flag takes precedence
|
||||
switch (argv.approvalMode) {
|
||||
case 'yolo':
|
||||
approvalMode = ApprovalMode.YOLO;
|
||||
break;
|
||||
case 'auto_edit':
|
||||
approvalMode = ApprovalMode.AUTO_EDIT;
|
||||
break;
|
||||
case 'default':
|
||||
approvalMode = ApprovalMode.DEFAULT;
|
||||
break;
|
||||
default:
|
||||
throw new Error(
|
||||
`Invalid approval mode: ${argv.approvalMode}. Valid values are: yolo, auto_edit, default`,
|
||||
);
|
||||
}
|
||||
approvalMode = parseApprovalModeValue(argv.approvalMode);
|
||||
} else if (argv.yolo) {
|
||||
approvalMode = ApprovalMode.YOLO;
|
||||
} else if (settings.approvalMode) {
|
||||
approvalMode = parseApprovalModeValue(settings.approvalMode);
|
||||
} else {
|
||||
// Fallback to legacy --yolo flag behavior
|
||||
approvalMode =
|
||||
argv.yolo || false ? ApprovalMode.YOLO : ApprovalMode.DEFAULT;
|
||||
approvalMode = ApprovalMode.DEFAULT;
|
||||
}
|
||||
|
||||
// Force approval mode to default if the folder is not trusted.
|
||||
if (!trustedFolder && approvalMode !== ApprovalMode.DEFAULT) {
|
||||
if (
|
||||
!trustedFolder &&
|
||||
approvalMode !== ApprovalMode.DEFAULT &&
|
||||
approvalMode !== ApprovalMode.PLAN
|
||||
) {
|
||||
logger.warn(
|
||||
`Approval mode overridden to "default" because the current folder is not trusted.`,
|
||||
);
|
||||
@@ -466,6 +498,7 @@ export async function loadCliConfig(
|
||||
const extraExcludes: string[] = [];
|
||||
if (!interactive && !argv.experimentalAcp) {
|
||||
switch (approvalMode) {
|
||||
case ApprovalMode.PLAN:
|
||||
case ApprovalMode.DEFAULT:
|
||||
// In default non-interactive mode, all tools that require approval are excluded.
|
||||
extraExcludes.push(ShellTool.Name, EditTool.Name, WriteFileTool.Name);
|
||||
@@ -524,6 +557,9 @@ export async function loadCliConfig(
|
||||
argv.screenReader !== undefined
|
||||
? argv.screenReader
|
||||
: (settings.ui?.accessibility?.screenReader ?? false);
|
||||
|
||||
const vlmSwitchMode =
|
||||
argv.vlmSwitchMode || settings.experimental?.vlmSwitchMode;
|
||||
return new Config({
|
||||
sessionId,
|
||||
embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
@@ -630,6 +666,7 @@ export async function loadCliConfig(
|
||||
skipNextSpeakerCheck: settings.model?.skipNextSpeakerCheck,
|
||||
enablePromptCompletion: settings.general?.enablePromptCompletion ?? false,
|
||||
skipLoopDetection: settings.skipLoopDetection ?? false,
|
||||
vlmSwitchMode,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,11 @@ const MOCK_WORKSPACE_SETTINGS_PATH = pathActual.join(
|
||||
);
|
||||
|
||||
// A more flexible type for test data that allows arbitrary properties.
|
||||
type TestSettings = Settings & { [key: string]: unknown };
|
||||
type TestSettings = Settings & {
|
||||
[key: string]: unknown;
|
||||
nested?: { [key: string]: unknown };
|
||||
nestedObj?: { [key: string]: unknown };
|
||||
};
|
||||
|
||||
vi.mock('fs', async (importOriginal) => {
|
||||
// Get all the functions from the real 'fs' module
|
||||
@@ -137,6 +141,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -197,6 +204,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -260,6 +270,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -320,6 +333,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -385,6 +401,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -477,6 +496,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -562,6 +584,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -691,6 +716,9 @@ describe('Settings Loading and Merging', () => {
|
||||
'/system/dir',
|
||||
],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -1431,6 +1459,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -1516,7 +1547,11 @@ describe('Settings Loading and Merging', () => {
|
||||
'workspace_endpoint_from_env/api',
|
||||
);
|
||||
expect(
|
||||
(settings.workspace.settings as TestSettings)['nested']['value'],
|
||||
(
|
||||
(settings.workspace.settings as TestSettings).nested as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['value'],
|
||||
).toBe('workspace_endpoint_from_env');
|
||||
expect((settings.merged as TestSettings)['endpoint']).toBe(
|
||||
'workspace_endpoint_from_env/api',
|
||||
@@ -1766,19 +1801,39 @@ describe('Settings Loading and Merging', () => {
|
||||
).toBeUndefined();
|
||||
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedNull'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedNull'],
|
||||
).toBeNull();
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedBool'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedBool'],
|
||||
).toBe(true);
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedNum'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedNum'],
|
||||
).toBe(0);
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedString'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedString'],
|
||||
).toBe('literal');
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['anotherEnv'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['anotherEnv'],
|
||||
).toBe('env_string_nested_value');
|
||||
|
||||
delete process.env['MY_ENV_STRING'];
|
||||
@@ -1864,6 +1919,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -2336,14 +2394,14 @@ describe('Settings Loading and Merging', () => {
|
||||
vimMode: false,
|
||||
},
|
||||
model: {
|
||||
maxSessionTurns: 0,
|
||||
maxSessionTurns: -1,
|
||||
},
|
||||
context: {
|
||||
includeDirectories: [],
|
||||
},
|
||||
security: {
|
||||
folderTrust: {
|
||||
enabled: null,
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -2352,9 +2410,9 @@ describe('Settings Loading and Merging', () => {
|
||||
|
||||
expect(v1Settings).toEqual({
|
||||
vimMode: false,
|
||||
maxSessionTurns: 0,
|
||||
maxSessionTurns: -1,
|
||||
includeDirectories: [],
|
||||
folderTrust: null,
|
||||
folderTrust: false,
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -396,6 +396,24 @@ function mergeSettings(
|
||||
]),
|
||||
],
|
||||
},
|
||||
experimental: {
|
||||
...(systemDefaults.experimental || {}),
|
||||
...(user.experimental || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.experimental || {}),
|
||||
...(system.experimental || {}),
|
||||
},
|
||||
contentGenerator: {
|
||||
...(systemDefaults.contentGenerator || {}),
|
||||
...(user.contentGenerator || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.contentGenerator || {}),
|
||||
...(system.contentGenerator || {}),
|
||||
},
|
||||
systemPromptMappings: {
|
||||
...(systemDefaults.systemPromptMappings || {}),
|
||||
...(user.systemPromptMappings || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.systemPromptMappings || {}),
|
||||
...(system.systemPromptMappings || {}),
|
||||
},
|
||||
extensions: {
|
||||
...(systemDefaults.extensions || {}),
|
||||
...(user.extensions || {}),
|
||||
|
||||
@@ -746,11 +746,21 @@ export const SETTINGS_SCHEMA = {
|
||||
label: 'Vision Model Preview',
|
||||
category: 'Experimental',
|
||||
requiresRestart: false,
|
||||
default: false,
|
||||
default: true,
|
||||
description:
|
||||
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
|
||||
showInDialog: true,
|
||||
},
|
||||
vlmSwitchMode: {
|
||||
type: 'string',
|
||||
label: 'VLM Switch Mode',
|
||||
category: 'Experimental',
|
||||
requiresRestart: false,
|
||||
default: undefined as string | undefined,
|
||||
description:
|
||||
'Default behavior when images are detected in input. Values: once (one-time switch), session (switch for entire session), persist (continue with current model). If not set, user will be prompted each time. This is a temporary experimental feature.',
|
||||
showInDialog: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -882,6 +892,16 @@ export const SETTINGS_SCHEMA = {
|
||||
description: 'Disable all loop detection checks (streaming and LLM).',
|
||||
showInDialog: true,
|
||||
},
|
||||
approvalMode: {
|
||||
type: 'string',
|
||||
label: 'Default Approval Mode',
|
||||
category: 'General',
|
||||
requiresRestart: false,
|
||||
default: 'default',
|
||||
description:
|
||||
'Default approval mode for tool usage. Valid values: plan, default, auto-edit, yolo.',
|
||||
showInDialog: true,
|
||||
},
|
||||
enableWelcomeBack: {
|
||||
type: 'boolean',
|
||||
label: 'Enable Welcome Back',
|
||||
|
||||
@@ -15,6 +15,14 @@ vi.mock('../ui/commands/aboutCommand.js', async () => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../ui/commands/approvalModeCommand.js', () => ({
|
||||
approvalModeCommand: {
|
||||
name: 'approval-mode',
|
||||
description: 'Approval mode command',
|
||||
kind: 'built-in',
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('../ui/commands/ideCommand.js', () => ({ ideCommand: vi.fn() }));
|
||||
vi.mock('../ui/commands/restoreCommand.js', () => ({
|
||||
restoreCommand: vi.fn(),
|
||||
@@ -128,6 +136,10 @@ describe('BuiltinCommandLoader', () => {
|
||||
expect(aboutCmd).toBeDefined();
|
||||
expect(aboutCmd?.kind).toBe(CommandKind.BUILT_IN);
|
||||
|
||||
const approvalModeCmd = commands.find((c) => c.name === 'approval-mode');
|
||||
expect(approvalModeCmd).toBeDefined();
|
||||
expect(approvalModeCmd?.kind).toBe(CommandKind.BUILT_IN);
|
||||
|
||||
const ideCmd = commands.find((c) => c.name === 'ide');
|
||||
expect(ideCmd).toBeDefined();
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import type { ICommandLoader } from './types.js';
|
||||
import type { SlashCommand } from '../ui/commands/types.js';
|
||||
import type { Config } from '@qwen-code/qwen-code-core';
|
||||
import { aboutCommand } from '../ui/commands/aboutCommand.js';
|
||||
import { agentsCommand } from '../ui/commands/agentsCommand.js';
|
||||
import { approvalModeCommand } from '../ui/commands/approvalModeCommand.js';
|
||||
import { authCommand } from '../ui/commands/authCommand.js';
|
||||
import { bugCommand } from '../ui/commands/bugCommand.js';
|
||||
import { chatCommand } from '../ui/commands/chatCommand.js';
|
||||
@@ -24,19 +26,18 @@ import { ideCommand } from '../ui/commands/ideCommand.js';
|
||||
import { initCommand } from '../ui/commands/initCommand.js';
|
||||
import { mcpCommand } from '../ui/commands/mcpCommand.js';
|
||||
import { memoryCommand } from '../ui/commands/memoryCommand.js';
|
||||
import { modelCommand } from '../ui/commands/modelCommand.js';
|
||||
import { privacyCommand } from '../ui/commands/privacyCommand.js';
|
||||
import { quitCommand, quitConfirmCommand } from '../ui/commands/quitCommand.js';
|
||||
import { restoreCommand } from '../ui/commands/restoreCommand.js';
|
||||
import { settingsCommand } from '../ui/commands/settingsCommand.js';
|
||||
import { statsCommand } from '../ui/commands/statsCommand.js';
|
||||
import { summaryCommand } from '../ui/commands/summaryCommand.js';
|
||||
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
|
||||
import { themeCommand } from '../ui/commands/themeCommand.js';
|
||||
import { toolsCommand } from '../ui/commands/toolsCommand.js';
|
||||
import { settingsCommand } from '../ui/commands/settingsCommand.js';
|
||||
import { vimCommand } from '../ui/commands/vimCommand.js';
|
||||
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
|
||||
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
|
||||
import { modelCommand } from '../ui/commands/modelCommand.js';
|
||||
import { agentsCommand } from '../ui/commands/agentsCommand.js';
|
||||
|
||||
/**
|
||||
* Loads the core, hard-coded slash commands that are an integral part
|
||||
@@ -56,6 +57,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
||||
const allDefinitions: Array<SlashCommand | null> = [
|
||||
aboutCommand,
|
||||
agentsCommand,
|
||||
approvalModeCommand,
|
||||
authCommand,
|
||||
bugCommand,
|
||||
chatCommand,
|
||||
|
||||
@@ -35,7 +35,10 @@ export const createMockCommandContext = (
|
||||
},
|
||||
services: {
|
||||
config: null,
|
||||
settings: { merged: {} } as LoadedSettings,
|
||||
settings: {
|
||||
merged: {},
|
||||
setValue: vi.fn(),
|
||||
} as unknown as LoadedSettings,
|
||||
git: undefined as GitService | undefined,
|
||||
logger: {
|
||||
log: vi.fn(),
|
||||
|
||||
@@ -566,7 +566,9 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
}
|
||||
|
||||
// Switch model for future use but return false to stop current retry
|
||||
config.setModel(fallbackModel);
|
||||
config.setModel(fallbackModel).catch((error) => {
|
||||
console.error('Failed to switch to fallback model:', error);
|
||||
});
|
||||
config.setFallbackMode(true);
|
||||
logFlashFallback(
|
||||
config,
|
||||
@@ -650,17 +652,28 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
}, []);
|
||||
|
||||
const handleModelSelect = useCallback(
|
||||
(modelId: string) => {
|
||||
config.setModel(modelId);
|
||||
setCurrentModel(modelId);
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Switched model to \`${modelId}\` for this session.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
async (modelId: string) => {
|
||||
try {
|
||||
await config.setModel(modelId);
|
||||
setCurrentModel(modelId);
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Switched model to \`${modelId}\` for this session.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to switch model:', error);
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.ERROR,
|
||||
text: `Failed to switch to model \`${modelId}\`. Please try again.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
},
|
||||
[config, setCurrentModel, addItem],
|
||||
);
|
||||
@@ -670,7 +683,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
if (!contentGeneratorConfig) return [];
|
||||
|
||||
const visionModelPreviewEnabled =
|
||||
settings.merged.experimental?.visionModelPreview ?? false;
|
||||
settings.merged.experimental?.visionModelPreview ?? true;
|
||||
|
||||
switch (contentGeneratorConfig.authType) {
|
||||
case AuthType.QWEN_OAUTH:
|
||||
@@ -759,7 +772,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
setModelSwitchedFromQuotaError,
|
||||
refreshStatic,
|
||||
() => cancelHandlerRef.current(),
|
||||
settings.merged.experimental?.visionModelPreview ?? false,
|
||||
settings.merged.experimental?.visionModelPreview ?? true,
|
||||
handleVisionSwitchRequired,
|
||||
);
|
||||
|
||||
|
||||
495
packages/cli/src/ui/commands/approvalModeCommand.test.ts
Normal file
495
packages/cli/src/ui/commands/approvalModeCommand.test.ts
Normal file
@@ -0,0 +1,495 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { approvalModeCommand } from './approvalModeCommand.js';
|
||||
import {
|
||||
type CommandContext,
|
||||
CommandKind,
|
||||
type MessageActionReturn,
|
||||
} from './types.js';
|
||||
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
|
||||
import { ApprovalMode } from '@qwen-code/qwen-code-core';
|
||||
import { SettingScope, type LoadedSettings } from '../../config/settings.js';
|
||||
|
||||
describe('approvalModeCommand', () => {
|
||||
let mockContext: CommandContext;
|
||||
let setApprovalModeMock: ReturnType<typeof vi.fn>;
|
||||
let setSettingsValueMock: ReturnType<typeof vi.fn>;
|
||||
const originalEnv = { ...process.env };
|
||||
const userSettingsPath = '/mock/user/settings.json';
|
||||
const projectSettingsPath = '/mock/project/settings.json';
|
||||
const userSettingsFile = { path: userSettingsPath, settings: {} };
|
||||
const projectSettingsFile = { path: projectSettingsPath, settings: {} };
|
||||
|
||||
const getModeSubCommand = (mode: ApprovalMode) =>
|
||||
approvalModeCommand.subCommands?.find((cmd) => cmd.name === mode);
|
||||
|
||||
const getScopeSubCommand = (
|
||||
mode: ApprovalMode,
|
||||
scope: '--session' | '--user' | '--project',
|
||||
) => getModeSubCommand(mode)?.subCommands?.find((cmd) => cmd.name === scope);
|
||||
|
||||
beforeEach(() => {
|
||||
setApprovalModeMock = vi.fn();
|
||||
setSettingsValueMock = vi.fn();
|
||||
|
||||
mockContext = createMockCommandContext({
|
||||
services: {
|
||||
config: {
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
setApprovalMode: setApprovalModeMock,
|
||||
},
|
||||
settings: {
|
||||
merged: {},
|
||||
setValue: setSettingsValueMock,
|
||||
forScope: vi
|
||||
.fn()
|
||||
.mockImplementation((scope: SettingScope) =>
|
||||
scope === SettingScope.User
|
||||
? userSettingsFile
|
||||
: scope === SettingScope.Workspace
|
||||
? projectSettingsFile
|
||||
: { path: '', settings: {} },
|
||||
),
|
||||
} as unknown as LoadedSettings,
|
||||
},
|
||||
} as unknown as CommandContext);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = { ...originalEnv };
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should have the correct command properties', () => {
|
||||
expect(approvalModeCommand.name).toBe('approval-mode');
|
||||
expect(approvalModeCommand.kind).toBe(CommandKind.BUILT_IN);
|
||||
expect(approvalModeCommand.description).toBe(
|
||||
'View or change the approval mode for tool usage',
|
||||
);
|
||||
});
|
||||
|
||||
it('should show current mode, options, and usage when no arguments provided', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('info');
|
||||
const expectedMessage = [
|
||||
'Current approval mode: default',
|
||||
'',
|
||||
'Available approval modes:',
|
||||
' - plan: Plan mode - Analyze only, do not modify files or execute commands',
|
||||
' - default: Default mode - Require approval for file edits or shell commands',
|
||||
' - auto-edit: Auto-edit mode - Automatically approve file edits',
|
||||
' - yolo: YOLO mode - Automatically approve all tools',
|
||||
'',
|
||||
'Usage: /approval-mode <mode> [--session|--user|--project]',
|
||||
].join('\n');
|
||||
expect(result.content).toBe(expectedMessage);
|
||||
});
|
||||
|
||||
it('should display error when config is not available', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const nullConfigContext = createMockCommandContext({
|
||||
services: {
|
||||
config: null,
|
||||
},
|
||||
} as unknown as CommandContext);
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
nullConfigContext,
|
||||
'',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toBe('Configuration not available.');
|
||||
});
|
||||
|
||||
it('should change approval mode when valid mode is provided', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'plan',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('info');
|
||||
expect(result.content).toBe('Approval mode changed to: plan');
|
||||
});
|
||||
|
||||
it('should accept canonical auto-edit mode value', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'auto-edit',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.AUTO_EDIT);
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('info');
|
||||
expect(result.content).toBe('Approval mode changed to: auto-edit');
|
||||
});
|
||||
|
||||
it('should accept auto-edit alias for compatibility', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'auto-edit',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.AUTO_EDIT);
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
expect(result.content).toBe('Approval mode changed to: auto-edit');
|
||||
});
|
||||
|
||||
it('should display error when invalid mode is provided', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'invalid',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toContain('Invalid approval mode: invalid');
|
||||
expect(result.content).toContain('Available approval modes:');
|
||||
expect(result.content).toContain(
|
||||
'Usage: /approval-mode <mode> [--session|--user|--project]',
|
||||
);
|
||||
});
|
||||
|
||||
it('should display error when setApprovalMode throws an error', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const errorMessage = 'Failed to set approval mode';
|
||||
mockContext.services.config!.setApprovalMode = vi
|
||||
.fn()
|
||||
.mockImplementation(() => {
|
||||
throw new Error(errorMessage);
|
||||
});
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'plan',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toBe(
|
||||
`Failed to change approval mode: ${errorMessage}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow selecting auto-edit with user scope via nested subcommands', async () => {
|
||||
if (!approvalModeCommand.subCommands) {
|
||||
throw new Error('approvalModeCommand must have subCommands.');
|
||||
}
|
||||
|
||||
const userSubCommand = getScopeSubCommand(ApprovalMode.AUTO_EDIT, '--user');
|
||||
if (!userSubCommand?.action) {
|
||||
throw new Error('--user scope subcommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await userSubCommand.action(
|
||||
mockContext,
|
||||
'',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.AUTO_EDIT);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'approvalMode',
|
||||
'auto-edit',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: auto-edit (saved to user settings at ${userSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow selecting plan with project scope via nested subcommands', async () => {
|
||||
if (!approvalModeCommand.subCommands) {
|
||||
throw new Error('approvalModeCommand must have subCommands.');
|
||||
}
|
||||
|
||||
const projectSubCommand = getScopeSubCommand(
|
||||
ApprovalMode.PLAN,
|
||||
'--project',
|
||||
);
|
||||
if (!projectSubCommand?.action) {
|
||||
throw new Error('--project scope subcommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await projectSubCommand.action(
|
||||
mockContext,
|
||||
'',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.Workspace,
|
||||
'approvalMode',
|
||||
'plan',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: plan (saved to project settings at ${projectSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow selecting plan with session scope via nested subcommands', async () => {
|
||||
if (!approvalModeCommand.subCommands) {
|
||||
throw new Error('approvalModeCommand must have subCommands.');
|
||||
}
|
||||
|
||||
const sessionSubCommand = getScopeSubCommand(
|
||||
ApprovalMode.PLAN,
|
||||
'--session',
|
||||
);
|
||||
if (!sessionSubCommand?.action) {
|
||||
throw new Error('--session scope subcommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await sessionSubCommand.action(
|
||||
mockContext,
|
||||
'',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
expect(result.content).toBe('Approval mode changed to: plan');
|
||||
});
|
||||
|
||||
it('should allow providing a scope argument after selecting a mode subcommand', async () => {
|
||||
if (!approvalModeCommand.subCommands) {
|
||||
throw new Error('approvalModeCommand must have subCommands.');
|
||||
}
|
||||
|
||||
const planSubCommand = getModeSubCommand(ApprovalMode.PLAN);
|
||||
if (!planSubCommand?.action) {
|
||||
throw new Error('plan subcommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await planSubCommand.action(
|
||||
mockContext,
|
||||
'--user',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'approvalMode',
|
||||
'plan',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: plan (saved to user settings at ${userSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should support --user plan pattern (scope first)', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'--user plan',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'approvalMode',
|
||||
'plan',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: plan (saved to user settings at ${userSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should support plan --user pattern (mode first)', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'plan --user',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.PLAN);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'approvalMode',
|
||||
'plan',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: plan (saved to user settings at ${userSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should support --project auto-edit pattern', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'--project auto-edit',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(setApprovalModeMock).toHaveBeenCalledWith(ApprovalMode.AUTO_EDIT);
|
||||
expect(setSettingsValueMock).toHaveBeenCalledWith(
|
||||
SettingScope.Workspace,
|
||||
'approvalMode',
|
||||
'auto-edit',
|
||||
);
|
||||
expect(result.content).toBe(
|
||||
`Approval mode changed to: auto-edit (saved to project settings at ${projectSettingsPath})`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should display error when only scope flag is provided', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'--user',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toContain('Missing approval mode');
|
||||
expect(setApprovalModeMock).not.toHaveBeenCalled();
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should display error when multiple scope flags are provided', async () => {
|
||||
if (!approvalModeCommand.action) {
|
||||
throw new Error('approvalModeCommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await approvalModeCommand.action(
|
||||
mockContext,
|
||||
'--user --project plan',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toContain('Multiple scope flags provided');
|
||||
expect(setApprovalModeMock).not.toHaveBeenCalled();
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should surface a helpful error when scope subcommands receive extra arguments', async () => {
|
||||
if (!approvalModeCommand.subCommands) {
|
||||
throw new Error('approvalModeCommand must have subCommands.');
|
||||
}
|
||||
|
||||
const userSubCommand = getScopeSubCommand(ApprovalMode.DEFAULT, '--user');
|
||||
if (!userSubCommand?.action) {
|
||||
throw new Error('--user scope subcommand must have an action.');
|
||||
}
|
||||
|
||||
const result = (await userSubCommand.action(
|
||||
mockContext,
|
||||
'extra',
|
||||
)) as MessageActionReturn;
|
||||
|
||||
expect(result.type).toBe('message');
|
||||
expect(result.messageType).toBe('error');
|
||||
expect(result.content).toBe(
|
||||
'Scope subcommands do not accept additional arguments.',
|
||||
);
|
||||
expect(setApprovalModeMock).not.toHaveBeenCalled();
|
||||
expect(setSettingsValueMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should provide completion for approval modes', async () => {
|
||||
if (!approvalModeCommand.completion) {
|
||||
throw new Error('approvalModeCommand must have a completion function.');
|
||||
}
|
||||
|
||||
// Test partial mode completion
|
||||
const result = await approvalModeCommand.completion(mockContext, 'p');
|
||||
expect(result).toEqual(['plan']);
|
||||
|
||||
const result2 = await approvalModeCommand.completion(mockContext, 'a');
|
||||
expect(result2).toEqual(['auto-edit']);
|
||||
|
||||
// Test empty completion - should suggest available modes first
|
||||
const result3 = await approvalModeCommand.completion(mockContext, '');
|
||||
expect(result3).toEqual(['plan', 'default', 'auto-edit', 'yolo']);
|
||||
|
||||
const result4 = await approvalModeCommand.completion(mockContext, 'AUTO');
|
||||
expect(result4).toEqual(['auto-edit']);
|
||||
|
||||
// Test mode first pattern: 'plan ' should suggest scope flags
|
||||
const result5 = await approvalModeCommand.completion(mockContext, 'plan ');
|
||||
expect(result5).toEqual(['--session', '--project', '--user']);
|
||||
|
||||
const result6 = await approvalModeCommand.completion(
|
||||
mockContext,
|
||||
'plan --u',
|
||||
);
|
||||
expect(result6).toEqual(['--user']);
|
||||
|
||||
// Test scope first pattern: '--user ' should suggest modes
|
||||
const result7 = await approvalModeCommand.completion(
|
||||
mockContext,
|
||||
'--user ',
|
||||
);
|
||||
expect(result7).toEqual(['plan', 'default', 'auto-edit', 'yolo']);
|
||||
|
||||
const result8 = await approvalModeCommand.completion(
|
||||
mockContext,
|
||||
'--user p',
|
||||
);
|
||||
expect(result8).toEqual(['plan']);
|
||||
|
||||
// Test completed patterns should return empty
|
||||
const result9 = await approvalModeCommand.completion(
|
||||
mockContext,
|
||||
'plan --user ',
|
||||
);
|
||||
expect(result9).toEqual([]);
|
||||
|
||||
const result10 = await approvalModeCommand.completion(
|
||||
mockContext,
|
||||
'--user plan ',
|
||||
);
|
||||
expect(result10).toEqual([]);
|
||||
});
|
||||
});
|
||||
434
packages/cli/src/ui/commands/approvalModeCommand.ts
Normal file
434
packages/cli/src/ui/commands/approvalModeCommand.ts
Normal file
@@ -0,0 +1,434 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
SlashCommand,
|
||||
CommandContext,
|
||||
MessageActionReturn,
|
||||
} from './types.js';
|
||||
import { CommandKind } from './types.js';
|
||||
import { ApprovalMode, APPROVAL_MODES } from '@qwen-code/qwen-code-core';
|
||||
import { SettingScope } from '../../config/settings.js';
|
||||
|
||||
const USAGE_MESSAGE =
|
||||
'Usage: /approval-mode <mode> [--session|--user|--project]';
|
||||
|
||||
const normalizeInputMode = (value: string): string =>
|
||||
value.trim().toLowerCase();
|
||||
|
||||
const tokenizeArgs = (args: string): string[] => {
|
||||
const matches = args.match(/(?:"[^"]*"|'[^']*'|[^\s"']+)/g);
|
||||
if (!matches) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return matches.map((token) => {
|
||||
if (
|
||||
(token.startsWith('"') && token.endsWith('"')) ||
|
||||
(token.startsWith("'") && token.endsWith("'"))
|
||||
) {
|
||||
return token.slice(1, -1);
|
||||
}
|
||||
return token;
|
||||
});
|
||||
};
|
||||
|
||||
const parseApprovalMode = (value: string | null): ApprovalMode | null => {
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const normalized = normalizeInputMode(value).replace(/_/g, '-');
|
||||
const matchIndex = APPROVAL_MODES.findIndex(
|
||||
(candidate) => candidate === normalized,
|
||||
);
|
||||
|
||||
return matchIndex === -1 ? null : APPROVAL_MODES[matchIndex];
|
||||
};
|
||||
|
||||
const formatModeDescription = (mode: ApprovalMode): string => {
|
||||
switch (mode) {
|
||||
case ApprovalMode.PLAN:
|
||||
return 'Plan mode - Analyze only, do not modify files or execute commands';
|
||||
case ApprovalMode.DEFAULT:
|
||||
return 'Default mode - Require approval for file edits or shell commands';
|
||||
case ApprovalMode.AUTO_EDIT:
|
||||
return 'Auto-edit mode - Automatically approve file edits';
|
||||
case ApprovalMode.YOLO:
|
||||
return 'YOLO mode - Automatically approve all tools';
|
||||
default:
|
||||
return `${mode} mode`;
|
||||
}
|
||||
};
|
||||
|
||||
const parseApprovalArgs = (
|
||||
args: string,
|
||||
): {
|
||||
mode: string | null;
|
||||
scope: 'session' | 'user' | 'project';
|
||||
error?: string;
|
||||
} => {
|
||||
const trimmedArgs = args.trim();
|
||||
if (!trimmedArgs) {
|
||||
return { mode: null, scope: 'session' };
|
||||
}
|
||||
|
||||
const tokens = tokenizeArgs(trimmedArgs);
|
||||
let mode: string | null = null;
|
||||
let scope: 'session' | 'user' | 'project' = 'session';
|
||||
let scopeFlag: string | null = null;
|
||||
|
||||
// Find scope flag and mode
|
||||
for (const token of tokens) {
|
||||
if (token === '--session' || token === '--user' || token === '--project') {
|
||||
if (scopeFlag) {
|
||||
return {
|
||||
mode: null,
|
||||
scope: 'session',
|
||||
error: 'Multiple scope flags provided',
|
||||
};
|
||||
}
|
||||
scopeFlag = token;
|
||||
scope = token.substring(2) as 'session' | 'user' | 'project';
|
||||
} else if (!mode) {
|
||||
mode = token;
|
||||
} else {
|
||||
return {
|
||||
mode: null,
|
||||
scope: 'session',
|
||||
error: 'Invalid arguments provided',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (!mode) {
|
||||
return { mode: null, scope: 'session', error: 'Missing approval mode' };
|
||||
}
|
||||
|
||||
return { mode, scope };
|
||||
};
|
||||
|
||||
const setApprovalModeWithScope = async (
|
||||
context: CommandContext,
|
||||
mode: ApprovalMode,
|
||||
scope: 'session' | 'user' | 'project',
|
||||
): Promise<MessageActionReturn> => {
|
||||
const { services } = context;
|
||||
const { config } = services;
|
||||
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
// Always set the mode in the current session
|
||||
config.setApprovalMode(mode);
|
||||
|
||||
// If scope is not session, also persist to settings
|
||||
if (scope !== 'session') {
|
||||
const { settings } = context.services;
|
||||
if (!settings || typeof settings.setValue !== 'function') {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content:
|
||||
'Settings service is not available; unable to persist the approval mode.',
|
||||
};
|
||||
}
|
||||
|
||||
const settingScope =
|
||||
scope === 'user' ? SettingScope.User : SettingScope.Workspace;
|
||||
const scopeLabel = scope === 'user' ? 'user' : 'project';
|
||||
let settingsPath: string | undefined;
|
||||
|
||||
try {
|
||||
if (typeof settings.forScope === 'function') {
|
||||
settingsPath = settings.forScope(settingScope)?.path;
|
||||
}
|
||||
} catch (_error) {
|
||||
settingsPath = undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
settings.setValue(settingScope, 'approvalMode', mode);
|
||||
} catch (error) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `Failed to save approval mode: ${(error as Error).message}`,
|
||||
};
|
||||
}
|
||||
|
||||
const locationSuffix = settingsPath ? ` at ${settingsPath}` : '';
|
||||
|
||||
const scopeSuffix = ` (saved to ${scopeLabel} settings${locationSuffix})`;
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `Approval mode changed to: ${mode}${scopeSuffix}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `Approval mode changed to: ${mode}`,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `Failed to change approval mode: ${(error as Error).message}`,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const approvalModeCommand: SlashCommand = {
|
||||
name: 'approval-mode',
|
||||
description: 'View or change the approval mode for tool usage',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
const { config } = context.services;
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
// If no arguments provided, show current mode and available options
|
||||
if (!args || args.trim() === '') {
|
||||
const currentMode =
|
||||
typeof config.getApprovalMode === 'function'
|
||||
? config.getApprovalMode()
|
||||
: null;
|
||||
|
||||
const messageLines: string[] = [];
|
||||
|
||||
if (currentMode) {
|
||||
messageLines.push(`Current approval mode: ${currentMode}`);
|
||||
messageLines.push('');
|
||||
}
|
||||
|
||||
messageLines.push('Available approval modes:');
|
||||
for (const mode of APPROVAL_MODES) {
|
||||
messageLines.push(` - ${mode}: ${formatModeDescription(mode)}`);
|
||||
}
|
||||
messageLines.push('');
|
||||
messageLines.push(USAGE_MESSAGE);
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: messageLines.join('\n'),
|
||||
};
|
||||
}
|
||||
|
||||
// Parse arguments flexibly
|
||||
const parsed = parseApprovalArgs(args);
|
||||
|
||||
if (parsed.error) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `${parsed.error}. ${USAGE_MESSAGE}`,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parsed.mode) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: USAGE_MESSAGE,
|
||||
};
|
||||
}
|
||||
|
||||
const requestedMode = parseApprovalMode(parsed.mode);
|
||||
|
||||
if (!requestedMode) {
|
||||
let message = `Invalid approval mode: ${parsed.mode}\n\n`;
|
||||
message += 'Available approval modes:\n';
|
||||
for (const mode of APPROVAL_MODES) {
|
||||
message += ` - ${mode}: ${formatModeDescription(mode)}\n`;
|
||||
}
|
||||
message += `\n${USAGE_MESSAGE}`;
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: message,
|
||||
};
|
||||
}
|
||||
|
||||
return setApprovalModeWithScope(context, requestedMode, parsed.scope);
|
||||
},
|
||||
subCommands: APPROVAL_MODES.map((mode) => ({
|
||||
name: mode,
|
||||
description: formatModeDescription(mode),
|
||||
kind: CommandKind.BUILT_IN,
|
||||
subCommands: [
|
||||
{
|
||||
name: '--session',
|
||||
description: 'Apply to current session only (temporary)',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
if (args.trim().length > 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Scope subcommands do not accept additional arguments.',
|
||||
};
|
||||
}
|
||||
return setApprovalModeWithScope(context, mode, 'session');
|
||||
},
|
||||
},
|
||||
{
|
||||
name: '--project',
|
||||
description: 'Persist for this project/workspace',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
if (args.trim().length > 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Scope subcommands do not accept additional arguments.',
|
||||
};
|
||||
}
|
||||
return setApprovalModeWithScope(context, mode, 'project');
|
||||
},
|
||||
},
|
||||
{
|
||||
name: '--user',
|
||||
description: 'Persist for this user on this machine',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
if (args.trim().length > 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Scope subcommands do not accept additional arguments.',
|
||||
};
|
||||
}
|
||||
return setApprovalModeWithScope(context, mode, 'user');
|
||||
},
|
||||
},
|
||||
],
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
if (args.trim().length > 0) {
|
||||
// Allow users who type `/approval-mode plan --user` via the subcommand path
|
||||
const parsed = parseApprovalArgs(`${mode} ${args}`);
|
||||
if (parsed.error) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `${parsed.error}. ${USAGE_MESSAGE}`,
|
||||
};
|
||||
}
|
||||
|
||||
const normalizedMode = parseApprovalMode(parsed.mode);
|
||||
if (!normalizedMode) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `Invalid approval mode: ${parsed.mode}. ${USAGE_MESSAGE}`,
|
||||
};
|
||||
}
|
||||
|
||||
return setApprovalModeWithScope(context, normalizedMode, parsed.scope);
|
||||
}
|
||||
|
||||
return setApprovalModeWithScope(context, mode, 'session');
|
||||
},
|
||||
})),
|
||||
completion: async (_context: CommandContext, partialArg: string) => {
|
||||
const tokens = tokenizeArgs(partialArg);
|
||||
const hasTrailingSpace = /\s$/.test(partialArg);
|
||||
const currentSegment = hasTrailingSpace
|
||||
? ''
|
||||
: tokens.length > 0
|
||||
? tokens[tokens.length - 1]
|
||||
: '';
|
||||
|
||||
const normalizedCurrent = normalizeInputMode(currentSegment).replace(
|
||||
/_/g,
|
||||
'-',
|
||||
);
|
||||
|
||||
const scopeValues = ['--session', '--project', '--user'];
|
||||
|
||||
const normalizeToken = (token: string) =>
|
||||
normalizeInputMode(token).replace(/_/g, '-');
|
||||
|
||||
const normalizedTokens = tokens.map(normalizeToken);
|
||||
|
||||
if (tokens.length === 0) {
|
||||
if (currentSegment.startsWith('-')) {
|
||||
return scopeValues.filter((scope) => scope.startsWith(currentSegment));
|
||||
}
|
||||
return APPROVAL_MODES;
|
||||
}
|
||||
|
||||
if (tokens.length === 1 && !hasTrailingSpace) {
|
||||
const originalToken = tokens[0];
|
||||
if (originalToken.startsWith('-')) {
|
||||
return scopeValues.filter((scope) =>
|
||||
scope.startsWith(normalizedCurrent),
|
||||
);
|
||||
}
|
||||
return APPROVAL_MODES.filter((mode) =>
|
||||
mode.startsWith(normalizedCurrent),
|
||||
);
|
||||
}
|
||||
|
||||
if (tokens.length === 1 && hasTrailingSpace) {
|
||||
const normalizedFirst = normalizedTokens[0];
|
||||
if (scopeValues.includes(tokens[0])) {
|
||||
return APPROVAL_MODES;
|
||||
}
|
||||
if (APPROVAL_MODES.includes(normalizedFirst as ApprovalMode)) {
|
||||
return scopeValues;
|
||||
}
|
||||
return APPROVAL_MODES;
|
||||
}
|
||||
|
||||
if (tokens.length === 2 && !hasTrailingSpace) {
|
||||
const normalizedFirst = normalizedTokens[0];
|
||||
if (scopeValues.includes(tokens[0])) {
|
||||
return APPROVAL_MODES.filter((mode) =>
|
||||
mode.startsWith(normalizedCurrent),
|
||||
);
|
||||
}
|
||||
if (APPROVAL_MODES.includes(normalizedFirst as ApprovalMode)) {
|
||||
return scopeValues.filter((scope) =>
|
||||
scope.startsWith(normalizedCurrent),
|
||||
);
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
return [];
|
||||
},
|
||||
};
|
||||
@@ -21,15 +21,20 @@ export const AutoAcceptIndicator: React.FC<AutoAcceptIndicatorProps> = ({
|
||||
let subText = '';
|
||||
|
||||
switch (approvalMode) {
|
||||
case ApprovalMode.PLAN:
|
||||
textColor = Colors.AccentBlue;
|
||||
textContent = 'plan mode';
|
||||
subText = ' (shift + tab to cycle)';
|
||||
break;
|
||||
case ApprovalMode.AUTO_EDIT:
|
||||
textColor = Colors.AccentGreen;
|
||||
textContent = 'accepting edits';
|
||||
subText = ' (shift + tab to toggle)';
|
||||
textContent = 'auto-accept edits';
|
||||
subText = ' (shift + tab to cycle)';
|
||||
break;
|
||||
case ApprovalMode.YOLO:
|
||||
textColor = Colors.AccentRed;
|
||||
textContent = 'YOLO mode';
|
||||
subText = ' (ctrl + y to toggle)';
|
||||
subText = ' (shift + tab to cycle)';
|
||||
break;
|
||||
case ApprovalMode.DEFAULT:
|
||||
default:
|
||||
|
||||
@@ -133,12 +133,6 @@ export const Help: React.FC<Help> = ({ commands }) => (
|
||||
</Text>{' '}
|
||||
- Open input in external editor
|
||||
</Text>
|
||||
<Text color={Colors.Foreground}>
|
||||
<Text bold color={Colors.AccentPurple}>
|
||||
Ctrl+Y
|
||||
</Text>{' '}
|
||||
- Toggle YOLO mode
|
||||
</Text>
|
||||
<Text color={Colors.Foreground}>
|
||||
<Text bold color={Colors.AccentPurple}>
|
||||
Enter
|
||||
@@ -155,7 +149,7 @@ export const Help: React.FC<Help> = ({ commands }) => (
|
||||
<Text bold color={Colors.AccentPurple}>
|
||||
Shift+Tab
|
||||
</Text>{' '}
|
||||
- Toggle auto-accepting edits
|
||||
- Cycle approval modes
|
||||
</Text>
|
||||
<Text color={Colors.Foreground}>
|
||||
<Text bold color={Colors.AccentPurple}>
|
||||
|
||||
@@ -46,8 +46,8 @@ describe('ModelSwitchDialog', () => {
|
||||
value: VisionSwitchOutcome.SwitchSessionToVL,
|
||||
},
|
||||
{
|
||||
label: 'Do not switch, show guidance',
|
||||
value: VisionSwitchOutcome.DisallowWithGuidance,
|
||||
label: 'Continue with current model',
|
||||
value: VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -81,18 +81,18 @@ describe('ModelSwitchDialog', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should call onSelect with DisallowWithGuidance when third option is selected', () => {
|
||||
it('should call onSelect with ContinueWithCurrentModel when third option is selected', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
|
||||
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
onSelectCallback(VisionSwitchOutcome.ContinueWithCurrentModel);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
);
|
||||
});
|
||||
|
||||
it('should setup escape key handler to call onSelect with DisallowWithGuidance', () => {
|
||||
it('should setup escape key handler to call onSelect with ContinueWithCurrentModel', () => {
|
||||
render(<ModelSwitchDialog onSelect={mockOnSelect} />);
|
||||
|
||||
expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
|
||||
@@ -104,7 +104,7 @@ describe('ModelSwitchDialog', () => {
|
||||
keypressHandler({ name: 'escape' });
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -126,13 +126,9 @@ describe('ModelSwitchDialog', () => {
|
||||
|
||||
describe('VisionSwitchOutcome enum', () => {
|
||||
it('should have correct enum values', () => {
|
||||
expect(VisionSwitchOutcome.SwitchOnce).toBe('switch_once');
|
||||
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe(
|
||||
'switch_session_to_vl',
|
||||
);
|
||||
expect(VisionSwitchOutcome.DisallowWithGuidance).toBe(
|
||||
'disallow_with_guidance',
|
||||
);
|
||||
expect(VisionSwitchOutcome.SwitchOnce).toBe('once');
|
||||
expect(VisionSwitchOutcome.SwitchSessionToVL).toBe('session');
|
||||
expect(VisionSwitchOutcome.ContinueWithCurrentModel).toBe('persist');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -144,7 +140,7 @@ describe('ModelSwitchDialog', () => {
|
||||
// Call multiple times
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchOnce);
|
||||
onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
|
||||
onSelectCallback(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
onSelectCallback(VisionSwitchOutcome.ContinueWithCurrentModel);
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledTimes(3);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(
|
||||
@@ -157,7 +153,7 @@ describe('ModelSwitchDialog', () => {
|
||||
);
|
||||
expect(mockOnSelect).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -179,7 +175,7 @@ describe('ModelSwitchDialog', () => {
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledTimes(2);
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -14,9 +14,9 @@ import {
|
||||
import { useKeypress } from '../hooks/useKeypress.js';
|
||||
|
||||
export enum VisionSwitchOutcome {
|
||||
SwitchOnce = 'switch_once',
|
||||
SwitchSessionToVL = 'switch_session_to_vl',
|
||||
DisallowWithGuidance = 'disallow_with_guidance',
|
||||
SwitchOnce = 'once',
|
||||
SwitchSessionToVL = 'session',
|
||||
ContinueWithCurrentModel = 'persist',
|
||||
}
|
||||
|
||||
export interface ModelSwitchDialogProps {
|
||||
@@ -29,7 +29,7 @@ export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name === 'escape') {
|
||||
onSelect(VisionSwitchOutcome.DisallowWithGuidance);
|
||||
onSelect(VisionSwitchOutcome.ContinueWithCurrentModel);
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
@@ -45,8 +45,8 @@ export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
|
||||
value: VisionSwitchOutcome.SwitchSessionToVL,
|
||||
},
|
||||
{
|
||||
label: 'Do not switch, show guidance',
|
||||
value: VisionSwitchOutcome.DisallowWithGuidance,
|
||||
label: 'Continue with current model',
|
||||
value: VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
41
packages/cli/src/ui/components/PlanSummaryDisplay.tsx
Normal file
41
packages/cli/src/ui/components/PlanSummaryDisplay.tsx
Normal file
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { MarkdownDisplay } from '../utils/MarkdownDisplay.js';
|
||||
import { Colors } from '../colors.js';
|
||||
import type { PlanResultDisplay } from '@qwen-code/qwen-code-core';
|
||||
|
||||
interface PlanSummaryDisplayProps {
|
||||
data: PlanResultDisplay;
|
||||
availableHeight?: number;
|
||||
childWidth: number;
|
||||
}
|
||||
|
||||
export const PlanSummaryDisplay: React.FC<PlanSummaryDisplayProps> = ({
|
||||
data,
|
||||
availableHeight,
|
||||
childWidth,
|
||||
}) => {
|
||||
const { message, plan } = data;
|
||||
|
||||
return (
|
||||
<Box flexDirection="column">
|
||||
<Box marginBottom={1}>
|
||||
<Text color={Colors.AccentGreen} wrap="wrap">
|
||||
{message}
|
||||
</Text>
|
||||
</Box>
|
||||
<MarkdownDisplay
|
||||
text={plan}
|
||||
isPending={false}
|
||||
availableTerminalHeight={availableHeight}
|
||||
terminalWidth={childWidth}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { EOL } from 'node:os';
|
||||
import { ToolConfirmationMessage } from './ToolConfirmationMessage.js';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
@@ -66,6 +67,30 @@ describe('ToolConfirmationMessage', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should render plan confirmation with markdown plan content', () => {
|
||||
const confirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'plan',
|
||||
title: 'Would you like to proceed?',
|
||||
plan: '# Implementation Plan\n- Step one\n- Step two'.replace(/\n/g, EOL),
|
||||
onConfirm: vi.fn(),
|
||||
};
|
||||
|
||||
const { lastFrame } = renderWithProviders(
|
||||
<ToolConfirmationMessage
|
||||
confirmationDetails={confirmationDetails}
|
||||
config={mockConfig}
|
||||
availableTerminalHeight={30}
|
||||
terminalWidth={80}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(lastFrame()).toContain('Yes, and auto-accept edits');
|
||||
expect(lastFrame()).toContain('Yes, and manually approve edits');
|
||||
expect(lastFrame()).toContain('No, keep planning');
|
||||
expect(lastFrame()).toContain('Implementation Plan');
|
||||
expect(lastFrame()).toContain('Step one');
|
||||
});
|
||||
|
||||
describe('with folder trust', () => {
|
||||
const editConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'edit',
|
||||
|
||||
@@ -9,6 +9,7 @@ import { Box, Text } from 'ink';
|
||||
import { DiffRenderer } from './DiffRenderer.js';
|
||||
import { Colors } from '../../colors.js';
|
||||
import { RenderInline } from '../../utils/InlineMarkdownRenderer.js';
|
||||
import { MarkdownDisplay } from '../../utils/MarkdownDisplay.js';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
@@ -235,6 +236,33 @@ export const ToolConfirmationMessage: React.FC<
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
} else if (confirmationDetails.type === 'plan') {
|
||||
const planProps = confirmationDetails;
|
||||
|
||||
question = planProps.title;
|
||||
options.push({
|
||||
label: 'Yes, and auto-accept edits',
|
||||
value: ToolConfirmationOutcome.ProceedAlways,
|
||||
});
|
||||
options.push({
|
||||
label: 'Yes, and manually approve edits',
|
||||
value: ToolConfirmationOutcome.ProceedOnce,
|
||||
});
|
||||
options.push({
|
||||
label: 'No, keep planning (esc)',
|
||||
value: ToolConfirmationOutcome.Cancel,
|
||||
});
|
||||
|
||||
bodyContent = (
|
||||
<Box flexDirection="column" paddingX={1} marginLeft={1}>
|
||||
<MarkdownDisplay
|
||||
text={planProps.plan}
|
||||
isPending={false}
|
||||
availableTerminalHeight={availableBodyContentHeight()}
|
||||
terminalWidth={childWidth}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
} else if (confirmationDetails.type === 'info') {
|
||||
const infoProps = confirmationDetails;
|
||||
const displayUrls =
|
||||
|
||||
@@ -18,9 +18,11 @@ import { TOOL_STATUS } from '../../constants.js';
|
||||
import type {
|
||||
TodoResultDisplay,
|
||||
TaskResultDisplay,
|
||||
PlanResultDisplay,
|
||||
Config,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { AgentExecutionDisplay } from '../subagents/index.js';
|
||||
import { PlanSummaryDisplay } from '../PlanSummaryDisplay.js';
|
||||
|
||||
const STATIC_HEIGHT = 1;
|
||||
const RESERVED_LINE_COUNT = 5; // for tool name, status, padding etc.
|
||||
@@ -35,6 +37,7 @@ export type TextEmphasis = 'high' | 'medium' | 'low';
|
||||
type DisplayRendererResult =
|
||||
| { type: 'none' }
|
||||
| { type: 'todo'; data: TodoResultDisplay }
|
||||
| { type: 'plan'; data: PlanResultDisplay }
|
||||
| { type: 'string'; data: string }
|
||||
| { type: 'diff'; data: { fileDiff: string; fileName: string } }
|
||||
| { type: 'task'; data: TaskResultDisplay };
|
||||
@@ -63,6 +66,18 @@ const useResultDisplayRenderer = (
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
typeof resultDisplay === 'object' &&
|
||||
resultDisplay !== null &&
|
||||
'type' in resultDisplay &&
|
||||
resultDisplay.type === 'plan_summary'
|
||||
) {
|
||||
return {
|
||||
type: 'plan',
|
||||
data: resultDisplay as PlanResultDisplay,
|
||||
};
|
||||
}
|
||||
|
||||
// Check for SubagentExecutionResultDisplay (for non-task tools)
|
||||
if (
|
||||
typeof resultDisplay === 'object' &&
|
||||
@@ -102,6 +117,18 @@ const TodoResultRenderer: React.FC<{ data: TodoResultDisplay }> = ({
|
||||
data,
|
||||
}) => <TodoDisplay todos={data.todos} />;
|
||||
|
||||
const PlanResultRenderer: React.FC<{
|
||||
data: PlanResultDisplay;
|
||||
availableHeight?: number;
|
||||
childWidth: number;
|
||||
}> = ({ data, availableHeight, childWidth }) => (
|
||||
<PlanSummaryDisplay
|
||||
data={data}
|
||||
availableHeight={availableHeight}
|
||||
childWidth={childWidth}
|
||||
/>
|
||||
);
|
||||
|
||||
/**
|
||||
* Component to render subagent execution results
|
||||
*/
|
||||
@@ -229,6 +256,13 @@ export const ToolMessage: React.FC<ToolMessageProps> = ({
|
||||
{displayRenderer.type === 'todo' && (
|
||||
<TodoResultRenderer data={displayRenderer.data} />
|
||||
)}
|
||||
{displayRenderer.type === 'plan' && (
|
||||
<PlanResultRenderer
|
||||
data={displayRenderer.data}
|
||||
availableHeight={availableHeight}
|
||||
childWidth={childWidth}
|
||||
/>
|
||||
)}
|
||||
{displayRenderer.type === 'task' && (
|
||||
<SubagentExecutionRenderer
|
||||
data={displayRenderer.data}
|
||||
|
||||
@@ -53,7 +53,7 @@ export function AgentsManagerDialog({
|
||||
const manager = config.getSubagentManager();
|
||||
|
||||
// Load agents from all levels separately to show all agents including conflicts
|
||||
const allAgents = await manager.listSubagents();
|
||||
const allAgents = await manager.listSubagents({ force: true });
|
||||
|
||||
setAvailableAgents(allAgents);
|
||||
}, [config]);
|
||||
|
||||
@@ -526,7 +526,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(keyHandler).toHaveBeenCalledTimes(2); // 1 paste event + 1 paste event for 'after'
|
||||
expect(keyHandler).toHaveBeenCalledTimes(6); // 1 paste event + 5 individual chars for 'after'
|
||||
});
|
||||
|
||||
// Should emit paste event first
|
||||
@@ -538,12 +538,40 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
// Then process 'after' as a paste event (since it's > 2 chars)
|
||||
// Then process 'after' as individual characters (since it doesn't contain return)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'after',
|
||||
name: 'a',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
expect.objectContaining({
|
||||
name: 'f',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
expect.objectContaining({
|
||||
name: 't',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
5,
|
||||
expect.objectContaining({
|
||||
name: 'e',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
6,
|
||||
expect.objectContaining({
|
||||
name: 'r',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
});
|
||||
@@ -571,7 +599,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(keyHandler).toHaveBeenCalledTimes(14); // Adjusted based on actual behavior
|
||||
expect(keyHandler).toHaveBeenCalledTimes(16); // 5 + 1 + 6 + 1 + 3 = 16 calls
|
||||
});
|
||||
|
||||
// Check the sequence: 'start' (5 chars) + paste1 + 'middle' (6 chars) + paste2 + 'end' (3 chars as paste)
|
||||
@@ -643,13 +671,18 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
// 'end' as paste event (since it's > 2 chars)
|
||||
// 'end' as individual characters (since it doesn't contain return)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'end',
|
||||
}),
|
||||
expect.objectContaining({ name: 'e' }),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({ name: 'n' }),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({ name: 'd' }),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -738,16 +771,18 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// With the current implementation, fragmented data gets processed differently
|
||||
// The first fragment '\x1b[20' gets processed as individual characters
|
||||
// The second fragment '0~content\x1b[2' gets processed as paste + individual chars
|
||||
// The third fragment '01~' gets processed as individual characters
|
||||
expect(keyHandler).toHaveBeenCalled();
|
||||
// With the current implementation, fragmented paste markers get reconstructed
|
||||
// into a single paste event for 'content'
|
||||
expect(keyHandler).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// The current implementation processes fragmented paste markers as separate events
|
||||
// rather than reconstructing them into a single paste event
|
||||
expect(keyHandler.mock.calls.length).toBeGreaterThan(1);
|
||||
// Should reconstruct the fragmented paste markers into a single paste event
|
||||
expect(keyHandler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'content',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -851,19 +886,38 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
stdin.emit('data', Buffer.from('lo'));
|
||||
});
|
||||
|
||||
// With the current implementation, data is processed as it arrives
|
||||
// First chunk 'hel' is treated as paste (multi-character)
|
||||
// With the current implementation, data is processed as individual characters
|
||||
// since 'hel' doesn't contain return (0x0d)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'hel',
|
||||
name: 'h',
|
||||
sequence: 'h',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
// Second chunk 'lo' is processed as individual characters
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
name: 'e',
|
||||
sequence: 'e',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
expect.objectContaining({
|
||||
name: 'l',
|
||||
sequence: 'l',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
// Second chunk 'lo' is also processed as individual characters
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
expect.objectContaining({
|
||||
name: 'l',
|
||||
sequence: 'l',
|
||||
@@ -872,7 +926,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
5,
|
||||
expect.objectContaining({
|
||||
name: 'o',
|
||||
sequence: 'o',
|
||||
@@ -880,7 +934,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenCalledTimes(3);
|
||||
expect(keyHandler).toHaveBeenCalledTimes(5);
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
@@ -907,14 +961,20 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
// Should flush immediately without waiting for timeout
|
||||
// Large data gets treated as paste event
|
||||
expect(keyHandler).toHaveBeenCalledTimes(1);
|
||||
expect(keyHandler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: largeData,
|
||||
}),
|
||||
);
|
||||
// Large data without return gets treated as individual characters
|
||||
expect(keyHandler).toHaveBeenCalledTimes(65);
|
||||
|
||||
// Each character should be processed individually
|
||||
for (let i = 0; i < 65; i++) {
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
i + 1,
|
||||
expect.objectContaining({
|
||||
name: 'x',
|
||||
sequence: 'x',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Advancing timer should not cause additional calls
|
||||
const callCountBefore = keyHandler.mock.calls.length;
|
||||
|
||||
@@ -407,7 +407,11 @@ export function KeypressProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
if (rawDataBuffer.length <= 2 || isPaste) {
|
||||
if (
|
||||
(rawDataBuffer.length <= 2 && rawDataBuffer.includes(0x0d)) ||
|
||||
!rawDataBuffer.includes(0x0d) ||
|
||||
isPaste
|
||||
) {
|
||||
keypressStream.write(rawDataBuffer);
|
||||
} else {
|
||||
// Flush raw data buffer as a paste event
|
||||
|
||||
@@ -158,7 +158,19 @@ describe('useAutoAcceptIndicator', () => {
|
||||
expect(mockConfigInstance.getApprovalMode).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should toggle the indicator and update config when Shift+Tab or Ctrl+Y is pressed', () => {
|
||||
it('should initialize with ApprovalMode.PLAN if config.getApprovalMode returns ApprovalMode.PLAN', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.PLAN);
|
||||
const { result } = renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: vi.fn(),
|
||||
}),
|
||||
);
|
||||
expect(result.current).toBe(ApprovalMode.PLAN);
|
||||
expect(mockConfigInstance.getApprovalMode).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should cycle approval modes when Shift+Tab is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
const { result } = renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
@@ -180,23 +192,10 @@ describe('useAutoAcceptIndicator', () => {
|
||||
expect(result.current).toBe(ApprovalMode.AUTO_EDIT);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
});
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
expect(result.current).toBe(ApprovalMode.YOLO);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
});
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
expect(result.current).toBe(ApprovalMode.DEFAULT);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
capturedUseKeypressHandler({
|
||||
name: 'tab',
|
||||
shift: true,
|
||||
} as Key);
|
||||
});
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.YOLO,
|
||||
@@ -210,9 +209,9 @@ describe('useAutoAcceptIndicator', () => {
|
||||
} as Key);
|
||||
});
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
ApprovalMode.PLAN,
|
||||
);
|
||||
expect(result.current).toBe(ApprovalMode.AUTO_EDIT);
|
||||
expect(result.current).toBe(ApprovalMode.PLAN);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({
|
||||
@@ -314,118 +313,10 @@ describe('useAutoAcceptIndicator', () => {
|
||||
mockConfigInstance.isTrustedFolder.mockReturnValue(false);
|
||||
});
|
||||
|
||||
it('should not enable YOLO mode when Ctrl+Y is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
mockConfigInstance.setApprovalMode.mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Cannot enable privileged approval modes in an untrusted folder.',
|
||||
);
|
||||
});
|
||||
const mockAddItem = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current).toBe(ApprovalMode.DEFAULT);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
});
|
||||
|
||||
// We expect setApprovalMode to be called, and the error to be caught.
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalled();
|
||||
// Verify the underlying config value was not changed
|
||||
expect(mockConfigInstance.getApprovalMode()).toBe(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should not enable AUTO_EDIT mode when Shift+Tab is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
mockConfigInstance.setApprovalMode.mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Cannot enable privileged approval modes in an untrusted folder.',
|
||||
);
|
||||
});
|
||||
const mockAddItem = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current).toBe(ApprovalMode.DEFAULT);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({
|
||||
name: 'tab',
|
||||
shift: true,
|
||||
} as Key);
|
||||
});
|
||||
|
||||
// We expect setApprovalMode to be called, and the error to be caught.
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalled();
|
||||
// Verify the underlying config value was not changed
|
||||
expect(mockConfigInstance.getApprovalMode()).toBe(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should disable YOLO mode when Ctrl+Y is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.YOLO);
|
||||
const mockAddItem = vi.fn();
|
||||
renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
});
|
||||
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
expect(mockConfigInstance.getApprovalMode()).toBe(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should disable AUTO_EDIT mode when Shift+Tab is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
const mockAddItem = vi.fn();
|
||||
renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({
|
||||
name: 'tab',
|
||||
shift: true,
|
||||
} as Key);
|
||||
});
|
||||
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
expect(mockConfigInstance.getApprovalMode()).toBe(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should show a warning when trying to enable privileged modes', () => {
|
||||
// Mock the error thrown by setApprovalMode
|
||||
it('should show a warning when cycling from DEFAULT to AUTO_EDIT', () => {
|
||||
const errorMessage =
|
||||
'Cannot enable privileged approval modes in an untrusted folder.';
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
mockConfigInstance.setApprovalMode.mockImplementation(() => {
|
||||
throw new Error(errorMessage);
|
||||
});
|
||||
@@ -438,11 +329,13 @@ describe('useAutoAcceptIndicator', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
// Try to enable YOLO mode
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'y', ctrl: true } as Key);
|
||||
capturedUseKeypressHandler({ name: 'tab', shift: true } as Key);
|
||||
});
|
||||
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
@@ -450,15 +343,33 @@ describe('useAutoAcceptIndicator', () => {
|
||||
},
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
// Try to enable AUTO_EDIT mode
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({
|
||||
name: 'tab',
|
||||
shift: true,
|
||||
} as Key);
|
||||
it('should show a warning when cycling from AUTO_EDIT to YOLO', () => {
|
||||
const errorMessage =
|
||||
'Cannot enable privileged approval modes in an untrusted folder.';
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
mockConfigInstance.setApprovalMode.mockImplementation(() => {
|
||||
throw new Error(errorMessage);
|
||||
});
|
||||
|
||||
const mockAddItem = vi.fn();
|
||||
renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'tab', shift: true } as Key);
|
||||
});
|
||||
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
@@ -466,8 +377,27 @@ describe('useAutoAcceptIndicator', () => {
|
||||
},
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenCalledTimes(2);
|
||||
it('should cycle from YOLO to PLAN when Shift+Tab is pressed', () => {
|
||||
mockConfigInstance.getApprovalMode.mockReturnValue(ApprovalMode.YOLO);
|
||||
const mockAddItem = vi.fn();
|
||||
renderHook(() =>
|
||||
useAutoAcceptIndicator({
|
||||
config: mockConfigInstance as unknown as ActualConfigType,
|
||||
addItem: mockAddItem,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
capturedUseKeypressHandler({ name: 'tab', shift: true } as Key);
|
||||
});
|
||||
|
||||
expect(mockConfigInstance.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.PLAN,
|
||||
);
|
||||
expect(mockConfigInstance.getApprovalMode()).toBe(ApprovalMode.PLAN);
|
||||
expect(mockAddItem).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { ApprovalMode, type Config } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
type ApprovalMode,
|
||||
APPROVAL_MODES,
|
||||
type Config,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useKeypress } from './useKeypress.js';
|
||||
import type { HistoryItemWithoutId } from '../types.js';
|
||||
@@ -29,34 +33,28 @@ export function useAutoAcceptIndicator({
|
||||
|
||||
useKeypress(
|
||||
(key) => {
|
||||
let nextApprovalMode: ApprovalMode | undefined;
|
||||
|
||||
if (key.ctrl && key.name === 'y') {
|
||||
nextApprovalMode =
|
||||
config.getApprovalMode() === ApprovalMode.YOLO
|
||||
? ApprovalMode.DEFAULT
|
||||
: ApprovalMode.YOLO;
|
||||
} else if (key.shift && key.name === 'tab') {
|
||||
nextApprovalMode =
|
||||
config.getApprovalMode() === ApprovalMode.AUTO_EDIT
|
||||
? ApprovalMode.DEFAULT
|
||||
: ApprovalMode.AUTO_EDIT;
|
||||
if (!(key.shift && key.name === 'tab')) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (nextApprovalMode) {
|
||||
try {
|
||||
config.setApprovalMode(nextApprovalMode);
|
||||
// Update local state immediately for responsiveness
|
||||
setShowAutoAcceptIndicator(nextApprovalMode);
|
||||
} catch (e) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: (e as Error).message,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
const currentMode = config.getApprovalMode();
|
||||
const currentIndex = APPROVAL_MODES.indexOf(currentMode);
|
||||
const nextIndex =
|
||||
currentIndex === -1 ? 0 : (currentIndex + 1) % APPROVAL_MODES.length;
|
||||
const nextApprovalMode = APPROVAL_MODES[nextIndex];
|
||||
|
||||
try {
|
||||
config.setApprovalMode(nextApprovalMode);
|
||||
// Update local state immediately for responsiveness
|
||||
setShowAutoAcceptIndicator(nextApprovalMode);
|
||||
} catch (e) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: (e as Error).message,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
|
||||
@@ -60,7 +60,9 @@ const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
|
||||
const mockHandleVisionSwitch = vi.hoisted(() =>
|
||||
vi.fn().mockResolvedValue({ shouldProceed: true }),
|
||||
);
|
||||
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
|
||||
const mockRestoreOriginalModel = vi.hoisted(() =>
|
||||
vi.fn().mockResolvedValue(undefined),
|
||||
);
|
||||
|
||||
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const actualCoreModule = (await importOriginal()) as any;
|
||||
@@ -301,6 +303,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
);
|
||||
},
|
||||
{
|
||||
@@ -462,6 +466,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -541,6 +547,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -649,6 +657,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -758,6 +768,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -887,6 +899,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
cancelSubmitSpy,
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1198,6 +1212,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1251,6 +1267,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1301,6 +1319,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1349,6 +1369,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1398,6 +1420,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1487,6 +1511,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1537,6 +1563,8 @@ describe('useGeminiStream', () => {
|
||||
vi.fn(), // setModelSwitched
|
||||
vi.fn(), // onEditorClose
|
||||
vi.fn(), // onCancelSubmit
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1602,6 +1630,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1680,6 +1710,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1734,6 +1766,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1943,6 +1977,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -1975,6 +2011,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -2028,6 +2066,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
@@ -2065,6 +2105,8 @@ describe('useGeminiStream', () => {
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
false, // visionModelPreviewEnabled
|
||||
undefined, // onVisionSwitchRequired (optional)
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ export const useGeminiStream = (
|
||||
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
onEditorClose: () => void,
|
||||
onCancelSubmit: () => void,
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
visionModelPreviewEnabled: boolean,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
@@ -765,7 +765,9 @@ export const useGeminiStream = (
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
restoreOriginalModel().catch((error) => {
|
||||
console.error('Failed to restore original model:', error);
|
||||
});
|
||||
isSubmittingQueryRef.current = false;
|
||||
return;
|
||||
}
|
||||
@@ -780,10 +782,14 @@ export const useGeminiStream = (
|
||||
}
|
||||
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
restoreOriginalModel().catch((error) => {
|
||||
console.error('Failed to restore original model:', error);
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
restoreOriginalModel().catch((error) => {
|
||||
console.error('Failed to restore original model:', error);
|
||||
});
|
||||
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError();
|
||||
|
||||
@@ -8,7 +8,28 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import type { Part, PartListUnion } from '@google/genai';
|
||||
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
|
||||
import { AuthType, type Config, ApprovalMode } from '@qwen-code/qwen-code-core';
|
||||
|
||||
// Mock the image format functions from core package
|
||||
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as Record<string, unknown>;
|
||||
return {
|
||||
...actual,
|
||||
isSupportedImageMimeType: vi.fn((mimeType: string) =>
|
||||
[
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/gif',
|
||||
'image/webp',
|
||||
].includes(mimeType),
|
||||
),
|
||||
getUnsupportedImageFormatWarning: vi.fn(
|
||||
() =>
|
||||
'Only the following image formats are supported: BMP, JPEG, JPG, PNG, TIFF, WEBP, HEIC. Other formats may not work as expected.',
|
||||
),
|
||||
};
|
||||
});
|
||||
import {
|
||||
shouldOfferVisionSwitch,
|
||||
processVisionSwitchOutcome,
|
||||
@@ -41,7 +62,7 @@ describe('useVisionAutoSwitch helpers', () => {
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen-vl-max-latest',
|
||||
'vision-model',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
@@ -108,6 +129,56 @@ describe('useVisionAutoSwitch helpers', () => {
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true when image parts exist in YOLO mode context', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false when no image parts exist in YOLO mode context', () => {
|
||||
const parts: PartListUnion = [{ text: 'just text' }];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when already using vision model in YOLO mode context', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'vision-model',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when authType is not QWEN_OAUTH in YOLO mode context', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.USE_GEMINI,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('processVisionSwitchOutcome', () => {
|
||||
@@ -125,11 +196,11 @@ describe('useVisionAutoSwitch helpers', () => {
|
||||
expect(result).toEqual({ persistSessionModel: vl });
|
||||
});
|
||||
|
||||
it('maps DisallowWithGuidance to showGuidance', () => {
|
||||
it('maps ContinueWithCurrentModel to empty result', () => {
|
||||
const result = processVisionSwitchOutcome(
|
||||
VisionSwitchOutcome.DisallowWithGuidance,
|
||||
VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
);
|
||||
expect(result).toEqual({ showGuidance: true });
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -151,13 +222,20 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
ts: number,
|
||||
) => any;
|
||||
|
||||
const createMockConfig = (authType: AuthType, initialModel: string) => {
|
||||
const createMockConfig = (
|
||||
authType: AuthType,
|
||||
initialModel: string,
|
||||
approvalMode: ApprovalMode = ApprovalMode.DEFAULT,
|
||||
vlmSwitchMode?: string,
|
||||
) => {
|
||||
let currentModel = initialModel;
|
||||
const mockConfig: Partial<Config> = {
|
||||
getModel: vi.fn(() => currentModel),
|
||||
setModel: vi.fn((m: string) => {
|
||||
setModel: vi.fn(async (m: string) => {
|
||||
currentModel = m;
|
||||
}),
|
||||
getApprovalMode: vi.fn(() => approvalMode),
|
||||
getVlmSwitchMode: vi.fn(() => vlmSwitchMode),
|
||||
getContentGeneratorConfig: vi.fn(() => ({
|
||||
authType,
|
||||
model: currentModel,
|
||||
@@ -226,11 +304,9 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('shows guidance and blocks when dialog returns showGuidance', async () => {
|
||||
it('continues with current model when dialog returns empty result', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ showGuidance: true });
|
||||
const onVisionSwitchRequired = vi.fn().mockResolvedValue({}); // Empty result for ContinueWithCurrentModel
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
@@ -245,11 +321,12 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
res = await result.current.handleVisionSwitch(parts, userTs, false);
|
||||
});
|
||||
|
||||
expect(addItem).toHaveBeenCalledWith(
|
||||
// Should not add any guidance message
|
||||
expect(addItem).not.toHaveBeenCalledWith(
|
||||
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
|
||||
userTs,
|
||||
);
|
||||
expect(res).toEqual({ shouldProceed: false });
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -258,7 +335,7 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ modelOverride: 'qwen-vl-max-latest' });
|
||||
.mockResolvedValue({ modelOverride: 'coder-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
@@ -273,20 +350,26 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
expect(config.setModel).toHaveBeenCalledWith('coder-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (one-time override)',
|
||||
});
|
||||
|
||||
// Now restore
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
await act(async () => {
|
||||
await result.current.restoreOriginalModel();
|
||||
});
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel);
|
||||
});
|
||||
|
||||
it('persists session model when dialog requests persistence', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ persistSessionModel: 'qwen-vl-max-latest' });
|
||||
.mockResolvedValue({ persistSessionModel: 'coder-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
@@ -301,16 +384,17 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).toHaveBeenCalledWith('qwen-vl-max-latest');
|
||||
expect(config.setModel).toHaveBeenCalledWith('coder-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (session persistent)',
|
||||
});
|
||||
|
||||
// Restore should be a no-op since no one-time override was used
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
await act(async () => {
|
||||
await result.current.restoreOriginalModel();
|
||||
});
|
||||
// Last call should still be the persisted model set
|
||||
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe(
|
||||
'qwen-vl-max-latest',
|
||||
);
|
||||
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe('coder-model');
|
||||
});
|
||||
|
||||
it('returns shouldProceed=true when dialog returns no special flags', async () => {
|
||||
@@ -371,4 +455,420 @@ describe('useVisionAutoSwitch hook', () => {
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('YOLO mode behavior', () => {
|
||||
it('automatically switches to vision model in YOLO mode without showing dialog', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called in YOLO mode
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 7070, false);
|
||||
});
|
||||
|
||||
// Should automatically switch without calling the dialog
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(res).toEqual({
|
||||
shouldProceed: true,
|
||||
originalModel: initialModel,
|
||||
});
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when no images are present', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [{ text: 'no images here' }];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 8080, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when already using vision model', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'vision-model',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 9090, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('restores original model after YOLO mode auto-switch', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
// First, trigger the auto-switch
|
||||
await act(async () => {
|
||||
await result.current.handleVisionSwitch(parts, 10100, false);
|
||||
});
|
||||
|
||||
// Verify model was switched
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
|
||||
// Now restore the original model
|
||||
await act(async () => {
|
||||
await result.current.restoreOriginalModel();
|
||||
});
|
||||
|
||||
// Verify model was restored
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when authType is not QWEN_OAUTH', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.USE_GEMINI,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 11110, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when visionModelPreviewEnabled is false', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
false,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 12120, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('handles multiple image formats in YOLO mode', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ text: 'Here are some images:' },
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
|
||||
{ fileData: { mimeType: 'image/png', fileUri: 'file://image.png' } },
|
||||
{ text: 'Please analyze them.' },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 13130, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({
|
||||
shouldProceed: true,
|
||||
originalModel: initialModel,
|
||||
});
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('VLM switch mode default behavior', () => {
|
||||
it('should automatically switch once when vlmSwitchMode is "once"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'once',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBe('qwen3-coder-plus');
|
||||
expect(config.setModel).toHaveBeenCalledWith('vision-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Default VLM switch mode: once (one-time override)',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should switch session when vlmSwitchMode is "session"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'session',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined(); // No original model for session switch
|
||||
expect(config.setModel).toHaveBeenCalledWith('vision-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Default VLM switch mode: session (session persistent)',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should continue with current model when vlmSwitchMode is "persist"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'persist',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fall back to user prompt when vlmSwitchMode is not set', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
undefined, // No default mode
|
||||
);
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ modelOverride: 'vision-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(onVisionSwitchRequired).toHaveBeenCalledWith(parts);
|
||||
});
|
||||
|
||||
it('should fall back to persist behavior when vlmSwitchMode has invalid value', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'invalid-value',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined();
|
||||
// For invalid values, it should continue with current model (persist behavior)
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { type PartListUnion, type Part } from '@google/genai';
|
||||
import { AuthType, type Config } from '@qwen-code/qwen-code-core';
|
||||
import { AuthType, type Config, ApprovalMode } from '@qwen-code/qwen-code-core';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
|
||||
import {
|
||||
@@ -121,7 +121,7 @@ export function shouldOfferVisionSwitch(
|
||||
parts: PartListUnion,
|
||||
authType: AuthType,
|
||||
currentModel: string,
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
visionModelPreviewEnabled: boolean = true,
|
||||
): boolean {
|
||||
// Only trigger for qwen-oauth
|
||||
if (authType !== AuthType.QWEN_OAUTH) {
|
||||
@@ -166,11 +166,11 @@ export function processVisionSwitchOutcome(
|
||||
case VisionSwitchOutcome.SwitchSessionToVL:
|
||||
return { persistSessionModel: vlModelId };
|
||||
|
||||
case VisionSwitchOutcome.DisallowWithGuidance:
|
||||
return { showGuidance: true };
|
||||
case VisionSwitchOutcome.ContinueWithCurrentModel:
|
||||
return {}; // Continue with current model, no changes needed
|
||||
|
||||
default:
|
||||
return { showGuidance: true };
|
||||
return {}; // Default to continuing with current model
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ export interface VisionSwitchHandlingResult {
|
||||
export function useVisionAutoSwitch(
|
||||
config: Config,
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
visionModelPreviewEnabled: boolean = false,
|
||||
visionModelPreviewEnabled: boolean = true,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
@@ -252,35 +252,91 @@ export function useVisionAutoSwitch(
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
try {
|
||||
const visionSwitchResult = await onVisionSwitchRequired(query);
|
||||
// In YOLO mode, automatically switch to vision model without user interaction
|
||||
if (config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
originalModelRef.current = config.getModel();
|
||||
await config.setModel(vlModelId, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
}
|
||||
|
||||
if (visionSwitchResult.showGuidance) {
|
||||
// Show guidance and don't proceed with the request
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: getVisionSwitchGuidanceMessage(),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
return { shouldProceed: false };
|
||||
// Check if there's a default VLM switch mode configured
|
||||
const defaultVlmSwitchMode = config.getVlmSwitchMode();
|
||||
if (defaultVlmSwitchMode) {
|
||||
// Convert string value to VisionSwitchOutcome enum
|
||||
let outcome: VisionSwitchOutcome;
|
||||
switch (defaultVlmSwitchMode) {
|
||||
case 'once':
|
||||
outcome = VisionSwitchOutcome.SwitchOnce;
|
||||
break;
|
||||
case 'session':
|
||||
outcome = VisionSwitchOutcome.SwitchSessionToVL;
|
||||
break;
|
||||
case 'persist':
|
||||
outcome = VisionSwitchOutcome.ContinueWithCurrentModel;
|
||||
break;
|
||||
default:
|
||||
// Invalid value, fall back to prompting user
|
||||
outcome = VisionSwitchOutcome.ContinueWithCurrentModel;
|
||||
}
|
||||
|
||||
// Process the default outcome
|
||||
const visionSwitchResult = processVisionSwitchOutcome(outcome);
|
||||
|
||||
if (visionSwitchResult.modelOverride) {
|
||||
// One-time model override
|
||||
originalModelRef.current = config.getModel();
|
||||
config.setModel(visionSwitchResult.modelOverride);
|
||||
await config.setModel(visionSwitchResult.modelOverride, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (one-time override)`,
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
} else if (visionSwitchResult.persistSessionModel) {
|
||||
// Persistent session model change
|
||||
config.setModel(visionSwitchResult.persistSessionModel);
|
||||
await config.setModel(visionSwitchResult.persistSessionModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (session persistent)`,
|
||||
});
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// For ContinueWithCurrentModel or any other case, proceed with current model
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
try {
|
||||
const visionSwitchResult = await onVisionSwitchRequired(query);
|
||||
|
||||
if (visionSwitchResult.modelOverride) {
|
||||
// One-time model override
|
||||
originalModelRef.current = config.getModel();
|
||||
await config.setModel(visionSwitchResult.modelOverride, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (one-time override)',
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
} else if (visionSwitchResult.persistSessionModel) {
|
||||
// Persistent session model change
|
||||
await config.setModel(visionSwitchResult.persistSessionModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (session persistent)',
|
||||
});
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// For ContinueWithCurrentModel or any other case, proceed with current model
|
||||
return { shouldProceed: true };
|
||||
} catch (_error) {
|
||||
// If vision switch dialog was cancelled or errored, don't proceed
|
||||
@@ -290,9 +346,12 @@ export function useVisionAutoSwitch(
|
||||
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
|
||||
);
|
||||
|
||||
const restoreOriginalModel = useCallback(() => {
|
||||
const restoreOriginalModel = useCallback(async () => {
|
||||
if (originalModelRef.current) {
|
||||
config.setModel(originalModelRef.current);
|
||||
await config.setModel(originalModelRef.current, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
originalModelRef.current = null;
|
||||
}
|
||||
}, [config]);
|
||||
|
||||
@@ -10,9 +10,12 @@ export type AvailableModel = {
|
||||
isVision?: boolean;
|
||||
};
|
||||
|
||||
export const MAINLINE_VLM = 'vision-model';
|
||||
export const MAINLINE_CODER = 'coder-model';
|
||||
|
||||
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
|
||||
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
|
||||
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
|
||||
{ id: MAINLINE_CODER, label: MAINLINE_CODER },
|
||||
{ id: MAINLINE_VLM, label: MAINLINE_VLM, isVision: true },
|
||||
];
|
||||
|
||||
/**
|
||||
@@ -42,7 +45,7 @@ export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
|
||||
* until our coding model supports multimodal.
|
||||
*/
|
||||
export function getDefaultVisionModel(): string {
|
||||
return 'qwen-vl-max-latest';
|
||||
return MAINLINE_VLM;
|
||||
}
|
||||
|
||||
export function isVisionModel(modelId: string): boolean {
|
||||
|
||||
@@ -9,7 +9,6 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { MarkdownDisplay } from './MarkdownDisplay.js';
|
||||
import { LoadedSettings } from '../../config/settings.js';
|
||||
import { SettingsContext } from '../contexts/SettingsContext.js';
|
||||
import { EOL } from 'node:os';
|
||||
|
||||
describe('<MarkdownDisplay />', () => {
|
||||
const baseProps = {
|
||||
@@ -57,7 +56,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
## Header 2
|
||||
### Header 3
|
||||
#### Header 4
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -67,10 +66,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
});
|
||||
|
||||
it('renders a fenced code block with a language', () => {
|
||||
const text = '```javascript\nconst x = 1;\nconsole.log(x);\n```'.replace(
|
||||
/\n/g,
|
||||
EOL,
|
||||
);
|
||||
const text = '```javascript\nconst x = 1;\nconsole.log(x);\n```';
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -80,7 +76,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
});
|
||||
|
||||
it('renders a fenced code block without a language', () => {
|
||||
const text = '```\nplain text\n```'.replace(/\n/g, EOL);
|
||||
const text = '```\nplain text\n```';
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -90,7 +86,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
});
|
||||
|
||||
it('handles unclosed (pending) code blocks', () => {
|
||||
const text = '```typescript\nlet y = 2;'.replace(/\n/g, EOL);
|
||||
const text = '```typescript\nlet y = 2;';
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} isPending={true} />
|
||||
@@ -104,7 +100,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
- item A
|
||||
* item B
|
||||
+ item C
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -118,7 +114,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
* Level 1
|
||||
* Level 2
|
||||
* Level 3
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -131,7 +127,7 @@ describe('<MarkdownDisplay />', () => {
|
||||
const text = `
|
||||
1. First item
|
||||
2. Second item
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -147,7 +143,7 @@ Hello
|
||||
World
|
||||
***
|
||||
Test
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -162,7 +158,7 @@ Test
|
||||
|----------|:--------:|
|
||||
| Cell 1 | Cell 2 |
|
||||
| Cell 3 | Cell 4 |
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -176,7 +172,7 @@ Test
|
||||
Some text before.
|
||||
| A | B |
|
||||
|---|
|
||||
| 1 | 2 |`.replace(/\n/g, EOL);
|
||||
| 1 | 2 |`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -188,7 +184,7 @@ Some text before.
|
||||
it('inserts a single space between paragraphs', () => {
|
||||
const text = `Paragraph 1.
|
||||
|
||||
Paragraph 2.`.replace(/\n/g, EOL);
|
||||
Paragraph 2.`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -211,7 +207,7 @@ some code
|
||||
\`\`\`
|
||||
|
||||
Another paragraph.
|
||||
`.replace(/\n/g, EOL);
|
||||
`;
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -221,7 +217,7 @@ Another paragraph.
|
||||
});
|
||||
|
||||
it('hides line numbers in code blocks when showLineNumbers is false', () => {
|
||||
const text = '```javascript\nconst x = 1;\n```'.replace(/\n/g, EOL);
|
||||
const text = '```javascript\nconst x = 1;\n```';
|
||||
const settings = new LoadedSettings(
|
||||
{ path: '', settings: {} },
|
||||
{ path: '', settings: {} },
|
||||
@@ -242,7 +238,7 @@ Another paragraph.
|
||||
});
|
||||
|
||||
it('shows line numbers in code blocks by default', () => {
|
||||
const text = '```javascript\nconst x = 1;\n```'.replace(/\n/g, EOL);
|
||||
const text = '```javascript\nconst x = 1;\n```';
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={text} />
|
||||
@@ -251,4 +247,21 @@ Another paragraph.
|
||||
expect(lastFrame()).toMatchSnapshot();
|
||||
expect(lastFrame()).toContain(' 1 ');
|
||||
});
|
||||
|
||||
it('correctly splits lines using \\n regardless of platform EOL', () => {
|
||||
// Test that the component uses \n for splitting, not EOL
|
||||
const textWithUnixLineEndings = 'Line 1\nLine 2\nLine 3';
|
||||
|
||||
const { lastFrame } = render(
|
||||
<SettingsContext.Provider value={mockSettings}>
|
||||
<MarkdownDisplay {...baseProps} text={textWithUnixLineEndings} />
|
||||
</SettingsContext.Provider>,
|
||||
);
|
||||
|
||||
const output = lastFrame();
|
||||
expect(output).toContain('Line 1');
|
||||
expect(output).toContain('Line 2');
|
||||
expect(output).toContain('Line 3');
|
||||
expect(output).toMatchSnapshot();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import React from 'react';
|
||||
import { Text, Box } from 'ink';
|
||||
import { EOL } from 'node:os';
|
||||
import { Colors } from '../colors.js';
|
||||
import { colorizeCode } from './CodeColorizer.js';
|
||||
import { TableRenderer } from './TableRenderer.js';
|
||||
@@ -35,7 +34,7 @@ const MarkdownDisplayInternal: React.FC<MarkdownDisplayProps> = ({
|
||||
}) => {
|
||||
if (!text) return <></>;
|
||||
|
||||
const lines = text.split(EOL);
|
||||
const lines = text.split(`\n`);
|
||||
const headerRegex = /^ *(#{1,4}) +(.*)/;
|
||||
const codeFenceRegex = /^ *(`{3,}|~{3,}) *(\w*?) *$/;
|
||||
const ulItemRegex = /^([ \t]*)([-*+]) +(.*)/;
|
||||
|
||||
@@ -14,6 +14,12 @@ Another paragraph.
|
||||
"
|
||||
`;
|
||||
|
||||
exports[`<MarkdownDisplay /> > correctly splits lines using \\n regardless of platform EOL 1`] = `
|
||||
"Line 1
|
||||
Line 2
|
||||
Line 3"
|
||||
`;
|
||||
|
||||
exports[`<MarkdownDisplay /> > handles a table at the end of the input 1`] = `
|
||||
"Some text before.
|
||||
| A | B |
|
||||
|
||||
@@ -126,6 +126,18 @@ describe('validateNonInterActiveAuth', () => {
|
||||
expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_OPENAI);
|
||||
});
|
||||
|
||||
it('uses configured QWEN_OAUTH if provided', async () => {
|
||||
const nonInteractiveConfig: NonInteractiveConfig = {
|
||||
refreshAuth: refreshAuthMock,
|
||||
};
|
||||
await validateNonInteractiveAuth(
|
||||
AuthType.QWEN_OAUTH,
|
||||
undefined,
|
||||
nonInteractiveConfig,
|
||||
);
|
||||
expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.QWEN_OAUTH);
|
||||
});
|
||||
|
||||
it('uses USE_VERTEX_AI if GOOGLE_GENAI_USE_VERTEXAI is true (with GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION)', async () => {
|
||||
process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true';
|
||||
process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project';
|
||||
|
||||
@@ -97,6 +97,18 @@ class GeminiAgent {
|
||||
name: 'Vertex AI',
|
||||
description: null,
|
||||
},
|
||||
{
|
||||
id: AuthType.USE_OPENAI,
|
||||
name: 'Use OpenAI API key',
|
||||
description:
|
||||
'Requires setting the `OPENAI_API_KEY` environment variable',
|
||||
},
|
||||
{
|
||||
id: AuthType.QWEN_OAUTH,
|
||||
name: 'Qwen OAuth',
|
||||
description:
|
||||
'OAuth authentication for Qwen models with 2000 daily requests',
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
@@ -871,6 +883,16 @@ function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null {
|
||||
type: 'content',
|
||||
content: { type: 'text', text: todoText },
|
||||
};
|
||||
} else if (
|
||||
'type' in toolResult.returnDisplay &&
|
||||
toolResult.returnDisplay.type === 'plan_summary'
|
||||
) {
|
||||
const planDisplay = toolResult.returnDisplay;
|
||||
const planText = `${planDisplay.message}\n\n${planDisplay.plan}`;
|
||||
return {
|
||||
type: 'content',
|
||||
content: { type: 'text', text: planText },
|
||||
};
|
||||
} else if ('fileDiff' in toolResult.returnDisplay) {
|
||||
// Handle FileDiff
|
||||
return {
|
||||
@@ -942,6 +964,15 @@ function toPermissionOptions(
|
||||
},
|
||||
...basicPermissionOptions,
|
||||
];
|
||||
case 'plan':
|
||||
return [
|
||||
{
|
||||
optionId: ToolConfirmationOutcome.ProceedAlways,
|
||||
name: `Always Allow Plans`,
|
||||
kind: 'allow_always',
|
||||
},
|
||||
...basicPermissionOptions,
|
||||
];
|
||||
default: {
|
||||
const unreachable: never = confirmation;
|
||||
throw new Error(`Unexpected: ${unreachable}`);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code-core",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"description": "Qwen Code Core",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
|
||||
@@ -710,6 +710,18 @@ describe('setApprovalMode with folder trust', () => {
|
||||
expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should NOT throw an error when setting PLAN mode in an untrusted folder', () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'test-model',
|
||||
cwd: '.',
|
||||
trustedFolder: false, // Untrusted
|
||||
});
|
||||
expect(() => config.setApprovalMode(ApprovalMode.PLAN)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should NOT throw an error when setting any mode in a trusted folder', () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test',
|
||||
@@ -722,6 +734,7 @@ describe('setApprovalMode with folder trust', () => {
|
||||
expect(() => config.setApprovalMode(ApprovalMode.YOLO)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.AUTO_EDIT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.PLAN)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should NOT throw an error when setting any mode if trustedFolder is undefined', () => {
|
||||
@@ -736,5 +749,87 @@ describe('setApprovalMode with folder trust', () => {
|
||||
expect(() => config.setApprovalMode(ApprovalMode.YOLO)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.AUTO_EDIT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.PLAN)).not.toThrow();
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch when setModel is called with different model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-model-switch',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model
|
||||
await config.setModel('qwen-vl-max-latest', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
|
||||
// Verify that logModelSwitch was called with correct parameters
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not log when setModel is called with same model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-same-model',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Set the same model
|
||||
await config.setModel('qwen3-coder-plus');
|
||||
|
||||
// Verify that logModelSwitch was not called
|
||||
expect(logModelSwitchSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use default reason when no options provided', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-default-reason',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model without options
|
||||
await config.setModel('qwen-vl-max-latest');
|
||||
|
||||
// Verify that logModelSwitch was called with default reason
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'manual',
|
||||
context: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
|
||||
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import { ExitPlanModeTool } from '../tools/exitPlanMode.js';
|
||||
import { GlobTool } from '../tools/glob.js';
|
||||
import { GrepTool } from '../tools/grep.js';
|
||||
import { LSTool } from '../tools/ls.js';
|
||||
@@ -56,16 +57,20 @@ import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} from './models.js';
|
||||
import { Storage } from './storage.js';
|
||||
import { Logger, type ModelSwitchEvent } from '../core/logger.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { AnyToolInvocation, MCPOAuthConfig };
|
||||
|
||||
export enum ApprovalMode {
|
||||
PLAN = 'plan',
|
||||
DEFAULT = 'default',
|
||||
AUTO_EDIT = 'autoEdit',
|
||||
AUTO_EDIT = 'auto-edit',
|
||||
YOLO = 'yolo',
|
||||
}
|
||||
|
||||
export const APPROVAL_MODES = Object.values(ApprovalMode);
|
||||
|
||||
export interface AccessibilitySettings {
|
||||
disableLoadingPhrases?: boolean;
|
||||
screenReader?: boolean;
|
||||
@@ -239,6 +244,7 @@ export interface ConfigParameters {
|
||||
extensionManagement?: boolean;
|
||||
enablePromptCompletion?: boolean;
|
||||
skipLoopDetection?: boolean;
|
||||
vlmSwitchMode?: string;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
@@ -330,9 +336,11 @@ export class Config {
|
||||
private readonly extensionManagement: boolean;
|
||||
private readonly enablePromptCompletion: boolean = false;
|
||||
private readonly skipLoopDetection: boolean;
|
||||
private readonly vlmSwitchMode: string | undefined;
|
||||
private initialized: boolean = false;
|
||||
readonly storage: Storage;
|
||||
private readonly fileExclusions: FileExclusions;
|
||||
private logger: Logger | null = null;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -424,8 +432,15 @@ export class Config {
|
||||
this.extensionManagement = params.extensionManagement ?? false;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
this.vlmSwitchMode = params.vlmSwitchMode;
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
|
||||
// Initialize logger asynchronously
|
||||
this.logger = new Logger(this.sessionId, this.storage);
|
||||
this.logger.initialize().catch((error) => {
|
||||
console.debug('Failed to initialize logger:', error);
|
||||
});
|
||||
|
||||
if (params.contextFileName) {
|
||||
setGeminiMdFilename(params.contextFileName);
|
||||
}
|
||||
@@ -517,21 +532,47 @@ export class Config {
|
||||
return this.contentGeneratorConfig?.model || this.model;
|
||||
}
|
||||
|
||||
setModel(newModel: string): void {
|
||||
async setModel(
|
||||
newModel: string,
|
||||
options?: {
|
||||
reason?: ModelSwitchEvent['reason'];
|
||||
context?: string;
|
||||
},
|
||||
): Promise<void> {
|
||||
const oldModel = this.getModel();
|
||||
|
||||
if (this.contentGeneratorConfig) {
|
||||
this.contentGeneratorConfig.model = newModel;
|
||||
}
|
||||
|
||||
// Log the model switch if the model actually changed
|
||||
if (oldModel !== newModel && this.logger) {
|
||||
const switchEvent: ModelSwitchEvent = {
|
||||
fromModel: oldModel,
|
||||
toModel: newModel,
|
||||
reason: options?.reason || 'manual',
|
||||
context: options?.context,
|
||||
};
|
||||
|
||||
// Log asynchronously to avoid blocking
|
||||
this.logger.logModelSwitch(switchEvent).catch((error) => {
|
||||
console.debug('Failed to log model switch:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// Reinitialize chat with updated configuration while preserving history
|
||||
const geminiClient = this.getGeminiClient();
|
||||
if (geminiClient && geminiClient.isInitialized()) {
|
||||
// Use async operation but don't await to avoid blocking
|
||||
geminiClient.reinitialize().catch((error) => {
|
||||
// Now await the reinitialize operation to ensure completion
|
||||
try {
|
||||
await geminiClient.reinitialize();
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'Failed to reinitialize chat with updated config:',
|
||||
error,
|
||||
);
|
||||
});
|
||||
throw error; // Re-throw to let callers handle the error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -663,7 +704,11 @@ export class Config {
|
||||
}
|
||||
|
||||
setApprovalMode(mode: ApprovalMode): void {
|
||||
if (this.isTrustedFolder() === false && mode !== ApprovalMode.DEFAULT) {
|
||||
if (
|
||||
this.isTrustedFolder() === false &&
|
||||
mode !== ApprovalMode.DEFAULT &&
|
||||
mode !== ApprovalMode.PLAN
|
||||
) {
|
||||
throw new Error(
|
||||
'Cannot enable privileged approval modes in an untrusted folder.',
|
||||
);
|
||||
@@ -938,6 +983,10 @@ export class Config {
|
||||
return this.skipLoopDetection;
|
||||
}
|
||||
|
||||
getVlmSwitchMode(): string | undefined {
|
||||
return this.vlmSwitchMode;
|
||||
}
|
||||
|
||||
async getGitService(): Promise<GitService> {
|
||||
if (!this.gitService) {
|
||||
this.gitService = new GitService(this.targetDir, this.storage);
|
||||
@@ -1002,11 +1051,12 @@ export class Config {
|
||||
registerCoreTool(GlobTool, this);
|
||||
registerCoreTool(EditTool, this);
|
||||
registerCoreTool(WriteFileTool, this);
|
||||
registerCoreTool(WebFetchTool, this);
|
||||
registerCoreTool(ReadManyFilesTool, this);
|
||||
registerCoreTool(ShellTool, this);
|
||||
registerCoreTool(MemoryTool);
|
||||
registerCoreTool(TodoWriteTool, this);
|
||||
registerCoreTool(ExitPlanModeTool, this);
|
||||
registerCoreTool(WebFetchTool, this);
|
||||
// Conditionally register web search tool only if Tavily API key is set
|
||||
if (this.getTavilyApiKey()) {
|
||||
registerCoreTool(WebSearchTool, this);
|
||||
|
||||
@@ -41,7 +41,7 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
// with the fallback mechanism. This will be necessary we introduce more
|
||||
// intelligent model routing.
|
||||
describe('setModel', () => {
|
||||
it('should only mark as switched if contentGeneratorConfig exists', () => {
|
||||
it('should only mark as switched if contentGeneratorConfig exists', async () => {
|
||||
// Create config without initializing contentGeneratorConfig
|
||||
const newConfig = new Config({
|
||||
sessionId: 'test-session-2',
|
||||
@@ -52,15 +52,15 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
});
|
||||
|
||||
// Should not crash when contentGeneratorConfig is undefined
|
||||
newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
await newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(newConfig.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getModel', () => {
|
||||
it('should return contentGeneratorConfig model if available', () => {
|
||||
it('should return contentGeneratorConfig model if available', async () => {
|
||||
// Simulate initialized content generator config
|
||||
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
@@ -88,8 +88,8 @@ describe('Flash Model Fallback Configuration', () => {
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
|
||||
it('should persist switched state throughout session', () => {
|
||||
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
it('should persist switched state throughout session', async () => {
|
||||
await config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
// Setting state for fallback mode as is expected of clients
|
||||
config.setFallbackMode(true);
|
||||
expect(config.isInFallbackMode()).toBe(true);
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export const DEFAULT_QWEN_MODEL = 'qwen3-coder-plus';
|
||||
// We do not have a fallback model for now, but note it here anyway.
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'qwen3-coder-flash';
|
||||
export const DEFAULT_QWEN_MODEL = 'coder-model';
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'coder-model';
|
||||
|
||||
export const DEFAULT_GEMINI_MODEL = 'qwen3-coder-plus';
|
||||
export const DEFAULT_GEMINI_MODEL = 'coder-model';
|
||||
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
|
||||
export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite';
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import { findIndexAfterFraction, GeminiClient } from './client.js';
|
||||
import { getPlanModeSystemReminder } from './prompts.js';
|
||||
import {
|
||||
AuthType,
|
||||
type ContentGenerator,
|
||||
@@ -50,6 +51,10 @@ const mockGenerateContentFn = vi.fn();
|
||||
const mockEmbedContentFn = vi.fn();
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
let ApprovalModeEnum: typeof import('../config/config.js').ApprovalMode;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let mockConfigObject: any;
|
||||
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('./turn', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('./turn.js')>();
|
||||
@@ -178,6 +183,12 @@ describe('Gemini Client (client.ts)', () => {
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
ApprovalModeEnum = (
|
||||
await vi.importActual<typeof import('../config/config.js')>(
|
||||
'../config/config.js',
|
||||
)
|
||||
).ApprovalMode;
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
|
||||
@@ -228,8 +239,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
};
|
||||
const mockSubagentManager = {
|
||||
listSubagents: vi.fn().mockResolvedValue([]),
|
||||
addChangeListener: vi.fn().mockReturnValue(() => {}),
|
||||
};
|
||||
const mockConfigObject = {
|
||||
mockConfigObject = {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue(contentGeneratorConfig),
|
||||
@@ -252,6 +264,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getNoBrowser: vi.fn().mockReturnValue(false),
|
||||
getSystemPromptMappings: vi.fn().mockReturnValue(undefined),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalModeEnum.DEFAULT),
|
||||
getIdeModeFeature: vi.fn().mockReturnValue(false),
|
||||
getIdeMode: vi.fn().mockReturnValue(true),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
@@ -948,6 +961,42 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
it('injects a plan mode reminder before user queries when approval mode is PLAN', async () => {
|
||||
const mockStream = (async function* () {})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
mockConfigObject.getApprovalMode.mockReturnValue(ApprovalModeEnum.PLAN);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
'Plan mode test',
|
||||
new AbortController().signal,
|
||||
'prompt-plan-1',
|
||||
);
|
||||
|
||||
await fromAsync(stream);
|
||||
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
[getPlanModeSystemReminder(), 'Plan mode test'],
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
mockConfigObject.getApprovalMode.mockReturnValue(
|
||||
ApprovalModeEnum.DEFAULT,
|
||||
);
|
||||
});
|
||||
|
||||
it('emits a compression event when the context was automatically compressed', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
@@ -1176,10 +1225,7 @@ ${JSON.stringify(
|
||||
|
||||
// Assert
|
||||
expect(ideContext.getIdeContext).toHaveBeenCalled();
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
initialRequest,
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(['Hi'], expect.any(Object));
|
||||
});
|
||||
|
||||
it('should add context if ideMode is enabled and there is one active file', async () => {
|
||||
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import type { File, IdeContext } from '../ide/ideContext.js';
|
||||
import { ideContext } from '../ide/ideContext.js';
|
||||
@@ -40,6 +41,7 @@ import { getFunctionCalls } from '../utils/generateContentResponseUtilities.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { flatMapTextParts } from '../utils/partUtils.js';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
@@ -50,6 +52,8 @@ import {
|
||||
getCompressionPrompt,
|
||||
getCoreSystemPrompt,
|
||||
getCustomSystemPrompt,
|
||||
getPlanModeSystemReminder,
|
||||
getSubagentSystemReminder,
|
||||
} from './prompts.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatCompressionInfo, ServerGeminiStreamEvent } from './turn.js';
|
||||
@@ -598,24 +602,6 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = false;
|
||||
}
|
||||
|
||||
if (isNewPrompt) {
|
||||
const taskTool = this.config.getToolRegistry().getTool(TaskTool.Name);
|
||||
const subagents = (
|
||||
await this.config.getSubagentManager().listSubagents()
|
||||
).filter((subagent) => subagent.level !== 'builtin');
|
||||
|
||||
if (taskTool && subagents.length > 0) {
|
||||
this.getChat().addHistory({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `<system-reminder>You have powerful specialized agents at your disposal, available agent types are: ${subagents.map((subagent) => subagent.name).join(', ')}. PROACTIVELY use the ${TaskTool.Name} tool to delegate user's task to appropriate agent when user's task matches agent capabilities. Ignore this message if user's task is not relevant to any agent. This message is for internal use only. Do not mention this to user in your response.</system-reminder>`,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
if (!this.config.getSkipLoopDetection()) {
|
||||
@@ -626,7 +612,30 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
const resultStream = turn.run(request, signal);
|
||||
// append system reminders to the request
|
||||
let requestToSent = await flatMapTextParts(request, async (text) => [text]);
|
||||
if (isNewPrompt) {
|
||||
const systemReminders = [];
|
||||
|
||||
// add subagent system reminder if there are subagents
|
||||
const hasTaskTool = this.config.getToolRegistry().getTool(TaskTool.Name);
|
||||
const subagents = (await this.config.getSubagentManager().listSubagents())
|
||||
.filter((subagent) => subagent.level !== 'builtin')
|
||||
.map((subagent) => subagent.name);
|
||||
|
||||
if (hasTaskTool && subagents.length > 0) {
|
||||
systemReminders.push(getSubagentSystemReminder(subagents));
|
||||
}
|
||||
|
||||
// add plan mode system reminder if approval mode is plan
|
||||
if (this.config.getApprovalMode() === ApprovalMode.PLAN) {
|
||||
systemReminders.push(getPlanModeSystemReminder());
|
||||
}
|
||||
|
||||
requestToSent = [...systemReminders, ...requestToSent];
|
||||
}
|
||||
|
||||
const resultStream = turn.run(requestToSent, signal);
|
||||
for await (const event of resultStream) {
|
||||
if (!this.config.getSkipLoopDetection()) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
@@ -1053,7 +1062,7 @@ export class GeminiClient {
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
await this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
|
||||
@@ -10,11 +10,13 @@ import { describe, expect, it, vi } from 'vitest';
|
||||
import type {
|
||||
Config,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolCallRequestInfo,
|
||||
ToolConfirmationPayload,
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
ToolResultDisplay,
|
||||
ToolRegistry,
|
||||
SuccessfulToolCall,
|
||||
} from '../index.js';
|
||||
import {
|
||||
ApprovalMode,
|
||||
@@ -24,11 +26,16 @@ import {
|
||||
ToolConfirmationOutcome,
|
||||
} from '../index.js';
|
||||
import { MockModifiableTool, MockTool } from '../test-utils/tools.js';
|
||||
import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js';
|
||||
import type {
|
||||
ToolCall,
|
||||
WaitingToolCall,
|
||||
ErroredToolCall,
|
||||
} from './coreToolScheduler.js';
|
||||
import {
|
||||
CoreToolScheduler,
|
||||
convertToFunctionResponse,
|
||||
} from './coreToolScheduler.js';
|
||||
import { getPlanModeSystemReminder } from './prompts.js';
|
||||
|
||||
class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> {
|
||||
static readonly Name = 'testApprovalTool';
|
||||
@@ -101,6 +108,49 @@ class TestApprovalInvocation extends BaseToolInvocation<
|
||||
}
|
||||
}
|
||||
|
||||
class SimpleToolInvocation extends BaseToolInvocation<
|
||||
Record<string, unknown>,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
params: Record<string, unknown>,
|
||||
private readonly executeImpl: () => Promise<ToolResult> | ToolResult,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return 'simple tool invocation';
|
||||
}
|
||||
|
||||
async execute(): Promise<ToolResult> {
|
||||
return await Promise.resolve(this.executeImpl());
|
||||
}
|
||||
}
|
||||
|
||||
class SimpleTool extends BaseDeclarativeTool<
|
||||
Record<string, unknown>,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
name: string,
|
||||
kind: Kind,
|
||||
private readonly executeImpl: () => Promise<ToolResult> | ToolResult,
|
||||
) {
|
||||
super(name, name, 'Simple test tool', kind, {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
additionalProperties: true,
|
||||
});
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: Record<string, unknown>,
|
||||
): ToolInvocation<Record<string, unknown>, ToolResult> {
|
||||
return new SimpleToolInvocation(params, this.executeImpl);
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForStatus(
|
||||
onToolCallsUpdate: Mock,
|
||||
status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled',
|
||||
@@ -197,6 +247,249 @@ describe('CoreToolScheduler', () => {
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
});
|
||||
|
||||
describe('plan mode enforcement', () => {
|
||||
it('returns plan reminder and skips execution for edit tools', async () => {
|
||||
const executeSpy = vi.fn().mockResolvedValue({
|
||||
llmContent: 'should not execute',
|
||||
returnDisplay: 'should not execute',
|
||||
});
|
||||
// Use MockTool with shouldConfirm=true to simulate a tool that requires confirmation
|
||||
const tool = new MockTool('write_file');
|
||||
tool.shouldConfirm = true;
|
||||
tool.executeFn = executeSpy;
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: vi.fn().mockReturnValue(tool),
|
||||
getAllToolNames: vi.fn().mockReturnValue([tool.name]),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'plan-session',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN),
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const request: ToolCallRequestInfo = {
|
||||
callId: 'plan-1',
|
||||
name: 'write_file',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-plan',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], new AbortController().signal);
|
||||
|
||||
const errorCall = (await waitForStatus(
|
||||
onToolCallsUpdate,
|
||||
'error',
|
||||
)) as ErroredToolCall;
|
||||
|
||||
expect(executeSpy).not.toHaveBeenCalled();
|
||||
expect(
|
||||
errorCall.response.responseParts[0]?.functionResponse?.response?.[
|
||||
'output'
|
||||
],
|
||||
).toBe(getPlanModeSystemReminder());
|
||||
expect(errorCall.response.resultDisplay).toContain('Plan mode');
|
||||
});
|
||||
|
||||
it('allows read tools to execute in plan mode', async () => {
|
||||
const executeSpy = vi.fn().mockResolvedValue({
|
||||
llmContent: 'read ok',
|
||||
returnDisplay: 'read ok',
|
||||
});
|
||||
const tool = new SimpleTool('read_file', Kind.Read, executeSpy);
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: vi.fn().mockReturnValue(tool),
|
||||
getAllToolNames: vi.fn().mockReturnValue([tool.name]),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'plan-session',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN),
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const request: ToolCallRequestInfo = {
|
||||
callId: 'plan-2',
|
||||
name: tool.name,
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-plan',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], new AbortController().signal);
|
||||
|
||||
const successCall = (await waitForStatus(
|
||||
onToolCallsUpdate,
|
||||
'success',
|
||||
)) as SuccessfulToolCall;
|
||||
|
||||
expect(executeSpy).toHaveBeenCalledTimes(1);
|
||||
expect(
|
||||
successCall.response.responseParts[0]?.functionResponse?.response?.[
|
||||
'output'
|
||||
],
|
||||
).toBe('read ok');
|
||||
});
|
||||
|
||||
it('enforces shell command restrictions in plan mode', async () => {
|
||||
const executeSpyAllowed = vi.fn().mockResolvedValue({
|
||||
llmContent: 'shell ok',
|
||||
returnDisplay: 'shell ok',
|
||||
});
|
||||
const allowedTool = new SimpleTool(
|
||||
'run_shell_command',
|
||||
Kind.Execute,
|
||||
executeSpyAllowed,
|
||||
);
|
||||
|
||||
const allowedToolRegistry = {
|
||||
getTool: vi.fn().mockReturnValue(allowedTool),
|
||||
getAllToolNames: vi.fn().mockReturnValue([allowedTool.name]),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const allowedConfig = {
|
||||
getSessionId: () => 'plan-session',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN),
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => allowedToolRegistry,
|
||||
} as unknown as Config;
|
||||
|
||||
const allowedUpdates = vi.fn();
|
||||
const allowedScheduler = new CoreToolScheduler({
|
||||
config: allowedConfig,
|
||||
onAllToolCallsComplete: vi.fn(),
|
||||
onToolCallsUpdate: allowedUpdates,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const allowedRequest: ToolCallRequestInfo = {
|
||||
callId: 'plan-shell-allowed',
|
||||
name: allowedTool.name,
|
||||
args: { command: 'ls -la' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-plan',
|
||||
};
|
||||
|
||||
await allowedScheduler.schedule(
|
||||
[allowedRequest],
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await waitForStatus(allowedUpdates, 'success');
|
||||
expect(executeSpyAllowed).toHaveBeenCalledTimes(1);
|
||||
|
||||
const executeSpyBlocked = vi.fn().mockResolvedValue({
|
||||
llmContent: 'blocked',
|
||||
returnDisplay: 'blocked',
|
||||
});
|
||||
// Use MockTool with shouldConfirm=true to simulate a shell tool that requires confirmation
|
||||
const blockedTool = new MockTool('run_shell_command');
|
||||
blockedTool.shouldConfirm = true;
|
||||
blockedTool.executeFn = executeSpyBlocked;
|
||||
|
||||
const blockedToolRegistry = {
|
||||
getTool: vi.fn().mockReturnValue(blockedTool),
|
||||
getAllToolNames: vi.fn().mockReturnValue([blockedTool.name]),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const blockedConfig = {
|
||||
getSessionId: () => 'plan-session',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN),
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => blockedToolRegistry,
|
||||
} as unknown as Config;
|
||||
|
||||
const blockedUpdates = vi.fn();
|
||||
const blockedScheduler = new CoreToolScheduler({
|
||||
config: blockedConfig,
|
||||
onAllToolCallsComplete: vi.fn(),
|
||||
onToolCallsUpdate: blockedUpdates,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const blockedRequest: ToolCallRequestInfo = {
|
||||
callId: 'plan-shell-blocked',
|
||||
name: 'run_shell_command',
|
||||
args: { command: 'rm -rf tmp' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-plan',
|
||||
};
|
||||
|
||||
await blockedScheduler.schedule(
|
||||
[blockedRequest],
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const blockedCall = (await waitForStatus(
|
||||
blockedUpdates,
|
||||
'error',
|
||||
)) as ErroredToolCall;
|
||||
expect(executeSpyBlocked).not.toHaveBeenCalled();
|
||||
expect(
|
||||
blockedCall.response.responseParts[0]?.functionResponse?.response?.[
|
||||
'output'
|
||||
],
|
||||
).toBe(getPlanModeSystemReminder());
|
||||
const observedStatuses = blockedUpdates.mock.calls
|
||||
.flatMap((call) => call[0] as ToolCall[])
|
||||
.map((tc) => tc.status);
|
||||
expect(observedStatuses).not.toContain('awaiting_approval');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolSuggestion', () => {
|
||||
it('should suggest the top N closest tool names for a typo', () => {
|
||||
// Create mocked tool registry
|
||||
|
||||
@@ -34,6 +34,7 @@ import {
|
||||
import * as Diff from 'diff';
|
||||
import { doesToolInvocationMatch } from '../utils/tool-utils.js';
|
||||
import levenshtein from 'fast-levenshtein';
|
||||
import { getPlanModeSystemReminder } from './prompts.js';
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: 'validating';
|
||||
@@ -674,7 +675,27 @@ export class CoreToolScheduler {
|
||||
}
|
||||
|
||||
const allowedTools = this.config.getAllowedTools() || [];
|
||||
if (
|
||||
const isPlanMode =
|
||||
this.config.getApprovalMode() === ApprovalMode.PLAN;
|
||||
const isExitPlanModeTool = reqInfo.name === 'exit_plan_mode';
|
||||
|
||||
if (isPlanMode && !isExitPlanModeTool) {
|
||||
if (confirmationDetails) {
|
||||
this.setStatusInternal(reqInfo.callId, 'error', {
|
||||
callId: reqInfo.callId,
|
||||
responseParts: convertToFunctionResponse(
|
||||
reqInfo.name,
|
||||
reqInfo.callId,
|
||||
getPlanModeSystemReminder(),
|
||||
),
|
||||
resultDisplay: 'Plan mode blocked a non-read-only tool call.',
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
});
|
||||
} else {
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
}
|
||||
} else if (
|
||||
this.config.getApprovalMode() === ApprovalMode.YOLO ||
|
||||
doesToolInvocationMatch(toolCall.tool, invocation, allowedTools)
|
||||
) {
|
||||
|
||||
@@ -224,7 +224,7 @@ export class GeminiChat {
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
await this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
|
||||
@@ -755,4 +755,84 @@ describe('Logger', () => {
|
||||
expect(logger['messageId']).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch events correctly', async () => {
|
||||
const testSessionId = 'test-session-model-switch';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
const modelSwitchEvent = {
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch' as const,
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
};
|
||||
|
||||
await logger.logModelSwitch(modelSwitchEvent);
|
||||
|
||||
// Read the log file to verify the entry was written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLog = logs.find(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLog).toBeDefined();
|
||||
expect(modelSwitchLog!.type).toBe(MessageSenderType.MODEL_SWITCH);
|
||||
|
||||
const loggedEvent = JSON.parse(modelSwitchLog!.message);
|
||||
expect(loggedEvent.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(loggedEvent.toModel).toBe('qwen-vl-max-latest');
|
||||
expect(loggedEvent.reason).toBe('vision_auto_switch');
|
||||
expect(loggedEvent.context).toBe(
|
||||
'YOLO mode auto-switch for image content',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple model switch events', async () => {
|
||||
const testSessionId = 'test-session-multiple-switches';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
// Log first switch
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Auto-switch for image',
|
||||
});
|
||||
|
||||
// Log second switch (restore)
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen-vl-max-latest',
|
||||
toModel: 'qwen3-coder-plus',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model',
|
||||
});
|
||||
|
||||
// Read the log file to verify both entries were written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLogs = logs.filter(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLogs).toHaveLength(2);
|
||||
|
||||
const firstSwitch = JSON.parse(modelSwitchLogs[0].message);
|
||||
expect(firstSwitch.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(firstSwitch.toModel).toBe('qwen-vl-max-latest');
|
||||
|
||||
const secondSwitch = JSON.parse(modelSwitchLogs[1].message);
|
||||
expect(secondSwitch.fromModel).toBe('qwen-vl-max-latest');
|
||||
expect(secondSwitch.toModel).toBe('qwen3-coder-plus');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ const LOG_FILE_NAME = 'logs.json';
|
||||
|
||||
export enum MessageSenderType {
|
||||
USER = 'user',
|
||||
MODEL_SWITCH = 'model_switch',
|
||||
}
|
||||
|
||||
export interface LogEntry {
|
||||
@@ -23,6 +24,13 @@ export interface LogEntry {
|
||||
message: string;
|
||||
}
|
||||
|
||||
export interface ModelSwitchEvent {
|
||||
fromModel: string;
|
||||
toModel: string;
|
||||
reason: 'vision_auto_switch' | 'manual' | 'fallback' | 'other';
|
||||
context?: string;
|
||||
}
|
||||
|
||||
// This regex matches any character that is NOT a letter (a-z, A-Z),
|
||||
// a number (0-9), a hyphen (-), an underscore (_), or a dot (.).
|
||||
|
||||
@@ -270,6 +278,17 @@ export class Logger {
|
||||
}
|
||||
}
|
||||
|
||||
async logModelSwitch(event: ModelSwitchEvent): Promise<void> {
|
||||
const message = JSON.stringify({
|
||||
fromModel: event.fromModel,
|
||||
toModel: event.toModel,
|
||||
reason: event.reason,
|
||||
context: event.context,
|
||||
});
|
||||
|
||||
await this.logMessage(MessageSenderType.MODEL_SWITCH, message);
|
||||
}
|
||||
|
||||
private _checkpointPath(tag: string): string {
|
||||
if (!tag.length) {
|
||||
throw new Error('No checkpoint tag specified.');
|
||||
|
||||
@@ -376,28 +376,22 @@ export class OpenAIContentConverter {
|
||||
parsedParts: Pick<ParsedParts, 'textParts' | 'mediaParts'>,
|
||||
): OpenAI.Chat.ChatCompletionMessageParam | null {
|
||||
const { textParts, mediaParts } = parsedParts;
|
||||
const combinedText = textParts.join('');
|
||||
const content = textParts.map((text) => ({ type: 'text' as const, text }));
|
||||
|
||||
// If no media parts, return simple text message
|
||||
if (mediaParts.length === 0) {
|
||||
return combinedText ? { role, content: combinedText } : null;
|
||||
return content.length > 0 ? { role, content } : null;
|
||||
}
|
||||
|
||||
// For assistant messages with media, convert to text only
|
||||
// since OpenAI assistant messages don't support media content arrays
|
||||
if (role === 'assistant') {
|
||||
return combinedText
|
||||
? { role: 'assistant' as const, content: combinedText }
|
||||
return content.length > 0
|
||||
? { role: 'assistant' as const, content }
|
||||
: null;
|
||||
}
|
||||
|
||||
// Create multimodal content array for user messages
|
||||
const contentArray: OpenAI.Chat.ChatCompletionContentPart[] = [];
|
||||
|
||||
// Add text content
|
||||
if (combinedText) {
|
||||
contentArray.push({ type: 'text', text: combinedText });
|
||||
}
|
||||
const contentArray: OpenAI.Chat.ChatCompletionContentPart[] = [...content];
|
||||
|
||||
// Add media content
|
||||
for (const mediaPart of mediaParts) {
|
||||
@@ -405,14 +399,14 @@ export class OpenAIContentConverter {
|
||||
if (mediaPart.fileUri) {
|
||||
// For file URIs, use the URI directly
|
||||
contentArray.push({
|
||||
type: 'image_url',
|
||||
type: 'image_url' as const,
|
||||
image_url: { url: mediaPart.fileUri },
|
||||
});
|
||||
} else if (mediaPart.data) {
|
||||
// For inline data, create data URL
|
||||
const dataUrl = `data:${mediaPart.mimeType};base64,${mediaPart.data}`;
|
||||
contentArray.push({
|
||||
type: 'image_url',
|
||||
type: 'image_url' as const,
|
||||
image_url: { url: dataUrl },
|
||||
});
|
||||
}
|
||||
@@ -421,7 +415,7 @@ export class OpenAIContentConverter {
|
||||
const format = this.getAudioFormat(mediaPart.mimeType);
|
||||
if (format) {
|
||||
contentArray.push({
|
||||
type: 'input_audio',
|
||||
type: 'input_audio' as const,
|
||||
input_audio: {
|
||||
data: mediaPart.data,
|
||||
format: format as 'wav' | 'mp3',
|
||||
|
||||
@@ -1105,5 +1105,164 @@ describe('ContentGenerationPipeline', () => {
|
||||
expect.any(Array),
|
||||
);
|
||||
});
|
||||
|
||||
it('should collect all OpenAI chunks for logging even when Gemini responses are filtered', async () => {
|
||||
// Create chunks that would produce empty Gemini responses (partial tool calls)
|
||||
const partialToolCallChunk1: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-1',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: { name: 'test_function', arguments: '{"par' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const partialToolCallChunk2: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-2',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
function: { arguments: 'am": "value"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const finishChunk: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-3',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// Mock empty Gemini responses for partial chunks (they get filtered)
|
||||
const emptyGeminiResponse1 = new GenerateContentResponse();
|
||||
emptyGeminiResponse1.candidates = [
|
||||
{
|
||||
content: { parts: [], role: 'model' },
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
const emptyGeminiResponse2 = new GenerateContentResponse();
|
||||
emptyGeminiResponse2.candidates = [
|
||||
{
|
||||
content: { parts: [], role: 'model' },
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
// Mock final Gemini response with tool call
|
||||
const finalGeminiResponse = new GenerateContentResponse();
|
||||
finalGeminiResponse.candidates = [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
id: 'call_123',
|
||||
name: 'test_function',
|
||||
args: { param: 'value' },
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'model',
|
||||
},
|
||||
finishReason: FinishReason.STOP,
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
// Setup converter mocks
|
||||
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([
|
||||
{ role: 'user', content: 'test' },
|
||||
]);
|
||||
(mockConverter.convertOpenAIChunkToGemini as Mock)
|
||||
.mockReturnValueOnce(emptyGeminiResponse1) // First partial chunk -> empty response
|
||||
.mockReturnValueOnce(emptyGeminiResponse2) // Second partial chunk -> empty response
|
||||
.mockReturnValueOnce(finalGeminiResponse); // Finish chunk -> complete response
|
||||
|
||||
// Mock stream
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield partialToolCallChunk1;
|
||||
yield partialToolCallChunk2;
|
||||
yield finishChunk;
|
||||
},
|
||||
};
|
||||
|
||||
(mockClient.chat.completions.create as Mock).mockResolvedValue(
|
||||
mockStream,
|
||||
);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'test' }] }],
|
||||
};
|
||||
|
||||
// Collect responses
|
||||
const responses: GenerateContentResponse[] = [];
|
||||
const resultGenerator = await pipeline.executeStream(
|
||||
request,
|
||||
'test-prompt-id',
|
||||
);
|
||||
for await (const response of resultGenerator) {
|
||||
responses.push(response);
|
||||
}
|
||||
|
||||
// Should only yield the final response (empty ones are filtered)
|
||||
expect(responses).toHaveLength(1);
|
||||
expect(responses[0]).toBe(finalGeminiResponse);
|
||||
|
||||
// Verify telemetry was called with ALL OpenAI chunks, including the filtered ones
|
||||
expect(mockTelemetryService.logStreamingSuccess).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
duration: expect.any(Number),
|
||||
userPromptId: 'test-prompt-id',
|
||||
authType: 'openai',
|
||||
}),
|
||||
[finalGeminiResponse], // Only the non-empty Gemini response
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
messages: [{ role: 'user', content: 'test' }],
|
||||
}),
|
||||
[partialToolCallChunk1, partialToolCallChunk2, finishChunk], // ALL OpenAI chunks
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -118,6 +118,9 @@ export class ContentGenerationPipeline {
|
||||
try {
|
||||
// Stage 2a: Convert and yield each chunk while preserving original
|
||||
for await (const chunk of stream) {
|
||||
// Always collect OpenAI chunks for logging, regardless of Gemini conversion result
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
|
||||
const response = this.converter.convertOpenAIChunkToGemini(chunk);
|
||||
|
||||
// Stage 2b: Filter empty responses to avoid downstream issues
|
||||
@@ -132,9 +135,7 @@ export class ContentGenerationPipeline {
|
||||
// Stage 2c: Handle chunk merging for providers that send finishReason and usageMetadata separately
|
||||
const shouldYield = this.handleChunkMerging(
|
||||
response,
|
||||
chunk,
|
||||
collectedGeminiResponses,
|
||||
collectedOpenAIChunks,
|
||||
(mergedResponse) => {
|
||||
pendingFinishResponse = mergedResponse;
|
||||
},
|
||||
@@ -182,17 +183,13 @@ export class ContentGenerationPipeline {
|
||||
* finishReason and the most up-to-date usage information from any provider pattern.
|
||||
*
|
||||
* @param response Current Gemini response
|
||||
* @param chunk Current OpenAI chunk
|
||||
* @param collectedGeminiResponses Array to collect responses for logging
|
||||
* @param collectedOpenAIChunks Array to collect chunks for logging
|
||||
* @param setPendingFinish Callback to set pending finish response
|
||||
* @returns true if the response should be yielded, false if it should be held for merging
|
||||
*/
|
||||
private handleChunkMerging(
|
||||
response: GenerateContentResponse,
|
||||
chunk: OpenAI.Chat.ChatCompletionChunk,
|
||||
collectedGeminiResponses: GenerateContentResponse[],
|
||||
collectedOpenAIChunks: OpenAI.Chat.ChatCompletionChunk[],
|
||||
setPendingFinish: (response: GenerateContentResponse) => void,
|
||||
): boolean {
|
||||
const isFinishChunk = response.candidates?.[0]?.finishReason;
|
||||
@@ -206,7 +203,6 @@ export class ContentGenerationPipeline {
|
||||
if (isFinishChunk) {
|
||||
// This is a finish reason chunk
|
||||
collectedGeminiResponses.push(response);
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
setPendingFinish(response);
|
||||
return false; // Don't yield yet, wait for potential subsequent chunks to merge
|
||||
} else if (hasPendingFinish) {
|
||||
@@ -228,7 +224,6 @@ export class ContentGenerationPipeline {
|
||||
// Update the collected responses with the merged response
|
||||
collectedGeminiResponses[collectedGeminiResponses.length - 1] =
|
||||
mergedResponse;
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
|
||||
setPendingFinish(mergedResponse);
|
||||
return true; // Yield the merged response
|
||||
@@ -236,7 +231,6 @@ export class ContentGenerationPipeline {
|
||||
|
||||
// Normal chunk - collect and yield
|
||||
collectedGeminiResponses.push(response);
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -560,4 +560,146 @@ describe('DashScopeOpenAICompatibleProvider', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('output token limits', () => {
|
||||
it('should limit max_tokens when it exceeds model limit for qwen3-coder-plus', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Exceeds the 65536 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(65536); // Should be limited to model's output limit
|
||||
});
|
||||
|
||||
it('should limit max_tokens when it exceeds model limit for qwen-vl-max-latest', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-vl-max-latest',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 20000, // Exceeds the 8192 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(8192); // Should be limited to model's output limit
|
||||
});
|
||||
|
||||
it('should not modify max_tokens when it is within model limit', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 1000, // Within the 65536 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(1000); // Should remain unchanged
|
||||
});
|
||||
|
||||
it('should not add max_tokens when not present in request', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
// No max_tokens parameter
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBeUndefined(); // Should remain undefined
|
||||
});
|
||||
|
||||
it('should handle null max_tokens parameter', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: null,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBeNull(); // Should remain null
|
||||
});
|
||||
|
||||
it('should use default output limit for unknown models', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'unknown-model',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 10000, // Exceeds the default 4096 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(4096); // Should be limited to default output limit
|
||||
});
|
||||
|
||||
it('should preserve other request parameters when limiting max_tokens', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Will be limited
|
||||
temperature: 0.8,
|
||||
top_p: 0.9,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.2,
|
||||
stop: ['END'],
|
||||
user: 'test-user',
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
// max_tokens should be limited
|
||||
expect(result.max_tokens).toBe(65536);
|
||||
|
||||
// Other parameters should be preserved
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.top_p).toBe(0.9);
|
||||
expect(result.frequency_penalty).toBe(0.1);
|
||||
expect(result.presence_penalty).toBe(0.2);
|
||||
expect(result.stop).toEqual(['END']);
|
||||
expect(result.user).toBe('test-user');
|
||||
});
|
||||
|
||||
it('should work with vision models and output token limits', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-vl-max-latest',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this image:' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/image.jpg' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens: 20000, // Exceeds the 8192 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(8192); // Should be limited
|
||||
expect(
|
||||
(result as { vl_high_resolution_images?: boolean })
|
||||
.vl_high_resolution_images,
|
||||
).toBe(true); // Vision-specific parameter should be preserved
|
||||
});
|
||||
|
||||
it('should handle streaming requests with output token limits', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Exceeds the 65536 limit
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(65536); // Should be limited
|
||||
expect(result.stream).toBe(true); // Streaming should be preserved
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { Config } from '../../../config/config.js';
|
||||
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { AuthType } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
import { tokenLimit } from '../../tokenLimits.js';
|
||||
import type {
|
||||
OpenAICompatibleProvider,
|
||||
DashScopeRequestMetadata,
|
||||
@@ -65,6 +66,19 @@ export class DashScopeOpenAICompatibleProvider
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build and configure the request for DashScope API.
|
||||
*
|
||||
* This method applies DashScope-specific configurations including:
|
||||
* - Cache control for system and user messages
|
||||
* - Output token limits based on model capabilities
|
||||
* - Vision model specific parameters (vl_high_resolution_images)
|
||||
* - Request metadata for session tracking
|
||||
*
|
||||
* @param request - The original chat completion request parameters
|
||||
* @param userPromptId - Unique identifier for the user prompt for session tracking
|
||||
* @returns Configured request with DashScope-specific parameters applied
|
||||
*/
|
||||
buildRequest(
|
||||
request: OpenAI.Chat.ChatCompletionCreateParams,
|
||||
userPromptId: string,
|
||||
@@ -79,21 +93,28 @@ export class DashScopeOpenAICompatibleProvider
|
||||
messages = this.addDashScopeCacheControl(messages, cacheTarget);
|
||||
}
|
||||
|
||||
// Apply output token limits based on model capabilities
|
||||
// This ensures max_tokens doesn't exceed the model's maximum output limit
|
||||
const requestWithTokenLimits = this.applyOutputTokenLimit(
|
||||
request,
|
||||
request.model,
|
||||
);
|
||||
|
||||
if (request.model.startsWith('qwen-vl')) {
|
||||
return {
|
||||
...request,
|
||||
...requestWithTokenLimits,
|
||||
messages,
|
||||
...(this.buildMetadata(userPromptId) || {}),
|
||||
/* @ts-expect-error dashscope exclusive */
|
||||
vl_high_resolution_images: true,
|
||||
};
|
||||
} as OpenAI.Chat.ChatCompletionCreateParams;
|
||||
}
|
||||
|
||||
return {
|
||||
...request, // Preserve all original parameters including sampling params
|
||||
...requestWithTokenLimits, // Preserve all original parameters including sampling params and adjusted max_tokens
|
||||
messages,
|
||||
...(this.buildMetadata(userPromptId) || {}),
|
||||
};
|
||||
} as OpenAI.Chat.ChatCompletionCreateParams;
|
||||
}
|
||||
|
||||
buildMetadata(userPromptId: string): DashScopeRequestMetadata {
|
||||
@@ -246,6 +267,41 @@ export class DashScopeOpenAICompatibleProvider
|
||||
return contentArray;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply output token limit to a request's max_tokens parameter.
|
||||
*
|
||||
* Ensures that existing max_tokens parameters don't exceed the model's maximum output
|
||||
* token limit. Only modifies max_tokens when already present in the request.
|
||||
*
|
||||
* @param request - The chat completion request parameters
|
||||
* @param model - The model name to get the output token limit for
|
||||
* @returns The request with max_tokens adjusted to respect the model's limits (if present)
|
||||
*/
|
||||
private applyOutputTokenLimit<T extends { max_tokens?: number | null }>(
|
||||
request: T,
|
||||
model: string,
|
||||
): T {
|
||||
const currentMaxTokens = request.max_tokens;
|
||||
|
||||
// Only process if max_tokens is already present in the request
|
||||
if (currentMaxTokens === undefined || currentMaxTokens === null) {
|
||||
return request; // No max_tokens parameter, return unchanged
|
||||
}
|
||||
|
||||
const modelLimit = tokenLimit(model, 'output');
|
||||
|
||||
// If max_tokens exceeds the model limit, cap it to the model's limit
|
||||
if (currentMaxTokens > modelLimit) {
|
||||
return {
|
||||
...request,
|
||||
max_tokens: modelLimit,
|
||||
};
|
||||
}
|
||||
|
||||
// If max_tokens is within the limit, return the request unchanged
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if cache control should be disabled based on configuration.
|
||||
*
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { getCoreSystemPrompt, getCustomSystemPrompt } from './prompts.js';
|
||||
import {
|
||||
getCoreSystemPrompt,
|
||||
getCustomSystemPrompt,
|
||||
getSubagentSystemReminder,
|
||||
getPlanModeSystemReminder,
|
||||
} from './prompts.js';
|
||||
import { isGitRepository } from '../utils/gitUtils.js';
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
@@ -519,3 +524,53 @@ describe('getCustomSystemPrompt', () => {
|
||||
expect(result).toContain('---');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSubagentSystemReminder', () => {
|
||||
it('should format single agent type correctly', () => {
|
||||
const result = getSubagentSystemReminder(['python']);
|
||||
|
||||
expect(result).toMatch(/^<system-reminder>.*<\/system-reminder>$/);
|
||||
expect(result).toContain('available agent types are: python');
|
||||
expect(result).toContain('PROACTIVELY use the');
|
||||
});
|
||||
|
||||
it('should join multiple agent types with commas', () => {
|
||||
const result = getSubagentSystemReminder(['python', 'web', 'analysis']);
|
||||
|
||||
expect(result).toContain(
|
||||
'available agent types are: python, web, analysis',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle empty array', () => {
|
||||
const result = getSubagentSystemReminder([]);
|
||||
|
||||
expect(result).toContain('available agent types are: ');
|
||||
expect(result).toContain('<system-reminder>');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getPlanModeSystemReminder', () => {
|
||||
it('should return plan mode system reminder with proper structure', () => {
|
||||
const result = getPlanModeSystemReminder();
|
||||
|
||||
expect(result).toMatch(/^<system-reminder>[\s\S]*<\/system-reminder>$/);
|
||||
expect(result).toContain('Plan mode is active');
|
||||
expect(result).toContain('MUST NOT make any edits');
|
||||
});
|
||||
|
||||
it('should include workflow instructions', () => {
|
||||
const result = getPlanModeSystemReminder();
|
||||
|
||||
expect(result).toContain("1. Answer the user's query comprehensively");
|
||||
expect(result).toContain("2. When you're done researching");
|
||||
expect(result).toContain('exit_plan_mode tool');
|
||||
});
|
||||
|
||||
it('should be deterministic', () => {
|
||||
const result1 = getPlanModeSystemReminder();
|
||||
const result2 = getPlanModeSystemReminder();
|
||||
|
||||
expect(result1).toBe(result2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -820,7 +820,65 @@ function getToolCallExamples(model?: string): string {
|
||||
if (/qwen[^-]*-vl/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
// Match coder-model pattern (same as qwen3-coder)
|
||||
if (/coder-model/i.test(model)) {
|
||||
return qwenCoderToolCallExamples;
|
||||
}
|
||||
// Match vision-model pattern (same as qwen3-vl)
|
||||
if (/vision-model/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
}
|
||||
|
||||
return generalToolCallExamples;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a system reminder message about available subagents for the AI assistant.
|
||||
*
|
||||
* This function creates an internal system message that informs the AI about specialized
|
||||
* agents it can delegate tasks to. The reminder encourages proactive use of the TASK tool
|
||||
* when user requests match agent capabilities.
|
||||
*
|
||||
* @param agentTypes - Array of available agent type names (e.g., ['python', 'web', 'analysis'])
|
||||
* @returns A formatted system reminder string wrapped in XML tags for internal AI processing
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const reminder = getSubagentSystemReminder(['python', 'web']);
|
||||
* // Returns: "<system-reminder>You have powerful specialized agents..."
|
||||
* ```
|
||||
*/
|
||||
export function getSubagentSystemReminder(agentTypes: string[]): string {
|
||||
return `<system-reminder>You have powerful specialized agents at your disposal, available agent types are: ${agentTypes.join(', ')}. PROACTIVELY use the ${ToolNames.TASK} tool to delegate user's task to appropriate agent when user's task matches agent capabilities. Ignore this message if user's task is not relevant to any agent. This message is for internal use only. Do not mention this to user in your response.</system-reminder>`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a system reminder message for plan mode operation.
|
||||
*
|
||||
* This function creates an internal system message that enforces plan mode constraints,
|
||||
* preventing the AI from making any modifications to the system until the user confirms
|
||||
* the proposed plan. It overrides other instructions to ensure read-only behavior.
|
||||
*
|
||||
* @returns A formatted system reminder string that enforces plan mode restrictions
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const reminder = getPlanModeSystemReminder();
|
||||
* // Returns: "<system-reminder>Plan mode is active..."
|
||||
* ```
|
||||
*
|
||||
* @remarks
|
||||
* Plan mode ensures the AI will:
|
||||
* - Only perform read-only operations (research, analysis)
|
||||
* - Present a comprehensive plan via ExitPlanMode tool
|
||||
* - Wait for user confirmation before making any changes
|
||||
* - Override any other instructions that would modify system state
|
||||
*/
|
||||
export function getPlanModeSystemReminder(): string {
|
||||
return `<system-reminder>
|
||||
Plan mode is active. The user indicated that they do not want you to execute yet -- you MUST NOT make any edits, run any non-readonly tools (including changing configs or making commits), or otherwise make any changes to the system. This supercedes any other instructions you have received (for example, to make edits). Instead, you should:
|
||||
1. Answer the user's query comprehensively
|
||||
2. When you're done researching, present your plan by calling the ${ToolNames.EXIT_PLAN_MODE} tool, which will prompt the user to confirm the plan. Do NOT make any file changes or run any tools that modify the system state in any way until the user has confirmed the plan.
|
||||
</system-reminder>`;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { normalize, tokenLimit, DEFAULT_TOKEN_LIMIT } from './tokenLimits.js';
|
||||
import {
|
||||
normalize,
|
||||
tokenLimit,
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
} from './tokenLimits.js';
|
||||
|
||||
describe('normalize', () => {
|
||||
it('should lowercase and trim the model string', () => {
|
||||
@@ -225,3 +230,96 @@ describe('tokenLimit', () => {
|
||||
expect(tokenLimit('CLAUDE-3.5-SONNET')).toBe(200000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tokenLimit with output type', () => {
|
||||
describe('Qwen models with output limits', () => {
|
||||
it('should return the correct output limit for qwen3-coder-plus', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'output')).toBe(65536);
|
||||
expect(tokenLimit('qwen3-coder-plus-20250601', 'output')).toBe(65536);
|
||||
});
|
||||
|
||||
it('should return the correct output limit for qwen-vl-max-latest', () => {
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'output')).toBe(8192);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Default output limits', () => {
|
||||
it('should return the default output limit for unknown models', () => {
|
||||
expect(tokenLimit('unknown-model', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('gpt-4', 'output')).toBe(DEFAULT_OUTPUT_TOKEN_LIMIT);
|
||||
expect(tokenLimit('claude-3.5-sonnet', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the default output limit for models without specific output patterns', () => {
|
||||
expect(tokenLimit('qwen3-coder-7b', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('qwen-plus', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('qwen-vl-max', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Input vs Output limits comparison', () => {
|
||||
it('should return different limits for input vs output for qwen3-coder-plus', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'input')).toBe(1048576); // 1M input
|
||||
expect(tokenLimit('qwen3-coder-plus', 'output')).toBe(65536); // 64K output
|
||||
});
|
||||
|
||||
it('should return different limits for input vs output for qwen-vl-max-latest', () => {
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'input')).toBe(131072); // 128K input
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'output')).toBe(8192); // 8K output
|
||||
});
|
||||
|
||||
it('should return same default limits for unknown models', () => {
|
||||
expect(tokenLimit('unknown-model', 'input')).toBe(DEFAULT_TOKEN_LIMIT); // 128K input
|
||||
expect(tokenLimit('unknown-model', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
); // 4K output
|
||||
});
|
||||
});
|
||||
|
||||
describe('Backward compatibility', () => {
|
||||
it('should default to input type when no type is specified', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus')).toBe(1048576); // Should be input limit
|
||||
expect(tokenLimit('qwen-vl-max-latest')).toBe(131072); // Should be input limit
|
||||
expect(tokenLimit('unknown-model')).toBe(DEFAULT_TOKEN_LIMIT); // Should be input default
|
||||
});
|
||||
|
||||
it('should work with explicit input type', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'input')).toBe(1048576);
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'input')).toBe(131072);
|
||||
expect(tokenLimit('unknown-model', 'input')).toBe(DEFAULT_TOKEN_LIMIT);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model normalization with output limits', () => {
|
||||
it('should handle normalized model names for output limits', () => {
|
||||
expect(tokenLimit('QWEN3-CODER-PLUS', 'output')).toBe(65536);
|
||||
expect(tokenLimit('qwen3-coder-plus-20250601', 'output')).toBe(65536);
|
||||
expect(tokenLimit('QWEN-VL-MAX-LATEST', 'output')).toBe(8192);
|
||||
});
|
||||
|
||||
it('should handle complex model strings for output limits', () => {
|
||||
expect(
|
||||
tokenLimit(
|
||||
' a/b/c|QWEN3-CODER-PLUS:qwen3-coder-plus-2024-05-13 ',
|
||||
'output',
|
||||
),
|
||||
).toBe(65536);
|
||||
expect(
|
||||
tokenLimit(
|
||||
'provider/qwen-vl-max-latest:qwen-vl-max-latest-v1',
|
||||
'output',
|
||||
),
|
||||
).toBe(8192);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
type Model = string;
|
||||
type TokenCount = number;
|
||||
|
||||
/**
|
||||
* Token limit types for different use cases.
|
||||
* - 'input': Maximum input context window size
|
||||
* - 'output': Maximum output tokens that can be generated in a single response
|
||||
*/
|
||||
export type TokenLimitType = 'input' | 'output';
|
||||
|
||||
export const DEFAULT_TOKEN_LIMIT: TokenCount = 131_072; // 128K (power-of-two)
|
||||
export const DEFAULT_OUTPUT_TOKEN_LIMIT: TokenCount = 4_096; // 4K tokens
|
||||
|
||||
/**
|
||||
* Accurate numeric limits:
|
||||
@@ -18,6 +26,10 @@ const LIMITS = {
|
||||
'1m': 1_048_576,
|
||||
'2m': 2_097_152,
|
||||
'10m': 10_485_760, // 10 million tokens
|
||||
// Output token limits (typically much smaller than input limits)
|
||||
'4k': 4_096,
|
||||
'8k': 8_192,
|
||||
'16k': 16_384,
|
||||
} as const;
|
||||
|
||||
/** Robust normalizer: strips provider prefixes, pipes/colons, date/version suffixes, etc. */
|
||||
@@ -36,7 +48,7 @@ export function normalize(model: string): string {
|
||||
// - dates (e.g., -20250219), -v1, version numbers, 'latest', 'preview' etc.
|
||||
s = s.replace(/-preview/g, '');
|
||||
// Special handling for Qwen model names that include "-latest" as part of the model name
|
||||
if (!s.match(/^qwen-(?:plus|flash)-latest$/)) {
|
||||
if (!s.match(/^qwen-(?:plus|flash|vl-max)-latest$/)) {
|
||||
// \d{6,} - Match 6 or more digits (dates) like -20250219 (6+ digit dates)
|
||||
// \d+x\d+b - Match patterns like 4x8b, -7b, -70b
|
||||
// v\d+(?:\.\d+)* - Match version patterns starting with 'v' like -v1, -v1.2, -v2.1.3
|
||||
@@ -99,6 +111,12 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Commercial Qwen3-Coder-Flash: 1M token context
|
||||
[/^qwen3-coder-flash(-.*)?$/, LIMITS['1m']], // catches "qwen3-coder-flash" and date variants
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (1M token context)
|
||||
[/^coder-model$/, LIMITS['1m']],
|
||||
|
||||
// Commercial Qwen3-Max-Preview: 256K token context
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['256k']], // catches "qwen3-max-preview" and date variants
|
||||
|
||||
// Open-source Qwen3-Coder variants: 256K native
|
||||
[/^qwen3-coder-.*$/, LIMITS['256k']],
|
||||
// Open-source Qwen3 2507 variants: 256K native
|
||||
@@ -119,6 +137,9 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Qwen Vision Models
|
||||
[/^qwen-vl-max.*$/, LIMITS['128k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max (128K token context)
|
||||
[/^vision-model$/, LIMITS['128k']],
|
||||
|
||||
// -------------------
|
||||
// ByteDance Seed-OSS (512K)
|
||||
// -------------------
|
||||
@@ -142,16 +163,60 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
[/^mistral-large-2.*$/, LIMITS['128k']],
|
||||
];
|
||||
|
||||
/** Return the token limit for a model string (uses normalize + ordered regex list). */
|
||||
export function tokenLimit(model: Model): TokenCount {
|
||||
/**
|
||||
* Output token limit patterns for specific model families.
|
||||
* These patterns define the maximum number of tokens that can be generated
|
||||
* in a single response for specific models.
|
||||
*/
|
||||
const OUTPUT_PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// -------------------
|
||||
// Alibaba / Qwen - DashScope Models
|
||||
// -------------------
|
||||
// Qwen3-Coder-Plus: 65,536 max output tokens
|
||||
[/^qwen3-coder-plus(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (64K max output tokens)
|
||||
[/^coder-model$/, LIMITS['64k']],
|
||||
|
||||
// Qwen3-Max-Preview: 65,536 max output tokens
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Qwen-VL-Max-Latest: 8,192 max output tokens
|
||||
[/^qwen-vl-max-latest$/, LIMITS['8k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max-latest (8K max output tokens)
|
||||
[/^vision-model$/, LIMITS['8k']],
|
||||
|
||||
// Qwen3-VL-Plus: 8,192 max output tokens
|
||||
[/^qwen3-vl-plus$/, LIMITS['8k']],
|
||||
];
|
||||
|
||||
/**
|
||||
* Return the token limit for a model string based on the specified type.
|
||||
*
|
||||
* This function determines the maximum number of tokens for either input context
|
||||
* or output generation based on the model and token type. It uses the same
|
||||
* normalization logic for consistency across both input and output limits.
|
||||
*
|
||||
* @param model - The model name to get the token limit for
|
||||
* @param type - The type of token limit ('input' for context window, 'output' for generation)
|
||||
* @returns The maximum number of tokens allowed for this model and type
|
||||
*/
|
||||
export function tokenLimit(
|
||||
model: Model,
|
||||
type: TokenLimitType = 'input',
|
||||
): TokenCount {
|
||||
const norm = normalize(model);
|
||||
|
||||
for (const [regex, limit] of PATTERNS) {
|
||||
// Choose the appropriate patterns based on token type
|
||||
const patterns = type === 'output' ? OUTPUT_PATTERNS : PATTERNS;
|
||||
|
||||
for (const [regex, limit] of patterns) {
|
||||
if (regex.test(norm)) {
|
||||
return limit;
|
||||
}
|
||||
}
|
||||
|
||||
// final fallback: DEFAULT_TOKEN_LIMIT (power-of-two 128K)
|
||||
return DEFAULT_TOKEN_LIMIT;
|
||||
// Return appropriate default based on token type
|
||||
return type === 'output' ? DEFAULT_OUTPUT_TOKEN_LIMIT : DEFAULT_TOKEN_LIMIT;
|
||||
}
|
||||
|
||||
@@ -712,8 +712,6 @@ async function authWithQwenDeviceFlow(
|
||||
`Polling... (attempt ${attempt + 1}/${maxAttempts})`,
|
||||
);
|
||||
|
||||
process.stdout.write('.');
|
||||
|
||||
// Wait with cancellation check every 100ms
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkInterval = 100; // Check every 100ms
|
||||
|
||||
@@ -901,5 +901,37 @@ describe('SharedTokenManager', () => {
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should properly clean up timeout when file operation completes before timeout', async () => {
|
||||
const tokenManager = SharedTokenManager.getInstance();
|
||||
tokenManager.clearCache();
|
||||
|
||||
const mockClient = {
|
||||
getCredentials: vi.fn().mockReturnValue(null),
|
||||
setCredentials: vi.fn(),
|
||||
getAccessToken: vi.fn(),
|
||||
requestDeviceAuthorization: vi.fn(),
|
||||
pollDeviceToken: vi.fn(),
|
||||
refreshAccessToken: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock clearTimeout to verify it's called
|
||||
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
|
||||
|
||||
// Mock file stat to resolve quickly (before timeout)
|
||||
mockFs.stat.mockResolvedValue({ mtimeMs: 12345 } as Stats);
|
||||
|
||||
// Call checkAndReloadIfNeeded which uses withTimeout internally
|
||||
const checkMethod = getPrivateProperty(
|
||||
tokenManager,
|
||||
'checkAndReloadIfNeeded',
|
||||
) as (client?: IQwenOAuth2Client) => Promise<void>;
|
||||
await checkMethod.call(tokenManager, mockClient);
|
||||
|
||||
// Verify that clearTimeout was called to clean up the timer
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -290,6 +290,36 @@ export class SharedTokenManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility method to add timeout to any promise operation
|
||||
* Properly cleans up the timeout when the promise completes
|
||||
*/
|
||||
private withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
operationType = 'Operation',
|
||||
): Promise<T> {
|
||||
let timeoutId: NodeJS.Timeout;
|
||||
|
||||
return Promise.race([
|
||||
promise.finally(() => {
|
||||
// Clear timeout when main promise completes (success or failure)
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}),
|
||||
new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(`${operationType} timed out after ${timeoutMs}ms`),
|
||||
),
|
||||
timeoutMs,
|
||||
);
|
||||
}),
|
||||
]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the actual file check and reload operation
|
||||
* This is separated to enable proper promise-based synchronization
|
||||
@@ -303,25 +333,12 @@ export class SharedTokenManager {
|
||||
|
||||
try {
|
||||
const filePath = this.getCredentialFilePath();
|
||||
// Add timeout to file stat operation
|
||||
const withTimeout = async <T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
): Promise<T> =>
|
||||
Promise.race([
|
||||
promise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(`File operation timed out after ${timeoutMs}ms`),
|
||||
),
|
||||
timeoutMs,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
const stats = await withTimeout(fs.stat(filePath), 3000);
|
||||
const stats = await this.withTimeout(
|
||||
fs.stat(filePath),
|
||||
3000,
|
||||
'File operation',
|
||||
);
|
||||
const fileModTime = stats.mtimeMs;
|
||||
|
||||
// Reload credentials if file has been modified since last cache
|
||||
@@ -451,7 +468,7 @@ export class SharedTokenManager {
|
||||
// Check if we have a refresh token before attempting refresh
|
||||
const currentCredentials = qwenClient.getCredentials();
|
||||
if (!currentCredentials.refresh_token) {
|
||||
console.debug('create a NO_REFRESH_TOKEN error');
|
||||
// console.debug('create a NO_REFRESH_TOKEN error');
|
||||
throw new TokenManagerError(
|
||||
TokenError.NO_REFRESH_TOKEN,
|
||||
'No refresh token available for token refresh',
|
||||
@@ -589,26 +606,12 @@ export class SharedTokenManager {
|
||||
const dirPath = path.dirname(filePath);
|
||||
const tempPath = `${filePath}.tmp.${randomUUID()}`;
|
||||
|
||||
// Add timeout wrapper for file operations
|
||||
const withTimeout = async <T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
): Promise<T> =>
|
||||
Promise.race([
|
||||
promise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Operation timed out after ${timeoutMs}ms`)),
|
||||
timeoutMs,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
// Create directory with restricted permissions
|
||||
try {
|
||||
await withTimeout(
|
||||
await this.withTimeout(
|
||||
fs.mkdir(dirPath, { recursive: true, mode: 0o700 }),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
} catch (error) {
|
||||
throw new TokenManagerError(
|
||||
@@ -622,21 +625,30 @@ export class SharedTokenManager {
|
||||
|
||||
try {
|
||||
// Write to temporary file first with restricted permissions
|
||||
await withTimeout(
|
||||
await this.withTimeout(
|
||||
fs.writeFile(tempPath, credString, { mode: 0o600 }),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
|
||||
// Atomic move to final location
|
||||
await withTimeout(fs.rename(tempPath, filePath), 5000);
|
||||
await this.withTimeout(
|
||||
fs.rename(tempPath, filePath),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
|
||||
// Update cached file modification time atomically after successful write
|
||||
const stats = await withTimeout(fs.stat(filePath), 5000);
|
||||
const stats = await this.withTimeout(
|
||||
fs.stat(filePath),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
this.memoryCache.fileModTime = stats.mtimeMs;
|
||||
} catch (error) {
|
||||
// Clean up temp file if it exists
|
||||
try {
|
||||
await withTimeout(fs.unlink(tempPath), 1000);
|
||||
await this.withTimeout(fs.unlink(tempPath), 1000, 'File operation');
|
||||
} catch (_cleanupError) {
|
||||
// Ignore cleanup errors - temp file might not exist
|
||||
}
|
||||
|
||||
@@ -40,11 +40,29 @@ const AGENT_CONFIG_DIR = 'agents';
|
||||
export class SubagentManager {
|
||||
private readonly validator: SubagentValidator;
|
||||
private subagentsCache: Map<SubagentLevel, SubagentConfig[]> | null = null;
|
||||
private readonly changeListeners: Set<() => void> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
this.validator = new SubagentValidator();
|
||||
}
|
||||
|
||||
addChangeListener(listener: () => void): () => void {
|
||||
this.changeListeners.add(listener);
|
||||
return () => {
|
||||
this.changeListeners.delete(listener);
|
||||
};
|
||||
}
|
||||
|
||||
private notifyChangeListeners(): void {
|
||||
for (const listener of this.changeListeners) {
|
||||
try {
|
||||
listener();
|
||||
} catch (error) {
|
||||
console.warn('Subagent change listener threw an error:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new subagent configuration.
|
||||
*
|
||||
@@ -93,8 +111,8 @@ export class SubagentManager {
|
||||
|
||||
try {
|
||||
await fs.writeFile(filePath, content, 'utf8');
|
||||
// Clear cache after successful creation
|
||||
this.clearCache();
|
||||
// Refresh cache after successful creation
|
||||
await this.refreshCache();
|
||||
} catch (error) {
|
||||
throw new SubagentError(
|
||||
`Failed to write subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
@@ -183,8 +201,8 @@ export class SubagentManager {
|
||||
|
||||
try {
|
||||
await fs.writeFile(existing.filePath, content, 'utf8');
|
||||
// Clear cache after successful update
|
||||
this.clearCache();
|
||||
// Refresh cache after successful update
|
||||
await this.refreshCache();
|
||||
} catch (error) {
|
||||
throw new SubagentError(
|
||||
`Failed to update subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
@@ -242,8 +260,8 @@ export class SubagentManager {
|
||||
);
|
||||
}
|
||||
|
||||
// Clear cache after successful deletion
|
||||
this.clearCache();
|
||||
// Refresh cache after successful deletion
|
||||
await this.refreshCache();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -327,21 +345,17 @@ export class SubagentManager {
|
||||
* @private
|
||||
*/
|
||||
private async refreshCache(): Promise<void> {
|
||||
this.subagentsCache = new Map();
|
||||
const subagentsCache = new Map();
|
||||
|
||||
const levels: SubagentLevel[] = ['project', 'user', 'builtin'];
|
||||
|
||||
for (const level of levels) {
|
||||
const levelSubagents = await this.listSubagentsAtLevel(level);
|
||||
this.subagentsCache.set(level, levelSubagents);
|
||||
subagentsCache.set(level, levelSubagents);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the subagents cache, forcing the next listSubagents call to reload from disk.
|
||||
*/
|
||||
clearCache(): void {
|
||||
this.subagentsCache = null;
|
||||
this.subagentsCache = subagentsCache;
|
||||
this.notifyChangeListeners();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -41,12 +41,14 @@ import type {
|
||||
ToolConfig,
|
||||
} from './types.js';
|
||||
import { SubagentTerminateMode } from './types.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
|
||||
vi.mock('../core/geminiChat.js');
|
||||
vi.mock('../core/contentGenerator.js');
|
||||
vi.mock('../utils/environmentContext.js');
|
||||
vi.mock('../core/nonInteractiveToolExecutor.js');
|
||||
vi.mock('../ide/ide-client.js');
|
||||
vi.mock('../core/client.js');
|
||||
|
||||
async function createMockConfig(
|
||||
toolRegistryMocks = {},
|
||||
@@ -72,6 +74,19 @@ async function createMockConfig(
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(mockToolRegistry);
|
||||
|
||||
// Mock getContentGeneratorConfig to return a valid config
|
||||
vi.spyOn(config, 'getContentGeneratorConfig').mockReturnValue({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
});
|
||||
|
||||
// Mock setModel method
|
||||
vi.spyOn(config, 'setModel').mockResolvedValue();
|
||||
|
||||
// Mock getSessionId method
|
||||
vi.spyOn(config, 'getSessionId').mockReturnValue('test-session');
|
||||
|
||||
return { config, toolRegistry: mockToolRegistry };
|
||||
}
|
||||
|
||||
@@ -181,6 +196,28 @@ describe('subagent.ts', () => {
|
||||
}) as unknown as GeminiChat,
|
||||
);
|
||||
|
||||
// Mock GeminiClient constructor to return a properly mocked client
|
||||
const mockGeminiChat = {
|
||||
setTools: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setHistory: vi.fn(),
|
||||
sendMessageStream: vi.fn(),
|
||||
};
|
||||
|
||||
const mockGeminiClient = {
|
||||
getChat: vi.fn().mockReturnValue(mockGeminiChat),
|
||||
setTools: vi.fn().mockResolvedValue(undefined),
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
setHistory: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock the GeminiClient constructor
|
||||
vi.mocked(GeminiClient).mockImplementation(
|
||||
() => mockGeminiClient as unknown as GeminiClient,
|
||||
);
|
||||
|
||||
// Default mock for executeToolCall
|
||||
vi.mocked(executeToolCall).mockResolvedValue({
|
||||
callId: 'default-call',
|
||||
|
||||
@@ -826,7 +826,7 @@ export class SubAgentScope {
|
||||
);
|
||||
|
||||
if (this.modelConfig.model) {
|
||||
this.runtimeContext.setModel(this.modelConfig.model);
|
||||
await this.runtimeContext.setModel(this.modelConfig.model);
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
|
||||
@@ -6,22 +6,10 @@
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
const mockEnsureCorrectEdit = vi.hoisted(() => vi.fn());
|
||||
const mockGenerateJson = vi.hoisted(() => vi.fn());
|
||||
const mockOpenDiff = vi.hoisted(() => vi.fn());
|
||||
|
||||
import { IDEConnectionStatus } from '../ide/ide-client.js';
|
||||
|
||||
vi.mock('../utils/editCorrector.js', () => ({
|
||||
ensureCorrectEdit: mockEnsureCorrectEdit,
|
||||
}));
|
||||
|
||||
vi.mock('../core/client.js', () => ({
|
||||
GeminiClient: vi.fn().mockImplementation(() => ({
|
||||
generateJson: mockGenerateJson,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/editor.js', () => ({
|
||||
openDiff: mockOpenDiff,
|
||||
}));
|
||||
@@ -42,7 +30,6 @@ import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import type { Content, Part, SchemaUnion } from '@google/genai';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
@@ -51,20 +38,13 @@ describe('EditTool', () => {
|
||||
let tempDir: string;
|
||||
let rootDir: string;
|
||||
let mockConfig: Config;
|
||||
let geminiClient: any;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'edit-tool-test-'));
|
||||
rootDir = path.join(tempDir, 'root');
|
||||
fs.mkdirSync(rootDir);
|
||||
|
||||
geminiClient = {
|
||||
generateJson: mockGenerateJson, // mockGenerateJson is already defined and hoisted
|
||||
};
|
||||
|
||||
mockConfig = {
|
||||
getGeminiClient: vi.fn().mockReturnValue(geminiClient),
|
||||
getTargetDir: () => rootDir,
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
@@ -72,9 +52,6 @@ describe('EditTool', () => {
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
getIdeClient: () => undefined,
|
||||
getIdeMode: () => false,
|
||||
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
|
||||
// Add other properties/methods of Config if EditTool uses them
|
||||
// Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses:
|
||||
getApiKey: () => 'test-api-key',
|
||||
getModel: () => 'test-model',
|
||||
getSandbox: () => false,
|
||||
@@ -98,65 +75,6 @@ describe('EditTool', () => {
|
||||
// Default to not skipping confirmation
|
||||
(mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.DEFAULT);
|
||||
|
||||
// Reset mocks and set default implementation for ensureCorrectEdit
|
||||
mockEnsureCorrectEdit.mockReset();
|
||||
mockEnsureCorrectEdit.mockImplementation(
|
||||
async (_, currentContent, params) => {
|
||||
let occurrences = 0;
|
||||
if (params.old_string && currentContent) {
|
||||
// Simple string counting for the mock
|
||||
let index = currentContent.indexOf(params.old_string);
|
||||
while (index !== -1) {
|
||||
occurrences++;
|
||||
index = currentContent.indexOf(params.old_string, index + 1);
|
||||
}
|
||||
} else if (params.old_string === '') {
|
||||
occurrences = 0; // Creating a new file
|
||||
}
|
||||
return Promise.resolve({ params, occurrences });
|
||||
},
|
||||
);
|
||||
|
||||
// Default mock for generateJson to return the snippet unchanged
|
||||
mockGenerateJson.mockReset();
|
||||
mockGenerateJson.mockImplementation(
|
||||
async (contents: Content[], schema: SchemaUnion) => {
|
||||
// The problematic_snippet is the last part of the user's content
|
||||
const userContent = contents.find((c: Content) => c.role === 'user');
|
||||
let promptText = '';
|
||||
if (userContent && userContent.parts) {
|
||||
promptText = userContent.parts
|
||||
.filter((p: Part) => typeof (p as any).text === 'string')
|
||||
.map((p: Part) => (p as any).text)
|
||||
.join('\n');
|
||||
}
|
||||
const snippetMatch = promptText.match(
|
||||
/Problematic target snippet:\n```\n([\s\S]*?)\n```/,
|
||||
);
|
||||
const problematicSnippet =
|
||||
snippetMatch && snippetMatch[1] ? snippetMatch[1] : '';
|
||||
|
||||
if (((schema as any).properties as any)?.corrected_target_snippet) {
|
||||
return Promise.resolve({
|
||||
corrected_target_snippet: problematicSnippet,
|
||||
});
|
||||
}
|
||||
if (((schema as any).properties as any)?.corrected_new_string) {
|
||||
// For new_string correction, we might need more sophisticated logic,
|
||||
// but for now, returning original is a safe default if not specified by a test.
|
||||
const originalNewStringMatch = promptText.match(
|
||||
/original_new_string \(what was intended to replace original_old_string\):\n```\n([\s\S]*?)\n```/,
|
||||
);
|
||||
const originalNewString =
|
||||
originalNewStringMatch && originalNewStringMatch[1]
|
||||
? originalNewStringMatch[1]
|
||||
: '';
|
||||
return Promise.resolve({ corrected_new_string: originalNewString });
|
||||
}
|
||||
return Promise.resolve({}); // Default empty object if schema doesn't match
|
||||
},
|
||||
);
|
||||
|
||||
tool = new EditTool(mockConfig);
|
||||
});
|
||||
|
||||
@@ -249,8 +167,6 @@ describe('EditTool', () => {
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
// ensureCorrectEdit will be called by shouldConfirmExecute
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -264,14 +180,13 @@ describe('EditTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false if old_string is not found (ensureCorrectEdit returns 0)', async () => {
|
||||
it('should return false if old_string is not found', async () => {
|
||||
fs.writeFileSync(filePath, 'some content here');
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
old_string: 'not_found',
|
||||
new_string: 'new',
|
||||
};
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -279,14 +194,13 @@ describe('EditTool', () => {
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
|
||||
it('should return false if multiple occurrences of old_string are found', async () => {
|
||||
fs.writeFileSync(filePath, 'old old content here');
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -302,10 +216,6 @@ describe('EditTool', () => {
|
||||
old_string: '',
|
||||
new_string: 'new file content',
|
||||
};
|
||||
// ensureCorrectEdit might not be called if old_string is empty,
|
||||
// as shouldConfirmExecute handles this for diff generation.
|
||||
// If it is called, it should return 0 occurrences for a new file.
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -319,65 +229,9 @@ describe('EditTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should use corrected params from ensureCorrectEdit for diff generation', async () => {
|
||||
const originalContent = 'This is the original string to be replaced.';
|
||||
const originalOldString = 'original string';
|
||||
const originalNewString = 'new string';
|
||||
|
||||
const correctedOldString = 'original string to be replaced'; // More specific
|
||||
const correctedNewString = 'completely new string'; // Different replacement
|
||||
const expectedFinalContent = 'This is the completely new string.';
|
||||
|
||||
fs.writeFileSync(filePath, originalContent);
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
old_string: originalOldString,
|
||||
new_string: originalNewString,
|
||||
};
|
||||
|
||||
// The main beforeEach already calls mockEnsureCorrectEdit.mockReset()
|
||||
// Set a specific mock for this test case
|
||||
let mockCalled = false;
|
||||
mockEnsureCorrectEdit.mockImplementationOnce(
|
||||
async (_, content, p, client) => {
|
||||
mockCalled = true;
|
||||
expect(content).toBe(originalContent);
|
||||
expect(p).toBe(params);
|
||||
expect(client).toBe(geminiClient);
|
||||
return {
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: correctedOldString,
|
||||
new_string: correctedNewString,
|
||||
},
|
||||
occurrences: 1,
|
||||
};
|
||||
},
|
||||
);
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = (await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
)) as FileDiff;
|
||||
|
||||
expect(mockCalled).toBe(true); // Check if the mock implementation was run
|
||||
// expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(originalContent, params, expect.anything()); // Keep this commented for now
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Edit: ${testFile}`,
|
||||
fileName: testFile,
|
||||
}),
|
||||
);
|
||||
// Check that the diff is based on the corrected strings leading to the new state
|
||||
expect(confirmation.fileDiff).toContain(`-${originalContent}`);
|
||||
expect(confirmation.fileDiff).toContain(`+${expectedFinalContent}`);
|
||||
|
||||
// Verify that applying the correctedOldString and correctedNewString to originalContent
|
||||
// indeed produces the expectedFinalContent, which is what the diff should reflect.
|
||||
const patchedContent = originalContent.replace(
|
||||
correctedOldString, // This was the string identified by ensureCorrectEdit for replacement
|
||||
correctedNewString, // This was the string identified by ensureCorrectEdit as the replacement
|
||||
);
|
||||
expect(patchedContent).toBe(expectedFinalContent);
|
||||
// This test is no longer relevant since editCorrector functionality was removed
|
||||
it.skip('should use corrected params from ensureCorrectEdit for diff generation', async () => {
|
||||
// Test skipped - editCorrector functionality removed
|
||||
});
|
||||
});
|
||||
|
||||
@@ -387,20 +241,6 @@ describe('EditTool', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
filePath = path.join(rootDir, testFile);
|
||||
// Default for execute tests, can be overridden
|
||||
mockEnsureCorrectEdit.mockImplementation(async (_, content, params) => {
|
||||
let occurrences = 0;
|
||||
if (params.old_string && content) {
|
||||
let index = content.indexOf(params.old_string);
|
||||
while (index !== -1) {
|
||||
occurrences++;
|
||||
index = content.indexOf(params.old_string, index + 1);
|
||||
}
|
||||
} else if (params.old_string === '') {
|
||||
occurrences = 0;
|
||||
}
|
||||
return { params, occurrences };
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error if file path is not absolute', async () => {
|
||||
@@ -433,10 +273,6 @@ describe('EditTool', () => {
|
||||
new_string: 'new',
|
||||
};
|
||||
|
||||
// Specific mock for this test's execution path in calculateEdit
|
||||
// ensureCorrectEdit is NOT called by calculateEdit, only by shouldConfirmExecute
|
||||
// So, the default mockEnsureCorrectEdit should correctly return 1 occurrence for 'old' in initialContent
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
@@ -477,7 +313,6 @@ describe('EditTool', () => {
|
||||
old_string: 'nonexistent',
|
||||
new_string: 'replacement',
|
||||
};
|
||||
// The default mockEnsureCorrectEdit will return 0 occurrences for 'nonexistent'
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.llmContent).toMatch(
|
||||
@@ -495,7 +330,6 @@ describe('EditTool', () => {
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
// The default mockEnsureCorrectEdit will return 2 occurrences for 'old'
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.llmContent).toMatch(
|
||||
@@ -638,8 +472,7 @@ describe('EditTool', () => {
|
||||
});
|
||||
|
||||
it('should return EDIT_NO_CHANGE error if replacement results in identical content', async () => {
|
||||
// This can happen if ensureCorrectEdit finds a fuzzy match, but the literal
|
||||
// string replacement with `replaceAll` results in no change.
|
||||
// This can happen if the literal string replacement with `replaceAll` results in no change.
|
||||
const initialContent = 'line 1\nline 2\nline 3'; // Note the double space
|
||||
fs.writeFileSync(filePath, initialContent, 'utf8');
|
||||
const params: EditToolParams = {
|
||||
@@ -649,16 +482,12 @@ describe('EditTool', () => {
|
||||
new_string: 'line 1\nnew line 2\nline 3',
|
||||
};
|
||||
|
||||
// Mock ensureCorrectEdit to simulate it finding a match (e.g., via fuzzy matching)
|
||||
// but it doesn't correct the old_string to the literal content.
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.EDIT_NO_CHANGE);
|
||||
expect(result.error?.type).toBe(ToolErrorType.EDIT_NO_OCCURRENCE_FOUND);
|
||||
expect(result.returnDisplay).toMatch(
|
||||
/No changes to apply. The new content is identical to the current content./,
|
||||
/Failed to edit, could not find the string to replace./,
|
||||
);
|
||||
// Ensure the file was not actually changed
|
||||
expect(fs.readFileSync(filePath, 'utf8')).toBe(initialContent);
|
||||
@@ -870,10 +699,6 @@ describe('EditTool', () => {
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({
|
||||
params: { ...params, old_string: 'old', new_string: 'new' },
|
||||
occurrences: 1,
|
||||
});
|
||||
ideClient.openDiff.mockResolvedValueOnce({
|
||||
status: 'accepted',
|
||||
content: modifiedContent,
|
||||
|
||||
@@ -21,7 +21,6 @@ import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import { ensureCorrectEdit } from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
|
||||
import { ReadFileTool } from './read-file.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
@@ -116,16 +115,13 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
* @returns An object describing the potential edit outcome
|
||||
* @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
|
||||
*/
|
||||
private async calculateEdit(
|
||||
params: EditToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CalculatedEdit> {
|
||||
private async calculateEdit(params: EditToolParams): Promise<CalculatedEdit> {
|
||||
const expectedReplacements = params.expected_replacements ?? 1;
|
||||
let currentContent: string | null = null;
|
||||
let fileExists = false;
|
||||
let isNewFile = false;
|
||||
let finalNewString = params.new_string;
|
||||
let finalOldString = params.old_string;
|
||||
const finalNewString = params.new_string;
|
||||
const finalOldString = params.old_string;
|
||||
let occurrences = 0;
|
||||
let error:
|
||||
| { display: string; raw: string; type: ToolErrorType }
|
||||
@@ -157,18 +153,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
type: ToolErrorType.FILE_NOT_FOUND,
|
||||
};
|
||||
} else if (currentContent !== null) {
|
||||
// Editing an existing file
|
||||
const correctedEdit = await ensureCorrectEdit(
|
||||
params.file_path,
|
||||
currentContent,
|
||||
params,
|
||||
this.config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
finalOldString = correctedEdit.params.old_string;
|
||||
finalNewString = correctedEdit.params.new_string;
|
||||
occurrences = correctedEdit.occurrences;
|
||||
|
||||
occurrences = this.countOccurrences(currentContent, params.old_string);
|
||||
if (params.old_string === '') {
|
||||
// Error: Trying to create a file that already exists
|
||||
error = {
|
||||
@@ -234,12 +219,28 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts occurrences of a substring in a string
|
||||
*/
|
||||
private countOccurrences(str: string, substr: string): number {
|
||||
if (substr === '') {
|
||||
return 0;
|
||||
}
|
||||
let count = 0;
|
||||
let pos = str.indexOf(substr);
|
||||
while (pos !== -1) {
|
||||
count++;
|
||||
pos = str.indexOf(substr, pos + substr.length); // Start search after the current match
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the confirmation prompt for the Edit tool in the CLI.
|
||||
* It needs to calculate the diff to show the user.
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
@@ -247,7 +248,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
|
||||
let editData: CalculatedEdit;
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, abortSignal);
|
||||
editData = await this.calculateEdit(this.params);
|
||||
} catch (error) {
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
console.log(`Error preparing edit: ${errorMsg}`);
|
||||
@@ -330,10 +331,10 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
* @param params Parameters for the edit operation
|
||||
* @returns Result of the edit operation
|
||||
*/
|
||||
async execute(signal: AbortSignal): Promise<ToolResult> {
|
||||
async execute(_signal: AbortSignal): Promise<ToolResult> {
|
||||
let editData: CalculatedEdit;
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, signal);
|
||||
editData = await this.calculateEdit(this.params);
|
||||
} catch (error) {
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
|
||||
292
packages/core/src/tools/exitPlanMode.test.ts
Normal file
292
packages/core/src/tools/exitPlanMode.test.ts
Normal file
@@ -0,0 +1,292 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ExitPlanModeTool, type ExitPlanModeParams } from './exitPlanMode.js';
|
||||
import { ApprovalMode, type Config } from '../config/config.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
|
||||
describe('ExitPlanModeTool', () => {
|
||||
let tool: ExitPlanModeTool;
|
||||
let mockConfig: Config;
|
||||
let approvalMode: ApprovalMode;
|
||||
|
||||
beforeEach(() => {
|
||||
approvalMode = ApprovalMode.PLAN;
|
||||
mockConfig = {
|
||||
getApprovalMode: vi.fn(() => approvalMode),
|
||||
setApprovalMode: vi.fn((mode: ApprovalMode) => {
|
||||
approvalMode = mode;
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
tool = new ExitPlanModeTool(mockConfig);
|
||||
});
|
||||
|
||||
describe('constructor and metadata', () => {
|
||||
it('should have correct tool name', () => {
|
||||
expect(tool.name).toBe('exit_plan_mode');
|
||||
expect(ExitPlanModeTool.Name).toBe('exit_plan_mode');
|
||||
});
|
||||
|
||||
it('should have correct display name', () => {
|
||||
expect(tool.displayName).toBe('ExitPlanMode');
|
||||
});
|
||||
|
||||
it('should have correct kind', () => {
|
||||
expect(tool.kind).toBe('think');
|
||||
});
|
||||
|
||||
it('should have correct schema', () => {
|
||||
expect(tool.schema).toEqual({
|
||||
name: 'exit_plan_mode',
|
||||
description: expect.stringContaining(
|
||||
'Use this tool when you are in plan mode',
|
||||
),
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
plan: {
|
||||
type: 'string',
|
||||
description: expect.stringContaining('The plan you came up with'),
|
||||
},
|
||||
},
|
||||
required: ['plan'],
|
||||
additionalProperties: false,
|
||||
$schema: 'http://json-schema.org/draft-07/schema#',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateToolParams', () => {
|
||||
it('should accept valid parameters', () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'This is a comprehensive plan for the implementation.',
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should reject missing plan parameter', () => {
|
||||
const params = {} as ExitPlanModeParams;
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toBe('Parameter "plan" must be a non-empty string.');
|
||||
});
|
||||
|
||||
it('should reject empty plan parameter', () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: '',
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toBe('Parameter "plan" must be a non-empty string.');
|
||||
});
|
||||
|
||||
it('should reject whitespace-only plan parameter', () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: ' \n\t ',
|
||||
};
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toBe('Parameter "plan" must be a non-empty string.');
|
||||
});
|
||||
|
||||
it('should reject non-string plan parameter', () => {
|
||||
const params = {
|
||||
plan: 123,
|
||||
} as unknown as ExitPlanModeParams;
|
||||
|
||||
const result = tool.validateToolParams(params);
|
||||
expect(result).toBe('Parameter "plan" must be a non-empty string.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('tool execution', () => {
|
||||
it('should execute successfully through tool interface after approval', async () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'This is my implementation plan:\n1. Step 1\n2. Step 2\n3. Step 3',
|
||||
};
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
// Use the tool's public build method
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation).toBeDefined();
|
||||
expect(invocation.params).toEqual(params);
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(signal);
|
||||
expect(confirmation).toMatchObject({
|
||||
type: 'plan',
|
||||
title: 'Would you like to proceed?',
|
||||
plan: params.plan,
|
||||
});
|
||||
|
||||
if (confirmation) {
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
const result = await invocation.execute(signal);
|
||||
const expectedLlmMessage =
|
||||
'User has approved your plan. You can now start coding. Start with updating your todo list if applicable.';
|
||||
|
||||
expect(result).toEqual({
|
||||
llmContent: expectedLlmMessage,
|
||||
returnDisplay: {
|
||||
type: 'plan_summary',
|
||||
message: 'User approved the plan.',
|
||||
plan: params.plan,
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockConfig.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
expect(approvalMode).toBe(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
it('should request confirmation with plan details', async () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'Simple plan',
|
||||
};
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (confirmation) {
|
||||
expect(confirmation.type).toBe('plan');
|
||||
if (confirmation.type === 'plan') {
|
||||
expect(confirmation.plan).toBe(params.plan);
|
||||
}
|
||||
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlways);
|
||||
}
|
||||
|
||||
expect(mockConfig.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
|
||||
});
|
||||
|
||||
it('should remain in plan mode when confirmation is rejected', async () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'Remain in planning',
|
||||
};
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (confirmation) {
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
|
||||
}
|
||||
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result).toEqual({
|
||||
llmContent: JSON.stringify({
|
||||
success: false,
|
||||
plan: params.plan,
|
||||
error: 'Plan execution was not approved. Remaining in plan mode.',
|
||||
}),
|
||||
returnDisplay:
|
||||
'Plan execution was not approved. Remaining in plan mode.',
|
||||
});
|
||||
|
||||
expect(mockConfig.setApprovalMode).toHaveBeenCalledWith(
|
||||
ApprovalMode.PLAN,
|
||||
);
|
||||
expect(approvalMode).toBe(ApprovalMode.PLAN);
|
||||
});
|
||||
|
||||
it('should have correct description', () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'Test plan',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation.getDescription()).toBe(
|
||||
'Present implementation plan for user approval',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle execution errors gracefully', async () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'Test plan',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
if (confirmation) {
|
||||
// Don't approve the plan so we go through the rejection path
|
||||
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
|
||||
}
|
||||
|
||||
// Create a spy to simulate an error during the execution
|
||||
const consoleSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
// Mock JSON.stringify to throw an error in the rejection path
|
||||
const originalStringify = JSON.stringify;
|
||||
vi.spyOn(JSON, 'stringify').mockImplementationOnce(() => {
|
||||
throw new Error('JSON stringify error');
|
||||
});
|
||||
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result).toEqual({
|
||||
llmContent: JSON.stringify({
|
||||
success: false,
|
||||
error: 'Failed to present plan. Detail: JSON stringify error',
|
||||
}),
|
||||
returnDisplay: 'Error presenting plan: JSON stringify error',
|
||||
});
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'[ExitPlanModeTool] Error executing exit_plan_mode: JSON stringify error',
|
||||
);
|
||||
|
||||
// Restore original JSON.stringify
|
||||
JSON.stringify = originalStringify;
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should return empty tool locations', () => {
|
||||
const params: ExitPlanModeParams = {
|
||||
plan: 'Test plan',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation.toolLocations()).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tool description', () => {
|
||||
it('should contain usage guidelines', () => {
|
||||
expect(tool.description).toContain(
|
||||
'Only use this tool when the task requires planning',
|
||||
);
|
||||
expect(tool.description).toContain(
|
||||
'Do not use the exit plan mode tool because you are not planning',
|
||||
);
|
||||
expect(tool.description).toContain(
|
||||
'Use the exit plan mode tool after you have finished planning',
|
||||
);
|
||||
});
|
||||
|
||||
it('should contain examples', () => {
|
||||
expect(tool.description).toContain(
|
||||
'Search for and understand the implementation of vim mode',
|
||||
);
|
||||
expect(tool.description).toContain('Help me implement yank mode for vim');
|
||||
});
|
||||
});
|
||||
});
|
||||
191
packages/core/src/tools/exitPlanMode.ts
Normal file
191
packages/core/src/tools/exitPlanMode.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ToolPlanConfirmationDetails, ToolResult } from './tools.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
import type { FunctionDeclaration } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
|
||||
export interface ExitPlanModeParams {
|
||||
plan: string;
|
||||
}
|
||||
|
||||
const exitPlanModeToolDescription = `Use this tool when you are in plan mode and have finished presenting your plan and are ready to code. This will prompt the user to exit plan mode.
|
||||
IMPORTANT: Only use this tool when the task requires planning the implementation steps of a task that requires writing code. For research tasks where you're gathering information, searching files, reading files or in general trying to understand the codebase - do NOT use this tool.
|
||||
|
||||
Eg.
|
||||
1. Initial task: "Search for and understand the implementation of vim mode in the codebase" - Do not use the exit plan mode tool because you are not planning the implementation steps of a task.
|
||||
2. Initial task: "Help me implement yank mode for vim" - Use the exit plan mode tool after you have finished planning the implementation steps of the task.
|
||||
`;
|
||||
|
||||
const exitPlanModeToolSchemaData: FunctionDeclaration = {
|
||||
name: 'exit_plan_mode',
|
||||
description: exitPlanModeToolDescription,
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
plan: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The plan you came up with, that you want to run by the user for approval. Supports markdown. The plan should be pretty concise.',
|
||||
},
|
||||
},
|
||||
required: ['plan'],
|
||||
additionalProperties: false,
|
||||
$schema: 'http://json-schema.org/draft-07/schema#',
|
||||
},
|
||||
};
|
||||
|
||||
class ExitPlanModeToolInvocation extends BaseToolInvocation<
|
||||
ExitPlanModeParams,
|
||||
ToolResult
|
||||
> {
|
||||
private wasApproved = false;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: ExitPlanModeParams,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return 'Present implementation plan for user approval';
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolPlanConfirmationDetails> {
|
||||
const details: ToolPlanConfirmationDetails = {
|
||||
type: 'plan',
|
||||
title: 'Would you like to proceed?',
|
||||
plan: this.params.plan,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
switch (outcome) {
|
||||
case ToolConfirmationOutcome.ProceedAlways:
|
||||
this.wasApproved = true;
|
||||
this.setApprovalModeSafely(ApprovalMode.AUTO_EDIT);
|
||||
break;
|
||||
case ToolConfirmationOutcome.ProceedOnce:
|
||||
this.wasApproved = true;
|
||||
this.setApprovalModeSafely(ApprovalMode.DEFAULT);
|
||||
break;
|
||||
case ToolConfirmationOutcome.Cancel:
|
||||
this.wasApproved = false;
|
||||
this.setApprovalModeSafely(ApprovalMode.PLAN);
|
||||
break;
|
||||
default:
|
||||
// Treat any other outcome as manual approval to preserve conservative behaviour.
|
||||
this.wasApproved = true;
|
||||
this.setApprovalModeSafely(ApprovalMode.DEFAULT);
|
||||
break;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
return details;
|
||||
}
|
||||
|
||||
private setApprovalModeSafely(mode: ApprovalMode): void {
|
||||
try {
|
||||
this.config.setApprovalMode(mode);
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
console.error(
|
||||
`[ExitPlanModeTool] Failed to set approval mode to "${mode}": ${errorMessage}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async execute(_signal: AbortSignal): Promise<ToolResult> {
|
||||
const { plan } = this.params;
|
||||
|
||||
try {
|
||||
if (!this.wasApproved) {
|
||||
const rejectionMessage =
|
||||
'Plan execution was not approved. Remaining in plan mode.';
|
||||
return {
|
||||
llmContent: JSON.stringify({
|
||||
success: false,
|
||||
plan,
|
||||
error: rejectionMessage,
|
||||
}),
|
||||
returnDisplay: rejectionMessage,
|
||||
};
|
||||
}
|
||||
|
||||
const llmMessage =
|
||||
'User has approved your plan. You can now start coding. Start with updating your todo list if applicable.';
|
||||
const displayMessage = 'User approved the plan.';
|
||||
|
||||
return {
|
||||
llmContent: llmMessage,
|
||||
returnDisplay: {
|
||||
type: 'plan_summary',
|
||||
message: displayMessage,
|
||||
plan,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
console.error(
|
||||
`[ExitPlanModeTool] Error executing exit_plan_mode: ${errorMessage}`,
|
||||
);
|
||||
return {
|
||||
llmContent: JSON.stringify({
|
||||
success: false,
|
||||
error: `Failed to present plan. Detail: ${errorMessage}`,
|
||||
}),
|
||||
returnDisplay: `Error presenting plan: ${errorMessage}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class ExitPlanModeTool extends BaseDeclarativeTool<
|
||||
ExitPlanModeParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name: string = exitPlanModeToolSchemaData.name!;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
ExitPlanModeTool.Name,
|
||||
'ExitPlanMode',
|
||||
exitPlanModeToolDescription,
|
||||
Kind.Think,
|
||||
exitPlanModeToolSchemaData.parametersJsonSchema as Record<
|
||||
string,
|
||||
unknown
|
||||
>,
|
||||
);
|
||||
}
|
||||
|
||||
override validateToolParams(params: ExitPlanModeParams): string | null {
|
||||
// Validate plan parameter
|
||||
if (
|
||||
!params.plan ||
|
||||
typeof params.plan !== 'string' ||
|
||||
params.plan.trim() === ''
|
||||
) {
|
||||
return 'Parameter "plan" must be a non-empty string.';
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
protected createInvocation(params: ExitPlanModeParams) {
|
||||
return new ExitPlanModeToolInvocation(this.config, params);
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
import { EOL } from 'node:os';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { rgPath } from '@lvce-editor/ripgrep';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
@@ -18,6 +17,14 @@ import type { Config } from '../config/config.js';
|
||||
|
||||
const DEFAULT_TOTAL_MAX_MATCHES = 20000;
|
||||
|
||||
/**
|
||||
* Lazy loads the ripgrep binary path to avoid loading the library until needed
|
||||
*/
|
||||
async function getRipgrepPath(): Promise<string> {
|
||||
const { rgPath } = await import('@lvce-editor/ripgrep');
|
||||
return rgPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for the GrepTool
|
||||
*/
|
||||
@@ -292,8 +299,9 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
rgArgs.push(absolutePath);
|
||||
|
||||
try {
|
||||
const ripgrepPath = await getRipgrepPath();
|
||||
const output = await new Promise<string>((resolve, reject) => {
|
||||
const child = spawn(rgPath, rgArgs, {
|
||||
const child = spawn(ripgrepPath, rgArgs, {
|
||||
windowsHide: true,
|
||||
});
|
||||
|
||||
|
||||
@@ -777,6 +777,19 @@ describe('ShellTool', () => {
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
it('should not request confirmation for read-only commands', async () => {
|
||||
const invocation = shellTool.build({
|
||||
command: 'ls -la',
|
||||
is_background: false,
|
||||
});
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should request confirmation for a new command and whitelist it on "Always"', async () => {
|
||||
const params = { command: 'npm install', is_background: false };
|
||||
const invocation = shellTool.build(params);
|
||||
|
||||
@@ -32,6 +32,7 @@ import { formatMemoryUsage } from '../utils/formatters.js';
|
||||
import {
|
||||
getCommandRoots,
|
||||
isCommandAllowed,
|
||||
isCommandNeedsPermission,
|
||||
stripShellWrapper,
|
||||
} from '../utils/shell-utils.js';
|
||||
|
||||
@@ -87,6 +88,11 @@ class ShellToolInvocation extends BaseToolInvocation<
|
||||
return false; // already approved and whitelisted
|
||||
}
|
||||
|
||||
const permissionCheck = isCommandNeedsPermission(command);
|
||||
if (!permissionCheck.requiresPermission) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const confirmationDetails: ToolExecuteConfirmationDetails = {
|
||||
type: 'exec',
|
||||
title: 'Confirm Shell Command',
|
||||
|
||||
@@ -43,6 +43,7 @@ describe('TaskTool', () => {
|
||||
let config: Config;
|
||||
let taskTool: TaskTool;
|
||||
let mockSubagentManager: SubagentManager;
|
||||
let changeListeners: Array<() => void>;
|
||||
|
||||
const mockSubagents: SubagentConfig[] = [
|
||||
{
|
||||
@@ -70,13 +71,25 @@ describe('TaskTool', () => {
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getSubagentManager: vi.fn(),
|
||||
getGeminiClient: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as Config;
|
||||
|
||||
changeListeners = [];
|
||||
|
||||
// Setup SubagentManager mock
|
||||
mockSubagentManager = {
|
||||
listSubagents: vi.fn().mockResolvedValue(mockSubagents),
|
||||
loadSubagent: vi.fn(),
|
||||
createSubagentScope: vi.fn(),
|
||||
addChangeListener: vi.fn((listener: () => void) => {
|
||||
changeListeners.push(listener);
|
||||
return () => {
|
||||
const index = changeListeners.indexOf(listener);
|
||||
if (index >= 0) {
|
||||
changeListeners.splice(index, 1);
|
||||
}
|
||||
};
|
||||
}),
|
||||
} as unknown as SubagentManager;
|
||||
|
||||
MockedSubagentManager.mockImplementation(() => mockSubagentManager);
|
||||
@@ -106,6 +119,10 @@ describe('TaskTool', () => {
|
||||
expect(mockSubagentManager.listSubagents).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should subscribe to subagent manager changes', () => {
|
||||
expect(mockSubagentManager.addChangeListener).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should update description with available subagents', () => {
|
||||
expect(taskTool.description).toContain('file-search');
|
||||
expect(taskTool.description).toContain(
|
||||
@@ -232,6 +249,31 @@ describe('TaskTool', () => {
|
||||
});
|
||||
|
||||
describe('refreshSubagents', () => {
|
||||
it('should refresh when change listener fires', async () => {
|
||||
const newSubagents: SubagentConfig[] = [
|
||||
{
|
||||
name: 'new-agent',
|
||||
description: 'A brand new agent',
|
||||
systemPrompt: 'Do new things.',
|
||||
level: 'project',
|
||||
filePath: '/project/.qwen/agents/new-agent.md',
|
||||
},
|
||||
];
|
||||
|
||||
vi.mocked(mockSubagentManager.listSubagents).mockResolvedValueOnce(
|
||||
newSubagents,
|
||||
);
|
||||
|
||||
const listener = changeListeners[0];
|
||||
expect(listener).toBeDefined();
|
||||
|
||||
listener?.();
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
expect(taskTool.description).toContain('new-agent');
|
||||
expect(taskTool.description).toContain('A brand new agent');
|
||||
});
|
||||
|
||||
it('should refresh available subagents and update description', async () => {
|
||||
const newSubagents: SubagentConfig[] = [
|
||||
{
|
||||
|
||||
@@ -86,16 +86,19 @@ export class TaskTool extends BaseDeclarativeTool<TaskParams, ToolResult> {
|
||||
);
|
||||
|
||||
this.subagentManager = config.getSubagentManager();
|
||||
this.subagentManager.addChangeListener(() => {
|
||||
void this.refreshSubagents();
|
||||
});
|
||||
|
||||
// Initialize the tool asynchronously
|
||||
this.initializeAsync();
|
||||
this.refreshSubagents();
|
||||
}
|
||||
|
||||
/**
|
||||
* Asynchronously initializes the tool by loading available subagents
|
||||
* and updating the description and schema.
|
||||
*/
|
||||
private async initializeAsync(): Promise<void> {
|
||||
async refreshSubagents(): Promise<void> {
|
||||
try {
|
||||
this.availableSubagents = await this.subagentManager.listSubagents();
|
||||
this.updateDescriptionAndSchema();
|
||||
@@ -103,6 +106,12 @@ export class TaskTool extends BaseDeclarativeTool<TaskParams, ToolResult> {
|
||||
console.warn('Failed to load subagents for Task tool:', error);
|
||||
this.availableSubagents = [];
|
||||
this.updateDescriptionAndSchema();
|
||||
} finally {
|
||||
// Update the client with the new tools
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
if (geminiClient) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,14 +210,6 @@ assistant: "I'm going to use the Task tool to launch the with the greeting-respo
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the available subagents and updates the tool description.
|
||||
* This can be called when subagents are added or removed.
|
||||
*/
|
||||
async refreshSubagents(): Promise<void> {
|
||||
await this.initializeAsync();
|
||||
}
|
||||
|
||||
override validateToolParams(params: TaskParams): string | null {
|
||||
// Validate required fields
|
||||
if (
|
||||
|
||||
@@ -20,4 +20,5 @@ export const ToolNames = {
|
||||
TODO_WRITE: 'todo_write',
|
||||
MEMORY: 'save_memory',
|
||||
TASK: 'task',
|
||||
EXIT_PLAN_MODE: 'exit_plan_mode',
|
||||
} as const;
|
||||
|
||||
@@ -464,6 +464,7 @@ export type ToolResultDisplay =
|
||||
| string
|
||||
| FileDiff
|
||||
| TodoResultDisplay
|
||||
| PlanResultDisplay
|
||||
| TaskResultDisplay;
|
||||
|
||||
export interface FileDiff {
|
||||
@@ -490,6 +491,12 @@ export interface TodoResultDisplay {
|
||||
}>;
|
||||
}
|
||||
|
||||
export interface PlanResultDisplay {
|
||||
type: 'plan_summary';
|
||||
message: string;
|
||||
plan: string;
|
||||
}
|
||||
|
||||
export interface ToolEditConfirmationDetails {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
@@ -541,7 +548,15 @@ export type ToolCallConfirmationDetails =
|
||||
| ToolEditConfirmationDetails
|
||||
| ToolExecuteConfirmationDetails
|
||||
| ToolMcpConfirmationDetails
|
||||
| ToolInfoConfirmationDetails;
|
||||
| ToolInfoConfirmationDetails
|
||||
| ToolPlanConfirmationDetails;
|
||||
|
||||
export interface ToolPlanConfirmationDetails {
|
||||
type: 'plan';
|
||||
title: string;
|
||||
plan: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
}
|
||||
|
||||
export enum ToolConfirmationOutcome {
|
||||
ProceedOnce = 'proceed_once',
|
||||
|
||||
@@ -10,9 +10,13 @@ import {
|
||||
Kind,
|
||||
type ToolInvocation,
|
||||
type ToolResult,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolInfoConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
interface TavilyResultItem {
|
||||
@@ -61,6 +65,26 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
return `Searching the web for: "${this.params.query}"`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const confirmationDetails: ToolInfoConfirmationDetails = {
|
||||
type: 'info',
|
||||
title: 'Confirm Web Search',
|
||||
prompt: `Search the web for: "${this.params.query}"`,
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(signal: AbortSignal): Promise<WebSearchToolResult> {
|
||||
const apiKey =
|
||||
this.config.getTavilyApiKey() || process.env['TAVILY_API_KEY'];
|
||||
|
||||
@@ -18,7 +18,6 @@ import { getCorrectedFileContent, WriteFileTool } from './write-file.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import type { FileDiff, ToolEditConfirmationDetails } from './tools.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import { type EditToolParams } from './edit.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
@@ -26,11 +25,6 @@ import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import type { CorrectedEditResult } from '../utils/editCorrector.js';
|
||||
import {
|
||||
ensureCorrectEdit,
|
||||
ensureCorrectFileContent,
|
||||
} from '../utils/editCorrector.js';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
@@ -38,17 +32,8 @@ const rootDir = path.resolve(os.tmpdir(), 'qwen-code-test-root');
|
||||
|
||||
// --- MOCKS ---
|
||||
vi.mock('../core/client.js');
|
||||
vi.mock('../utils/editCorrector.js');
|
||||
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
const mockEnsureCorrectEdit = vi.fn<typeof ensureCorrectEdit>();
|
||||
const mockEnsureCorrectFileContent = vi.fn<typeof ensureCorrectFileContent>();
|
||||
|
||||
// Wire up the mocked functions to be used by the actual module imports
|
||||
vi.mocked(ensureCorrectEdit).mockImplementation(mockEnsureCorrectEdit);
|
||||
vi.mocked(ensureCorrectFileContent).mockImplementation(
|
||||
mockEnsureCorrectFileContent,
|
||||
);
|
||||
|
||||
// Mock Config
|
||||
const fsService = new StandardFileSystemService();
|
||||
@@ -111,11 +96,6 @@ describe('WriteFileTool', () => {
|
||||
) as Mocked<GeminiClient>;
|
||||
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClientInstance);
|
||||
|
||||
vi.mocked(ensureCorrectEdit).mockImplementation(mockEnsureCorrectEdit);
|
||||
vi.mocked(ensureCorrectFileContent).mockImplementation(
|
||||
mockEnsureCorrectFileContent,
|
||||
);
|
||||
|
||||
// Now that mockGeminiClientInstance is initialized, set the mock implementation for getGeminiClient
|
||||
mockConfigInternal.getGeminiClient.mockReturnValue(
|
||||
mockGeminiClientInstance,
|
||||
@@ -134,40 +114,6 @@ describe('WriteFileTool', () => {
|
||||
// Reset mocks before each test
|
||||
mockConfigInternal.getApprovalMode.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
mockConfigInternal.setApprovalMode.mockClear();
|
||||
mockEnsureCorrectEdit.mockReset();
|
||||
mockEnsureCorrectFileContent.mockReset();
|
||||
|
||||
// Default mock implementations that return valid structures
|
||||
mockEnsureCorrectEdit.mockImplementation(
|
||||
async (
|
||||
filePath: string,
|
||||
_currentContent: string,
|
||||
params: EditToolParams,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal, // Make AbortSignal optional to match usage
|
||||
): Promise<CorrectedEditResult> => {
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve({
|
||||
params: { ...params, new_string: params.new_string ?? '' },
|
||||
occurrences: 1,
|
||||
});
|
||||
},
|
||||
);
|
||||
mockEnsureCorrectFileContent.mockImplementation(
|
||||
async (
|
||||
content: string,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> => {
|
||||
// Make AbortSignal optional
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve(content ?? '');
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -240,71 +186,35 @@ describe('WriteFileTool', () => {
|
||||
});
|
||||
|
||||
describe('getCorrectedFileContent', () => {
|
||||
it('should call ensureCorrectFileContent for a new file', async () => {
|
||||
it('should return proposed content unchanged for a new file', async () => {
|
||||
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
||||
const proposedContent = 'Proposed new content.';
|
||||
const correctedContent = 'Corrected new content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
const result = await getCorrectedFileContent(
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedContent);
|
||||
expect(result.correctedContent).toBe(proposedContent);
|
||||
expect(result.originalContent).toBe('');
|
||||
expect(result.fileExists).toBe(false);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call ensureCorrectEdit for an existing file', async () => {
|
||||
it('should return proposed content unchanged for an existing file', async () => {
|
||||
const filePath = path.join(rootDir, 'existing_corrected_file.txt');
|
||||
const originalContent = 'Original existing content.';
|
||||
const proposedContent = 'Proposed replacement content.';
|
||||
const correctedProposedContent = 'Corrected replacement content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
// Ensure this mock is active and returns the correct structure
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: originalContent,
|
||||
new_string: correctedProposedContent,
|
||||
},
|
||||
occurrences: 1,
|
||||
} as CorrectedEditResult);
|
||||
|
||||
const result = await getCorrectedFileContent(
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent,
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedProposedContent);
|
||||
expect(result.correctedContent).toBe(proposedContent);
|
||||
expect(result.originalContent).toBe(originalContent);
|
||||
expect(result.fileExists).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
@@ -313,7 +223,6 @@ describe('WriteFileTool', () => {
|
||||
it('should return error if reading an existing file fails (e.g. permissions)', async () => {
|
||||
const filePath = path.join(rootDir, 'unreadable_file.txt');
|
||||
const proposedContent = 'some content';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Permission denied');
|
||||
@@ -325,12 +234,9 @@ describe('WriteFileTool', () => {
|
||||
mockConfig,
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fsService.readTextFile).toHaveBeenCalledWith(filePath);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(proposedContent);
|
||||
expect(result.originalContent).toBe('');
|
||||
expect(result.fileExists).toBe(true);
|
||||
@@ -363,11 +269,9 @@ describe('WriteFileTool', () => {
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
|
||||
it('should request confirmation with diff for a new file (with corrected content)', async () => {
|
||||
it('should request confirmation with diff for a new file', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_new_file.txt');
|
||||
const proposedContent = 'Proposed new content for confirmation.';
|
||||
const correctedContent = 'Corrected new content for confirmation.';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); // Ensure this mock is active
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
@@ -375,16 +279,11 @@ describe('WriteFileTool', () => {
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Write: ${path.basename(filePath)}`,
|
||||
fileName: 'confirm_new_file.txt',
|
||||
fileDiff: expect.stringContaining(correctedContent),
|
||||
fileDiff: expect.stringContaining(proposedContent),
|
||||
}),
|
||||
);
|
||||
expect(confirmation.fileDiff).toMatch(
|
||||
@@ -395,45 +294,23 @@ describe('WriteFileTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should request confirmation with diff for an existing file (with corrected content)', async () => {
|
||||
it('should request confirmation with diff for an existing file', async () => {
|
||||
const filePath = path.join(rootDir, 'confirm_existing_file.txt');
|
||||
const originalContent = 'Original content for confirmation.';
|
||||
const proposedContent = 'Proposed replacement for confirmation.';
|
||||
const correctedProposedContent =
|
||||
'Corrected replacement for confirmation.';
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: originalContent,
|
||||
new_string: correctedProposedContent,
|
||||
},
|
||||
occurrences: 1,
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
const confirmation = (await invocation.shouldConfirmExecute(
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent,
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Write: ${path.basename(filePath)}`,
|
||||
fileName: 'confirm_existing_file.txt',
|
||||
fileDiff: expect.stringContaining(correctedProposedContent),
|
||||
fileDiff: expect.stringContaining(proposedContent),
|
||||
}),
|
||||
);
|
||||
expect(confirmation.fileDiff).toMatch(
|
||||
@@ -470,11 +347,9 @@ describe('WriteFileTool', () => {
|
||||
fs.chmodSync(filePath, 0o600);
|
||||
});
|
||||
|
||||
it('should write a new file with corrected content and return diff', async () => {
|
||||
const filePath = path.join(rootDir, 'execute_new_corrected_file.txt');
|
||||
it('should write a new file and return diff', async () => {
|
||||
const filePath = path.join(rootDir, 'execute_new_file.txt');
|
||||
const proposedContent = 'Proposed new content for execute.';
|
||||
const correctedContent = 'Corrected new content for execute.';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
@@ -490,49 +365,27 @@ describe('WriteFileTool', () => {
|
||||
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(
|
||||
/Successfully created and wrote to new file/,
|
||||
);
|
||||
expect(fs.existsSync(filePath)).toBe(true);
|
||||
const writtenContent = await fsService.readTextFile(filePath);
|
||||
expect(writtenContent).toBe(correctedContent);
|
||||
expect(writtenContent).toBe(proposedContent);
|
||||
const display = result.returnDisplay as FileDiff;
|
||||
expect(display.fileName).toBe('execute_new_corrected_file.txt');
|
||||
expect(display.fileName).toBe('execute_new_file.txt');
|
||||
expect(display.fileDiff).toMatch(/--- execute_new_file.txt\tOriginal/);
|
||||
expect(display.fileDiff).toMatch(/\+\+\+ execute_new_file.txt\tWritten/);
|
||||
expect(display.fileDiff).toMatch(
|
||||
/--- execute_new_corrected_file.txt\tOriginal/,
|
||||
);
|
||||
expect(display.fileDiff).toMatch(
|
||||
/\+\+\+ execute_new_corrected_file.txt\tWritten/,
|
||||
);
|
||||
expect(display.fileDiff).toMatch(
|
||||
correctedContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
proposedContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should overwrite an existing file with corrected content and return diff', async () => {
|
||||
const filePath = path.join(
|
||||
rootDir,
|
||||
'execute_existing_corrected_file.txt',
|
||||
);
|
||||
it('should overwrite an existing file and return diff', async () => {
|
||||
const filePath = path.join(rootDir, 'execute_existing_file.txt');
|
||||
const initialContent = 'Initial content for execute.';
|
||||
const proposedContent = 'Proposed overwrite for execute.';
|
||||
const correctedProposedContent = 'Corrected overwrite for execute.';
|
||||
fs.writeFileSync(filePath, initialContent, 'utf8');
|
||||
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
old_string: initialContent,
|
||||
new_string: correctedProposedContent,
|
||||
},
|
||||
occurrences: 1,
|
||||
});
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
@@ -547,27 +400,16 @@ describe('WriteFileTool', () => {
|
||||
|
||||
const result = await invocation.execute(abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
initialContent,
|
||||
{
|
||||
old_string: initialContent,
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
||||
const writtenContent = await fsService.readTextFile(filePath);
|
||||
expect(writtenContent).toBe(correctedProposedContent);
|
||||
expect(writtenContent).toBe(proposedContent);
|
||||
const display = result.returnDisplay as FileDiff;
|
||||
expect(display.fileName).toBe('execute_existing_corrected_file.txt');
|
||||
expect(display.fileName).toBe('execute_existing_file.txt');
|
||||
expect(display.fileDiff).toMatch(
|
||||
initialContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
);
|
||||
expect(display.fileDiff).toMatch(
|
||||
correctedProposedContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
proposedContent.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&'),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -575,7 +417,6 @@ describe('WriteFileTool', () => {
|
||||
const dirPath = path.join(rootDir, 'new_dir_for_write');
|
||||
const filePath = path.join(dirPath, 'file_in_new_dir.txt');
|
||||
const content = 'Content in new directory';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(content); // Ensure this mock is active
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
const invocation = tool.build(params);
|
||||
@@ -600,7 +441,6 @@ describe('WriteFileTool', () => {
|
||||
it('should include modification message when proposed content is modified', async () => {
|
||||
const filePath = path.join(rootDir, 'new_file_modified.txt');
|
||||
const content = 'New file content modified by user';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(content);
|
||||
|
||||
const params = {
|
||||
file_path: filePath,
|
||||
@@ -616,7 +456,6 @@ describe('WriteFileTool', () => {
|
||||
it('should not include modification message when proposed content is not modified', async () => {
|
||||
const filePath = path.join(rootDir, 'new_file_unmodified.txt');
|
||||
const content = 'New file content not modified';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(content);
|
||||
|
||||
const params = {
|
||||
file_path: filePath,
|
||||
@@ -632,7 +471,6 @@ describe('WriteFileTool', () => {
|
||||
it('should not include modification message when modified_by_user is not provided', async () => {
|
||||
const filePath = path.join(rootDir, 'new_file_unmodified.txt');
|
||||
const content = 'New file content not modified';
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(content);
|
||||
|
||||
const params = {
|
||||
file_path: filePath,
|
||||
|
||||
@@ -26,10 +26,6 @@ import {
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import {
|
||||
ensureCorrectEdit,
|
||||
ensureCorrectFileContent,
|
||||
} from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import type {
|
||||
@@ -79,11 +75,10 @@ export async function getCorrectedFileContent(
|
||||
config: Config,
|
||||
filePath: string,
|
||||
proposedContent: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<GetCorrectedFileContentResult> {
|
||||
let originalContent = '';
|
||||
let fileExists = false;
|
||||
let correctedContent = proposedContent;
|
||||
const correctedContent = proposedContent;
|
||||
|
||||
try {
|
||||
originalContent = await config
|
||||
@@ -107,32 +102,6 @@ export async function getCorrectedFileContent(
|
||||
}
|
||||
}
|
||||
|
||||
// If readError is set, we have returned.
|
||||
// So, file was either read successfully (fileExists=true, originalContent set)
|
||||
// or it was ENOENT (fileExists=false, originalContent='').
|
||||
|
||||
if (fileExists) {
|
||||
// This implies originalContent is available
|
||||
const { params: correctedParams } = await ensureCorrectEdit(
|
||||
filePath,
|
||||
originalContent,
|
||||
{
|
||||
old_string: originalContent, // Treat entire current content as old_string
|
||||
new_string: proposedContent,
|
||||
file_path: filePath,
|
||||
},
|
||||
config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
} else {
|
||||
// This implies new file (ENOENT)
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
config.getGeminiClient(),
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
return { originalContent, correctedContent, fileExists };
|
||||
}
|
||||
|
||||
@@ -160,7 +129,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
@@ -170,7 +139,6 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
this.config,
|
||||
this.params.file_path,
|
||||
this.params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
@@ -226,14 +194,13 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(abortSignal: AbortSignal): Promise<ToolResult> {
|
||||
async execute(_abortSignal: AbortSignal): Promise<ToolResult> {
|
||||
const { file_path, content, ai_proposed_content, modified_by_user } =
|
||||
this.params;
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
file_path,
|
||||
content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
@@ -476,7 +443,7 @@ export class WriteFileTool
|
||||
}
|
||||
|
||||
getModifyContext(
|
||||
abortSignal: AbortSignal,
|
||||
_abortSignal: AbortSignal,
|
||||
): ModifyContext<WriteFileToolParams> {
|
||||
return {
|
||||
getFilePath: (params: WriteFileToolParams) => params.file_path,
|
||||
@@ -485,7 +452,6 @@ export class WriteFileTool
|
||||
this.config,
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
return correctedContentResult.originalContent;
|
||||
},
|
||||
@@ -494,7 +460,6 @@ export class WriteFileTool
|
||||
this.config,
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
return correctedContentResult.correctedContent;
|
||||
},
|
||||
|
||||
@@ -1,761 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import type { Mock } from 'vitest';
|
||||
import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest';
|
||||
import * as fs from 'node:fs';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
|
||||
// MOCKS
|
||||
let callCount = 0;
|
||||
const mockResponses: any[] = [];
|
||||
|
||||
let mockGenerateJson: any;
|
||||
let mockStartChat: any;
|
||||
let mockSendMessageStream: any;
|
||||
|
||||
vi.mock('fs', () => ({
|
||||
statSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../core/client.js', () => ({
|
||||
GeminiClient: vi.fn().mockImplementation(function (
|
||||
this: any,
|
||||
_config: Config,
|
||||
) {
|
||||
this.generateJson = (...params: any[]) => mockGenerateJson(...params); // Corrected: use mockGenerateJson
|
||||
this.startChat = (...params: any[]) => mockStartChat(...params); // Corrected: use mockStartChat
|
||||
this.sendMessageStream = (...params: any[]) =>
|
||||
mockSendMessageStream(...params); // Corrected: use mockSendMessageStream
|
||||
return this;
|
||||
}),
|
||||
}));
|
||||
// END MOCKS
|
||||
|
||||
import {
|
||||
countOccurrences,
|
||||
ensureCorrectEdit,
|
||||
ensureCorrectFileContent,
|
||||
unescapeStringForGeminiBug,
|
||||
resetEditCorrectorCaches_TEST_ONLY,
|
||||
} from './editCorrector.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
|
||||
vi.mock('../tools/tool-registry.js');
|
||||
|
||||
describe('editCorrector', () => {
|
||||
describe('countOccurrences', () => {
|
||||
it('should return 0 for empty string', () => {
|
||||
expect(countOccurrences('', 'a')).toBe(0);
|
||||
});
|
||||
it('should return 0 for empty substring', () => {
|
||||
expect(countOccurrences('abc', '')).toBe(0);
|
||||
});
|
||||
it('should return 0 if substring is not found', () => {
|
||||
expect(countOccurrences('abc', 'd')).toBe(0);
|
||||
});
|
||||
it('should return 1 if substring is found once', () => {
|
||||
expect(countOccurrences('abc', 'b')).toBe(1);
|
||||
});
|
||||
it('should return correct count for multiple occurrences', () => {
|
||||
expect(countOccurrences('ababa', 'a')).toBe(3);
|
||||
expect(countOccurrences('ababab', 'ab')).toBe(3);
|
||||
});
|
||||
it('should count non-overlapping occurrences', () => {
|
||||
expect(countOccurrences('aaaaa', 'aa')).toBe(2);
|
||||
expect(countOccurrences('ababab', 'aba')).toBe(1);
|
||||
});
|
||||
it('should correctly count occurrences when substring is longer', () => {
|
||||
expect(countOccurrences('abc', 'abcdef')).toBe(0);
|
||||
});
|
||||
it('should be case-sensitive', () => {
|
||||
expect(countOccurrences('abcABC', 'a')).toBe(1);
|
||||
expect(countOccurrences('abcABC', 'A')).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('unescapeStringForGeminiBug', () => {
|
||||
it('should unescape common sequences', () => {
|
||||
expect(unescapeStringForGeminiBug('\\n')).toBe('\n');
|
||||
expect(unescapeStringForGeminiBug('\\t')).toBe('\t');
|
||||
expect(unescapeStringForGeminiBug("\\'")).toBe("'");
|
||||
expect(unescapeStringForGeminiBug('\\"')).toBe('"');
|
||||
expect(unescapeStringForGeminiBug('\\`')).toBe('`');
|
||||
});
|
||||
it('should handle multiple escaped sequences', () => {
|
||||
expect(unescapeStringForGeminiBug('Hello\\nWorld\\tTest')).toBe(
|
||||
'Hello\nWorld\tTest',
|
||||
);
|
||||
});
|
||||
it('should not alter already correct sequences', () => {
|
||||
expect(unescapeStringForGeminiBug('\n')).toBe('\n');
|
||||
expect(unescapeStringForGeminiBug('Correct string')).toBe(
|
||||
'Correct string',
|
||||
);
|
||||
});
|
||||
it('should handle mixed correct and incorrect sequences', () => {
|
||||
expect(unescapeStringForGeminiBug('\\nCorrect\t\\`')).toBe(
|
||||
'\nCorrect\t`',
|
||||
);
|
||||
});
|
||||
it('should handle backslash followed by actual newline character', () => {
|
||||
expect(unescapeStringForGeminiBug('\\\n')).toBe('\n');
|
||||
expect(unescapeStringForGeminiBug('First line\\\nSecond line')).toBe(
|
||||
'First line\nSecond line',
|
||||
);
|
||||
});
|
||||
it('should handle multiple backslashes before an escapable character (aggressive unescaping)', () => {
|
||||
expect(unescapeStringForGeminiBug('\\\\n')).toBe('\n');
|
||||
expect(unescapeStringForGeminiBug('\\\\\\t')).toBe('\t');
|
||||
expect(unescapeStringForGeminiBug('\\\\\\\\`')).toBe('`');
|
||||
});
|
||||
it('should return empty string for empty input', () => {
|
||||
expect(unescapeStringForGeminiBug('')).toBe('');
|
||||
});
|
||||
it('should not alter strings with no targeted escape sequences', () => {
|
||||
expect(unescapeStringForGeminiBug('abc def')).toBe('abc def');
|
||||
expect(unescapeStringForGeminiBug('C:\\Folder\\File')).toBe(
|
||||
'C:\\Folder\\File',
|
||||
);
|
||||
});
|
||||
it('should correctly process strings with some targeted escapes', () => {
|
||||
expect(unescapeStringForGeminiBug('C:\\Users\\name')).toBe(
|
||||
'C:\\Users\name',
|
||||
);
|
||||
});
|
||||
it('should handle complex cases with mixed slashes and characters', () => {
|
||||
expect(
|
||||
unescapeStringForGeminiBug('\\\\\\\nLine1\\\nLine2\\tTab\\\\`Tick\\"'),
|
||||
).toBe('\nLine1\nLine2\tTab`Tick"');
|
||||
});
|
||||
it('should handle escaped backslashes', () => {
|
||||
expect(unescapeStringForGeminiBug('\\\\')).toBe('\\');
|
||||
expect(unescapeStringForGeminiBug('C:\\\\Users')).toBe('C:\\Users');
|
||||
expect(unescapeStringForGeminiBug('path\\\\to\\\\file')).toBe(
|
||||
'path\to\\file',
|
||||
);
|
||||
});
|
||||
it('should handle escaped backslashes mixed with other escapes (aggressive unescaping)', () => {
|
||||
expect(unescapeStringForGeminiBug('line1\\\\\\nline2')).toBe(
|
||||
'line1\nline2',
|
||||
);
|
||||
expect(unescapeStringForGeminiBug('quote\\\\"text\\\\nline')).toBe(
|
||||
'quote"text\nline',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ensureCorrectEdit', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||
const configParams = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'test-model',
|
||||
sandbox: false as boolean | string,
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
question: undefined as string | undefined,
|
||||
fullContext: false,
|
||||
coreTools: undefined as string[] | undefined,
|
||||
toolDiscoveryCommand: undefined as string | undefined,
|
||||
toolCallCommand: undefined as string | undefined,
|
||||
mcpServerCommand: undefined as string | undefined,
|
||||
mcpServers: undefined as Record<string, any> | undefined,
|
||||
userAgent: 'test-agent',
|
||||
userMemory: '',
|
||||
geminiMdFileCount: 0,
|
||||
alwaysSkipModificationConfirmation: false,
|
||||
};
|
||||
mockConfigInstance = {
|
||||
...configParams,
|
||||
getApiKey: vi.fn(() => configParams.apiKey),
|
||||
getModel: vi.fn(() => configParams.model),
|
||||
getSandbox: vi.fn(() => configParams.sandbox),
|
||||
getTargetDir: vi.fn(() => configParams.targetDir),
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry),
|
||||
getDebugMode: vi.fn(() => configParams.debugMode),
|
||||
getQuestion: vi.fn(() => configParams.question),
|
||||
getFullContext: vi.fn(() => configParams.fullContext),
|
||||
getCoreTools: vi.fn(() => configParams.coreTools),
|
||||
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
|
||||
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
|
||||
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
|
||||
getMcpServers: vi.fn(() => configParams.mcpServers),
|
||||
getUserAgent: vi.fn(() => configParams.userAgent),
|
||||
getUserMemory: vi.fn(() => configParams.userMemory),
|
||||
setUserMemory: vi.fn((mem: string) => {
|
||||
configParams.userMemory = mem;
|
||||
}),
|
||||
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
|
||||
setGeminiMdFileCount: vi.fn((count: number) => {
|
||||
configParams.geminiMdFileCount = count;
|
||||
}),
|
||||
getAlwaysSkipModificationConfirmation: vi.fn(
|
||||
() => configParams.alwaysSkipModificationConfirmation,
|
||||
),
|
||||
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
|
||||
configParams.alwaysSkipModificationConfirmation = skip;
|
||||
}),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
callCount = 0;
|
||||
mockResponses.length = 0;
|
||||
mockGenerateJson = vi
|
||||
.fn()
|
||||
.mockImplementation((_contents, _schema, signal) => {
|
||||
// Check if the signal is aborted. If so, throw an error or return a specific response.
|
||||
if (signal && signal.aborted) {
|
||||
return Promise.reject(new Error('Aborted')); // Or some other specific error/response
|
||||
}
|
||||
const response = mockResponses[callCount];
|
||||
callCount++;
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockStartChat = vi.fn();
|
||||
mockSendMessageStream = vi.fn();
|
||||
|
||||
mockGeminiClientInstance = new GeminiClient(
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
describe('Scenario Group 1: originalParams.old_string matches currentContent directly', () => {
|
||||
it('Test 1.1: old_string (no literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
|
||||
const currentContent = 'This is a test string to find me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find me',
|
||||
new_string: 'replace with \\"this\\"',
|
||||
};
|
||||
mockResponses.push({
|
||||
corrected_new_string_escaping: 'replace with "this"',
|
||||
});
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
expect(result.params.old_string).toBe('find me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 1.2: old_string (no literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
|
||||
const currentContent = 'This is a test string to find me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find me',
|
||||
new_string: 'replace with this',
|
||||
};
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
expect(result.params.old_string).toBe('find me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 1.3: old_string (with literal \\), new_string (escaped by Gemini) -> new_string unchanged (still escaped)', async () => {
|
||||
const currentContent = 'This is a test string to find\\me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find\\me',
|
||||
new_string: 'replace with \\"this\\"',
|
||||
};
|
||||
mockResponses.push({
|
||||
corrected_new_string_escaping: 'replace with "this"',
|
||||
});
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
expect(result.params.old_string).toBe('find\\me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 1.4: old_string (with literal \\), new_string (correctly formatted) -> new_string unchanged', async () => {
|
||||
const currentContent = 'This is a test string to find\\me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find\\me',
|
||||
new_string: 'replace with this',
|
||||
};
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
expect(result.params.old_string).toBe('find\\me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Scenario Group 2: originalParams.old_string does NOT match, but unescapeStringForGeminiBug(originalParams.old_string) DOES match', () => {
|
||||
it('Test 2.1: old_string (over-escaped, no intended literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => {
|
||||
const currentContent = 'This is a test string to find "me".';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find \\"me\\"',
|
||||
new_string: 'replace with \\"this\\"',
|
||||
};
|
||||
mockResponses.push({ corrected_new_string: 'replace with "this"' });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
expect(result.params.old_string).toBe('find "me"');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 2.2: old_string (over-escaped, no intended literal \\), new_string (correctly formatted) -> new_string unescaped (harmlessly)', async () => {
|
||||
const currentContent = 'This is a test string to find "me".';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find \\"me\\"',
|
||||
new_string: 'replace with this',
|
||||
};
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
expect(result.params.old_string).toBe('find "me"');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 2.3: old_string (over-escaped, with intended literal \\), new_string (simple) -> new_string corrected', async () => {
|
||||
const currentContent = 'This is a test string to find \\me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find \\\\me',
|
||||
new_string: 'replace with foobar',
|
||||
};
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with foobar');
|
||||
expect(result.params.old_string).toBe('find \\me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Scenario Group 3: LLM Correction Path', () => {
|
||||
it('Test 3.1: old_string (no literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is double unescaped', async () => {
|
||||
const currentContent = 'This is a test string to corrected find me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find me',
|
||||
new_string: 'replace with \\\\"this\\\\"',
|
||||
};
|
||||
const llmNewString = 'LLM says replace with "that"';
|
||||
mockResponses.push({ corrected_new_string_escaping: llmNewString });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe(llmNewString);
|
||||
expect(result.params.old_string).toBe('find me');
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 3.2: old_string (with literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is unescaped once', async () => {
|
||||
const currentContent = 'This is a test string to corrected find me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find\\me',
|
||||
new_string: 'replace with \\\\"this\\\\"',
|
||||
};
|
||||
const llmCorrectedOldString = 'corrected find me';
|
||||
const llmNewString = 'LLM says replace with "that"';
|
||||
mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
|
||||
mockResponses.push({ corrected_new_string: llmNewString });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
expect(result.params.new_string).toBe(llmNewString);
|
||||
expect(result.params.old_string).toBe(llmCorrectedOldString);
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 3.3: old_string needs LLM, new_string is fine -> old_string corrected, new_string original', async () => {
|
||||
const currentContent = 'This is a test string to be corrected.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'fiiind me',
|
||||
new_string: 'replace with "this"',
|
||||
};
|
||||
const llmCorrectedOldString = 'to be corrected';
|
||||
mockResponses.push({ corrected_target_snippet: llmCorrectedOldString });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
expect(result.params.old_string).toBe(llmCorrectedOldString);
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
it('Test 3.4: LLM correction path, correctNewString returns the originalNewString it was passed (which was unescaped) -> final new_string is unescaped', async () => {
|
||||
const currentContent = 'This is a test string to corrected find me.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find me',
|
||||
new_string: 'replace with \\\\"this\\\\"',
|
||||
};
|
||||
const newStringForLLMAndReturnedByLLM = 'replace with "this"';
|
||||
mockResponses.push({
|
||||
corrected_new_string_escaping: newStringForLLMAndReturnedByLLM,
|
||||
});
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Scenario Group 4: No Match Found / Multiple Matches', () => {
|
||||
it('Test 4.1: No version of old_string (original, unescaped, LLM-corrected) matches -> returns original params, 0 occurrences', async () => {
|
||||
const currentContent = 'This content has nothing to find.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'nonexistent string',
|
||||
new_string: 'some new string',
|
||||
};
|
||||
mockResponses.push({ corrected_target_snippet: 'still nonexistent' });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params).toEqual(originalParams);
|
||||
expect(result.occurrences).toBe(0);
|
||||
});
|
||||
it('Test 4.2: unescapedOldStringAttempt results in >1 occurrences -> returns original params, count occurrences', async () => {
|
||||
const currentContent =
|
||||
'This content has find "me" and also find "me" again.';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'find "me"',
|
||||
new_string: 'some new string',
|
||||
};
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params).toEqual(originalParams);
|
||||
expect(result.occurrences).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Scenario Group 5: Specific unescapeStringForGeminiBug checks (integrated into ensureCorrectEdit)', () => {
|
||||
it('Test 5.1: old_string needs LLM to become currentContent, new_string also needs correction', async () => {
|
||||
const currentContent = 'const x = "a\nbc\\"def\\"';
|
||||
const originalParams = {
|
||||
file_path: '/test/file.txt',
|
||||
old_string: 'const x = \\"a\\nbc\\\\"def\\\\"',
|
||||
new_string: 'const y = \\"new\\nval\\\\"content\\\\"',
|
||||
};
|
||||
const expectedFinalNewString = 'const y = "new\nval\\"content\\"';
|
||||
mockResponses.push({ corrected_target_snippet: currentContent });
|
||||
mockResponses.push({ corrected_new_string: expectedFinalNewString });
|
||||
const result = await ensureCorrectEdit(
|
||||
'/test/file.txt',
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
expect(result.params.old_string).toBe(currentContent);
|
||||
expect(result.params.new_string).toBe(expectedFinalNewString);
|
||||
expect(result.occurrences).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Scenario Group 6: Concurrent Edits', () => {
|
||||
it('Test 6.1: should return early if file was modified by another process', async () => {
|
||||
const filePath = '/test/file.txt';
|
||||
const currentContent =
|
||||
'This content has been modified by someone else.';
|
||||
const originalParams = {
|
||||
file_path: filePath,
|
||||
old_string: 'nonexistent string',
|
||||
new_string: 'some new string',
|
||||
};
|
||||
|
||||
const now = Date.now();
|
||||
const lastEditTime = now - 5000; // 5 seconds ago
|
||||
|
||||
// Mock the file's modification time to be recent
|
||||
vi.spyOn(fs, 'statSync').mockReturnValue({
|
||||
mtimeMs: now,
|
||||
} as fs.Stats);
|
||||
|
||||
// Mock the last edit timestamp from our history to be in the past
|
||||
const history = [
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: EditTool.Name,
|
||||
id: `${EditTool.Name}-${lastEditTime}-123`,
|
||||
response: {
|
||||
output: {
|
||||
llmContent: `Successfully modified file: ${filePath}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
(mockGeminiClientInstance.getHistory as Mock).mockResolvedValue(
|
||||
history,
|
||||
);
|
||||
|
||||
const result = await ensureCorrectEdit(
|
||||
filePath,
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result.occurrences).toBe(0);
|
||||
expect(result.params).toEqual(originalParams);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('ensureCorrectFileContent', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||
const configParams = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'test-model',
|
||||
sandbox: false as boolean | string,
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
question: undefined as string | undefined,
|
||||
fullContext: false,
|
||||
coreTools: undefined as string[] | undefined,
|
||||
toolDiscoveryCommand: undefined as string | undefined,
|
||||
toolCallCommand: undefined as string | undefined,
|
||||
mcpServerCommand: undefined as string | undefined,
|
||||
mcpServers: undefined as Record<string, any> | undefined,
|
||||
userAgent: 'test-agent',
|
||||
userMemory: '',
|
||||
geminiMdFileCount: 0,
|
||||
alwaysSkipModificationConfirmation: false,
|
||||
};
|
||||
mockConfigInstance = {
|
||||
...configParams,
|
||||
getApiKey: vi.fn(() => configParams.apiKey),
|
||||
getModel: vi.fn(() => configParams.model),
|
||||
getSandbox: vi.fn(() => configParams.sandbox),
|
||||
getTargetDir: vi.fn(() => configParams.targetDir),
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry),
|
||||
getDebugMode: vi.fn(() => configParams.debugMode),
|
||||
getQuestion: vi.fn(() => configParams.question),
|
||||
getFullContext: vi.fn(() => configParams.fullContext),
|
||||
getCoreTools: vi.fn(() => configParams.coreTools),
|
||||
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
|
||||
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
|
||||
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
|
||||
getMcpServers: vi.fn(() => configParams.mcpServers),
|
||||
getUserAgent: vi.fn(() => configParams.userAgent),
|
||||
getUserMemory: vi.fn(() => configParams.userMemory),
|
||||
setUserMemory: vi.fn((mem: string) => {
|
||||
configParams.userMemory = mem;
|
||||
}),
|
||||
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
|
||||
setGeminiMdFileCount: vi.fn((count: number) => {
|
||||
configParams.geminiMdFileCount = count;
|
||||
}),
|
||||
getAlwaysSkipModificationConfirmation: vi.fn(
|
||||
() => configParams.alwaysSkipModificationConfirmation,
|
||||
),
|
||||
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
|
||||
configParams.alwaysSkipModificationConfirmation = skip;
|
||||
}),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
callCount = 0;
|
||||
mockResponses.length = 0;
|
||||
mockGenerateJson = vi
|
||||
.fn()
|
||||
.mockImplementation((_contents, _schema, signal) => {
|
||||
if (signal && signal.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
const response = mockResponses[callCount];
|
||||
callCount++;
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockStartChat = vi.fn();
|
||||
mockSendMessageStream = vi.fn();
|
||||
|
||||
mockGeminiClientInstance = new GeminiClient(
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
it('should return content unchanged if no escaping issues detected', async () => {
|
||||
const content = 'This is normal content without escaping issues';
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBe(content);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
});
|
||||
|
||||
it('should call correctStringEscaping for potentially escaped content', async () => {
|
||||
const content = 'console.log(\\"Hello World\\");';
|
||||
const correctedContent = 'console.log("Hello World");';
|
||||
mockResponses.push({
|
||||
corrected_string_escaping: correctedContent,
|
||||
});
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result).toBe(correctedContent);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle correctStringEscaping returning corrected content via correct property name', async () => {
|
||||
// This test specifically verifies the property name fix
|
||||
const content = 'const message = \\"Hello\\nWorld\\";';
|
||||
const correctedContent = 'const message = "Hello\nWorld";';
|
||||
|
||||
// Mock the response with the correct property name
|
||||
mockResponses.push({
|
||||
corrected_string_escaping: correctedContent,
|
||||
});
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result).toBe(correctedContent);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should return original content if LLM correction fails', async () => {
|
||||
const content = 'console.log(\\"Hello World\\");';
|
||||
// Mock empty response to simulate LLM failure
|
||||
mockResponses.push({});
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result).toBe(content);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle various escape sequences that need correction', async () => {
|
||||
const content =
|
||||
'const obj = { name: \\"John\\", age: 30, bio: \\"Developer\\nEngineer\\" };';
|
||||
const correctedContent =
|
||||
'const obj = { name: "John", age: 30, bio: "Developer\nEngineer" };';
|
||||
|
||||
mockResponses.push({
|
||||
corrected_string_escaping: correctedContent,
|
||||
});
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result).toBe(correctedContent);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,747 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, GenerateContentConfig } from '@google/genai';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { EditToolParams } from '../tools/edit.js';
|
||||
import { ToolNames } from '../tools/tool-names.js';
|
||||
import { LruCache } from './LruCache.js';
|
||||
import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js';
|
||||
import {
|
||||
isFunctionResponse,
|
||||
isFunctionCall,
|
||||
} from '../utils/messageInspectors.js';
|
||||
import * as fs from 'node:fs';
|
||||
|
||||
const EditModel = DEFAULT_QWEN_FLASH_MODEL;
|
||||
const EditConfig: GenerateContentConfig = {
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const MAX_CACHE_SIZE = 50;
|
||||
|
||||
// Cache for ensureCorrectEdit results
|
||||
const editCorrectionCache = new LruCache<string, CorrectedEditResult>(
|
||||
MAX_CACHE_SIZE,
|
||||
);
|
||||
|
||||
// Cache for ensureCorrectFileContent results
|
||||
const fileContentCorrectionCache = new LruCache<string, string>(MAX_CACHE_SIZE);
|
||||
|
||||
/**
|
||||
* Defines the structure of the parameters within CorrectedEditResult
|
||||
*/
|
||||
interface CorrectedEditParams {
|
||||
file_path: string;
|
||||
old_string: string;
|
||||
new_string: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines the result structure for ensureCorrectEdit.
|
||||
*/
|
||||
export interface CorrectedEditResult {
|
||||
params: CorrectedEditParams;
|
||||
occurrences: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the timestamp from the .id value, which is in format
|
||||
* <tool.name>-<timestamp>-<uuid>
|
||||
* @param fcnId the ID value of a functionCall or functionResponse object
|
||||
* @returns -1 if the timestamp could not be extracted, else the timestamp (as a number)
|
||||
*/
|
||||
function getTimestampFromFunctionId(fcnId: string): number {
|
||||
const idParts = fcnId.split('-');
|
||||
if (idParts.length > 2) {
|
||||
const timestamp = parseInt(idParts[1], 10);
|
||||
if (!isNaN(timestamp)) {
|
||||
return timestamp;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Will look through the gemini client history and determine when the most recent
|
||||
* edit to a target file occurred. If no edit happened, it will return -1
|
||||
* @param filePath the path to the file
|
||||
* @param client the geminiClient, so that we can get the history
|
||||
* @returns a DateTime (as a number) of when the last edit occurred, or -1 if no edit was found.
|
||||
*/
|
||||
async function findLastEditTimestamp(
|
||||
filePath: string,
|
||||
client: GeminiClient,
|
||||
): Promise<number> {
|
||||
const history = (await client.getHistory()) ?? [];
|
||||
|
||||
// Tools that may reference the file path in their FunctionResponse `output`.
|
||||
const toolsInResp = new Set<string>([
|
||||
ToolNames.WRITE_FILE,
|
||||
ToolNames.EDIT,
|
||||
ToolNames.READ_MANY_FILES,
|
||||
ToolNames.GREP,
|
||||
]);
|
||||
// Tools that may reference the file path in their FunctionCall `args`.
|
||||
const toolsInCall = new Set<string>([...toolsInResp, ToolNames.READ_FILE]);
|
||||
|
||||
// Iterate backwards to find the most recent relevant action.
|
||||
for (const entry of history.slice().reverse()) {
|
||||
if (!entry.parts) continue;
|
||||
|
||||
for (const part of entry.parts) {
|
||||
let id: string | undefined;
|
||||
let content: unknown;
|
||||
|
||||
// Check for a relevant FunctionCall with the file path in its arguments.
|
||||
if (
|
||||
isFunctionCall(entry) &&
|
||||
part.functionCall?.name &&
|
||||
toolsInCall.has(part.functionCall.name)
|
||||
) {
|
||||
id = part.functionCall.id;
|
||||
content = part.functionCall.args;
|
||||
}
|
||||
// Check for a relevant FunctionResponse with the file path in its output.
|
||||
else if (
|
||||
isFunctionResponse(entry) &&
|
||||
part.functionResponse?.name &&
|
||||
toolsInResp.has(part.functionResponse.name)
|
||||
) {
|
||||
const { response } = part.functionResponse;
|
||||
if (response && !('error' in response) && 'output' in response) {
|
||||
id = part.functionResponse.id;
|
||||
content = response['output'];
|
||||
}
|
||||
}
|
||||
|
||||
if (!id || content === undefined) continue;
|
||||
|
||||
// Use the "blunt hammer" approach to find the file path in the content.
|
||||
// Note that the tool response data is inconsistent in their formatting
|
||||
// with successes and errors - so, we just check for the existence
|
||||
// as the best guess to if error/failed occurred with the response.
|
||||
const stringified = JSON.stringify(content);
|
||||
if (
|
||||
!stringified.includes('Error') && // only applicable for functionResponse
|
||||
!stringified.includes('Failed') && // only applicable for functionResponse
|
||||
stringified.includes(filePath)
|
||||
) {
|
||||
return getTimestampFromFunctionId(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to correct edit parameters if the original old_string is not found.
|
||||
* It tries unescaping, and then LLM-based correction.
|
||||
* Results are cached to avoid redundant processing.
|
||||
*
|
||||
* @param currentContent The current content of the file.
|
||||
* @param originalParams The original EditToolParams
|
||||
* @param client The GeminiClient for LLM calls.
|
||||
* @returns A promise resolving to an object containing the (potentially corrected)
|
||||
* EditToolParams (as CorrectedEditParams) and the final occurrences count.
|
||||
*/
|
||||
export async function ensureCorrectEdit(
|
||||
filePath: string,
|
||||
currentContent: string,
|
||||
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CorrectedEditResult> {
|
||||
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
||||
const cachedResult = editCorrectionCache.get(cacheKey);
|
||||
if (cachedResult) {
|
||||
return cachedResult;
|
||||
}
|
||||
|
||||
let finalNewString = originalParams.new_string;
|
||||
const newStringPotentiallyEscaped =
|
||||
unescapeStringForGeminiBug(originalParams.new_string) !==
|
||||
originalParams.new_string;
|
||||
|
||||
const expectedReplacements = originalParams.expected_replacements ?? 1;
|
||||
|
||||
let finalOldString = originalParams.old_string;
|
||||
let occurrences = countOccurrences(currentContent, finalOldString);
|
||||
|
||||
if (occurrences === expectedReplacements) {
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewStringEscaping(
|
||||
client,
|
||||
finalOldString,
|
||||
originalParams.new_string,
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else if (occurrences > expectedReplacements) {
|
||||
const expectedReplacements = originalParams.expected_replacements ?? 1;
|
||||
|
||||
// If user expects multiple replacements, return as-is
|
||||
if (occurrences === expectedReplacements) {
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences,
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// If user expects 1 but found multiple, try to correct (existing behavior)
|
||||
if (expectedReplacements === 1) {
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences,
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// If occurrences don't match expected, return as-is (will fail validation later)
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences,
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
} else {
|
||||
// occurrences is 0 or some other unexpected state initially
|
||||
const unescapedOldStringAttempt = unescapeStringForGeminiBug(
|
||||
originalParams.old_string,
|
||||
);
|
||||
occurrences = countOccurrences(currentContent, unescapedOldStringAttempt);
|
||||
|
||||
if (occurrences === expectedReplacements) {
|
||||
finalOldString = unescapedOldStringAttempt;
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
originalParams.old_string, // original old
|
||||
unescapedOldStringAttempt, // corrected old
|
||||
originalParams.new_string, // original new (which is potentially escaped)
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else if (occurrences === 0) {
|
||||
if (filePath) {
|
||||
// In order to keep from clobbering edits made outside our system,
|
||||
// let's check if there was a more recent edit to the file than what
|
||||
// our system has done
|
||||
const lastEditedByUsTime = await findLastEditTimestamp(
|
||||
filePath,
|
||||
client,
|
||||
);
|
||||
|
||||
// Add a 1-second buffer to account for timing inaccuracies. If the file
|
||||
// was modified more than a second after the last edit tool was run, we
|
||||
// can assume it was modified by something else.
|
||||
if (lastEditedByUsTime > 0) {
|
||||
const stats = fs.statSync(filePath);
|
||||
const diff = stats.mtimeMs - lastEditedByUsTime;
|
||||
if (diff > 2000) {
|
||||
// Hard coded for 2 seconds
|
||||
// This file was edited sooner
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences: 0, // Explicitly 0 as LLM failed
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const llmCorrectedOldString = await correctOldStringMismatch(
|
||||
client,
|
||||
currentContent,
|
||||
unescapedOldStringAttempt,
|
||||
abortSignal,
|
||||
);
|
||||
const llmOldOccurrences = countOccurrences(
|
||||
currentContent,
|
||||
llmCorrectedOldString,
|
||||
);
|
||||
|
||||
if (llmOldOccurrences === expectedReplacements) {
|
||||
finalOldString = llmCorrectedOldString;
|
||||
occurrences = llmOldOccurrences;
|
||||
|
||||
if (newStringPotentiallyEscaped) {
|
||||
const baseNewStringForLLMCorrection = unescapeStringForGeminiBug(
|
||||
originalParams.new_string,
|
||||
);
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
originalParams.old_string, // original old
|
||||
llmCorrectedOldString, // corrected old
|
||||
baseNewStringForLLMCorrection, // base new for correction
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// LLM correction also failed for old_string
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences: 0, // Explicitly 0 as LLM failed
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
// Unescaping old_string resulted in > 1 occurrence
|
||||
const result: CorrectedEditResult = {
|
||||
params: { ...originalParams },
|
||||
occurrences, // This will be > 1
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
const { targetString, pair } = trimPairIfPossible(
|
||||
finalOldString,
|
||||
finalNewString,
|
||||
currentContent,
|
||||
expectedReplacements,
|
||||
);
|
||||
finalOldString = targetString;
|
||||
finalNewString = pair;
|
||||
|
||||
// Final result construction
|
||||
const result: CorrectedEditResult = {
|
||||
params: {
|
||||
file_path: originalParams.file_path,
|
||||
old_string: finalOldString,
|
||||
new_string: finalNewString,
|
||||
},
|
||||
occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string
|
||||
};
|
||||
editCorrectionCache.set(cacheKey, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
export async function ensureCorrectFileContent(
|
||||
content: string,
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const cachedResult = fileContentCorrectionCache.get(content);
|
||||
if (cachedResult) {
|
||||
return cachedResult;
|
||||
}
|
||||
|
||||
const contentPotentiallyEscaped =
|
||||
unescapeStringForGeminiBug(content) !== content;
|
||||
if (!contentPotentiallyEscaped) {
|
||||
fileContentCorrectionCache.set(content, content);
|
||||
return content;
|
||||
}
|
||||
|
||||
const correctedContent = await correctStringEscaping(
|
||||
content,
|
||||
client,
|
||||
abortSignal,
|
||||
);
|
||||
fileContentCorrectionCache.set(content, correctedContent);
|
||||
return correctedContent;
|
||||
}
|
||||
|
||||
// Define the expected JSON schema for the LLM response for old_string correction
|
||||
const OLD_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
corrected_target_snippet: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
|
||||
},
|
||||
},
|
||||
required: ['corrected_target_snippet'],
|
||||
};
|
||||
|
||||
export async function correctOldStringMismatch(
|
||||
geminiClient: GeminiClient,
|
||||
fileContent: string,
|
||||
problematicSnippet: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
|
||||
|
||||
Task: Analyze the provided file content and the problematic target snippet. Identify the segment in the file content that the snippet was *most likely* intended to match. Output the *exact*, literal text of that segment from the file content. Focus *only* on removing extra escape characters and correcting formatting, whitespace, or minor differences to achieve a PERFECT literal match. The output must be the exact literal text as it appears in the file.
|
||||
|
||||
Problematic target snippet:
|
||||
\`\`\`
|
||||
${problematicSnippet}
|
||||
\`\`\`
|
||||
|
||||
File Content:
|
||||
\`\`\`
|
||||
${fileContent}
|
||||
\`\`\`
|
||||
|
||||
For example, if the problematic target snippet was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and the file content had content that looked like "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", then corrected_target_snippet should likely be "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;" to fix the incorrect escaping to match the original file content.
|
||||
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_target_snippet.
|
||||
|
||||
Return ONLY the corrected target snippet in the specified JSON format with the key 'corrected_target_snippet'. If no clear, unique match can be found, return an empty string for 'corrected_target_snippet'.
|
||||
`.trim();
|
||||
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
OLD_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
|
||||
if (
|
||||
result &&
|
||||
typeof result['corrected_target_snippet'] === 'string' &&
|
||||
result['corrected_target_snippet'].length > 0
|
||||
) {
|
||||
return result['corrected_target_snippet'];
|
||||
} else {
|
||||
return problematicSnippet;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for old string snippet correction:',
|
||||
error,
|
||||
);
|
||||
|
||||
return problematicSnippet;
|
||||
}
|
||||
}
|
||||
|
||||
// Define the expected JSON schema for the new_string correction LLM response
|
||||
const NEW_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
corrected_new_string: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
|
||||
},
|
||||
},
|
||||
required: ['corrected_new_string'],
|
||||
};
|
||||
|
||||
/**
|
||||
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
|
||||
*/
|
||||
export async function correctNewString(
|
||||
geminiClient: GeminiClient,
|
||||
originalOldString: string,
|
||||
correctedOldString: string,
|
||||
originalNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
if (originalOldString === correctedOldString) {
|
||||
return originalNewString;
|
||||
}
|
||||
|
||||
const prompt = `
|
||||
Context: A text replacement operation was planned. The original text to be replaced (original_old_string) was slightly different from the actual text in the file (corrected_old_string). The original_old_string has now been corrected to match the file content.
|
||||
We now need to adjust the replacement text (original_new_string) so that it makes sense as a replacement for the corrected_old_string, while preserving the original intent of the change.
|
||||
|
||||
original_old_string (what was initially intended to be found):
|
||||
\`\`\`
|
||||
${originalOldString}
|
||||
\`\`\`
|
||||
|
||||
corrected_old_string (what was actually found in the file and will be replaced):
|
||||
\`\`\`
|
||||
${correctedOldString}
|
||||
\`\`\`
|
||||
|
||||
original_new_string (what was intended to replace original_old_string):
|
||||
\`\`\`
|
||||
${originalNewString}
|
||||
\`\`\`
|
||||
|
||||
Task: Based on the differences between original_old_string and corrected_old_string, and the content of original_new_string, generate a corrected_new_string. This corrected_new_string should be what original_new_string would have been if it was designed to replace corrected_old_string directly, while maintaining the spirit of the original transformation.
|
||||
|
||||
For example, if original_old_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and corrected_old_string is "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", and original_new_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name} \${lastName}\\\\\`\`;", then corrected_new_string should likely be "\nconst greeting = \`Hello ${'\\`'}\${name} \${lastName}${'\\`'}\`;" to fix the incorrect escaping.
|
||||
If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_new_string.
|
||||
|
||||
Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string'. If no adjustment is deemed necessary or possible, return the original_new_string.
|
||||
`.trim();
|
||||
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
NEW_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
|
||||
if (
|
||||
result &&
|
||||
typeof result['corrected_new_string'] === 'string' &&
|
||||
result['corrected_new_string'].length > 0
|
||||
) {
|
||||
return result['corrected_new_string'];
|
||||
} else {
|
||||
return originalNewString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error('Error during LLM call for new_string correction:', error);
|
||||
return originalNewString;
|
||||
}
|
||||
}
|
||||
|
||||
const CORRECT_NEW_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
corrected_new_string_escaping: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The new_string with corrected escaping, ensuring it is a proper replacement for the old_string, especially considering potential over-escaping issues from previous LLM generations.',
|
||||
},
|
||||
},
|
||||
required: ['corrected_new_string_escaping'],
|
||||
};
|
||||
|
||||
export async function correctNewStringEscaping(
|
||||
geminiClient: GeminiClient,
|
||||
oldString: string,
|
||||
potentiallyProblematicNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||
|
||||
old_string (this is the exact text that will be replaced):
|
||||
\`\`\`
|
||||
${oldString}
|
||||
\`\`\`
|
||||
|
||||
potentially_problematic_new_string (this is the text that should replace old_string, but MIGHT have bad escaping, or might be entirely correct):
|
||||
\`\`\`
|
||||
${potentiallyProblematicNewString}
|
||||
\`\`\`
|
||||
|
||||
Task: Analyze the potentially_problematic_new_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the new_string, when inserted into the code, will be a valid and correctly interpreted.
|
||||
|
||||
For example, if old_string is "foo" and potentially_problematic_new_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz".
|
||||
If potentially_problematic_new_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").
|
||||
|
||||
Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_new_string.
|
||||
`.trim();
|
||||
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
|
||||
if (
|
||||
result &&
|
||||
typeof result['corrected_new_string_escaping'] === 'string' &&
|
||||
result['corrected_new_string_escaping'].length > 0
|
||||
) {
|
||||
return result['corrected_new_string_escaping'];
|
||||
} else {
|
||||
return potentiallyProblematicNewString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for new_string escaping correction:',
|
||||
error,
|
||||
);
|
||||
return potentiallyProblematicNewString;
|
||||
}
|
||||
}
|
||||
|
||||
const CORRECT_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
corrected_string_escaping: {
|
||||
type: 'string',
|
||||
description:
|
||||
'The string with corrected escaping, ensuring it is valid, specially considering potential over-escaping issues from previous LLM generations.',
|
||||
},
|
||||
},
|
||||
required: ['corrected_string_escaping'],
|
||||
};
|
||||
|
||||
export async function correctStringEscaping(
|
||||
potentiallyProblematicString: string,
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||
|
||||
potentially_problematic_string (this text MIGHT have bad escaping, or might be entirely correct):
|
||||
\`\`\`
|
||||
${potentiallyProblematicString}
|
||||
\`\`\`
|
||||
|
||||
Task: Analyze the potentially_problematic_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the text will be a valid and correctly interpreted.
|
||||
|
||||
For example, if potentially_problematic_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz".
|
||||
If potentially_problematic_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").
|
||||
|
||||
Return ONLY the corrected string in the specified JSON format with the key 'corrected_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_string.
|
||||
`.trim();
|
||||
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await client.generateJson(
|
||||
contents,
|
||||
CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
|
||||
if (
|
||||
result &&
|
||||
typeof result['corrected_string_escaping'] === 'string' &&
|
||||
result['corrected_string_escaping'].length > 0
|
||||
) {
|
||||
return result['corrected_string_escaping'];
|
||||
} else {
|
||||
return potentiallyProblematicString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for string escaping correction:',
|
||||
error,
|
||||
);
|
||||
return potentiallyProblematicString;
|
||||
}
|
||||
}
|
||||
|
||||
function trimPairIfPossible(
|
||||
target: string,
|
||||
trimIfTargetTrims: string,
|
||||
currentContent: string,
|
||||
expectedReplacements: number,
|
||||
) {
|
||||
const trimmedTargetString = target.trim();
|
||||
if (target.length !== trimmedTargetString.length) {
|
||||
const trimmedTargetOccurrences = countOccurrences(
|
||||
currentContent,
|
||||
trimmedTargetString,
|
||||
);
|
||||
|
||||
if (trimmedTargetOccurrences === expectedReplacements) {
|
||||
const trimmedReactiveString = trimIfTargetTrims.trim();
|
||||
return {
|
||||
targetString: trimmedTargetString,
|
||||
pair: trimmedReactiveString,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
targetString: target,
|
||||
pair: trimIfTargetTrims,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Unescapes a string that might have been overly escaped by an LLM.
|
||||
*/
|
||||
export function unescapeStringForGeminiBug(inputString: string): string {
|
||||
// Regex explanation:
|
||||
// \\ : Matches exactly one literal backslash character.
|
||||
// (n|t|r|'|"|`|\\|\n) : This is a capturing group. It matches one of the following:
|
||||
// n, t, r, ', ", ` : These match the literal characters 'n', 't', 'r', single quote, double quote, or backtick.
|
||||
// This handles cases like "\\n", "\\`", etc.
|
||||
// \\ : This matches a literal backslash. This handles cases like "\\\\" (escaped backslash).
|
||||
// \n : This matches an actual newline character. This handles cases where the input
|
||||
// string might have something like "\\\n" (a literal backslash followed by a newline).
|
||||
// g : Global flag, to replace all occurrences.
|
||||
|
||||
return inputString.replace(
|
||||
/\\+(n|t|r|'|"|`|\\|\n)/g,
|
||||
(match, capturedChar) => {
|
||||
// 'match' is the entire erroneous sequence, e.g., if the input (in memory) was "\\\\`", match is "\\\\`".
|
||||
// 'capturedChar' is the character that determines the true meaning, e.g., '`'.
|
||||
|
||||
switch (capturedChar) {
|
||||
case 'n':
|
||||
return '\n'; // Correctly escaped: \n (newline character)
|
||||
case 't':
|
||||
return '\t'; // Correctly escaped: \t (tab character)
|
||||
case 'r':
|
||||
return '\r'; // Correctly escaped: \r (carriage return character)
|
||||
case "'":
|
||||
return "'"; // Correctly escaped: ' (apostrophe character)
|
||||
case '"':
|
||||
return '"'; // Correctly escaped: " (quotation mark character)
|
||||
case '`':
|
||||
return '`'; // Correctly escaped: ` (backtick character)
|
||||
case '\\': // This handles when 'capturedChar' is a literal backslash
|
||||
return '\\'; // Replace escaped backslash (e.g., "\\\\") with single backslash
|
||||
case '\n': // This handles when 'capturedChar' is an actual newline
|
||||
return '\n'; // Replace the whole erroneous sequence (e.g., "\\\n" in memory) with a clean newline
|
||||
default:
|
||||
// This fallback should ideally not be reached if the regex captures correctly.
|
||||
// It would return the original matched sequence if an unexpected character was captured.
|
||||
return match;
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Counts occurrences of a substring in a string
|
||||
*/
|
||||
export function countOccurrences(str: string, substr: string): number {
|
||||
if (substr === '') {
|
||||
return 0;
|
||||
}
|
||||
let count = 0;
|
||||
let pos = str.indexOf(substr);
|
||||
while (pos !== -1) {
|
||||
count++;
|
||||
pos = str.indexOf(substr, pos + substr.length); // Start search after the current match
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
export function resetEditCorrectorCaches_TEST_ONLY() {
|
||||
editCorrectionCache.clear();
|
||||
fileContentCorrectionCache.clear();
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
getCommandRoots,
|
||||
getShellConfiguration,
|
||||
isCommandAllowed,
|
||||
isCommandNeedsPermission,
|
||||
stripShellWrapper,
|
||||
} from './shell-utils.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -27,8 +28,10 @@ vi.mock('os', () => ({
|
||||
}));
|
||||
|
||||
const mockQuote = vi.hoisted(() => vi.fn());
|
||||
const mockParse = vi.hoisted(() => vi.fn());
|
||||
vi.mock('shell-quote', () => ({
|
||||
quote: mockQuote,
|
||||
parse: mockParse,
|
||||
}));
|
||||
|
||||
let config: Config;
|
||||
@@ -38,6 +41,7 @@ beforeEach(() => {
|
||||
mockQuote.mockImplementation((args: string[]) =>
|
||||
args.map((arg) => `'${arg}'`).join(' '),
|
||||
);
|
||||
mockParse.mockImplementation((cmd: string) => cmd.split(' '));
|
||||
config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => [],
|
||||
@@ -436,3 +440,16 @@ describe('getShellConfiguration', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('isCommandNeedPermission', () => {
|
||||
it('returns false for read-only commands', () => {
|
||||
const result = isCommandNeedsPermission('ls');
|
||||
expect(result.requiresPermission).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true for mutating commands with reason', () => {
|
||||
const result = isCommandNeedsPermission('rm -rf temp');
|
||||
expect(result.requiresPermission).toBe(true);
|
||||
expect(result.reason).toContain('requires permission to execute');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,6 +9,7 @@ import type { Config } from '../config/config.js';
|
||||
import os from 'node:os';
|
||||
import { quote } from 'shell-quote';
|
||||
import { doesToolInvocationMatch } from './tool-utils.js';
|
||||
import { isShellCommandReadOnly } from './shellReadOnlyChecker.js';
|
||||
|
||||
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
|
||||
|
||||
@@ -469,3 +470,19 @@ export function isCommandAllowed(
|
||||
}
|
||||
return { allowed: false, reason: blockReason };
|
||||
}
|
||||
|
||||
export function isCommandNeedsPermission(command: string): {
|
||||
requiresPermission: boolean;
|
||||
reason?: string;
|
||||
} {
|
||||
const isAllowed = isShellCommandReadOnly(command);
|
||||
|
||||
if (isAllowed) {
|
||||
return { requiresPermission: false };
|
||||
}
|
||||
|
||||
return {
|
||||
requiresPermission: true,
|
||||
reason: 'Command requires permission to execute.',
|
||||
};
|
||||
}
|
||||
|
||||
56
packages/core/src/utils/shellReadOnlyChecker.test.ts
Normal file
56
packages/core/src/utils/shellReadOnlyChecker.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { isShellCommandReadOnly } from './shellReadOnlyChecker.js';
|
||||
|
||||
describe('evaluateShellCommandReadOnly', () => {
|
||||
it('allows simple read-only command', () => {
|
||||
const result = isShellCommandReadOnly('ls -la');
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects mutating commands like rm', () => {
|
||||
const result = isShellCommandReadOnly('rm -rf temp');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects redirection output', () => {
|
||||
const result = isShellCommandReadOnly('ls > out.txt');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects command substitution', () => {
|
||||
const result = isShellCommandReadOnly('echo $(touch file)');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('allows git status but rejects git commit', () => {
|
||||
expect(isShellCommandReadOnly('git status')).toBe(true);
|
||||
const commitResult = isShellCommandReadOnly('git commit -am "msg"');
|
||||
expect(commitResult).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects find with exec', () => {
|
||||
const result = isShellCommandReadOnly('find . -exec rm {} \\;');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects sed in-place', () => {
|
||||
const result = isShellCommandReadOnly("sed -i 's/foo/bar/' file");
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects empty command', () => {
|
||||
const result = isShellCommandReadOnly(' ');
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('respects environment prefix followed by allowed command', () => {
|
||||
const result = isShellCommandReadOnly('FOO=bar ls');
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
300
packages/core/src/utils/shellReadOnlyChecker.ts
Normal file
300
packages/core/src/utils/shellReadOnlyChecker.ts
Normal file
@@ -0,0 +1,300 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { parse } from 'shell-quote';
|
||||
import {
|
||||
detectCommandSubstitution,
|
||||
splitCommands,
|
||||
stripShellWrapper,
|
||||
} from './shell-utils.js';
|
||||
|
||||
const READ_ONLY_ROOT_COMMANDS = new Set([
|
||||
'awk',
|
||||
'basename',
|
||||
'cat',
|
||||
'cd',
|
||||
'column',
|
||||
'cut',
|
||||
'df',
|
||||
'dirname',
|
||||
'du',
|
||||
'echo',
|
||||
'env',
|
||||
'find',
|
||||
'git',
|
||||
'grep',
|
||||
'head',
|
||||
'less',
|
||||
'ls',
|
||||
'more',
|
||||
'printenv',
|
||||
'printf',
|
||||
'ps',
|
||||
'pwd',
|
||||
'rg',
|
||||
'ripgrep',
|
||||
'sed',
|
||||
'sort',
|
||||
'stat',
|
||||
'tail',
|
||||
'tree',
|
||||
'uniq',
|
||||
'wc',
|
||||
'which',
|
||||
'where',
|
||||
'whoami',
|
||||
]);
|
||||
|
||||
const BLOCKED_FIND_FLAGS = new Set([
|
||||
'-delete',
|
||||
'-exec',
|
||||
'-execdir',
|
||||
'-ok',
|
||||
'-okdir',
|
||||
]);
|
||||
|
||||
const BLOCKED_FIND_PREFIXES = ['-fprint', '-fprintf'];
|
||||
|
||||
const READ_ONLY_GIT_SUBCOMMANDS = new Set([
|
||||
'blame',
|
||||
'branch',
|
||||
'cat-file',
|
||||
'diff',
|
||||
'grep',
|
||||
'log',
|
||||
'ls-files',
|
||||
'remote',
|
||||
'rev-parse',
|
||||
'show',
|
||||
'status',
|
||||
'describe',
|
||||
]);
|
||||
|
||||
const BLOCKED_GIT_REMOTE_ACTIONS = new Set([
|
||||
'add',
|
||||
'remove',
|
||||
'rename',
|
||||
'set-url',
|
||||
'prune',
|
||||
'update',
|
||||
]);
|
||||
|
||||
const BLOCKED_GIT_BRANCH_FLAGS = new Set([
|
||||
'-d',
|
||||
'-D',
|
||||
'--delete',
|
||||
'--move',
|
||||
'-m',
|
||||
]);
|
||||
|
||||
const BLOCKED_SED_PREFIXES = ['-i'];
|
||||
|
||||
const ENV_ASSIGNMENT_REGEX = /^[A-Za-z_][A-Za-z0-9_]*=/;
|
||||
|
||||
function containsWriteRedirection(command: string): boolean {
|
||||
let inSingleQuotes = false;
|
||||
let inDoubleQuotes = false;
|
||||
let escapeNext = false;
|
||||
|
||||
for (const char of command) {
|
||||
if (escapeNext) {
|
||||
escapeNext = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === '\\' && !inSingleQuotes) {
|
||||
escapeNext = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === "'" && !inDoubleQuotes) {
|
||||
inSingleQuotes = !inSingleQuotes;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === '"' && !inSingleQuotes) {
|
||||
inDoubleQuotes = !inDoubleQuotes;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inSingleQuotes && !inDoubleQuotes && char === '>') {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
function normalizeTokens(segment: string): string[] {
|
||||
const parsed = parse(segment);
|
||||
const tokens: string[] = [];
|
||||
for (const token of parsed) {
|
||||
if (typeof token === 'string') {
|
||||
tokens.push(token);
|
||||
}
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
function skipEnvironmentAssignments(tokens: string[]): {
|
||||
root?: string;
|
||||
args: string[];
|
||||
} {
|
||||
let index = 0;
|
||||
while (index < tokens.length && ENV_ASSIGNMENT_REGEX.test(tokens[index]!)) {
|
||||
index++;
|
||||
}
|
||||
|
||||
if (index >= tokens.length) {
|
||||
return { args: [] };
|
||||
}
|
||||
|
||||
return {
|
||||
root: tokens[index],
|
||||
args: tokens.slice(index + 1),
|
||||
};
|
||||
}
|
||||
|
||||
function evaluateFindCommand(tokens: string[]): boolean {
|
||||
const [, ...rest] = tokens;
|
||||
for (const token of rest) {
|
||||
const lower = token.toLowerCase();
|
||||
if (BLOCKED_FIND_FLAGS.has(lower)) {
|
||||
return false;
|
||||
}
|
||||
if (BLOCKED_FIND_PREFIXES.some((prefix) => lower.startsWith(prefix))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function evaluateSedCommand(tokens: string[]): boolean {
|
||||
const [, ...rest] = tokens;
|
||||
for (const token of rest) {
|
||||
if (
|
||||
BLOCKED_SED_PREFIXES.some((prefix) => token.startsWith(prefix)) ||
|
||||
token === '--in-place'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function evaluateGitRemoteArgs(args: string[]): boolean {
|
||||
for (const arg of args) {
|
||||
if (BLOCKED_GIT_REMOTE_ACTIONS.has(arg.toLowerCase())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function evaluateGitBranchArgs(args: string[]): boolean {
|
||||
for (const arg of args) {
|
||||
if (BLOCKED_GIT_BRANCH_FLAGS.has(arg)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function evaluateGitCommand(tokens: string[]): boolean {
|
||||
let index = 1;
|
||||
while (index < tokens.length && tokens[index]!.startsWith('-')) {
|
||||
const flag = tokens[index]!.toLowerCase();
|
||||
if (flag === '--version' || flag === '--help') {
|
||||
return true;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
if (index >= tokens.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const subcommand = tokens[index]!.toLowerCase();
|
||||
if (!READ_ONLY_GIT_SUBCOMMANDS.has(subcommand)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const args = tokens.slice(index + 1);
|
||||
|
||||
if (subcommand === 'remote') {
|
||||
return evaluateGitRemoteArgs(args);
|
||||
}
|
||||
|
||||
if (subcommand === 'branch') {
|
||||
return evaluateGitBranchArgs(args);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
function evaluateShellSegment(segment: string): boolean {
|
||||
if (!segment.trim()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const stripped = stripShellWrapper(segment);
|
||||
if (!stripped) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (detectCommandSubstitution(stripped)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (containsWriteRedirection(stripped)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const tokens = normalizeTokens(stripped);
|
||||
if (tokens.length === 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const { root, args } = skipEnvironmentAssignments(tokens);
|
||||
if (!root) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const normalizedRoot = root.toLowerCase();
|
||||
if (!READ_ONLY_ROOT_COMMANDS.has(normalizedRoot)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (normalizedRoot === 'find') {
|
||||
return evaluateFindCommand([normalizedRoot, ...args]);
|
||||
}
|
||||
|
||||
if (normalizedRoot === 'sed') {
|
||||
return evaluateSedCommand([normalizedRoot, ...args]);
|
||||
}
|
||||
|
||||
if (normalizedRoot === 'git') {
|
||||
return evaluateGitCommand([normalizedRoot, ...args]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
export function isShellCommandReadOnly(command: string): boolean {
|
||||
if (typeof command !== 'string' || !command.trim()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const segments = splitCommands(command);
|
||||
for (const segment of segments) {
|
||||
const isAllowed = evaluateShellSegment(segment);
|
||||
if (!isAllowed) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code-test-utils",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"private": true,
|
||||
"main": "src/index.ts",
|
||||
"license": "Apache-2.0",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"name": "qwen-code-vscode-ide-companion",
|
||||
"displayName": "Qwen Code Companion",
|
||||
"description": "Enable Qwen Code with direct access to your VS Code workspace.",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.14-nightly.1",
|
||||
"publisher": "qwenlm",
|
||||
"icon": "assets/icon.png",
|
||||
"repository": {
|
||||
|
||||
Reference in New Issue
Block a user