mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-24 02:29:13 +00:00
Compare commits
25 Commits
v0.0.12-ni
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
45c0330ac5 | ||
|
|
e38947a62d | ||
|
|
e66bb7e717 | ||
|
|
490c36caeb | ||
|
|
e4d16adf7b | ||
|
|
85a2b8d6e0 | ||
|
|
5ecb4a2430 | ||
|
|
9c1d7228cb | ||
|
|
deb99a3b21 | ||
|
|
014059e8a6 | ||
|
|
3579d6555a | ||
|
|
9a56560eb4 | ||
|
|
da0863b943 | ||
|
|
5f68a8b6b3 | ||
|
|
761833c915 | ||
|
|
56808ac210 | ||
|
|
724c24933c | ||
|
|
17cdce6298 | ||
|
|
de468f0525 | ||
|
|
50199288ec | ||
|
|
8803b2eb76 | ||
|
|
e552bc9609 | ||
|
|
5f90472a7d | ||
|
|
19950e5b7c | ||
|
|
8e2fc76c15 |
13
.vscode/launch.json
vendored
13
.vscode/launch.json
vendored
@@ -101,6 +101,13 @@
|
||||
"env": {
|
||||
"GEMINI_SANDBOX": "false"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Attach by Process ID",
|
||||
"processId": "${command:PickProcess}",
|
||||
"request": "attach",
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"type": "node"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
@@ -115,6 +122,12 @@
|
||||
"type": "promptString",
|
||||
"description": "Enter your prompt for non-interactive mode",
|
||||
"default": "Explain this code"
|
||||
},
|
||||
{
|
||||
"id": "debugPort",
|
||||
"type": "promptString",
|
||||
"description": "Enter the debug port number (default: 9229)",
|
||||
"default": "9229"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
20
CHANGELOG.md
20
CHANGELOG.md
@@ -1,5 +1,25 @@
|
||||
# Changelog
|
||||
|
||||
## 0.0.12
|
||||
|
||||
- Added vision model support for Qwen-OAuth authentication.
|
||||
- Synced upstream `gemini-cli` to v0.3.4 with numerous improvements and bug fixes.
|
||||
- Enhanced subagent functionality with system reminders and improved user experience.
|
||||
- Added tool call type coercion for better compatibility.
|
||||
- Fixed arrow key navigation issues on Windows.
|
||||
- Fixed missing tool call chunks for OpenAI logging.
|
||||
- Fixed system prompt issues to avoid malformed tool calls.
|
||||
- Fixed terminal flicker when subagent is executing.
|
||||
- Fixed duplicate subagents configuration when running in home directory.
|
||||
- Fixed Esc key unable to cancel subagent dialog.
|
||||
- Added confirmation prompt for `/init` command when context file exists.
|
||||
- Added `skipLoopDetection` configuration option.
|
||||
- Fixed `is_background` parameter reset issues.
|
||||
- Enhanced Windows compatibility with multi-line paste handling.
|
||||
- Improved subagent documentation and branding consistency.
|
||||
- Fixed various linting errors and improved code quality.
|
||||
- Miscellaneous improvements and bug fixes.
|
||||
|
||||
## 0.0.11
|
||||
|
||||
- Added subagents feature with file-based configuration system for specialized AI assistants.
|
||||
|
||||
53
README.md
53
README.md
@@ -54,6 +54,7 @@ For detailed setup instructions, see [Authorization](#authorization).
|
||||
- **Code Understanding & Editing** - Query and edit large codebases beyond traditional context window limits
|
||||
- **Workflow Automation** - Automate operational tasks like handling pull requests and complex rebases
|
||||
- **Enhanced Parser** - Adapted parser specifically optimized for Qwen-Coder models
|
||||
- **Vision Model Support** - Automatically detect images in your input and seamlessly switch to vision-capable models for multimodal analysis
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -121,6 +122,58 @@ Create or edit `.qwen/settings.json` in your home directory:
|
||||
|
||||
> 📝 **Note**: Session token limit applies to a single conversation, not cumulative API calls.
|
||||
|
||||
### Vision Model Configuration
|
||||
|
||||
Qwen Code includes intelligent vision model auto-switching that detects images in your input and can automatically switch to vision-capable models for multimodal analysis. **This feature is enabled by default** - when you include images in your queries, you'll see a dialog asking how you'd like to handle the vision model switch.
|
||||
|
||||
#### Skip the Switch Dialog (Optional)
|
||||
|
||||
If you don't want to see the interactive dialog each time, configure the default behavior in your `.qwen/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"vlmSwitchMode": "once"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Available modes:**
|
||||
|
||||
- **`"once"`** - Switch to vision model for this query only, then revert
|
||||
- **`"session"`** - Switch to vision model for the entire session
|
||||
- **`"persist"`** - Continue with current model (no switching)
|
||||
- **Not set** - Show interactive dialog each time (default)
|
||||
|
||||
#### Command Line Override
|
||||
|
||||
You can also set the behavior via command line:
|
||||
|
||||
```bash
|
||||
# Switch once per query
|
||||
qwen --vlm-switch-mode once
|
||||
|
||||
# Switch for entire session
|
||||
qwen --vlm-switch-mode session
|
||||
|
||||
# Never switch automatically
|
||||
qwen --vlm-switch-mode persist
|
||||
```
|
||||
|
||||
#### Disable Vision Models (Optional)
|
||||
|
||||
To completely disable vision model support, add to your `.qwen/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"visionModelPreview": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 💡 **Tip**: In YOLO mode (`--yolo`), vision switching happens automatically without prompts when images are detected.
|
||||
|
||||
### Authorization
|
||||
|
||||
Choose your preferred authentication method based on your needs:
|
||||
|
||||
@@ -133,6 +133,28 @@ Focus on creating clear, comprehensive documentation that helps both
|
||||
new contributors and end users understand the project.
|
||||
```
|
||||
|
||||
## Using Subagents Effectively
|
||||
|
||||
### Automatic Delegation
|
||||
|
||||
Qwen Code proactively delegates tasks based on:
|
||||
|
||||
- The task description in your request
|
||||
- The description field in subagent configurations
|
||||
- Current context and available tools
|
||||
|
||||
To encourage more proactive subagent use, include phrases like "use PROACTIVELY" or "MUST BE USED" in your description field.
|
||||
|
||||
### Explicit Invocation
|
||||
|
||||
Request a specific subagent by mentioning it in your command:
|
||||
|
||||
```
|
||||
> Let the testing-expert subagent create unit tests for the payment module
|
||||
> Have the documentation-writer subagent update the API reference
|
||||
> Get the react-specialist subagent to optimize this component's performance
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Development Workflow Agents
|
||||
|
||||
12
package-lock.json
generated
12
package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"workspaces": [
|
||||
"packages/*"
|
||||
],
|
||||
@@ -13454,7 +13454,7 @@
|
||||
},
|
||||
"packages/cli": {
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"dependencies": {
|
||||
"@google/genai": "1.9.0",
|
||||
"@iarna/toml": "^2.2.5",
|
||||
@@ -13662,7 +13662,7 @@
|
||||
},
|
||||
"packages/core": {
|
||||
"name": "@qwen-code/qwen-code-core",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"dependencies": {
|
||||
"@google/genai": "1.13.0",
|
||||
"@lvce-editor/ripgrep": "^1.6.0",
|
||||
@@ -13788,7 +13788,7 @@
|
||||
},
|
||||
"packages/test-utils": {
|
||||
"name": "@qwen-code/qwen-code-test-utils",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"devDependencies": {
|
||||
@@ -13800,7 +13800,7 @@
|
||||
},
|
||||
"packages/vscode-ide-companion": {
|
||||
"name": "qwen-code-vscode-ide-companion",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"license": "LICENSE",
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.15.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
@@ -13,7 +13,7 @@
|
||||
"url": "git+https://github.com/QwenLM/qwen-code.git"
|
||||
},
|
||||
"config": {
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.11"
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.13-nightly.2"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "node scripts/start.js",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"description": "Qwen Code",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
@@ -25,7 +25,7 @@
|
||||
"dist"
|
||||
],
|
||||
"config": {
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.11"
|
||||
"sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.13-nightly.2"
|
||||
},
|
||||
"dependencies": {
|
||||
"@google/genai": "1.9.0",
|
||||
|
||||
@@ -1514,7 +1514,7 @@ describe('loadCliConfig model selection', () => {
|
||||
argv,
|
||||
);
|
||||
|
||||
expect(config.getModel()).toBe('qwen3-coder-plus');
|
||||
expect(config.getModel()).toBe('coder-model');
|
||||
});
|
||||
|
||||
it('always prefers model from argvs', async () => {
|
||||
|
||||
@@ -82,6 +82,7 @@ export interface CliArgs {
|
||||
includeDirectories: string[] | undefined;
|
||||
tavilyApiKey: string | undefined;
|
||||
screenReader: boolean | undefined;
|
||||
vlmSwitchMode: string | undefined;
|
||||
}
|
||||
|
||||
export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
@@ -249,6 +250,13 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
description: 'Enable screen reader mode for accessibility.',
|
||||
default: false,
|
||||
})
|
||||
.option('vlm-switch-mode', {
|
||||
type: 'string',
|
||||
choices: ['once', 'session', 'persist'],
|
||||
description:
|
||||
'Default behavior when images are detected in input. Values: once (one-time switch), session (switch for entire session), persist (continue with current model). Overrides settings files.',
|
||||
default: process.env['VLM_SWITCH_MODE'],
|
||||
})
|
||||
.check((argv) => {
|
||||
if (argv.prompt && argv['promptInteractive']) {
|
||||
throw new Error(
|
||||
@@ -524,6 +532,9 @@ export async function loadCliConfig(
|
||||
argv.screenReader !== undefined
|
||||
? argv.screenReader
|
||||
: (settings.ui?.accessibility?.screenReader ?? false);
|
||||
|
||||
const vlmSwitchMode =
|
||||
argv.vlmSwitchMode || settings.experimental?.vlmSwitchMode;
|
||||
return new Config({
|
||||
sessionId,
|
||||
embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
@@ -630,6 +641,7 @@ export async function loadCliConfig(
|
||||
skipNextSpeakerCheck: settings.model?.skipNextSpeakerCheck,
|
||||
enablePromptCompletion: settings.general?.enablePromptCompletion ?? false,
|
||||
skipLoopDetection: settings.skipLoopDetection ?? false,
|
||||
vlmSwitchMode,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,11 @@ const MOCK_WORKSPACE_SETTINGS_PATH = pathActual.join(
|
||||
);
|
||||
|
||||
// A more flexible type for test data that allows arbitrary properties.
|
||||
type TestSettings = Settings & { [key: string]: unknown };
|
||||
type TestSettings = Settings & {
|
||||
[key: string]: unknown;
|
||||
nested?: { [key: string]: unknown };
|
||||
nestedObj?: { [key: string]: unknown };
|
||||
};
|
||||
|
||||
vi.mock('fs', async (importOriginal) => {
|
||||
// Get all the functions from the real 'fs' module
|
||||
@@ -137,6 +141,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -197,6 +204,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -260,6 +270,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -320,6 +333,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -385,6 +401,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -477,6 +496,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -562,6 +584,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -691,6 +716,9 @@ describe('Settings Loading and Merging', () => {
|
||||
'/system/dir',
|
||||
],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -1431,6 +1459,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -1516,7 +1547,11 @@ describe('Settings Loading and Merging', () => {
|
||||
'workspace_endpoint_from_env/api',
|
||||
);
|
||||
expect(
|
||||
(settings.workspace.settings as TestSettings)['nested']['value'],
|
||||
(
|
||||
(settings.workspace.settings as TestSettings).nested as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['value'],
|
||||
).toBe('workspace_endpoint_from_env');
|
||||
expect((settings.merged as TestSettings)['endpoint']).toBe(
|
||||
'workspace_endpoint_from_env/api',
|
||||
@@ -1766,19 +1801,39 @@ describe('Settings Loading and Merging', () => {
|
||||
).toBeUndefined();
|
||||
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedNull'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedNull'],
|
||||
).toBeNull();
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedBool'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedBool'],
|
||||
).toBe(true);
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedNum'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedNum'],
|
||||
).toBe(0);
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['nestedString'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['nestedString'],
|
||||
).toBe('literal');
|
||||
expect(
|
||||
(settings.user.settings as TestSettings)['nestedObj']['anotherEnv'],
|
||||
(
|
||||
(settings.user.settings as TestSettings).nestedObj as {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
)['anotherEnv'],
|
||||
).toBe('env_string_nested_value');
|
||||
|
||||
delete process.env['MY_ENV_STRING'];
|
||||
@@ -1864,6 +1919,9 @@ describe('Settings Loading and Merging', () => {
|
||||
advanced: {
|
||||
excludedEnvVars: [],
|
||||
},
|
||||
experimental: {},
|
||||
contentGenerator: {},
|
||||
systemPromptMappings: {},
|
||||
extensions: {
|
||||
disabled: [],
|
||||
workspacesWithMigrationNudge: [],
|
||||
@@ -2336,14 +2394,14 @@ describe('Settings Loading and Merging', () => {
|
||||
vimMode: false,
|
||||
},
|
||||
model: {
|
||||
maxSessionTurns: 0,
|
||||
maxSessionTurns: -1,
|
||||
},
|
||||
context: {
|
||||
includeDirectories: [],
|
||||
},
|
||||
security: {
|
||||
folderTrust: {
|
||||
enabled: null,
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -2352,9 +2410,9 @@ describe('Settings Loading and Merging', () => {
|
||||
|
||||
expect(v1Settings).toEqual({
|
||||
vimMode: false,
|
||||
maxSessionTurns: 0,
|
||||
maxSessionTurns: -1,
|
||||
includeDirectories: [],
|
||||
folderTrust: null,
|
||||
folderTrust: false,
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -396,6 +396,24 @@ function mergeSettings(
|
||||
]),
|
||||
],
|
||||
},
|
||||
experimental: {
|
||||
...(systemDefaults.experimental || {}),
|
||||
...(user.experimental || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.experimental || {}),
|
||||
...(system.experimental || {}),
|
||||
},
|
||||
contentGenerator: {
|
||||
...(systemDefaults.contentGenerator || {}),
|
||||
...(user.contentGenerator || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.contentGenerator || {}),
|
||||
...(system.contentGenerator || {}),
|
||||
},
|
||||
systemPromptMappings: {
|
||||
...(systemDefaults.systemPromptMappings || {}),
|
||||
...(user.systemPromptMappings || {}),
|
||||
...(safeWorkspaceWithoutFolderTrust.systemPromptMappings || {}),
|
||||
...(system.systemPromptMappings || {}),
|
||||
},
|
||||
extensions: {
|
||||
...(systemDefaults.extensions || {}),
|
||||
...(user.extensions || {}),
|
||||
|
||||
@@ -741,6 +741,26 @@ export const SETTINGS_SCHEMA = {
|
||||
description: 'Enable extension management features.',
|
||||
showInDialog: false,
|
||||
},
|
||||
visionModelPreview: {
|
||||
type: 'boolean',
|
||||
label: 'Vision Model Preview',
|
||||
category: 'Experimental',
|
||||
requiresRestart: false,
|
||||
default: true,
|
||||
description:
|
||||
'Enable vision model support and auto-switching functionality. When disabled, vision models like qwen-vl-max-latest will be hidden and auto-switching will not occur.',
|
||||
showInDialog: true,
|
||||
},
|
||||
vlmSwitchMode: {
|
||||
type: 'string',
|
||||
label: 'VLM Switch Mode',
|
||||
category: 'Experimental',
|
||||
requiresRestart: false,
|
||||
default: undefined as string | undefined,
|
||||
description:
|
||||
'Default behavior when images are detected in input. Values: once (one-time switch), session (switch for entire session), persist (continue with current model). If not set, user will be prompted each time. This is a temporary experimental feature.',
|
||||
showInDialog: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
|
||||
@@ -56,6 +56,13 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({
|
||||
kind: 'BUILT_IN',
|
||||
},
|
||||
}));
|
||||
vi.mock('../ui/commands/modelCommand.js', () => ({
|
||||
modelCommand: {
|
||||
name: 'model',
|
||||
description: 'Model command',
|
||||
kind: 'BUILT_IN',
|
||||
},
|
||||
}));
|
||||
|
||||
describe('BuiltinCommandLoader', () => {
|
||||
let mockConfig: Config;
|
||||
@@ -126,5 +133,8 @@ describe('BuiltinCommandLoader', () => {
|
||||
|
||||
const mcpCmd = commands.find((c) => c.name === 'mcp');
|
||||
expect(mcpCmd).toBeDefined();
|
||||
|
||||
const modelCmd = commands.find((c) => c.name === 'model');
|
||||
expect(modelCmd).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -35,6 +35,7 @@ import { settingsCommand } from '../ui/commands/settingsCommand.js';
|
||||
import { vimCommand } from '../ui/commands/vimCommand.js';
|
||||
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
|
||||
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
|
||||
import { modelCommand } from '../ui/commands/modelCommand.js';
|
||||
import { agentsCommand } from '../ui/commands/agentsCommand.js';
|
||||
|
||||
/**
|
||||
@@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
||||
initCommand,
|
||||
mcpCommand,
|
||||
memoryCommand,
|
||||
modelCommand,
|
||||
privacyCommand,
|
||||
quitCommand,
|
||||
quitConfirmCommand,
|
||||
|
||||
@@ -53,6 +53,17 @@ import { FolderTrustDialog } from './components/FolderTrustDialog.js';
|
||||
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
|
||||
import { QuitConfirmationDialog } from './components/QuitConfirmationDialog.js';
|
||||
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
|
||||
import { ModelSelectionDialog } from './components/ModelSelectionDialog.js';
|
||||
import {
|
||||
ModelSwitchDialog,
|
||||
type VisionSwitchOutcome,
|
||||
} from './components/ModelSwitchDialog.js';
|
||||
import {
|
||||
getOpenAIAvailableModelFromEnv,
|
||||
getFilteredQwenModels,
|
||||
type AvailableModel,
|
||||
} from './models/availableModels.js';
|
||||
import { processVisionSwitchOutcome } from './hooks/useVisionAutoSwitch.js';
|
||||
import {
|
||||
AgentCreationWizard,
|
||||
AgentsManagerDialog,
|
||||
@@ -248,6 +259,20 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
onWorkspaceMigrationDialogClose,
|
||||
} = useWorkspaceMigration(settings);
|
||||
|
||||
// Model selection dialog states
|
||||
const [isModelSelectionDialogOpen, setIsModelSelectionDialogOpen] =
|
||||
useState(false);
|
||||
const [isVisionSwitchDialogOpen, setIsVisionSwitchDialogOpen] =
|
||||
useState(false);
|
||||
const [visionSwitchResolver, setVisionSwitchResolver] = useState<{
|
||||
resolve: (result: {
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}) => void;
|
||||
reject: () => void;
|
||||
} | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribe = ideContext.subscribeToIdeContext(setIdeContextState);
|
||||
// Set the initial value
|
||||
@@ -590,6 +615,75 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
openAuthDialog();
|
||||
}, [openAuthDialog, setAuthError]);
|
||||
|
||||
// Vision switch handler for auto-switch functionality
|
||||
const handleVisionSwitchRequired = useCallback(
|
||||
async (_query: unknown) =>
|
||||
new Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>((resolve, reject) => {
|
||||
setVisionSwitchResolver({ resolve, reject });
|
||||
setIsVisionSwitchDialogOpen(true);
|
||||
}),
|
||||
[],
|
||||
);
|
||||
|
||||
const handleVisionSwitchSelect = useCallback(
|
||||
(outcome: VisionSwitchOutcome) => {
|
||||
setIsVisionSwitchDialogOpen(false);
|
||||
if (visionSwitchResolver) {
|
||||
const result = processVisionSwitchOutcome(outcome);
|
||||
visionSwitchResolver.resolve(result);
|
||||
setVisionSwitchResolver(null);
|
||||
}
|
||||
},
|
||||
[visionSwitchResolver],
|
||||
);
|
||||
|
||||
const handleModelSelectionOpen = useCallback(() => {
|
||||
setIsModelSelectionDialogOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleModelSelectionClose = useCallback(() => {
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
}, []);
|
||||
|
||||
const handleModelSelect = useCallback(
|
||||
(modelId: string) => {
|
||||
config.setModel(modelId);
|
||||
setCurrentModel(modelId);
|
||||
setIsModelSelectionDialogOpen(false);
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Switched model to \`${modelId}\` for this session.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
},
|
||||
[config, setCurrentModel, addItem],
|
||||
);
|
||||
|
||||
const getAvailableModelsForCurrentAuth = useCallback((): AvailableModel[] => {
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
if (!contentGeneratorConfig) return [];
|
||||
|
||||
const visionModelPreviewEnabled =
|
||||
settings.merged.experimental?.visionModelPreview ?? true;
|
||||
|
||||
switch (contentGeneratorConfig.authType) {
|
||||
case AuthType.QWEN_OAUTH:
|
||||
return getFilteredQwenModels(visionModelPreviewEnabled);
|
||||
case AuthType.USE_OPENAI: {
|
||||
const openAIModel = getOpenAIAvailableModelFromEnv();
|
||||
return openAIModel ? [openAIModel] : [];
|
||||
}
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
}, [config, settings.merged.experimental?.visionModelPreview]);
|
||||
|
||||
// Core hooks and processors
|
||||
const {
|
||||
vimEnabled: vimModeEnabled,
|
||||
@@ -620,6 +714,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
setQuittingMessages,
|
||||
openPrivacyNotice,
|
||||
openSettingsDialog,
|
||||
handleModelSelectionOpen,
|
||||
openSubagentCreateDialog,
|
||||
openAgentsManagerDialog,
|
||||
toggleVimEnabled,
|
||||
@@ -664,10 +759,18 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
setModelSwitchedFromQuotaError,
|
||||
refreshStatic,
|
||||
() => cancelHandlerRef.current(),
|
||||
settings.merged.experimental?.visionModelPreview ?? true,
|
||||
handleVisionSwitchRequired,
|
||||
);
|
||||
|
||||
const pendingHistoryItems = useMemo(
|
||||
() => [...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems],
|
||||
() =>
|
||||
[...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems].map(
|
||||
(item, index) => ({
|
||||
...item,
|
||||
id: index,
|
||||
}),
|
||||
),
|
||||
[pendingSlashCommandHistoryItems, pendingGeminiHistoryItems],
|
||||
);
|
||||
|
||||
@@ -1028,6 +1131,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
!isAuthDialogOpen &&
|
||||
!isThemeDialogOpen &&
|
||||
!isEditorDialogOpen &&
|
||||
!isModelSelectionDialogOpen &&
|
||||
!isVisionSwitchDialogOpen &&
|
||||
!isSubagentCreateDialogOpen &&
|
||||
!showPrivacyNotice &&
|
||||
!showWelcomeBackDialog &&
|
||||
@@ -1049,6 +1154,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
showWelcomeBackDialog,
|
||||
welcomeBackChoice,
|
||||
geminiClient,
|
||||
isModelSelectionDialogOpen,
|
||||
isVisionSwitchDialogOpen,
|
||||
]);
|
||||
|
||||
if (quittingMessages) {
|
||||
@@ -1121,16 +1228,14 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
</Static>
|
||||
<OverflowProvider>
|
||||
<Box ref={pendingHistoryItemRef} flexDirection="column">
|
||||
{pendingHistoryItems.map((item, i) => (
|
||||
{pendingHistoryItems.map((item) => (
|
||||
<HistoryItemDisplay
|
||||
key={i}
|
||||
key={item.id}
|
||||
availableTerminalHeight={
|
||||
constrainHeight ? availableTerminalHeight : undefined
|
||||
}
|
||||
terminalWidth={mainAreaWidth}
|
||||
// TODO(taehykim): It seems like references to ids aren't necessary in
|
||||
// HistoryItemDisplay. Refactor later. Use a fake id for now.
|
||||
item={{ ...item, id: 0 }}
|
||||
item={item}
|
||||
isPending={true}
|
||||
config={config}
|
||||
isFocused={!isEditorDialogOpen}
|
||||
@@ -1318,6 +1423,15 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||
onExit={exitEditorDialog}
|
||||
/>
|
||||
</Box>
|
||||
) : isModelSelectionDialogOpen ? (
|
||||
<ModelSelectionDialog
|
||||
availableModels={getAvailableModelsForCurrentAuth()}
|
||||
currentModel={currentModel}
|
||||
onSelect={handleModelSelect}
|
||||
onCancel={handleModelSelectionClose}
|
||||
/>
|
||||
) : isVisionSwitchDialogOpen ? (
|
||||
<ModelSwitchDialog onSelect={handleVisionSwitchSelect} />
|
||||
) : showPrivacyNotice ? (
|
||||
<PrivacyNotice
|
||||
onExit={() => setShowPrivacyNotice(false)}
|
||||
|
||||
179
packages/cli/src/ui/commands/modelCommand.test.ts
Normal file
179
packages/cli/src/ui/commands/modelCommand.test.ts
Normal file
@@ -0,0 +1,179 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { modelCommand } from './modelCommand.js';
|
||||
import { type CommandContext } from './types.js';
|
||||
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
|
||||
import {
|
||||
AuthType,
|
||||
type ContentGeneratorConfig,
|
||||
type Config,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import * as availableModelsModule from '../models/availableModels.js';
|
||||
|
||||
// Mock the availableModels module
|
||||
vi.mock('../models/availableModels.js', () => ({
|
||||
AVAILABLE_MODELS_QWEN: [
|
||||
{ id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
|
||||
{ id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
|
||||
],
|
||||
getOpenAIAvailableModelFromEnv: vi.fn(),
|
||||
}));
|
||||
|
||||
// Helper function to create a mock config
|
||||
function createMockConfig(
|
||||
contentGeneratorConfig: ContentGeneratorConfig | null,
|
||||
): Partial<Config> {
|
||||
return {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue(contentGeneratorConfig),
|
||||
};
|
||||
}
|
||||
|
||||
describe('modelCommand', () => {
|
||||
let mockContext: CommandContext;
|
||||
const mockGetOpenAIAvailableModelFromEnv = vi.mocked(
|
||||
availableModelsModule.getOpenAIAvailableModelFromEnv,
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
mockContext = createMockCommandContext();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should have the correct name and description', () => {
|
||||
expect(modelCommand.name).toBe('model');
|
||||
expect(modelCommand.description).toBe('Switch the model for this session');
|
||||
});
|
||||
|
||||
  // Guard: /model requires a Config instance on the command context.
  it('should return error when config is not available', async () => {
    mockContext.services.config = null;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Configuration not available.',
    });
  });

  // Guard: the content-generator config must be present on the Config.
  it('should return error when content generator config is not available', async () => {
    const mockConfig = createMockConfig(null);
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Content generator configuration not available.',
    });
  });

  // Guard: an auth type is required to know which model list applies.
  it('should return error when auth type is not available', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: undefined,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Authentication type not available.',
    });
  });

  // Qwen OAuth always has a built-in model list, so the dialog opens.
  it('should return dialog action for QWEN_OAUTH auth type', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.QWEN_OAUTH,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'dialog',
      dialog: 'model',
    });
  });

  // OpenAI auth only offers a model when one is configured via environment.
  it('should return dialog action for USE_OPENAI auth type when model is available', async () => {
    mockGetOpenAIAvailableModelFromEnv.mockReturnValue({
      id: 'gpt-4',
      label: 'gpt-4',
    });

    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.USE_OPENAI,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'dialog',
      dialog: 'model',
    });
  });

  // With no env-configured OpenAI model, the command reports an empty list.
  it('should return error for USE_OPENAI auth type when no model is available', async () => {
    mockGetOpenAIAvailableModelFromEnv.mockReturnValue(null);

    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: AuthType.USE_OPENAI,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content:
        'No models available for the current authentication type (openai).',
    });
  });

  // Unknown auth types fall through to the empty model list and error out,
  // echoing the raw auth-type string in the message.
  it('should return error for unsupported auth types', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content:
        'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).',
    });
  });

  // NOTE(review): this duplicates the 'auth type is not available' case
  // above (same setup, same expectation) — consider removing one of them.
  it('should handle undefined auth type', async () => {
    const mockConfig = createMockConfig({
      model: 'test-model',
      authType: undefined,
    });
    mockContext.services.config = mockConfig as Config;

    const result = await modelCommand.action!(mockContext, '');

    expect(result).toEqual({
      type: 'message',
      messageType: 'error',
      content: 'Authentication type not available.',
    });
  });
|
||||
});
|
||||
88
packages/cli/src/ui/commands/modelCommand.ts
Normal file
88
packages/cli/src/ui/commands/modelCommand.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { AuthType } from '@qwen-code/qwen-code-core';
|
||||
import type {
|
||||
SlashCommand,
|
||||
CommandContext,
|
||||
OpenDialogActionReturn,
|
||||
MessageActionReturn,
|
||||
} from './types.js';
|
||||
import { CommandKind } from './types.js';
|
||||
import {
|
||||
AVAILABLE_MODELS_QWEN,
|
||||
getOpenAIAvailableModelFromEnv,
|
||||
type AvailableModel,
|
||||
} from '../models/availableModels.js';
|
||||
|
||||
function getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] {
|
||||
switch (authType) {
|
||||
case AuthType.QWEN_OAUTH:
|
||||
return AVAILABLE_MODELS_QWEN;
|
||||
case AuthType.USE_OPENAI: {
|
||||
const openAIModel = getOpenAIAvailableModelFromEnv();
|
||||
return openAIModel ? [openAIModel] : [];
|
||||
}
|
||||
default:
|
||||
// For other auth types, return empty array for now
|
||||
// This can be expanded later according to the design doc
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export const modelCommand: SlashCommand = {
|
||||
name: 'model',
|
||||
description: 'Switch the model for this session',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
): Promise<OpenDialogActionReturn | MessageActionReturn> => {
|
||||
const { services } = context;
|
||||
const { config } = services;
|
||||
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
if (!contentGeneratorConfig) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Content generator configuration not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const authType = contentGeneratorConfig.authType;
|
||||
if (!authType) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Authentication type not available.',
|
||||
};
|
||||
}
|
||||
|
||||
const availableModels = getAvailableModelsForAuthType(authType);
|
||||
|
||||
if (availableModels.length === 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `No models available for the current authentication type (${authType}).`,
|
||||
};
|
||||
}
|
||||
|
||||
// Trigger model selection dialog
|
||||
return {
|
||||
type: 'dialog',
|
||||
dialog: 'model',
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -116,6 +116,7 @@ export interface OpenDialogActionReturn {
|
||||
| 'editor'
|
||||
| 'privacy'
|
||||
| 'settings'
|
||||
| 'model'
|
||||
| 'subagent_create'
|
||||
| 'subagent_list';
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { memo } from 'react';
|
||||
import type { HistoryItem } from '../types.js';
|
||||
import { UserMessage } from './messages/UserMessage.js';
|
||||
import { UserShellMessage } from './messages/UserShellMessage.js';
|
||||
@@ -35,7 +36,7 @@ interface HistoryItemDisplayProps {
|
||||
commands?: readonly SlashCommand[];
|
||||
}
|
||||
|
||||
export const HistoryItemDisplay: React.FC<HistoryItemDisplayProps> = ({
|
||||
const HistoryItemDisplayComponent: React.FC<HistoryItemDisplayProps> = ({
|
||||
item,
|
||||
availableTerminalHeight,
|
||||
terminalWidth,
|
||||
@@ -101,3 +102,7 @@ export const HistoryItemDisplay: React.FC<HistoryItemDisplayProps> = ({
|
||||
{item.type === 'summary' && <SummaryMessage summary={item.summary} />}
|
||||
</Box>
|
||||
);
|
||||
|
||||
HistoryItemDisplayComponent.displayName = 'HistoryItemDisplay';
|
||||
|
||||
export const HistoryItemDisplay = memo(HistoryItemDisplayComponent);
|
||||
|
||||
246
packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
Normal file
246
packages/cli/src/ui/components/ModelSelectionDialog.test.tsx
Normal file
@@ -0,0 +1,246 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink-testing-library';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ModelSelectionDialog } from './ModelSelectionDialog.js';
|
||||
import type { AvailableModel } from '../models/availableModels.js';
|
||||
import type { RadioSelectItem } from './shared/RadioButtonSelect.js';
|
||||
|
||||
// Mock the useKeypress hook
|
||||
const mockUseKeypress = vi.hoisted(() => vi.fn());
|
||||
vi.mock('../hooks/useKeypress.js', () => ({
|
||||
useKeypress: mockUseKeypress,
|
||||
}));
|
||||
|
||||
// Mock the RadioButtonSelect component
|
||||
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./shared/RadioButtonSelect.js', () => ({
|
||||
RadioButtonSelect: mockRadioButtonSelect,
|
||||
}));
|
||||
|
||||
describe('ModelSelectionDialog', () => {
  // Fixture: a plain model, a vision-capable model, and one whose label
  // differs from its id — together they cover every label decoration path.
  const mockAvailableModels: AvailableModel[] = [
    { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' },
    { id: 'qwen-vl-max-latest', label: 'qwen-vl-max', isVision: true },
    { id: 'gpt-4', label: 'GPT-4' },
  ];

  const mockOnSelect = vi.fn();
  const mockOnCancel = vi.fn();

  beforeEach(() => {
    vi.clearAllMocks();

    // Mock RadioButtonSelect to return a simple div so the dialog renders
    // without pulling in real ink selection behavior.
    mockRadioButtonSelect.mockReturnValue(
      React.createElement('div', { 'data-testid': 'radio-select' }),
    );
  });

  it('should setup escape key handler to call onCancel', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
      isActive: true,
    });

    // Simulate escape key press by invoking the registered handler directly.
    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'escape' });

    expect(mockOnCancel).toHaveBeenCalled();
  });

  it('should not call onCancel for non-escape keys', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'enter' });

    expect(mockOnCancel).not.toHaveBeenCalled();
  });

  it('should set correct initial index for current model', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen-vl-max-latest"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(1); // qwen-vl-max-latest is at index 1
  });

  it('should set initial index to 0 when current model is not found', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="non-existent-model"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(0);
  });

  it('should call onSelect when a model is selected', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(typeof callArgs.onSelect).toBe('function');

    // Simulate selection through the callback the dialog handed to the
    // (mocked) RadioButtonSelect.
    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback('qwen-vl-max-latest');

    expect(mockOnSelect).toHaveBeenCalledWith('qwen-vl-max-latest');
  });

  it('should handle empty models array', () => {
    render(
      <ModelSelectionDialog
        availableModels={[]}
        currentModel=""
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual([]);
    expect(callArgs.initialIndex).toBe(0);
  });

  // Labels should carry the "[Vision]" and "(current)" decorations.
  it('should create correct option items with proper labels', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const expectedItems = [
      {
        label: 'qwen3-coder-plus (current)',
        value: 'qwen3-coder-plus',
      },
      {
        label: 'qwen-vl-max [Vision]',
        value: 'qwen-vl-max-latest',
      },
      {
        label: 'GPT-4',
        value: 'gpt-4',
      },
    ];

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual(expectedItems);
  });

  it('should show vision indicator for vision models', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="gpt-4"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    const visionModelItem = callArgs.items.find(
      (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
    );

    expect(visionModelItem?.label).toContain('[Vision]');
  });

  it('should show current indicator for the current model', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen-vl-max-latest"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    const currentModelItem = callArgs.items.find(
      (item: RadioSelectItem<string>) => item.value === 'qwen-vl-max-latest',
    );

    expect(currentModelItem?.label).toContain('(current)');
  });

  it('should pass isFocused prop to RadioButtonSelect', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.isFocused).toBe(true);
  });

  // The dialog does not debounce or latch: each selection is forwarded.
  it('should handle multiple onSelect calls correctly', () => {
    render(
      <ModelSelectionDialog
        availableModels={mockAvailableModels}
        currentModel="qwen3-coder-plus"
        onSelect={mockOnSelect}
        onCancel={mockOnCancel}
      />,
    );

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;

    // Call multiple times
    onSelectCallback('qwen3-coder-plus');
    onSelectCallback('qwen-vl-max-latest');
    onSelectCallback('gpt-4');

    expect(mockOnSelect).toHaveBeenCalledTimes(3);
    expect(mockOnSelect).toHaveBeenNthCalledWith(1, 'qwen3-coder-plus');
    expect(mockOnSelect).toHaveBeenNthCalledWith(2, 'qwen-vl-max-latest');
    expect(mockOnSelect).toHaveBeenNthCalledWith(3, 'gpt-4');
  });
});
|
||||
87
packages/cli/src/ui/components/ModelSelectionDialog.tsx
Normal file
87
packages/cli/src/ui/components/ModelSelectionDialog.tsx
Normal file
@@ -0,0 +1,87 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { Colors } from '../colors.js';
|
||||
import {
|
||||
RadioButtonSelect,
|
||||
type RadioSelectItem,
|
||||
} from './shared/RadioButtonSelect.js';
|
||||
import { useKeypress } from '../hooks/useKeypress.js';
|
||||
import type { AvailableModel } from '../models/availableModels.js';
|
||||
|
||||
/** Props for {@link ModelSelectionDialog}. */
export interface ModelSelectionDialogProps {
  // Models the user may choose between.
  availableModels: AvailableModel[];
  // Id of the model currently in use; marked "(current)" and preselected.
  currentModel: string;
  // Invoked with the chosen model id when the user confirms a selection.
  onSelect: (modelId: string) => void;
  // Invoked when the user dismisses the dialog with Esc.
  onCancel: () => void;
}
|
||||
|
||||
export const ModelSelectionDialog: React.FC<ModelSelectionDialogProps> = ({
|
||||
availableModels,
|
||||
currentModel,
|
||||
onSelect,
|
||||
onCancel,
|
||||
}) => {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name === 'escape') {
|
||||
onCancel();
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
const options: Array<RadioSelectItem<string>> = availableModels.map(
|
||||
(model) => {
|
||||
const visionIndicator = model.isVision ? ' [Vision]' : '';
|
||||
const currentIndicator = model.id === currentModel ? ' (current)' : '';
|
||||
return {
|
||||
label: `${model.label}${visionIndicator}${currentIndicator}`,
|
||||
value: model.id,
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const initialIndex = Math.max(
|
||||
0,
|
||||
availableModels.findIndex((model) => model.id === currentModel),
|
||||
);
|
||||
|
||||
const handleSelect = (modelId: string) => {
|
||||
onSelect(modelId);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
flexDirection="column"
|
||||
borderStyle="round"
|
||||
borderColor={Colors.AccentBlue}
|
||||
padding={1}
|
||||
width="100%"
|
||||
marginLeft={1}
|
||||
>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold>Select Model</Text>
|
||||
<Text>Choose a model for this session:</Text>
|
||||
</Box>
|
||||
|
||||
<Box marginBottom={1}>
|
||||
<RadioButtonSelect
|
||||
items={options}
|
||||
initialIndex={initialIndex}
|
||||
onSelect={handleSelect}
|
||||
isFocused
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<Box>
|
||||
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
181
packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
Normal file
181
packages/cli/src/ui/components/ModelSwitchDialog.test.tsx
Normal file
@@ -0,0 +1,181 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink-testing-library';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ModelSwitchDialog, VisionSwitchOutcome } from './ModelSwitchDialog.js';
|
||||
|
||||
// Mock the useKeypress hook
|
||||
const mockUseKeypress = vi.hoisted(() => vi.fn());
|
||||
vi.mock('../hooks/useKeypress.js', () => ({
|
||||
useKeypress: mockUseKeypress,
|
||||
}));
|
||||
|
||||
// Mock the RadioButtonSelect component
|
||||
const mockRadioButtonSelect = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./shared/RadioButtonSelect.js', () => ({
|
||||
RadioButtonSelect: mockRadioButtonSelect,
|
||||
}));
|
||||
|
||||
describe('ModelSwitchDialog', () => {
  const mockOnSelect = vi.fn();

  beforeEach(() => {
    vi.clearAllMocks();

    // Mock RadioButtonSelect to return a simple div so the dialog renders
    // without real ink selection behavior.
    mockRadioButtonSelect.mockReturnValue(
      React.createElement('div', { 'data-testid': 'radio-select' }),
    );
  });

  // The three outcomes must be offered in this exact order, with the
  // once-only switch preselected.
  it('should setup RadioButtonSelect with correct options', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const expectedItems = [
      {
        label: 'Switch for this request only',
        value: VisionSwitchOutcome.SwitchOnce,
      },
      {
        label: 'Switch session to vision model',
        value: VisionSwitchOutcome.SwitchSessionToVL,
      },
      {
        label: 'Continue with current model',
        value: VisionSwitchOutcome.ContinueWithCurrentModel,
      },
    ];

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.items).toEqual(expectedItems);
    expect(callArgs.initialIndex).toBe(0);
    expect(callArgs.isFocused).toBe(true);
  });

  it('should call onSelect when an option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(typeof callArgs.onSelect).toBe('function');

    // Simulate selection of "Switch for this request only"
    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.SwitchOnce);

    expect(mockOnSelect).toHaveBeenCalledWith(VisionSwitchOutcome.SwitchOnce);
  });

  it('should call onSelect with SwitchSessionToVL when second option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.SwitchSessionToVL,
    );
  });

  it('should call onSelect with ContinueWithCurrentModel when third option is selected', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;
    onSelectCallback(VisionSwitchOutcome.ContinueWithCurrentModel);

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.ContinueWithCurrentModel,
    );
  });

  // Esc is treated as "keep the current model", not as a separate cancel.
  it('should setup escape key handler to call onSelect with ContinueWithCurrentModel', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    expect(mockUseKeypress).toHaveBeenCalledWith(expect.any(Function), {
      isActive: true,
    });

    // Simulate escape key press
    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'escape' });

    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.ContinueWithCurrentModel,
    );
  });

  it('should not call onSelect for non-escape keys', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const keypressHandler = mockUseKeypress.mock.calls[0][0];
    keypressHandler({ name: 'enter' });

    expect(mockOnSelect).not.toHaveBeenCalled();
  });

  it('should set initial index to 0 (first option)', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.initialIndex).toBe(0);
  });

  describe('VisionSwitchOutcome enum', () => {
    // Pin the string values — they may be persisted/compared elsewhere.
    it('should have correct enum values', () => {
      expect(VisionSwitchOutcome.SwitchOnce).toBe('once');
      expect(VisionSwitchOutcome.SwitchSessionToVL).toBe('session');
      expect(VisionSwitchOutcome.ContinueWithCurrentModel).toBe('persist');
    });
  });

  // The dialog does not latch: every selection is forwarded.
  it('should handle multiple onSelect calls correctly', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const onSelectCallback = mockRadioButtonSelect.mock.calls[0][0].onSelect;

    // Call multiple times
    onSelectCallback(VisionSwitchOutcome.SwitchOnce);
    onSelectCallback(VisionSwitchOutcome.SwitchSessionToVL);
    onSelectCallback(VisionSwitchOutcome.ContinueWithCurrentModel);

    expect(mockOnSelect).toHaveBeenCalledTimes(3);
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      1,
      VisionSwitchOutcome.SwitchOnce,
    );
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      2,
      VisionSwitchOutcome.SwitchSessionToVL,
    );
    expect(mockOnSelect).toHaveBeenNthCalledWith(
      3,
      VisionSwitchOutcome.ContinueWithCurrentModel,
    );
  });

  it('should pass isFocused prop to RadioButtonSelect', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const callArgs = mockRadioButtonSelect.mock.calls[0][0];
    expect(callArgs.isFocused).toBe(true);
  });

  // Esc is not debounced either; each press re-invokes onSelect.
  it('should handle escape key multiple times', () => {
    render(<ModelSwitchDialog onSelect={mockOnSelect} />);

    const keypressHandler = mockUseKeypress.mock.calls[0][0];

    // Call escape multiple times
    keypressHandler({ name: 'escape' });
    keypressHandler({ name: 'escape' });

    expect(mockOnSelect).toHaveBeenCalledTimes(2);
    expect(mockOnSelect).toHaveBeenCalledWith(
      VisionSwitchOutcome.ContinueWithCurrentModel,
    );
  });
});
|
||||
89
packages/cli/src/ui/components/ModelSwitchDialog.tsx
Normal file
89
packages/cli/src/ui/components/ModelSwitchDialog.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type React from 'react';
|
||||
import { Box, Text } from 'ink';
|
||||
import { Colors } from '../colors.js';
|
||||
import {
|
||||
RadioButtonSelect,
|
||||
type RadioSelectItem,
|
||||
} from './shared/RadioButtonSelect.js';
|
||||
import { useKeypress } from '../hooks/useKeypress.js';
|
||||
|
||||
/**
 * User's decision when the current model cannot handle an image input.
 * String values are asserted by tests — do not change them.
 */
export enum VisionSwitchOutcome {
  // Use the vision model for this request only.
  SwitchOnce = 'once',
  // Switch the whole session over to the vision model.
  SwitchSessionToVL = 'session',
  // Keep using the current (non-vision) model.
  ContinueWithCurrentModel = 'persist',
}
|
||||
|
||||
/** Props for {@link ModelSwitchDialog}. */
export interface ModelSwitchDialogProps {
  // Receives the chosen outcome; Esc maps to ContinueWithCurrentModel.
  onSelect: (outcome: VisionSwitchOutcome) => void;
}
|
||||
|
||||
export const ModelSwitchDialog: React.FC<ModelSwitchDialogProps> = ({
|
||||
onSelect,
|
||||
}) => {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name === 'escape') {
|
||||
onSelect(VisionSwitchOutcome.ContinueWithCurrentModel);
|
||||
}
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
const options: Array<RadioSelectItem<VisionSwitchOutcome>> = [
|
||||
{
|
||||
label: 'Switch for this request only',
|
||||
value: VisionSwitchOutcome.SwitchOnce,
|
||||
},
|
||||
{
|
||||
label: 'Switch session to vision model',
|
||||
value: VisionSwitchOutcome.SwitchSessionToVL,
|
||||
},
|
||||
{
|
||||
label: 'Continue with current model',
|
||||
value: VisionSwitchOutcome.ContinueWithCurrentModel,
|
||||
},
|
||||
];
|
||||
|
||||
const handleSelect = (outcome: VisionSwitchOutcome) => {
|
||||
onSelect(outcome);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
flexDirection="column"
|
||||
borderStyle="round"
|
||||
borderColor={Colors.AccentYellow}
|
||||
padding={1}
|
||||
width="100%"
|
||||
marginLeft={1}
|
||||
>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold>Vision Model Switch Required</Text>
|
||||
<Text>
|
||||
Your message contains an image, but the current model doesn't
|
||||
support vision.
|
||||
</Text>
|
||||
<Text>How would you like to proceed?</Text>
|
||||
</Box>
|
||||
|
||||
<Box marginBottom={1}>
|
||||
<RadioButtonSelect
|
||||
items={options}
|
||||
initialIndex={0}
|
||||
onSelect={handleSelect}
|
||||
isFocused
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<Box>
|
||||
<Text color={Colors.Gray}>Press Enter to select, Esc to cancel</Text>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -27,6 +27,7 @@ export interface ToolConfirmationMessageProps {
|
||||
isFocused?: boolean;
|
||||
availableTerminalHeight?: number;
|
||||
terminalWidth: number;
|
||||
compactMode?: boolean;
|
||||
}
|
||||
|
||||
export const ToolConfirmationMessage: React.FC<
|
||||
@@ -37,6 +38,7 @@ export const ToolConfirmationMessage: React.FC<
|
||||
isFocused = true,
|
||||
availableTerminalHeight,
|
||||
terminalWidth,
|
||||
compactMode = false,
|
||||
}) => {
|
||||
const { onConfirm } = confirmationDetails;
|
||||
const childWidth = terminalWidth - 2; // 2 for padding
|
||||
@@ -70,6 +72,40 @@ export const ToolConfirmationMessage: React.FC<
|
||||
|
||||
const handleSelect = (item: ToolConfirmationOutcome) => handleConfirm(item);
|
||||
|
||||
// Compact mode: return simple 3-option display
|
||||
if (compactMode) {
|
||||
const compactOptions: Array<RadioSelectItem<ToolConfirmationOutcome>> = [
|
||||
{
|
||||
label: 'Yes, allow once',
|
||||
value: ToolConfirmationOutcome.ProceedOnce,
|
||||
},
|
||||
{
|
||||
label: 'Allow always',
|
||||
value: ToolConfirmationOutcome.ProceedAlways,
|
||||
},
|
||||
{
|
||||
label: 'No',
|
||||
value: ToolConfirmationOutcome.Cancel,
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<Box flexDirection="column">
|
||||
<Box>
|
||||
<Text wrap="truncate">Do you want to proceed?</Text>
|
||||
</Box>
|
||||
<Box>
|
||||
<RadioButtonSelect
|
||||
items={compactOptions}
|
||||
onSelect={handleSelect}
|
||||
isFocused={isFocused}
|
||||
/>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Original logic continues unchanged below
|
||||
let bodyContent: React.ReactNode | null = null; // Removed contextDisplay here
|
||||
let question: string;
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { useReducer, useCallback, useMemo } from 'react';
|
||||
import { Box, Text, useInput } from 'ink';
|
||||
import { Box, Text } from 'ink';
|
||||
import { wizardReducer, initialWizardState } from '../reducers.js';
|
||||
import { LocationSelector } from './LocationSelector.js';
|
||||
import { GenerationMethodSelector } from './GenerationMethodSelector.js';
|
||||
@@ -20,6 +20,7 @@ import type { Config } from '@qwen-code/qwen-code-core';
|
||||
import { Colors } from '../../../colors.js';
|
||||
import { theme } from '../../../semantic-colors.js';
|
||||
import { TextEntryStep } from './TextEntryStep.js';
|
||||
import { useKeypress } from '../../../hooks/useKeypress.js';
|
||||
|
||||
interface AgentCreationWizardProps {
|
||||
onClose: () => void;
|
||||
@@ -49,8 +50,12 @@ export function AgentCreationWizard({
|
||||
}, [onClose]);
|
||||
|
||||
// Centralized ESC key handling for the entire wizard
|
||||
useInput((input, key) => {
|
||||
if (key.escape) {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name !== 'escape') {
|
||||
return;
|
||||
}
|
||||
|
||||
// LLM DescriptionInput handles its own ESC logic when generating
|
||||
const kind = getStepKind(state.generationMethod, state.currentStep);
|
||||
if (kind === 'LLM_DESC' && state.isGenerating) {
|
||||
@@ -64,8 +69,9 @@ export function AgentCreationWizard({
|
||||
// On other steps, ESC goes back to previous step
|
||||
handlePrevious();
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
const stepProps: WizardStepProps = useMemo(
|
||||
() => ({
|
||||
|
||||
@@ -227,7 +227,7 @@ export const AgentSelectionStep = ({
|
||||
const textColor = isSelected ? theme.text.accent : theme.text.primary;
|
||||
|
||||
return (
|
||||
<Box key={agent.name} alignItems="center">
|
||||
<Box key={`${agent.name}-${agent.level}`} alignItems="center">
|
||||
<Box minWidth={2} flexShrink={0}>
|
||||
<Text color={isSelected ? theme.text.accent : theme.text.primary}>
|
||||
{isSelected ? '●' : ' '}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { useState, useCallback, useMemo, useEffect } from 'react';
|
||||
import { Box, Text, useInput } from 'ink';
|
||||
import { Box, Text } from 'ink';
|
||||
import { AgentSelectionStep } from './AgentSelectionStep.js';
|
||||
import { ActionSelectionStep } from './ActionSelectionStep.js';
|
||||
import { AgentViewerStep } from './AgentViewerStep.js';
|
||||
@@ -17,7 +17,8 @@ import { MANAGEMENT_STEPS } from '../types.js';
|
||||
import { Colors } from '../../../colors.js';
|
||||
import { theme } from '../../../semantic-colors.js';
|
||||
import { getColorForDisplay, shouldShowColor } from '../utils.js';
|
||||
import type { Config, SubagentConfig } from '@qwen-code/qwen-code-core';
|
||||
import type { SubagentConfig, Config } from '@qwen-code/qwen-code-core';
|
||||
import { useKeypress } from '../../../hooks/useKeypress.js';
|
||||
|
||||
interface AgentsManagerDialogProps {
|
||||
onClose: () => void;
|
||||
@@ -52,18 +53,7 @@ export function AgentsManagerDialog({
|
||||
const manager = config.getSubagentManager();
|
||||
|
||||
// Load agents from all levels separately to show all agents including conflicts
|
||||
const [projectAgents, userAgents, builtinAgents] = await Promise.all([
|
||||
manager.listSubagents({ level: 'project' }),
|
||||
manager.listSubagents({ level: 'user' }),
|
||||
manager.listSubagents({ level: 'builtin' }),
|
||||
]);
|
||||
|
||||
// Combine all agents (project, user, and builtin level)
|
||||
const allAgents = [
|
||||
...(projectAgents || []),
|
||||
...(userAgents || []),
|
||||
...(builtinAgents || []),
|
||||
];
|
||||
const allAgents = await manager.listSubagents();
|
||||
|
||||
setAvailableAgents(allAgents);
|
||||
}, [config]);
|
||||
@@ -122,8 +112,12 @@ export function AgentsManagerDialog({
|
||||
);
|
||||
|
||||
// Centralized ESC key handling for the entire dialog
|
||||
useInput((input, key) => {
|
||||
if (key.escape) {
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.name !== 'escape') {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentStep = getCurrentStep();
|
||||
if (currentStep === MANAGEMENT_STEPS.AGENT_SELECTION) {
|
||||
// On first step, ESC cancels the entire dialog
|
||||
@@ -132,8 +126,9 @@ export function AgentsManagerDialog({
|
||||
// On other steps, ESC goes back to previous step in navigation stack
|
||||
handleNavigateBack();
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
// Props for child components - now using direct state and callbacks
|
||||
const commonProps = useMemo(
|
||||
|
||||
@@ -18,12 +18,12 @@ import { COLOR_OPTIONS } from '../constants.js';
|
||||
import { fmtDuration } from '../utils.js';
|
||||
import { ToolConfirmationMessage } from '../../messages/ToolConfirmationMessage.js';
|
||||
|
||||
export type DisplayMode = 'default' | 'verbose';
|
||||
export type DisplayMode = 'compact' | 'default' | 'verbose';
|
||||
|
||||
export interface AgentExecutionDisplayProps {
|
||||
data: TaskResultDisplay;
|
||||
availableHeight?: number;
|
||||
childWidth?: number;
|
||||
childWidth: number;
|
||||
config: Config;
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
|
||||
childWidth,
|
||||
config,
|
||||
}) => {
|
||||
const [displayMode, setDisplayMode] = React.useState<DisplayMode>('default');
|
||||
const [displayMode, setDisplayMode] = React.useState<DisplayMode>('compact');
|
||||
|
||||
const agentColor = useMemo(() => {
|
||||
const colorOption = COLOR_OPTIONS.find(
|
||||
@@ -93,8 +93,6 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
|
||||
// This component only listens to keyboard shortcut events when the subagent is running
|
||||
if (data.status !== 'running') return '';
|
||||
|
||||
if (displayMode === 'verbose') return 'Press ctrl+r to show less.';
|
||||
|
||||
if (displayMode === 'default') {
|
||||
const hasMoreLines =
|
||||
data.taskPrompt.split('\n').length > MAX_TASK_PROMPT_LINES;
|
||||
@@ -102,17 +100,28 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
|
||||
data.toolCalls && data.toolCalls.length > MAX_TOOL_CALLS;
|
||||
|
||||
if (hasMoreToolCalls || hasMoreLines) {
|
||||
return 'Press ctrl+r to show more.';
|
||||
return 'Press ctrl+r to show less, ctrl+e to show more.';
|
||||
}
|
||||
return '';
|
||||
return 'Press ctrl+r to show less.';
|
||||
}
|
||||
return '';
|
||||
}, [displayMode, data.toolCalls, data.taskPrompt, data.status]);
|
||||
|
||||
// Handle ctrl+r keypresses to control display mode
|
||||
if (displayMode === 'verbose') {
|
||||
return 'Press ctrl+e to show less.';
|
||||
}
|
||||
|
||||
return '';
|
||||
}, [displayMode, data]);
|
||||
|
||||
// Handle keyboard shortcuts to control display mode
|
||||
useKeypress(
|
||||
(key) => {
|
||||
if (key.ctrl && key.name === 'r') {
|
||||
// ctrl+r toggles between compact and default
|
||||
setDisplayMode((current) =>
|
||||
current === 'compact' ? 'default' : 'compact',
|
||||
);
|
||||
} else if (key.ctrl && key.name === 'e') {
|
||||
// ctrl+e toggles between default and verbose
|
||||
setDisplayMode((current) =>
|
||||
current === 'default' ? 'verbose' : 'default',
|
||||
);
|
||||
@@ -121,6 +130,82 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
|
||||
{ isActive: true },
|
||||
);
|
||||
|
||||
if (displayMode === 'compact') {
|
||||
return (
|
||||
<Box flexDirection="column">
|
||||
{/* Header: Agent name and status */}
|
||||
{!data.pendingConfirmation && (
|
||||
<Box flexDirection="row">
|
||||
<Text bold color={agentColor}>
|
||||
{data.subagentName}
|
||||
</Text>
|
||||
<StatusDot status={data.status} />
|
||||
<StatusIndicator status={data.status} />
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Running state: Show current tool call and progress */}
|
||||
{data.status === 'running' && (
|
||||
<>
|
||||
{/* Current tool call */}
|
||||
{data.toolCalls && data.toolCalls.length > 0 && (
|
||||
<Box flexDirection="column">
|
||||
<ToolCallItem
|
||||
toolCall={data.toolCalls[data.toolCalls.length - 1]}
|
||||
compact={true}
|
||||
/>
|
||||
{/* Show count of additional tool calls if there are more than 1 */}
|
||||
{data.toolCalls.length > 1 && !data.pendingConfirmation && (
|
||||
<Box flexDirection="row" paddingLeft={4}>
|
||||
<Text color={Colors.Gray}>
|
||||
+{data.toolCalls.length - 1} more tool calls (ctrl+r to
|
||||
expand)
|
||||
</Text>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Inline approval prompt when awaiting confirmation */}
|
||||
{data.pendingConfirmation && (
|
||||
<Box flexDirection="column" marginTop={1} paddingLeft={1}>
|
||||
<ToolConfirmationMessage
|
||||
confirmationDetails={data.pendingConfirmation}
|
||||
isFocused={true}
|
||||
availableTerminalHeight={availableHeight}
|
||||
terminalWidth={childWidth}
|
||||
compactMode={true}
|
||||
config={config}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Completed state: Show summary line */}
|
||||
{data.status === 'completed' && data.executionSummary && (
|
||||
<Box flexDirection="row" marginTop={1}>
|
||||
<Text color={theme.text.secondary}>
|
||||
Execution Summary: {data.executionSummary.totalToolCalls} tool
|
||||
uses · {data.executionSummary.totalTokens.toLocaleString()} tokens
|
||||
· {fmtDuration(data.executionSummary.totalDurationMs)}
|
||||
</Text>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Failed/Cancelled state: Show error reason */}
|
||||
{data.status === 'failed' && (
|
||||
<Box flexDirection="row" marginTop={1}>
|
||||
<Text color={theme.status.error}>
|
||||
Failed: {data.terminateReason}
|
||||
</Text>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Default and verbose modes use normal layout
|
||||
return (
|
||||
<Box flexDirection="column" paddingX={1} gap={1}>
|
||||
{/* Header with subagent name and status */}
|
||||
@@ -158,7 +243,8 @@ export const AgentExecutionDisplay: React.FC<AgentExecutionDisplayProps> = ({
|
||||
config={config}
|
||||
isFocused={true}
|
||||
availableTerminalHeight={availableHeight}
|
||||
terminalWidth={childWidth ?? 80}
|
||||
terminalWidth={childWidth}
|
||||
compactMode={true}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
@@ -280,7 +366,8 @@ const ToolCallItem: React.FC<{
|
||||
resultDisplay?: string;
|
||||
description?: string;
|
||||
};
|
||||
}> = ({ toolCall }) => {
|
||||
compact?: boolean;
|
||||
}> = ({ toolCall, compact = false }) => {
|
||||
const STATUS_INDICATOR_WIDTH = 3;
|
||||
|
||||
// Map subagent status to ToolCallStatus-like display
|
||||
@@ -335,8 +422,8 @@ const ToolCallItem: React.FC<{
|
||||
</Text>
|
||||
</Box>
|
||||
|
||||
{/* Second line: truncated returnDisplay output */}
|
||||
{truncatedOutput && (
|
||||
{/* Second line: truncated returnDisplay output - hidden in compact mode */}
|
||||
{!compact && truncatedOutput && (
|
||||
<Box flexDirection="row" paddingLeft={STATUS_INDICATOR_WIDTH}>
|
||||
<Text color={Colors.Gray}>{truncatedOutput}</Text>
|
||||
</Box>
|
||||
|
||||
@@ -526,7 +526,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(keyHandler).toHaveBeenCalledTimes(2); // 1 paste event + 1 paste event for 'after'
|
||||
expect(keyHandler).toHaveBeenCalledTimes(6); // 1 paste event + 5 individual chars for 'after'
|
||||
});
|
||||
|
||||
// Should emit paste event first
|
||||
@@ -538,12 +538,40 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
// Then process 'after' as a paste event (since it's > 2 chars)
|
||||
// Then process 'after' as individual characters (since it doesn't contain return)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'after',
|
||||
name: 'a',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
expect.objectContaining({
|
||||
name: 'f',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
expect.objectContaining({
|
||||
name: 't',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
5,
|
||||
expect.objectContaining({
|
||||
name: 'e',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
6,
|
||||
expect.objectContaining({
|
||||
name: 'r',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
});
|
||||
@@ -571,7 +599,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(keyHandler).toHaveBeenCalledTimes(14); // Adjusted based on actual behavior
|
||||
expect(keyHandler).toHaveBeenCalledTimes(16); // 5 + 1 + 6 + 1 + 3 = 16 calls
|
||||
});
|
||||
|
||||
// Check the sequence: 'start' (5 chars) + paste1 + 'middle' (6 chars) + paste2 + 'end' (3 chars as paste)
|
||||
@@ -643,13 +671,18 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
// 'end' as paste event (since it's > 2 chars)
|
||||
// 'end' as individual characters (since it doesn't contain return)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'end',
|
||||
}),
|
||||
expect.objectContaining({ name: 'e' }),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({ name: 'n' }),
|
||||
);
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
callIndex++,
|
||||
expect.objectContaining({ name: 'd' }),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -738,16 +771,18 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// With the current implementation, fragmented data gets processed differently
|
||||
// The first fragment '\x1b[20' gets processed as individual characters
|
||||
// The second fragment '0~content\x1b[2' gets processed as paste + individual chars
|
||||
// The third fragment '01~' gets processed as individual characters
|
||||
expect(keyHandler).toHaveBeenCalled();
|
||||
// With the current implementation, fragmented paste markers get reconstructed
|
||||
// into a single paste event for 'content'
|
||||
expect(keyHandler).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// The current implementation processes fragmented paste markers as separate events
|
||||
// rather than reconstructing them into a single paste event
|
||||
expect(keyHandler.mock.calls.length).toBeGreaterThan(1);
|
||||
// Should reconstruct the fragmented paste markers into a single paste event
|
||||
expect(keyHandler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'content',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -851,19 +886,38 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
stdin.emit('data', Buffer.from('lo'));
|
||||
});
|
||||
|
||||
// With the current implementation, data is processed as it arrives
|
||||
// First chunk 'hel' is treated as paste (multi-character)
|
||||
// With the current implementation, data is processed as individual characters
|
||||
// since 'hel' doesn't contain return (0x0d)
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: 'hel',
|
||||
name: 'h',
|
||||
sequence: 'h',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
// Second chunk 'lo' is processed as individual characters
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
name: 'e',
|
||||
sequence: 'e',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
expect.objectContaining({
|
||||
name: 'l',
|
||||
sequence: 'l',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
|
||||
// Second chunk 'lo' is also processed as individual characters
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
expect.objectContaining({
|
||||
name: 'l',
|
||||
sequence: 'l',
|
||||
@@ -872,7 +926,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
5,
|
||||
expect.objectContaining({
|
||||
name: 'o',
|
||||
sequence: 'o',
|
||||
@@ -880,7 +934,7 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
expect(keyHandler).toHaveBeenCalledTimes(3);
|
||||
expect(keyHandler).toHaveBeenCalledTimes(5);
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
@@ -907,14 +961,20 @@ describe('KeypressContext - Kitty Protocol', () => {
|
||||
});
|
||||
|
||||
// Should flush immediately without waiting for timeout
|
||||
// Large data gets treated as paste event
|
||||
expect(keyHandler).toHaveBeenCalledTimes(1);
|
||||
expect(keyHandler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
paste: true,
|
||||
sequence: largeData,
|
||||
}),
|
||||
);
|
||||
// Large data without return gets treated as individual characters
|
||||
expect(keyHandler).toHaveBeenCalledTimes(65);
|
||||
|
||||
// Each character should be processed individually
|
||||
for (let i = 0; i < 65; i++) {
|
||||
expect(keyHandler).toHaveBeenNthCalledWith(
|
||||
i + 1,
|
||||
expect.objectContaining({
|
||||
name: 'x',
|
||||
sequence: 'x',
|
||||
paste: false,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Advancing timer should not cause additional calls
|
||||
const callCountBefore = keyHandler.mock.calls.length;
|
||||
|
||||
@@ -407,7 +407,11 @@ export function KeypressProvider({
|
||||
return;
|
||||
}
|
||||
|
||||
if (rawDataBuffer.length <= 2 || isPaste) {
|
||||
if (
|
||||
(rawDataBuffer.length <= 2 && rawDataBuffer.includes(0x0d)) ||
|
||||
!rawDataBuffer.includes(0x0d) ||
|
||||
isPaste
|
||||
) {
|
||||
keypressStream.write(rawDataBuffer);
|
||||
} else {
|
||||
// Flush raw data buffer as a paste event
|
||||
|
||||
@@ -106,6 +106,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const mockLoadHistory = vi.fn();
|
||||
const mockOpenThemeDialog = vi.fn();
|
||||
const mockOpenAuthDialog = vi.fn();
|
||||
const mockOpenModelSelectionDialog = vi.fn();
|
||||
const mockSetQuittingMessages = vi.fn();
|
||||
|
||||
const mockConfig = makeFakeConfig({});
|
||||
@@ -122,6 +123,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockBuiltinLoadCommands.mockResolvedValue([]);
|
||||
mockFileLoadCommands.mockResolvedValue([]);
|
||||
mockMcpLoadCommands.mockResolvedValue([]);
|
||||
mockOpenModelSelectionDialog.mockClear();
|
||||
});
|
||||
|
||||
const setupProcessorHook = (
|
||||
@@ -150,11 +152,13 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockSetQuittingMessages,
|
||||
vi.fn(), // openPrivacyNotice
|
||||
vi.fn(), // openSettingsDialog
|
||||
mockOpenModelSelectionDialog,
|
||||
vi.fn(), // openSubagentCreateDialog
|
||||
vi.fn(), // openAgentsManagerDialog
|
||||
vi.fn(), // toggleVimEnabled
|
||||
setIsProcessing,
|
||||
vi.fn(), // setGeminiMdFileCount
|
||||
vi.fn(), // _showQuitConfirmation
|
||||
),
|
||||
);
|
||||
|
||||
@@ -395,6 +399,21 @@ describe('useSlashCommandProcessor', () => {
|
||||
expect(mockOpenThemeDialog).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle "dialog: model" action', async () => {
|
||||
const command = createTestCommand({
|
||||
name: 'modelcmd',
|
||||
action: vi.fn().mockResolvedValue({ type: 'dialog', dialog: 'model' }),
|
||||
});
|
||||
const result = setupProcessorHook([command]);
|
||||
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleSlashCommand('/modelcmd');
|
||||
});
|
||||
|
||||
expect(mockOpenModelSelectionDialog).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle "load_history" action', async () => {
|
||||
const command = createTestCommand({
|
||||
name: 'load',
|
||||
@@ -904,11 +923,13 @@ describe('useSlashCommandProcessor', () => {
|
||||
mockSetQuittingMessages,
|
||||
vi.fn(), // openPrivacyNotice
|
||||
vi.fn(), // openSettingsDialog
|
||||
vi.fn(), // openModelSelectionDialog
|
||||
vi.fn(), // openSubagentCreateDialog
|
||||
vi.fn(), // openAgentsManagerDialog
|
||||
vi.fn(), // toggleVimEnabled
|
||||
vi.fn(), // setIsProcessing
|
||||
vi.fn(), // setGeminiMdFileCount
|
||||
vi.fn(), // _showQuitConfirmation
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ export const useSlashCommandProcessor = (
|
||||
setQuittingMessages: (message: HistoryItem[]) => void,
|
||||
openPrivacyNotice: () => void,
|
||||
openSettingsDialog: () => void,
|
||||
openModelSelectionDialog: () => void,
|
||||
openSubagentCreateDialog: () => void,
|
||||
openAgentsManagerDialog: () => void,
|
||||
toggleVimEnabled: () => Promise<boolean>,
|
||||
@@ -404,6 +405,9 @@ export const useSlashCommandProcessor = (
|
||||
case 'settings':
|
||||
openSettingsDialog();
|
||||
return { type: 'handled' };
|
||||
case 'model':
|
||||
openModelSelectionDialog();
|
||||
return { type: 'handled' };
|
||||
case 'subagent_create':
|
||||
openSubagentCreateDialog();
|
||||
return { type: 'handled' };
|
||||
@@ -663,6 +667,7 @@ export const useSlashCommandProcessor = (
|
||||
setSessionShellAllowlist,
|
||||
setIsProcessing,
|
||||
setConfirmationRequest,
|
||||
openModelSelectionDialog,
|
||||
session.stats,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -56,6 +56,12 @@ const MockedUserPromptEvent = vi.hoisted(() =>
|
||||
);
|
||||
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn());
|
||||
|
||||
// Vision auto-switch mocks (hoisted)
|
||||
const mockHandleVisionSwitch = vi.hoisted(() =>
|
||||
vi.fn().mockResolvedValue({ shouldProceed: true }),
|
||||
);
|
||||
const mockRestoreOriginalModel = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => {
|
||||
const actualCoreModule = (await importOriginal()) as any;
|
||||
return {
|
||||
@@ -76,6 +82,13 @@ vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('./useVisionAutoSwitch.js', () => ({
|
||||
useVisionAutoSwitch: vi.fn(() => ({
|
||||
handleVisionSwitch: mockHandleVisionSwitch,
|
||||
restoreOriginalModel: mockRestoreOriginalModel,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./useKeypress.js', () => ({
|
||||
useKeypress: vi.fn(),
|
||||
}));
|
||||
@@ -199,6 +212,7 @@ describe('useGeminiStream', () => {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue(contentGeneratorConfig),
|
||||
getMaxSessionTurns: vi.fn(() => 50),
|
||||
} as unknown as Config;
|
||||
mockOnDebugMessage = vi.fn();
|
||||
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
|
||||
@@ -1551,6 +1565,7 @@ describe('useGeminiStream', () => {
|
||||
expect.any(String), // Argument 3: The prompt_id string
|
||||
);
|
||||
});
|
||||
|
||||
describe('Thought Reset', () => {
|
||||
it('should reset thought to null when starting a new prompt', async () => {
|
||||
// First, simulate a response with a thought
|
||||
@@ -1900,4 +1915,166 @@ describe('useGeminiStream', () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// --- New tests focused on recent modifications ---
|
||||
describe('Vision Auto Switch Integration', () => {
|
||||
it('should call handleVisionSwitch and proceed to send when allowed', async () => {
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'ok' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('image prompt');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleVisionSwitch).toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should gate submission when handleVisionSwitch returns shouldProceed=false', async () => {
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: false });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('vision-gated');
|
||||
});
|
||||
|
||||
// No call to API, no restoreOriginalModel needed since no override occurred
|
||||
expect(mockSendMessageStream).not.toHaveBeenCalled();
|
||||
expect(mockRestoreOriginalModel).not.toHaveBeenCalled();
|
||||
|
||||
// Next call allowed (flag reset path)
|
||||
mockHandleVisionSwitch.mockResolvedValueOnce({ shouldProceed: true });
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'ok' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('after-gate');
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model restore on completion and errors', () => {
|
||||
it('should restore model after successful stream completion', async () => {
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'content' };
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('restore-success');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it('should restore model when an error occurs during streaming', async () => {
|
||||
const testError = new Error('stream failure');
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'content' };
|
||||
throw testError;
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('restore-error');
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRestoreOriginalModel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -42,6 +42,7 @@ import type {
|
||||
import { StreamingState, MessageType, ToolCallStatus } from '../types.js';
|
||||
import { isAtCommand, isSlashCommand } from '../utils/commandUtils.js';
|
||||
import { useShellCommandProcessor } from './shellCommandProcessor.js';
|
||||
import { useVisionAutoSwitch } from './useVisionAutoSwitch.js';
|
||||
import { handleAtCommand } from './atCommandProcessor.js';
|
||||
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
@@ -88,6 +89,12 @@ export const useGeminiStream = (
|
||||
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
onEditorClose: () => void,
|
||||
onCancelSubmit: () => void,
|
||||
visionModelPreviewEnabled: boolean,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>,
|
||||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
@@ -155,6 +162,13 @@ export const useGeminiStream = (
|
||||
geminiClient,
|
||||
);
|
||||
|
||||
const { handleVisionSwitch, restoreOriginalModel } = useVisionAutoSwitch(
|
||||
config,
|
||||
addItem,
|
||||
visionModelPreviewEnabled,
|
||||
onVisionSwitchRequired,
|
||||
);
|
||||
|
||||
const streamingState = useMemo(() => {
|
||||
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
|
||||
return StreamingState.WaitingForConfirmation;
|
||||
@@ -715,6 +729,20 @@ export const useGeminiStream = (
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle vision switch requirement
|
||||
const visionSwitchResult = await handleVisionSwitch(
|
||||
queryToSend,
|
||||
userMessageTimestamp,
|
||||
options?.isContinuation || false,
|
||||
);
|
||||
|
||||
if (!visionSwitchResult.shouldProceed) {
|
||||
isSubmittingQueryRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const finalQueryToSend = queryToSend;
|
||||
|
||||
if (!options?.isContinuation) {
|
||||
startNewPrompt();
|
||||
setThought(null); // Reset thought when starting a new prompt
|
||||
@@ -725,7 +753,7 @@ export const useGeminiStream = (
|
||||
|
||||
try {
|
||||
const stream = geminiClient.sendMessageStream(
|
||||
queryToSend,
|
||||
finalQueryToSend,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
@@ -736,6 +764,8 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
isSubmittingQueryRef.current = false;
|
||||
return;
|
||||
}
|
||||
@@ -748,7 +778,13 @@ export const useGeminiStream = (
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
} catch (error: unknown) {
|
||||
// Restore original model if it was temporarily overridden
|
||||
restoreOriginalModel();
|
||||
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError();
|
||||
} else if (!isNodeError(error) || error.name !== 'AbortError') {
|
||||
@@ -786,6 +822,8 @@ export const useGeminiStream = (
|
||||
startNewPrompt,
|
||||
getPromptCount,
|
||||
handleLoopDetectedEvent,
|
||||
handleVisionSwitch,
|
||||
restoreOriginalModel,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -911,10 +949,13 @@ export const useGeminiStream = (
|
||||
],
|
||||
);
|
||||
|
||||
const pendingHistoryItems = [
|
||||
pendingHistoryItemRef.current,
|
||||
pendingToolCallGroupDisplay,
|
||||
].filter((i) => i !== undefined && i !== null);
|
||||
const pendingHistoryItems = useMemo(
|
||||
() =>
|
||||
[pendingHistoryItemRef.current, pendingToolCallGroupDisplay].filter(
|
||||
(i) => i !== undefined && i !== null,
|
||||
),
|
||||
[pendingHistoryItemRef, pendingToolCallGroupDisplay],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const saveRestorableToolCalls = async () => {
|
||||
|
||||
853
packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
Normal file
853
packages/cli/src/ui/hooks/useVisionAutoSwitch.test.ts
Normal file
@@ -0,0 +1,853 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import type { Part, PartListUnion } from '@google/genai';
|
||||
import { AuthType, type Config, ApprovalMode } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
shouldOfferVisionSwitch,
|
||||
processVisionSwitchOutcome,
|
||||
getVisionSwitchGuidanceMessage,
|
||||
useVisionAutoSwitch,
|
||||
} from './useVisionAutoSwitch.js';
|
||||
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
|
||||
import { MessageType } from '../types.js';
|
||||
import { getDefaultVisionModel } from '../models/availableModels.js';
|
||||
|
||||
describe('useVisionAutoSwitch helpers', () => {
|
||||
describe('shouldOfferVisionSwitch', () => {
|
||||
it('returns false when authType is not QWEN_OAUTH', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.USE_GEMINI,
|
||||
'qwen3-coder-plus',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when current model is already a vision model', () => {
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
const result = shouldOfferVisionSwitch(
|
||||
parts,
|
||||
AuthType.QWEN_OAUTH,
|
||||
'vision-model',
|
||||
true,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
// Positive case: QWEN_OAUTH auth + non-vision model + image part → offer switch.
it('returns true when image parts exist, QWEN_OAUTH, and model is not vision', () => {
  const parts: PartListUnion = [
    { text: 'hello' },
    { inlineData: { mimeType: 'image/jpeg', data: '...' } },
  ];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(true);
});

// Image detection must also handle a bare Part (not wrapped in an array).
it('detects image when provided as a single Part object (non-array)', () => {
  const singleImagePart: PartListUnion = {
    fileData: { mimeType: 'image/gif', fileUri: 'file://image.gif' },
  } as Part;
  const result = shouldOfferVisionSwitch(
    singleImagePart,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(true);
});

// Text-only content never triggers the vision switch offer.
it('returns false when parts contain no images', () => {
  const parts: PartListUnion = [{ text: 'just text' }];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(false);
});

// A plain string query carries no parts at all, so no switch is offered.
it('returns false when parts is a plain string', () => {
  const parts: PartListUnion = 'plain text';
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(false);
});

// The vision preview flag gates the feature even when images are present.
it('returns false when visionModelPreviewEnabled is false', () => {
  const parts: PartListUnion = [
    { inlineData: { mimeType: 'image/png', data: '...' } },
  ];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    false,
  );
  expect(result).toBe(false);
});

// The same predicate backs YOLO-mode auto-switching: images present → true.
it('returns true when image parts exist in YOLO mode context', () => {
  const parts: PartListUnion = [
    { inlineData: { mimeType: 'image/png', data: '...' } },
  ];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(true);
});

// …and text-only content stays false in the YOLO context as well.
it('returns false when no image parts exist in YOLO mode context', () => {
  const parts: PartListUnion = [{ text: 'just text' }];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(false);
});

// Already on a vision-capable model: there is nothing to switch to.
it('returns false when already using vision model in YOLO mode context', () => {
  const parts: PartListUnion = [
    { inlineData: { mimeType: 'image/png', data: '...' } },
  ];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.QWEN_OAUTH,
    'vision-model',
    true,
  );
  expect(result).toBe(false);
});

// The switch is exclusive to Qwen OAuth; other auth types are ignored.
it('returns false when authType is not QWEN_OAUTH in YOLO mode context', () => {
  const parts: PartListUnion = [
    { inlineData: { mimeType: 'image/png', data: '...' } },
  ];
  const result = shouldOfferVisionSwitch(
    parts,
    AuthType.USE_GEMINI,
    'qwen3-coder-plus',
    true,
  );
  expect(result).toBe(false);
});
|
||||
});
|
||||
|
||||
// processVisionSwitchOutcome maps each dialog outcome onto the model-change
// action the hook should take (one-time override, session persist, or none).
describe('processVisionSwitchOutcome', () => {
  it('maps SwitchOnce to a one-time model override', () => {
    const vl = getDefaultVisionModel();
    const result = processVisionSwitchOutcome(VisionSwitchOutcome.SwitchOnce);
    expect(result).toEqual({ modelOverride: vl });
  });

  it('maps SwitchSessionToVL to a persistent session model', () => {
    const vl = getDefaultVisionModel();
    const result = processVisionSwitchOutcome(
      VisionSwitchOutcome.SwitchSessionToVL,
    );
    expect(result).toEqual({ persistSessionModel: vl });
  });

  // An empty result object means "keep the current model unchanged".
  it('maps ContinueWithCurrentModel to empty result', () => {
    const result = processVisionSwitchOutcome(
      VisionSwitchOutcome.ContinueWithCurrentModel,
    );
    expect(result).toEqual({});
  });
});
|
||||
|
||||
// The guidance message embeds the default vision model id, so it is rebuilt
// here from getDefaultVisionModel() rather than hard-coded.
describe('getVisionSwitchGuidanceMessage', () => {
  it('returns the expected guidance message', () => {
    const vl = getDefaultVisionModel();
    const expected =
      'To use images with your query, you can:\n' +
      `• Use /model set ${vl} to switch to a vision-capable model\n` +
      '• Or remove the image and provide a text description instead';
    expect(getVisionSwitchGuidanceMessage()).toBe(expected);
  });
});
|
||||
});
|
||||
|
||||
describe('useVisionAutoSwitch hook', () => {
|
||||
// Shape of the history addItem callback consumed by the hook under test:
// takes a message item plus a timestamp, return value is ignored.
type AddItemFn = (
  item: { type: MessageType; text: string },
  ts: number,
) => any;
|
||||
|
||||
// Builds a minimal Config double exposing only what useVisionAutoSwitch
// reads: the current model (mutable via setModel), approval mode, the
// optional default VLM switch mode, and the content-generator auth config.
const createMockConfig = (
  authType: AuthType,
  initialModel: string,
  approvalMode: ApprovalMode = ApprovalMode.DEFAULT,
  vlmSwitchMode?: string,
) => {
  // Tracked locally so setModel/getModel round-trip like the real Config.
  let currentModel = initialModel;
  const mockConfig: Partial<Config> = {
    getModel: vi.fn(() => currentModel),
    setModel: vi.fn((m: string) => {
      currentModel = m;
    }),
    getApprovalMode: vi.fn(() => approvalMode),
    getVlmSwitchMode: vi.fn(() => vlmSwitchMode),
    getContentGeneratorConfig: vi.fn(() => ({
      authType,
      model: currentModel,
      apiKey: 'test-key',
      vertexai: false,
    })),
  };
  // Cast is safe for these tests: the hook touches only the stubbed methods.
  return mockConfig as Config;
};
|
||||
|
||||
// Shared spy for history writes, recreated per test.
let addItem: AddItemFn;

beforeEach(() => {
  // Reset mock call counts so expectations never leak between cases.
  vi.clearAllMocks();
  addItem = vi.fn();
});
|
||||
|
||||
it('returns shouldProceed=true immediately for continuations', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, vi.fn()),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, Date.now(), true);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(addItem).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when authType is not QWEN_OAUTH', async () => {
|
||||
const config = createMockConfig(AuthType.USE_GEMINI, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 123, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when there are no image parts', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [{ text: 'no images here' }];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 456, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('continues with current model when dialog returns empty result', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockResolvedValue({}); // Empty result for ContinueWithCurrentModel
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
const userTs = 1010;
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, userTs, false);
|
||||
});
|
||||
|
||||
// Should not add any guidance message
|
||||
expect(addItem).not.toHaveBeenCalledWith(
|
||||
{ type: MessageType.INFO, text: getVisionSwitchGuidanceMessage() },
|
||||
userTs,
|
||||
);
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('applies a one-time override and returns originalModel, then restores', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, initialModel);
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ modelOverride: 'coder-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 2020, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true, originalModel: initialModel });
|
||||
expect(config.setModel).toHaveBeenCalledWith('coder-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (one-time override)',
|
||||
});
|
||||
|
||||
// Now restore
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('persists session model when dialog requests persistence', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ persistSessionModel: 'coder-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 3030, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).toHaveBeenCalledWith('coder-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (session persistent)',
|
||||
});
|
||||
|
||||
// Restore should be a no-op since no one-time override was used
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
// Last call should still be the persisted model set
|
||||
expect((config.setModel as any).mock.calls.pop()?.[0]).toBe('coder-model');
|
||||
});
|
||||
|
||||
it('returns shouldProceed=true when dialog returns no special flags', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockResolvedValue({});
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 4040, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('blocks when dialog throws or is cancelled', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn().mockRejectedValue(new Error('x'));
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(config, addItem as any, true, onVisionSwitchRequired),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 5050, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: false });
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when visionModelPreviewEnabled is false', async () => {
|
||||
const config = createMockConfig(AuthType.QWEN_OAUTH, 'qwen3-coder-plus');
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
false,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 6060, false);
|
||||
});
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('YOLO mode behavior', () => {
|
||||
it('automatically switches to vision model in YOLO mode without showing dialog', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called in YOLO mode
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 7070, false);
|
||||
});
|
||||
|
||||
// Should automatically switch without calling the dialog
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(res).toEqual({
|
||||
shouldProceed: true,
|
||||
originalModel: initialModel,
|
||||
});
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when no images are present', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [{ text: 'no images here' }];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 8080, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when already using vision model', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'vision-model',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 9090, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('restores original model after YOLO mode auto-switch', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
// First, trigger the auto-switch
|
||||
await act(async () => {
|
||||
await result.current.handleVisionSwitch(parts, 10100, false);
|
||||
});
|
||||
|
||||
// Verify model was switched
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
|
||||
// Now restore the original model
|
||||
act(() => {
|
||||
result.current.restoreOriginalModel();
|
||||
});
|
||||
|
||||
// Verify model was restored
|
||||
expect(config.setModel).toHaveBeenLastCalledWith(initialModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when authType is not QWEN_OAUTH', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.USE_GEMINI,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 11110, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not switch in YOLO mode when visionModelPreviewEnabled is false', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
false,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/png', data: '...' } },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 12120, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({ shouldProceed: true });
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('handles multiple image formats in YOLO mode', async () => {
|
||||
const initialModel = 'qwen3-coder-plus';
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
initialModel,
|
||||
ApprovalMode.YOLO,
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ text: 'Here are some images:' },
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: '...' } },
|
||||
{ fileData: { mimeType: 'image/png', fileUri: 'file://image.png' } },
|
||||
{ text: 'Please analyze them.' },
|
||||
];
|
||||
|
||||
let res: any;
|
||||
await act(async () => {
|
||||
res = await result.current.handleVisionSwitch(parts, 13130, false);
|
||||
});
|
||||
|
||||
expect(res).toEqual({
|
||||
shouldProceed: true,
|
||||
originalModel: initialModel,
|
||||
});
|
||||
expect(config.setModel).toHaveBeenCalledWith(getDefaultVisionModel(), {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('VLM switch mode default behavior', () => {
|
||||
it('should automatically switch once when vlmSwitchMode is "once"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'once',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBe('qwen3-coder-plus');
|
||||
expect(config.setModel).toHaveBeenCalledWith('vision-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Default VLM switch mode: once (one-time override)',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should switch session when vlmSwitchMode is "session"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'session',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined(); // No original model for session switch
|
||||
expect(config.setModel).toHaveBeenCalledWith('vision-model', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Default VLM switch mode: session (session persistent)',
|
||||
});
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should continue with current model when vlmSwitchMode is "persist"', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'persist',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined();
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fall back to user prompt when vlmSwitchMode is not set', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
undefined, // No default mode
|
||||
);
|
||||
const onVisionSwitchRequired = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ modelOverride: 'vision-model' });
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(onVisionSwitchRequired).toHaveBeenCalledWith(parts);
|
||||
});
|
||||
|
||||
it('should fall back to persist behavior when vlmSwitchMode has invalid value', async () => {
|
||||
const config = createMockConfig(
|
||||
AuthType.QWEN_OAUTH,
|
||||
'qwen3-coder-plus',
|
||||
ApprovalMode.DEFAULT,
|
||||
'invalid-value',
|
||||
);
|
||||
const onVisionSwitchRequired = vi.fn(); // Should not be called
|
||||
const { result } = renderHook(() =>
|
||||
useVisionAutoSwitch(
|
||||
config,
|
||||
addItem as any,
|
||||
true,
|
||||
onVisionSwitchRequired,
|
||||
),
|
||||
);
|
||||
|
||||
const parts: PartListUnion = [
|
||||
{ inlineData: { mimeType: 'image/jpeg', data: 'base64data' } },
|
||||
];
|
||||
|
||||
const switchResult = await result.current.handleVisionSwitch(
|
||||
parts,
|
||||
Date.now(),
|
||||
false,
|
||||
);
|
||||
|
||||
expect(switchResult.shouldProceed).toBe(true);
|
||||
expect(switchResult.originalModel).toBeUndefined();
|
||||
// For invalid values, it should continue with current model (persist behavior)
|
||||
expect(config.setModel).not.toHaveBeenCalled();
|
||||
expect(onVisionSwitchRequired).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
363
packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
Normal file
363
packages/cli/src/ui/hooks/useVisionAutoSwitch.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type PartListUnion, type Part } from '@google/genai';
|
||||
import { AuthType, type Config, ApprovalMode } from '@qwen-code/qwen-code-core';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { VisionSwitchOutcome } from '../components/ModelSwitchDialog.js';
|
||||
import {
|
||||
getDefaultVisionModel,
|
||||
isVisionModel,
|
||||
} from '../models/availableModels.js';
|
||||
import { MessageType } from '../types.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import {
|
||||
isSupportedImageMimeType,
|
||||
getUnsupportedImageFormatWarning,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
|
||||
/**
|
||||
* Checks if a PartListUnion contains image parts
|
||||
*/
|
||||
function hasImageParts(parts: PartListUnion): boolean {
|
||||
if (typeof parts === 'string') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Array.isArray(parts)) {
|
||||
return parts.some((part) => {
|
||||
// Skip string parts
|
||||
if (typeof part === 'string') return false;
|
||||
return isImagePart(part);
|
||||
});
|
||||
}
|
||||
|
||||
// If it's a single Part (not a string), check if it's an image
|
||||
if (typeof parts === 'object') {
|
||||
return isImagePart(parts);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a single Part is an image part
|
||||
*/
|
||||
function isImagePart(part: Part): boolean {
|
||||
// Check for inlineData with image mime type
|
||||
if ('inlineData' in part && part.inlineData?.mimeType?.startsWith('image/')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for fileData with image mime type
|
||||
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if image parts have supported formats and returns unsupported ones
|
||||
*/
|
||||
function checkImageFormatsSupport(parts: PartListUnion): {
|
||||
hasImages: boolean;
|
||||
hasUnsupportedFormats: boolean;
|
||||
unsupportedMimeTypes: string[];
|
||||
} {
|
||||
const unsupportedMimeTypes: string[] = [];
|
||||
let hasImages = false;
|
||||
|
||||
if (typeof parts === 'string') {
|
||||
return {
|
||||
hasImages: false,
|
||||
hasUnsupportedFormats: false,
|
||||
unsupportedMimeTypes: [],
|
||||
};
|
||||
}
|
||||
|
||||
const partsArray = Array.isArray(parts) ? parts : [parts];
|
||||
|
||||
for (const part of partsArray) {
|
||||
if (typeof part === 'string') continue;
|
||||
|
||||
let mimeType: string | undefined;
|
||||
|
||||
// Check inlineData
|
||||
if (
|
||||
'inlineData' in part &&
|
||||
part.inlineData?.mimeType?.startsWith('image/')
|
||||
) {
|
||||
hasImages = true;
|
||||
mimeType = part.inlineData.mimeType;
|
||||
}
|
||||
|
||||
// Check fileData
|
||||
if ('fileData' in part && part.fileData?.mimeType?.startsWith('image/')) {
|
||||
hasImages = true;
|
||||
mimeType = part.fileData.mimeType;
|
||||
}
|
||||
|
||||
// Check if the mime type is supported
|
||||
if (mimeType && !isSupportedImageMimeType(mimeType)) {
|
||||
unsupportedMimeTypes.push(mimeType);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
hasImages,
|
||||
hasUnsupportedFormats: unsupportedMimeTypes.length > 0,
|
||||
unsupportedMimeTypes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if we should offer vision switch for the given parts, auth type, and current model
|
||||
*/
|
||||
export function shouldOfferVisionSwitch(
|
||||
parts: PartListUnion,
|
||||
authType: AuthType,
|
||||
currentModel: string,
|
||||
visionModelPreviewEnabled: boolean = true,
|
||||
): boolean {
|
||||
// Only trigger for qwen-oauth
|
||||
if (authType !== AuthType.QWEN_OAUTH) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If vision model preview is disabled, never offer vision switch
|
||||
if (!visionModelPreviewEnabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If current model is already a vision model, no need to switch
|
||||
if (isVisionModel(currentModel)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the current message contains image parts
|
||||
return hasImageParts(parts);
|
||||
}
|
||||
|
||||
/**
 * Result of a vision-switch decision. At most one of the model fields is
 * expected to be set, indicating how (or whether) the model should change.
 */
export interface VisionSwitchResult {
  // One-time model override applied to the current request only.
  modelOverride?: string;
  // Model to persist for the remainder of the session.
  persistSessionModel?: string;
  // When true, show guidance text instead of switching models.
  showGuidance?: boolean;
}
|
||||
|
||||
/**
|
||||
* Processes the vision switch outcome and returns the appropriate result
|
||||
*/
|
||||
export function processVisionSwitchOutcome(
|
||||
outcome: VisionSwitchOutcome,
|
||||
): VisionSwitchResult {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
|
||||
switch (outcome) {
|
||||
case VisionSwitchOutcome.SwitchOnce:
|
||||
return { modelOverride: vlModelId };
|
||||
|
||||
case VisionSwitchOutcome.SwitchSessionToVL:
|
||||
return { persistSessionModel: vlModelId };
|
||||
|
||||
case VisionSwitchOutcome.ContinueWithCurrentModel:
|
||||
return {}; // Continue with current model, no changes needed
|
||||
|
||||
default:
|
||||
return {}; // Default to continuing with current model
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the guidance message for when vision switch is disallowed
|
||||
*/
|
||||
export function getVisionSwitchGuidanceMessage(): string {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
return `To use images with your query, you can:
|
||||
• Use /model set ${vlModelId} to switch to a vision-capable model
|
||||
• Or remove the image and provide a text description instead`;
|
||||
}
|
||||
|
||||
/**
 * Outcome of handleVisionSwitch: whether the submission should proceed,
 * and (for one-time overrides) the model to restore afterwards.
 */
export interface VisionSwitchHandlingResult {
  shouldProceed: boolean;
  // Set only when a temporary model override was applied, so the caller can
  // restore it once the request completes.
  originalModel?: string;
}
|
||||
|
||||
/**
|
||||
* Custom hook for handling vision model auto-switching
|
||||
*/
|
||||
export function useVisionAutoSwitch(
|
||||
config: Config,
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
visionModelPreviewEnabled: boolean = true,
|
||||
onVisionSwitchRequired?: (query: PartListUnion) => Promise<{
|
||||
modelOverride?: string;
|
||||
persistSessionModel?: string;
|
||||
showGuidance?: boolean;
|
||||
}>,
|
||||
) {
|
||||
const originalModelRef = useRef<string | null>(null);
|
||||
|
||||
const handleVisionSwitch = useCallback(
|
||||
async (
|
||||
query: PartListUnion,
|
||||
userMessageTimestamp: number,
|
||||
isContinuation: boolean,
|
||||
): Promise<VisionSwitchHandlingResult> => {
|
||||
// Skip vision switch handling for continuations or if no handler provided
|
||||
if (isContinuation || !onVisionSwitchRequired) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
const contentGeneratorConfig = config.getContentGeneratorConfig();
|
||||
|
||||
// Only handle qwen-oauth auth type
|
||||
if (contentGeneratorConfig?.authType !== AuthType.QWEN_OAUTH) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// Check image format support first
|
||||
const formatCheck = checkImageFormatsSupport(query);
|
||||
|
||||
// If there are unsupported image formats, show warning
|
||||
if (formatCheck.hasUnsupportedFormats) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: getUnsupportedImageFormatWarning(),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
// Continue processing but with warning shown
|
||||
}
|
||||
|
||||
// Check if vision switch is needed
|
||||
if (
|
||||
!shouldOfferVisionSwitch(
|
||||
query,
|
||||
contentGeneratorConfig.authType,
|
||||
config.getModel(),
|
||||
visionModelPreviewEnabled,
|
||||
)
|
||||
) {
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// In YOLO mode, automatically switch to vision model without user interaction
|
||||
if (config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
const vlModelId = getDefaultVisionModel();
|
||||
originalModelRef.current = config.getModel();
|
||||
config.setModel(vlModelId, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
}
|
||||
|
||||
// Check if there's a default VLM switch mode configured
|
||||
const defaultVlmSwitchMode = config.getVlmSwitchMode();
|
||||
if (defaultVlmSwitchMode) {
|
||||
// Convert string value to VisionSwitchOutcome enum
|
||||
let outcome: VisionSwitchOutcome;
|
||||
switch (defaultVlmSwitchMode) {
|
||||
case 'once':
|
||||
outcome = VisionSwitchOutcome.SwitchOnce;
|
||||
break;
|
||||
case 'session':
|
||||
outcome = VisionSwitchOutcome.SwitchSessionToVL;
|
||||
break;
|
||||
case 'persist':
|
||||
outcome = VisionSwitchOutcome.ContinueWithCurrentModel;
|
||||
break;
|
||||
default:
|
||||
// Invalid value, fall back to prompting user
|
||||
outcome = VisionSwitchOutcome.ContinueWithCurrentModel;
|
||||
}
|
||||
|
||||
// Process the default outcome
|
||||
const visionSwitchResult = processVisionSwitchOutcome(outcome);
|
||||
|
||||
if (visionSwitchResult.modelOverride) {
|
||||
// One-time model override
|
||||
originalModelRef.current = config.getModel();
|
||||
config.setModel(visionSwitchResult.modelOverride, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (one-time override)`,
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
} else if (visionSwitchResult.persistSessionModel) {
|
||||
// Persistent session model change
|
||||
config.setModel(visionSwitchResult.persistSessionModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: `Default VLM switch mode: ${defaultVlmSwitchMode} (session persistent)`,
|
||||
});
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// For ContinueWithCurrentModel or any other case, proceed with current model
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
try {
|
||||
const visionSwitchResult = await onVisionSwitchRequired(query);
|
||||
|
||||
if (visionSwitchResult.modelOverride) {
|
||||
// One-time model override
|
||||
originalModelRef.current = config.getModel();
|
||||
config.setModel(visionSwitchResult.modelOverride, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (one-time override)',
|
||||
});
|
||||
return {
|
||||
shouldProceed: true,
|
||||
originalModel: originalModelRef.current,
|
||||
};
|
||||
} else if (visionSwitchResult.persistSessionModel) {
|
||||
// Persistent session model change
|
||||
config.setModel(visionSwitchResult.persistSessionModel, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'User-prompted vision switch (session persistent)',
|
||||
});
|
||||
return { shouldProceed: true };
|
||||
}
|
||||
|
||||
// For ContinueWithCurrentModel or any other case, proceed with current model
|
||||
return { shouldProceed: true };
|
||||
} catch (_error) {
|
||||
// If vision switch dialog was cancelled or errored, don't proceed
|
||||
return { shouldProceed: false };
|
||||
}
|
||||
},
|
||||
[config, addItem, visionModelPreviewEnabled, onVisionSwitchRequired],
|
||||
);
|
||||
|
||||
const restoreOriginalModel = useCallback(() => {
|
||||
if (originalModelRef.current) {
|
||||
config.setModel(originalModelRef.current, {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model after vision switch',
|
||||
});
|
||||
originalModelRef.current = null;
|
||||
}
|
||||
}, [config]);
|
||||
|
||||
return {
|
||||
handleVisionSwitch,
|
||||
restoreOriginalModel,
|
||||
};
|
||||
}
|
||||
55
packages/cli/src/ui/models/availableModels.ts
Normal file
55
packages/cli/src/ui/models/availableModels.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export type AvailableModel = {
|
||||
id: string;
|
||||
label: string;
|
||||
isVision?: boolean;
|
||||
};
|
||||
|
||||
export const MAINLINE_VLM = 'vision-model';
|
||||
export const MAINLINE_CODER = 'coder-model';
|
||||
|
||||
export const AVAILABLE_MODELS_QWEN: AvailableModel[] = [
|
||||
{ id: MAINLINE_CODER, label: MAINLINE_CODER },
|
||||
{ id: MAINLINE_VLM, label: MAINLINE_VLM, isVision: true },
|
||||
];
|
||||
|
||||
/**
|
||||
* Get available Qwen models filtered by vision model preview setting
|
||||
*/
|
||||
export function getFilteredQwenModels(
|
||||
visionModelPreviewEnabled: boolean,
|
||||
): AvailableModel[] {
|
||||
if (visionModelPreviewEnabled) {
|
||||
return AVAILABLE_MODELS_QWEN;
|
||||
}
|
||||
return AVAILABLE_MODELS_QWEN.filter((model) => !model.isVision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently we use the single model of `OPENAI_MODEL` in the env.
|
||||
* In the future, after settings.json is updated, we will allow users to configure this themselves.
|
||||
*/
|
||||
export function getOpenAIAvailableModelFromEnv(): AvailableModel | null {
|
||||
const id = process.env['OPENAI_MODEL']?.trim();
|
||||
return id ? { id, label: id } : null;
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
* Hard code the default vision model as a string literal,
|
||||
* until our coding model supports multimodal.
|
||||
*/
|
||||
export function getDefaultVisionModel(): string {
|
||||
return MAINLINE_VLM;
|
||||
}
|
||||
|
||||
export function isVisionModel(modelId: string): boolean {
|
||||
return AVAILABLE_MODELS_QWEN.some(
|
||||
(model) => model.id === modelId && model.isVision,
|
||||
);
|
||||
}
|
||||
@@ -126,6 +126,18 @@ describe('validateNonInterActiveAuth', () => {
|
||||
expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_OPENAI);
|
||||
});
|
||||
|
||||
it('uses configured QWEN_OAUTH if provided', async () => {
|
||||
const nonInteractiveConfig: NonInteractiveConfig = {
|
||||
refreshAuth: refreshAuthMock,
|
||||
};
|
||||
await validateNonInteractiveAuth(
|
||||
AuthType.QWEN_OAUTH,
|
||||
undefined,
|
||||
nonInteractiveConfig,
|
||||
);
|
||||
expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.QWEN_OAUTH);
|
||||
});
|
||||
|
||||
it('uses USE_VERTEX_AI if GOOGLE_GENAI_USE_VERTEXAI is true (with GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION)', async () => {
|
||||
process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true';
|
||||
process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project';
|
||||
|
||||
@@ -97,6 +97,18 @@ class GeminiAgent {
|
||||
name: 'Vertex AI',
|
||||
description: null,
|
||||
},
|
||||
{
|
||||
id: AuthType.USE_OPENAI,
|
||||
name: 'Use OpenAI API key',
|
||||
description:
|
||||
'Requires setting the `OPENAI_API_KEY` environment variable',
|
||||
},
|
||||
{
|
||||
id: AuthType.QWEN_OAUTH,
|
||||
name: 'Qwen OAuth',
|
||||
description:
|
||||
'OAuth authentication for Qwen models with 2000 daily requests',
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
|
||||
@@ -19,3 +19,4 @@ export {
|
||||
} from './src/telemetry/types.js';
|
||||
export { makeFakeConfig } from './src/test-utils/config.js';
|
||||
export * from './src/utils/pathReader.js';
|
||||
export * from './src/utils/request-tokenizer/supportedImageFormats.js';
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code-core",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"description": "Qwen Code Core",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
|
||||
@@ -737,4 +737,85 @@ describe('setApprovalMode with folder trust', () => {
|
||||
expect(() => config.setApprovalMode(ApprovalMode.AUTO_EDIT)).not.toThrow();
|
||||
expect(() => config.setApprovalMode(ApprovalMode.DEFAULT)).not.toThrow();
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch when setModel is called with different model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-model-switch',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model
|
||||
config.setModel('qwen-vl-max-latest', {
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
|
||||
// Verify that logModelSwitch was called with correct parameters
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Test model switch',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not log when setModel is called with same model', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-same-model',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Set the same model
|
||||
config.setModel('qwen3-coder-plus');
|
||||
|
||||
// Verify that logModelSwitch was not called
|
||||
expect(logModelSwitchSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use default reason when no options provided', async () => {
|
||||
const config = new Config({
|
||||
sessionId: 'test-default-reason',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'qwen3-coder-plus',
|
||||
cwd: '.',
|
||||
});
|
||||
|
||||
// Initialize the config to set up content generator
|
||||
await config.initialize();
|
||||
|
||||
// Mock the logger's logModelSwitch method
|
||||
const logModelSwitchSpy = vi.spyOn(config['logger']!, 'logModelSwitch');
|
||||
|
||||
// Change the model without options
|
||||
config.setModel('qwen-vl-max-latest');
|
||||
|
||||
// Verify that logModelSwitch was called with default reason
|
||||
expect(logModelSwitchSpy).toHaveBeenCalledWith({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'manual',
|
||||
context: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -56,6 +56,7 @@ import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} from './models.js';
|
||||
import { Storage } from './storage.js';
|
||||
import { Logger, type ModelSwitchEvent } from '../core/logger.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { AnyToolInvocation, MCPOAuthConfig };
|
||||
@@ -239,6 +240,7 @@ export interface ConfigParameters {
|
||||
extensionManagement?: boolean;
|
||||
enablePromptCompletion?: boolean;
|
||||
skipLoopDetection?: boolean;
|
||||
vlmSwitchMode?: string;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
@@ -330,9 +332,11 @@ export class Config {
|
||||
private readonly extensionManagement: boolean;
|
||||
private readonly enablePromptCompletion: boolean = false;
|
||||
private readonly skipLoopDetection: boolean;
|
||||
private readonly vlmSwitchMode: string | undefined;
|
||||
private initialized: boolean = false;
|
||||
readonly storage: Storage;
|
||||
private readonly fileExclusions: FileExclusions;
|
||||
private logger: Logger | null = null;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -424,8 +428,15 @@ export class Config {
|
||||
this.extensionManagement = params.extensionManagement ?? false;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
this.vlmSwitchMode = params.vlmSwitchMode;
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
|
||||
// Initialize logger asynchronously
|
||||
this.logger = new Logger(this.sessionId, this.storage);
|
||||
this.logger.initialize().catch((error) => {
|
||||
console.debug('Failed to initialize logger:', error);
|
||||
});
|
||||
|
||||
if (params.contextFileName) {
|
||||
setGeminiMdFilename(params.contextFileName);
|
||||
}
|
||||
@@ -517,10 +528,45 @@ export class Config {
|
||||
return this.contentGeneratorConfig?.model || this.model;
|
||||
}
|
||||
|
||||
setModel(newModel: string): void {
|
||||
setModel(
|
||||
newModel: string,
|
||||
options?: {
|
||||
reason?: ModelSwitchEvent['reason'];
|
||||
context?: string;
|
||||
},
|
||||
): void {
|
||||
const oldModel = this.getModel();
|
||||
|
||||
if (this.contentGeneratorConfig) {
|
||||
this.contentGeneratorConfig.model = newModel;
|
||||
}
|
||||
|
||||
// Log the model switch if the model actually changed
|
||||
if (oldModel !== newModel && this.logger) {
|
||||
const switchEvent: ModelSwitchEvent = {
|
||||
fromModel: oldModel,
|
||||
toModel: newModel,
|
||||
reason: options?.reason || 'manual',
|
||||
context: options?.context,
|
||||
};
|
||||
|
||||
// Log asynchronously to avoid blocking
|
||||
this.logger.logModelSwitch(switchEvent).catch((error) => {
|
||||
console.debug('Failed to log model switch:', error);
|
||||
});
|
||||
}
|
||||
|
||||
// Reinitialize chat with updated configuration while preserving history
|
||||
const geminiClient = this.getGeminiClient();
|
||||
if (geminiClient && geminiClient.isInitialized()) {
|
||||
// Use async operation but don't await to avoid blocking
|
||||
geminiClient.reinitialize().catch((error) => {
|
||||
console.error(
|
||||
'Failed to reinitialize chat with updated config:',
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
isInFallbackMode(): boolean {
|
||||
@@ -926,6 +972,10 @@ export class Config {
|
||||
return this.skipLoopDetection;
|
||||
}
|
||||
|
||||
getVlmSwitchMode(): string | undefined {
|
||||
return this.vlmSwitchMode;
|
||||
}
|
||||
|
||||
async getGitService(): Promise<GitService> {
|
||||
if (!this.gitService) {
|
||||
this.gitService = new GitService(this.targetDir, this.storage);
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export const DEFAULT_QWEN_MODEL = 'qwen3-coder-plus';
|
||||
// We do not have a fallback model for now, but note it here anyway.
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'qwen3-coder-flash';
|
||||
export const DEFAULT_QWEN_MODEL = 'coder-model';
|
||||
export const DEFAULT_QWEN_FLASH_MODEL = 'coder-model';
|
||||
|
||||
export const DEFAULT_GEMINI_MODEL = 'qwen3-coder-plus';
|
||||
export const DEFAULT_GEMINI_MODEL = 'coder-model';
|
||||
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
|
||||
export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite';
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,9 +5,10 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { OpenAIContentGenerator } from '../openaiContentGenerator.js';
|
||||
import { OpenAIContentGenerator } from '../openaiContentGenerator/openaiContentGenerator.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { AuthType } from '../contentGenerator.js';
|
||||
import type { OpenAICompatibleProvider } from '../openaiContentGenerator/provider/index.js';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
// Mock OpenAI
|
||||
@@ -30,6 +31,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
let mockConfig: Config;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let mockOpenAIClient: any;
|
||||
let mockProvider: OpenAICompatibleProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
@@ -42,6 +44,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
mockConfig = {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'openai',
|
||||
enableOpenAILogging: false,
|
||||
}),
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
@@ -53,17 +56,34 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
create: vi.fn(),
|
||||
},
|
||||
},
|
||||
embeddings: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
vi.mocked(OpenAI).mockImplementation(() => mockOpenAIClient);
|
||||
|
||||
// Create mock provider
|
||||
mockProvider = {
|
||||
buildHeaders: vi.fn().mockReturnValue({
|
||||
'User-Agent': 'QwenCode/1.0.0 (test; test)',
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
};
|
||||
|
||||
// Create generator instance
|
||||
const contentGeneratorConfig = {
|
||||
model: 'gpt-4',
|
||||
apiKey: 'test-key',
|
||||
authType: AuthType.USE_OPENAI,
|
||||
enableOpenAILogging: false,
|
||||
};
|
||||
generator = new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
|
||||
generator = new OpenAIContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -209,7 +229,7 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
await expect(
|
||||
generator.generateContentStream(request, 'test-prompt-id'),
|
||||
).rejects.toThrow(
|
||||
/Streaming setup timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
|
||||
/Streaming request timeout after \d+s\. Try reducing input length or increasing timeout in config\./,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -227,12 +247,8 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
} catch (error: unknown) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
expect(errorMessage).toContain(
|
||||
'Streaming setup timeout troubleshooting:',
|
||||
);
|
||||
expect(errorMessage).toContain(
|
||||
'Check network connectivity and firewall settings',
|
||||
);
|
||||
expect(errorMessage).toContain('Streaming timeout troubleshooting:');
|
||||
expect(errorMessage).toContain('Check network connectivity');
|
||||
expect(errorMessage).toContain('Consider using non-streaming mode');
|
||||
}
|
||||
});
|
||||
@@ -246,23 +262,21 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'http://localhost:8080',
|
||||
};
|
||||
new OpenAIContentGenerator(contentGeneratorConfig, mockConfig);
|
||||
new OpenAIContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
);
|
||||
|
||||
// Verify OpenAI client was created with timeout config
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'http://localhost:8080',
|
||||
timeout: 120000,
|
||||
maxRetries: 3,
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
// Verify provider buildClient was called
|
||||
expect(mockProvider.buildClient).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use custom timeout from config', () => {
|
||||
const customConfig = {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
enableOpenAILogging: false,
|
||||
}),
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
|
||||
@@ -274,22 +288,31 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
timeout: 300000,
|
||||
maxRetries: 5,
|
||||
};
|
||||
new OpenAIContentGenerator(contentGeneratorConfig, customConfig);
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'http://localhost:8080',
|
||||
timeout: 300000,
|
||||
maxRetries: 5,
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
// Create a custom mock provider for this test
|
||||
const customMockProvider: OpenAICompatibleProvider = {
|
||||
buildHeaders: vi.fn().mockReturnValue({
|
||||
'User-Agent': 'QwenCode/1.0.0 (test; test)',
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
customConfig,
|
||||
customMockProvider,
|
||||
);
|
||||
|
||||
// Verify provider buildClient was called
|
||||
expect(customMockProvider.buildClient).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle missing timeout config gracefully', () => {
|
||||
const noTimeoutConfig = {
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
enableOpenAILogging: false,
|
||||
}),
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
} as unknown as Config;
|
||||
|
||||
@@ -299,17 +322,24 @@ describe('OpenAIContentGenerator Timeout Handling', () => {
|
||||
authType: AuthType.USE_OPENAI,
|
||||
baseUrl: 'http://localhost:8080',
|
||||
};
|
||||
new OpenAIContentGenerator(contentGeneratorConfig, noTimeoutConfig);
|
||||
|
||||
expect(OpenAI).toHaveBeenCalledWith({
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'http://localhost:8080',
|
||||
timeout: 120000, // default
|
||||
maxRetries: 3, // default
|
||||
defaultHeaders: {
|
||||
'User-Agent': expect.stringMatching(/^QwenCode/),
|
||||
},
|
||||
});
|
||||
// Create a custom mock provider for this test
|
||||
const noTimeoutMockProvider: OpenAICompatibleProvider = {
|
||||
buildHeaders: vi.fn().mockReturnValue({
|
||||
'User-Agent': 'QwenCode/1.0.0 (test; test)',
|
||||
}),
|
||||
buildClient: vi.fn().mockReturnValue(mockOpenAIClient),
|
||||
buildRequest: vi.fn().mockImplementation((req) => req),
|
||||
};
|
||||
|
||||
new OpenAIContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
noTimeoutConfig,
|
||||
noTimeoutMockProvider,
|
||||
);
|
||||
|
||||
// Verify provider buildClient was called
|
||||
expect(noTimeoutMockProvider.buildClient).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -226,6 +226,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
vertexai: false,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
const mockSubagentManager = {
|
||||
listSubagents: vi.fn().mockResolvedValue([]),
|
||||
};
|
||||
const mockConfigObject = {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
@@ -260,6 +263,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||
getSubagentManager: vi.fn().mockReturnValue(mockSubagentManager),
|
||||
getSkipLoopDetection: vi.fn().mockReturnValue(false),
|
||||
};
|
||||
const MockedConfig = vi.mocked(Config, true);
|
||||
@@ -437,7 +441,8 @@ describe('Gemini Client (client.ts)', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow overriding model and config', async () => {
|
||||
/* We now use model in contentGeneratorConfig in most cases. */
|
||||
it.skip('should allow overriding model and config', async () => {
|
||||
const contents: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
];
|
||||
@@ -2545,4 +2550,82 @@ ${JSON.stringify(
|
||||
expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts);
|
||||
});
|
||||
});
|
||||
|
||||
describe('initialize', () => {
|
||||
it('should accept extraHistory parameter and pass it to startChat', async () => {
|
||||
const mockStartChat = vi.fn().mockResolvedValue({});
|
||||
client['startChat'] = mockStartChat;
|
||||
|
||||
const extraHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Previous message' }] },
|
||||
{ role: 'model', parts: [{ text: 'Previous response' }] },
|
||||
];
|
||||
|
||||
const contentGeneratorConfig = {
|
||||
model: 'test-model',
|
||||
apiKey: 'test-key',
|
||||
vertexai: false,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
|
||||
await client.initialize(contentGeneratorConfig, extraHistory);
|
||||
|
||||
expect(mockStartChat).toHaveBeenCalledWith(extraHistory, 'test-model');
|
||||
});
|
||||
|
||||
it('should use empty array when no extraHistory is provided', async () => {
|
||||
const mockStartChat = vi.fn().mockResolvedValue({});
|
||||
client['startChat'] = mockStartChat;
|
||||
|
||||
const contentGeneratorConfig = {
|
||||
model: 'test-model',
|
||||
apiKey: 'test-key',
|
||||
vertexai: false,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
|
||||
await client.initialize(contentGeneratorConfig);
|
||||
|
||||
expect(mockStartChat).toHaveBeenCalledWith([], 'test-model');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reinitialize', () => {
|
||||
it('should reinitialize with preserved user history', async () => {
|
||||
// Mock the initialize method
|
||||
const mockInitialize = vi.fn().mockResolvedValue(undefined);
|
||||
client['initialize'] = mockInitialize;
|
||||
|
||||
// Set up initial history with environment context + user messages
|
||||
const mockHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Environment context' }] },
|
||||
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
||||
{ role: 'user', parts: [{ text: 'User message 1' }] },
|
||||
{ role: 'model', parts: [{ text: 'Model response 1' }] },
|
||||
];
|
||||
|
||||
const mockChat = {
|
||||
getHistory: vi.fn().mockReturnValue(mockHistory),
|
||||
};
|
||||
client['chat'] = mockChat as unknown as GeminiChat;
|
||||
client['getHistory'] = vi.fn().mockReturnValue(mockHistory);
|
||||
|
||||
await client.reinitialize();
|
||||
|
||||
// Should call initialize with preserved user history (excluding first 2 env messages)
|
||||
expect(mockInitialize).toHaveBeenCalledWith(
|
||||
expect.any(Object), // contentGeneratorConfig
|
||||
[
|
||||
{ role: 'user', parts: [{ text: 'User message 1' }] },
|
||||
{ role: 'model', parts: [{ text: 'Model response 1' }] },
|
||||
],
|
||||
);
|
||||
});
|
||||
|
||||
it('should not throw error when chat is not initialized', async () => {
|
||||
client['chat'] = undefined;
|
||||
|
||||
await expect(client.reinitialize()).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
makeChatCompressionEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { TaskTool } from '../tools/task.js';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
getEnvironmentContext,
|
||||
@@ -137,13 +138,24 @@ export class GeminiClient {
|
||||
this.lastPromptId = this.config.getSessionId();
|
||||
}
|
||||
|
||||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||
async initialize(
|
||||
contentGeneratorConfig: ContentGeneratorConfig,
|
||||
extraHistory?: Content[],
|
||||
) {
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
this.config,
|
||||
this.config.getSessionId(),
|
||||
);
|
||||
this.chat = await this.startChat();
|
||||
/**
|
||||
* Always take the model from contentGeneratorConfig to initialize,
|
||||
* despite the `this.config.contentGeneratorConfig` is not updated yet because in
|
||||
* `Config` it will not be updated until the initialization is successful.
|
||||
*/
|
||||
this.chat = await this.startChat(
|
||||
extraHistory || [],
|
||||
contentGeneratorConfig.model,
|
||||
);
|
||||
}
|
||||
|
||||
getContentGenerator(): ContentGenerator {
|
||||
@@ -216,6 +228,28 @@ export class GeminiClient {
|
||||
this.chat = await this.startChat();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reinitializes the chat with the current contentGeneratorConfig while preserving chat history.
|
||||
* This creates a new chat object using the existing history and updated configuration.
|
||||
* Should be called when configuration changes (model, auth, etc.) to ensure consistency.
|
||||
*/
|
||||
async reinitialize(): Promise<void> {
|
||||
if (!this.chat) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Preserve the current chat history (excluding environment context)
|
||||
const currentHistory = this.getHistory();
|
||||
// Remove the initial environment context (first 2 messages: user env + model acknowledgment)
|
||||
const userHistory = currentHistory.slice(2);
|
||||
|
||||
// Get current content generator config and reinitialize with preserved history
|
||||
const contentGeneratorConfig = this.config.getContentGeneratorConfig();
|
||||
if (contentGeneratorConfig) {
|
||||
await this.initialize(contentGeneratorConfig, userHistory);
|
||||
}
|
||||
}
|
||||
|
||||
async addDirectoryContext(): Promise<void> {
|
||||
if (!this.chat) {
|
||||
return;
|
||||
@@ -227,7 +261,10 @@ export class GeminiClient {
|
||||
});
|
||||
}
|
||||
|
||||
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
||||
async startChat(
|
||||
extraHistory?: Content[],
|
||||
model?: string,
|
||||
): Promise<GeminiChat> {
|
||||
this.forceFullIdeContext = true;
|
||||
this.hasFailedCompressionAttempt = false;
|
||||
const envParts = await getEnvironmentContext(this.config);
|
||||
@@ -247,9 +284,13 @@ export class GeminiClient {
|
||||
];
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
const systemInstruction = getCoreSystemPrompt(
|
||||
userMemory,
|
||||
{},
|
||||
model || this.config.getModel(),
|
||||
);
|
||||
const generateContentConfigWithThinking = isThinkingSupported(
|
||||
this.config.getModel(),
|
||||
model || this.config.getModel(),
|
||||
)
|
||||
? {
|
||||
...this.generateContentConfig,
|
||||
@@ -455,7 +496,8 @@ export class GeminiClient {
|
||||
turns: number = MAX_TURNS,
|
||||
originalModel?: string,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
const isNewPrompt = this.lastPromptId !== prompt_id;
|
||||
if (isNewPrompt) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
}
|
||||
@@ -488,7 +530,11 @@ export class GeminiClient {
|
||||
// Get all the content that would be sent in an API call
|
||||
const currentHistory = this.getChat().getHistory(true);
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemPrompt = getCoreSystemPrompt(userMemory);
|
||||
const systemPrompt = getCoreSystemPrompt(
|
||||
userMemory,
|
||||
{},
|
||||
this.config.getModel(),
|
||||
);
|
||||
const environment = await getEnvironmentContext(this.config);
|
||||
|
||||
// Create a mock request content to count total tokens
|
||||
@@ -552,6 +598,24 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = false;
|
||||
}
|
||||
|
||||
if (isNewPrompt) {
|
||||
const taskTool = this.config.getToolRegistry().getTool(TaskTool.Name);
|
||||
const subagents = (
|
||||
await this.config.getSubagentManager().listSubagents()
|
||||
).filter((subagent) => subagent.level !== 'builtin');
|
||||
|
||||
if (taskTool && subagents.length > 0) {
|
||||
this.getChat().addHistory({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `<system-reminder>You have powerful specialized agents at your disposal, available agent types are: ${subagents.map((subagent) => subagent.name).join(', ')}. PROACTIVELY use the ${TaskTool.Name} tool to delegate user's task to appropriate agent when user's task matches agent capabilities. Ignore this message if user's task is not relevant to any agent. This message is for internal use only. Do not mention this to user in your response.</system-reminder>`,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
if (!this.config.getSkipLoopDetection()) {
|
||||
@@ -624,14 +688,18 @@ export class GeminiClient {
|
||||
model?: string,
|
||||
config: GenerateContentConfig = {},
|
||||
): Promise<Record<string, unknown>> {
|
||||
// Use current model from config instead of hardcoded Flash model
|
||||
const modelToUse =
|
||||
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
|
||||
/**
|
||||
* TODO: ensure `model` consistency among GeminiClient, GeminiChat, and ContentGenerator
|
||||
* `model` passed to generateContent is not respected as we always use contentGenerator
|
||||
* We should ignore model for now because some calls use `DEFAULT_GEMINI_FLASH_MODEL`
|
||||
* which is not available as `qwen3-coder-flash`
|
||||
*/
|
||||
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const finalSystemInstruction = config.systemInstruction
|
||||
? getCustomSystemPrompt(config.systemInstruction, userMemory)
|
||||
: getCoreSystemPrompt(userMemory);
|
||||
: getCoreSystemPrompt(userMemory, {}, modelToUse);
|
||||
|
||||
const requestConfig = {
|
||||
abortSignal,
|
||||
@@ -722,7 +790,7 @@ export class GeminiClient {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const finalSystemInstruction = generationConfig.systemInstruction
|
||||
? getCustomSystemPrompt(generationConfig.systemInstruction, userMemory)
|
||||
: getCoreSystemPrompt(userMemory);
|
||||
: getCoreSystemPrompt(userMemory, {}, this.config.getModel());
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
|
||||
@@ -500,7 +500,7 @@ export class GeminiChat {
|
||||
if (error instanceof Error && error.message) {
|
||||
if (isSchemaDepthError(error.message)) return false;
|
||||
if (error.message.includes('429')) return true;
|
||||
if (error.message.match(/5\d{2}/)) return true;
|
||||
if (error.message.match(/^5\d{2}/)) return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
|
||||
@@ -755,4 +755,84 @@ describe('Logger', () => {
|
||||
expect(logger['messageId']).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model Switch Logging', () => {
|
||||
it('should log model switch events correctly', async () => {
|
||||
const testSessionId = 'test-session-model-switch';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
const modelSwitchEvent = {
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch' as const,
|
||||
context: 'YOLO mode auto-switch for image content',
|
||||
};
|
||||
|
||||
await logger.logModelSwitch(modelSwitchEvent);
|
||||
|
||||
// Read the log file to verify the entry was written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLog = logs.find(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLog).toBeDefined();
|
||||
expect(modelSwitchLog!.type).toBe(MessageSenderType.MODEL_SWITCH);
|
||||
|
||||
const loggedEvent = JSON.parse(modelSwitchLog!.message);
|
||||
expect(loggedEvent.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(loggedEvent.toModel).toBe('qwen-vl-max-latest');
|
||||
expect(loggedEvent.reason).toBe('vision_auto_switch');
|
||||
expect(loggedEvent.context).toBe(
|
||||
'YOLO mode auto-switch for image content',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple model switch events', async () => {
|
||||
const testSessionId = 'test-session-multiple-switches';
|
||||
const logger = new Logger(testSessionId, new Storage(process.cwd()));
|
||||
await logger.initialize();
|
||||
|
||||
// Log first switch
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen3-coder-plus',
|
||||
toModel: 'qwen-vl-max-latest',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Auto-switch for image',
|
||||
});
|
||||
|
||||
// Log second switch (restore)
|
||||
await logger.logModelSwitch({
|
||||
fromModel: 'qwen-vl-max-latest',
|
||||
toModel: 'qwen3-coder-plus',
|
||||
reason: 'vision_auto_switch',
|
||||
context: 'Restoring original model',
|
||||
});
|
||||
|
||||
// Read the log file to verify both entries were written
|
||||
const logContent = await fs.readFile(TEST_LOG_FILE_PATH, 'utf-8');
|
||||
const logs: LogEntry[] = JSON.parse(logContent);
|
||||
|
||||
const modelSwitchLogs = logs.filter(
|
||||
(log) =>
|
||||
log.sessionId === testSessionId &&
|
||||
log.type === MessageSenderType.MODEL_SWITCH,
|
||||
);
|
||||
|
||||
expect(modelSwitchLogs).toHaveLength(2);
|
||||
|
||||
const firstSwitch = JSON.parse(modelSwitchLogs[0].message);
|
||||
expect(firstSwitch.fromModel).toBe('qwen3-coder-plus');
|
||||
expect(firstSwitch.toModel).toBe('qwen-vl-max-latest');
|
||||
|
||||
const secondSwitch = JSON.parse(modelSwitchLogs[1].message);
|
||||
expect(secondSwitch.fromModel).toBe('qwen-vl-max-latest');
|
||||
expect(secondSwitch.toModel).toBe('qwen3-coder-plus');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ const LOG_FILE_NAME = 'logs.json';
|
||||
|
||||
export enum MessageSenderType {
|
||||
USER = 'user',
|
||||
MODEL_SWITCH = 'model_switch',
|
||||
}
|
||||
|
||||
export interface LogEntry {
|
||||
@@ -23,6 +24,13 @@ export interface LogEntry {
|
||||
message: string;
|
||||
}
|
||||
|
||||
export interface ModelSwitchEvent {
|
||||
fromModel: string;
|
||||
toModel: string;
|
||||
reason: 'vision_auto_switch' | 'manual' | 'fallback' | 'other';
|
||||
context?: string;
|
||||
}
|
||||
|
||||
// This regex matches any character that is NOT a letter (a-z, A-Z),
|
||||
// a number (0-9), a hyphen (-), an underscore (_), or a dot (.).
|
||||
|
||||
@@ -270,6 +278,17 @@ export class Logger {
|
||||
}
|
||||
}
|
||||
|
||||
async logModelSwitch(event: ModelSwitchEvent): Promise<void> {
|
||||
const message = JSON.stringify({
|
||||
fromModel: event.fromModel,
|
||||
toModel: event.toModel,
|
||||
reason: event.reason,
|
||||
context: event.context,
|
||||
});
|
||||
|
||||
await this.logMessage(MessageSenderType.MODEL_SWITCH, message);
|
||||
}
|
||||
|
||||
private _checkpointPath(tag: string): string {
|
||||
if (!tag.length) {
|
||||
throw new Error('No checkpoint tag specified.');
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,37 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
// Mock the request tokenizer module BEFORE importing the class that uses it
|
||||
const mockTokenizer = {
|
||||
calculateTokens: vi.fn().mockResolvedValue({
|
||||
totalTokens: 50,
|
||||
breakdown: {
|
||||
textTokens: 50,
|
||||
imageTokens: 0,
|
||||
audioTokens: 0,
|
||||
otherTokens: 0,
|
||||
},
|
||||
processingTime: 1,
|
||||
}),
|
||||
dispose: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mock('../../../utils/request-tokenizer/index.js', () => ({
|
||||
getDefaultTokenizer: vi.fn(() => mockTokenizer),
|
||||
DefaultRequestTokenizer: vi.fn(() => mockTokenizer),
|
||||
disposeDefaultTokenizer: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock tiktoken as well for completeness
|
||||
vi.mock('tiktoken', () => ({
|
||||
get_encoding: vi.fn(() => ({
|
||||
encode: vi.fn(() => new Array(50)), // Mock 50 tokens
|
||||
free: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Now import the modules that depend on the mocked modules
|
||||
import { OpenAIContentGenerator } from './openaiContentGenerator.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { AuthType } from '../contentGenerator.js';
|
||||
@@ -15,14 +46,6 @@ import type {
|
||||
import type { OpenAICompatibleProvider } from './provider/index.js';
|
||||
import type OpenAI from 'openai';
|
||||
|
||||
// Mock tiktoken
|
||||
vi.mock('tiktoken', () => ({
|
||||
get_encoding: vi.fn().mockReturnValue({
|
||||
encode: vi.fn().mockReturnValue(new Array(50)), // Mock 50 tokens
|
||||
free: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('OpenAIContentGenerator (Refactored)', () => {
|
||||
let generator: OpenAIContentGenerator;
|
||||
let mockConfig: Config;
|
||||
|
||||
@@ -13,6 +13,7 @@ import type { PipelineConfig } from './pipeline.js';
|
||||
import { ContentGenerationPipeline } from './pipeline.js';
|
||||
import { DefaultTelemetryService } from './telemetryService.js';
|
||||
import { EnhancedErrorHandler } from './errorHandler.js';
|
||||
import { getDefaultTokenizer } from '../../utils/request-tokenizer/index.js';
|
||||
import type { ContentGeneratorConfig } from '../contentGenerator.js';
|
||||
|
||||
export class OpenAIContentGenerator implements ContentGenerator {
|
||||
@@ -71,27 +72,30 @@ export class OpenAIContentGenerator implements ContentGenerator {
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
// Use tiktoken for accurate token counting
|
||||
const content = JSON.stringify(request.contents);
|
||||
let totalTokens = 0;
|
||||
|
||||
try {
|
||||
const { get_encoding } = await import('tiktoken');
|
||||
const encoding = get_encoding('cl100k_base'); // GPT-4 encoding, but estimate for qwen
|
||||
totalTokens = encoding.encode(content).length;
|
||||
encoding.free();
|
||||
// Use the new high-performance request tokenizer
|
||||
const tokenizer = getDefaultTokenizer();
|
||||
const result = await tokenizer.calculateTokens(request, {
|
||||
textEncoding: 'cl100k_base', // Use GPT-4 encoding for consistency
|
||||
});
|
||||
|
||||
return {
|
||||
totalTokens: result.totalTokens,
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to load tiktoken, falling back to character approximation:',
|
||||
'Failed to calculate tokens with new tokenizer, falling back to simple method:',
|
||||
error,
|
||||
);
|
||||
// Fallback: rough approximation using character count
|
||||
totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
|
||||
}
|
||||
|
||||
return {
|
||||
totalTokens,
|
||||
};
|
||||
// Fallback to original simple method
|
||||
const content = JSON.stringify(request.contents);
|
||||
const totalTokens = Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
|
||||
|
||||
return {
|
||||
totalTokens,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
|
||||
@@ -1105,5 +1105,164 @@ describe('ContentGenerationPipeline', () => {
|
||||
expect.any(Array),
|
||||
);
|
||||
});
|
||||
|
||||
it('should collect all OpenAI chunks for logging even when Gemini responses are filtered', async () => {
|
||||
// Create chunks that would produce empty Gemini responses (partial tool calls)
|
||||
const partialToolCallChunk1: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-1',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: { name: 'test_function', arguments: '{"par' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const partialToolCallChunk2: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-2',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
function: { arguments: 'am": "value"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const finishChunk: OpenAI.Chat.ChatCompletionChunk = {
|
||||
id: 'chunk-3',
|
||||
object: 'chat.completion.chunk',
|
||||
created: Date.now(),
|
||||
model: 'test-model',
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// Mock empty Gemini responses for partial chunks (they get filtered)
|
||||
const emptyGeminiResponse1 = new GenerateContentResponse();
|
||||
emptyGeminiResponse1.candidates = [
|
||||
{
|
||||
content: { parts: [], role: 'model' },
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
const emptyGeminiResponse2 = new GenerateContentResponse();
|
||||
emptyGeminiResponse2.candidates = [
|
||||
{
|
||||
content: { parts: [], role: 'model' },
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
// Mock final Gemini response with tool call
|
||||
const finalGeminiResponse = new GenerateContentResponse();
|
||||
finalGeminiResponse.candidates = [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
id: 'call_123',
|
||||
name: 'test_function',
|
||||
args: { param: 'value' },
|
||||
},
|
||||
},
|
||||
],
|
||||
role: 'model',
|
||||
},
|
||||
finishReason: FinishReason.STOP,
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
];
|
||||
|
||||
// Setup converter mocks
|
||||
(mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue([
|
||||
{ role: 'user', content: 'test' },
|
||||
]);
|
||||
(mockConverter.convertOpenAIChunkToGemini as Mock)
|
||||
.mockReturnValueOnce(emptyGeminiResponse1) // First partial chunk -> empty response
|
||||
.mockReturnValueOnce(emptyGeminiResponse2) // Second partial chunk -> empty response
|
||||
.mockReturnValueOnce(finalGeminiResponse); // Finish chunk -> complete response
|
||||
|
||||
// Mock stream
|
||||
const mockStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield partialToolCallChunk1;
|
||||
yield partialToolCallChunk2;
|
||||
yield finishChunk;
|
||||
},
|
||||
};
|
||||
|
||||
(mockClient.chat.completions.create as Mock).mockResolvedValue(
|
||||
mockStream,
|
||||
);
|
||||
|
||||
const request: GenerateContentParameters = {
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'test' }] }],
|
||||
};
|
||||
|
||||
// Collect responses
|
||||
const responses: GenerateContentResponse[] = [];
|
||||
const resultGenerator = await pipeline.executeStream(
|
||||
request,
|
||||
'test-prompt-id',
|
||||
);
|
||||
for await (const response of resultGenerator) {
|
||||
responses.push(response);
|
||||
}
|
||||
|
||||
// Should only yield the final response (empty ones are filtered)
|
||||
expect(responses).toHaveLength(1);
|
||||
expect(responses[0]).toBe(finalGeminiResponse);
|
||||
|
||||
// Verify telemetry was called with ALL OpenAI chunks, including the filtered ones
|
||||
expect(mockTelemetryService.logStreamingSuccess).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
duration: expect.any(Number),
|
||||
userPromptId: 'test-prompt-id',
|
||||
authType: 'openai',
|
||||
}),
|
||||
[finalGeminiResponse], // Only the non-empty Gemini response
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
messages: [{ role: 'user', content: 'test' }],
|
||||
}),
|
||||
[partialToolCallChunk1, partialToolCallChunk2, finishChunk], // ALL OpenAI chunks
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,14 +10,11 @@ import {
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { type ContentGeneratorConfig } from '../contentGenerator.js';
|
||||
import { type OpenAICompatibleProvider } from './provider/index.js';
|
||||
import type { ContentGeneratorConfig } from '../contentGenerator.js';
|
||||
import type { OpenAICompatibleProvider } from './provider/index.js';
|
||||
import { OpenAIContentConverter } from './converter.js';
|
||||
import {
|
||||
type TelemetryService,
|
||||
type RequestContext,
|
||||
} from './telemetryService.js';
|
||||
import { type ErrorHandler } from './errorHandler.js';
|
||||
import type { TelemetryService, RequestContext } from './telemetryService.js';
|
||||
import type { ErrorHandler } from './errorHandler.js';
|
||||
|
||||
export interface PipelineConfig {
|
||||
cliConfig: Config;
|
||||
@@ -101,7 +98,7 @@ export class ContentGenerationPipeline {
|
||||
* 2. Filter empty responses
|
||||
* 3. Handle chunk merging for providers that send finishReason and usageMetadata separately
|
||||
* 4. Collect both formats for logging
|
||||
* 5. Handle success/error logging with original OpenAI format
|
||||
* 5. Handle success/error logging
|
||||
*/
|
||||
private async *processStreamWithLogging(
|
||||
stream: AsyncIterable<OpenAI.Chat.ChatCompletionChunk>,
|
||||
@@ -121,6 +118,9 @@ export class ContentGenerationPipeline {
|
||||
try {
|
||||
// Stage 2a: Convert and yield each chunk while preserving original
|
||||
for await (const chunk of stream) {
|
||||
// Always collect OpenAI chunks for logging, regardless of Gemini conversion result
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
|
||||
const response = this.converter.convertOpenAIChunkToGemini(chunk);
|
||||
|
||||
// Stage 2b: Filter empty responses to avoid downstream issues
|
||||
@@ -135,9 +135,7 @@ export class ContentGenerationPipeline {
|
||||
// Stage 2c: Handle chunk merging for providers that send finishReason and usageMetadata separately
|
||||
const shouldYield = this.handleChunkMerging(
|
||||
response,
|
||||
chunk,
|
||||
collectedGeminiResponses,
|
||||
collectedOpenAIChunks,
|
||||
(mergedResponse) => {
|
||||
pendingFinishResponse = mergedResponse;
|
||||
},
|
||||
@@ -169,19 +167,11 @@ export class ContentGenerationPipeline {
|
||||
collectedOpenAIChunks,
|
||||
);
|
||||
} catch (error) {
|
||||
// Stage 2e: Stream failed - handle error and logging
|
||||
context.duration = Date.now() - context.startTime;
|
||||
|
||||
// Clear streaming tool calls on error to prevent data pollution
|
||||
this.converter.resetStreamingToolCalls();
|
||||
|
||||
await this.config.telemetryService.logError(
|
||||
context,
|
||||
error,
|
||||
openaiRequest,
|
||||
);
|
||||
|
||||
this.config.errorHandler.handle(error, context, request);
|
||||
// Use shared error handling logic
|
||||
await this.handleError(error, context, request);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,17 +183,13 @@ export class ContentGenerationPipeline {
|
||||
* finishReason and the most up-to-date usage information from any provider pattern.
|
||||
*
|
||||
* @param response Current Gemini response
|
||||
* @param chunk Current OpenAI chunk
|
||||
* @param collectedGeminiResponses Array to collect responses for logging
|
||||
* @param collectedOpenAIChunks Array to collect chunks for logging
|
||||
* @param setPendingFinish Callback to set pending finish response
|
||||
* @returns true if the response should be yielded, false if it should be held for merging
|
||||
*/
|
||||
private handleChunkMerging(
|
||||
response: GenerateContentResponse,
|
||||
chunk: OpenAI.Chat.ChatCompletionChunk,
|
||||
collectedGeminiResponses: GenerateContentResponse[],
|
||||
collectedOpenAIChunks: OpenAI.Chat.ChatCompletionChunk[],
|
||||
setPendingFinish: (response: GenerateContentResponse) => void,
|
||||
): boolean {
|
||||
const isFinishChunk = response.candidates?.[0]?.finishReason;
|
||||
@@ -217,7 +203,6 @@ export class ContentGenerationPipeline {
|
||||
if (isFinishChunk) {
|
||||
// This is a finish reason chunk
|
||||
collectedGeminiResponses.push(response);
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
setPendingFinish(response);
|
||||
return false; // Don't yield yet, wait for potential subsequent chunks to merge
|
||||
} else if (hasPendingFinish) {
|
||||
@@ -239,7 +224,6 @@ export class ContentGenerationPipeline {
|
||||
// Update the collected responses with the merged response
|
||||
collectedGeminiResponses[collectedGeminiResponses.length - 1] =
|
||||
mergedResponse;
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
|
||||
setPendingFinish(mergedResponse);
|
||||
return true; // Yield the merged response
|
||||
@@ -247,7 +231,6 @@ export class ContentGenerationPipeline {
|
||||
|
||||
// Normal chunk - collect and yield
|
||||
collectedGeminiResponses.push(response);
|
||||
collectedOpenAIChunks.push(chunk);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -365,25 +348,59 @@ export class ContentGenerationPipeline {
|
||||
context.duration = Date.now() - context.startTime;
|
||||
return result;
|
||||
} catch (error) {
|
||||
context.duration = Date.now() - context.startTime;
|
||||
|
||||
// Log error
|
||||
const openaiRequest = await this.buildRequest(
|
||||
// Use shared error handling logic
|
||||
return await this.handleError(
|
||||
error,
|
||||
context,
|
||||
request,
|
||||
userPromptId,
|
||||
isStreaming,
|
||||
);
|
||||
await this.config.telemetryService.logError(
|
||||
context,
|
||||
error,
|
||||
openaiRequest,
|
||||
);
|
||||
|
||||
// Handle and throw enhanced error
|
||||
this.config.errorHandler.handle(error, context, request);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared error handling logic for both executeWithErrorHandling and processStreamWithLogging
|
||||
* This centralizes the common error processing steps to avoid duplication
|
||||
*/
|
||||
private async handleError(
|
||||
error: unknown,
|
||||
context: RequestContext,
|
||||
request: GenerateContentParameters,
|
||||
userPromptId?: string,
|
||||
isStreaming?: boolean,
|
||||
): Promise<never> {
|
||||
context.duration = Date.now() - context.startTime;
|
||||
|
||||
// Build request for logging (may fail, but we still want to log the error)
|
||||
let openaiRequest: OpenAI.Chat.ChatCompletionCreateParams;
|
||||
try {
|
||||
if (userPromptId !== undefined && isStreaming !== undefined) {
|
||||
openaiRequest = await this.buildRequest(
|
||||
request,
|
||||
userPromptId,
|
||||
isStreaming,
|
||||
);
|
||||
} else {
|
||||
// For processStreamWithLogging, we don't have userPromptId/isStreaming,
|
||||
// so create a minimal request
|
||||
openaiRequest = {
|
||||
model: this.contentGeneratorConfig.model,
|
||||
messages: [],
|
||||
};
|
||||
}
|
||||
} catch (_buildError) {
|
||||
// If we can't build the request, create a minimal one for logging
|
||||
openaiRequest = {
|
||||
model: this.contentGeneratorConfig.model,
|
||||
messages: [],
|
||||
};
|
||||
}
|
||||
|
||||
await this.config.telemetryService.logError(context, error, openaiRequest);
|
||||
this.config.errorHandler.handle(error, context, request);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create request context with common properties
|
||||
*/
|
||||
|
||||
@@ -560,4 +560,146 @@ describe('DashScopeOpenAICompatibleProvider', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('output token limits', () => {
|
||||
it('should limit max_tokens when it exceeds model limit for qwen3-coder-plus', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Exceeds the 65536 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(65536); // Should be limited to model's output limit
|
||||
});
|
||||
|
||||
it('should limit max_tokens when it exceeds model limit for qwen-vl-max-latest', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-vl-max-latest',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 20000, // Exceeds the 8192 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(8192); // Should be limited to model's output limit
|
||||
});
|
||||
|
||||
it('should not modify max_tokens when it is within model limit', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 1000, // Within the 65536 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(1000); // Should remain unchanged
|
||||
});
|
||||
|
||||
it('should not add max_tokens when not present in request', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
// No max_tokens parameter
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBeUndefined(); // Should remain undefined
|
||||
});
|
||||
|
||||
it('should handle null max_tokens parameter', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: null,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBeNull(); // Should remain null
|
||||
});
|
||||
|
||||
it('should use default output limit for unknown models', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'unknown-model',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 10000, // Exceeds the default 4096 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(4096); // Should be limited to default output limit
|
||||
});
|
||||
|
||||
it('should preserve other request parameters when limiting max_tokens', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Will be limited
|
||||
temperature: 0.8,
|
||||
top_p: 0.9,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.2,
|
||||
stop: ['END'],
|
||||
user: 'test-user',
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
// max_tokens should be limited
|
||||
expect(result.max_tokens).toBe(65536);
|
||||
|
||||
// Other parameters should be preserved
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.top_p).toBe(0.9);
|
||||
expect(result.frequency_penalty).toBe(0.1);
|
||||
expect(result.presence_penalty).toBe(0.2);
|
||||
expect(result.stop).toEqual(['END']);
|
||||
expect(result.user).toBe('test-user');
|
||||
});
|
||||
|
||||
it('should work with vision models and output token limits', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen-vl-max-latest',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Look at this image:' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/image.jpg' },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens: 20000, // Exceeds the 8192 limit
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(8192); // Should be limited
|
||||
expect(
|
||||
(result as { vl_high_resolution_images?: boolean })
|
||||
.vl_high_resolution_images,
|
||||
).toBe(true); // Vision-specific parameter should be preserved
|
||||
});
|
||||
|
||||
it('should handle streaming requests with output token limits', () => {
|
||||
const request: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||
model: 'qwen3-coder-plus',
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
max_tokens: 100000, // Exceeds the 65536 limit
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const result = provider.buildRequest(request, 'test-prompt-id');
|
||||
|
||||
expect(result.max_tokens).toBe(65536); // Should be limited
|
||||
expect(result.stream).toBe(true); // Streaming should be preserved
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { Config } from '../../../config/config.js';
|
||||
import type { ContentGeneratorConfig } from '../../contentGenerator.js';
|
||||
import { AuthType } from '../../contentGenerator.js';
|
||||
import { DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES } from '../constants.js';
|
||||
import { tokenLimit } from '../../tokenLimits.js';
|
||||
import type {
|
||||
OpenAICompatibleProvider,
|
||||
DashScopeRequestMetadata,
|
||||
@@ -65,6 +66,19 @@ export class DashScopeOpenAICompatibleProvider
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build and configure the request for DashScope API.
|
||||
*
|
||||
* This method applies DashScope-specific configurations including:
|
||||
* - Cache control for system and user messages
|
||||
* - Output token limits based on model capabilities
|
||||
* - Vision model specific parameters (vl_high_resolution_images)
|
||||
* - Request metadata for session tracking
|
||||
*
|
||||
* @param request - The original chat completion request parameters
|
||||
* @param userPromptId - Unique identifier for the user prompt for session tracking
|
||||
* @returns Configured request with DashScope-specific parameters applied
|
||||
*/
|
||||
buildRequest(
|
||||
request: OpenAI.Chat.ChatCompletionCreateParams,
|
||||
userPromptId: string,
|
||||
@@ -79,11 +93,28 @@ export class DashScopeOpenAICompatibleProvider
|
||||
messages = this.addDashScopeCacheControl(messages, cacheTarget);
|
||||
}
|
||||
|
||||
// Apply output token limits based on model capabilities
|
||||
// This ensures max_tokens doesn't exceed the model's maximum output limit
|
||||
const requestWithTokenLimits = this.applyOutputTokenLimit(
|
||||
request,
|
||||
request.model,
|
||||
);
|
||||
|
||||
if (request.model.startsWith('qwen-vl')) {
|
||||
return {
|
||||
...requestWithTokenLimits,
|
||||
messages,
|
||||
...(this.buildMetadata(userPromptId) || {}),
|
||||
/* @ts-expect-error dashscope exclusive */
|
||||
vl_high_resolution_images: true,
|
||||
} as OpenAI.Chat.ChatCompletionCreateParams;
|
||||
}
|
||||
|
||||
return {
|
||||
...request, // Preserve all original parameters including sampling params
|
||||
...requestWithTokenLimits, // Preserve all original parameters including sampling params and adjusted max_tokens
|
||||
messages,
|
||||
...(this.buildMetadata(userPromptId) || {}),
|
||||
};
|
||||
} as OpenAI.Chat.ChatCompletionCreateParams;
|
||||
}
|
||||
|
||||
buildMetadata(userPromptId: string): DashScopeRequestMetadata {
|
||||
@@ -236,6 +267,41 @@ export class DashScopeOpenAICompatibleProvider
|
||||
return contentArray;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply output token limit to a request's max_tokens parameter.
|
||||
*
|
||||
* Ensures that existing max_tokens parameters don't exceed the model's maximum output
|
||||
* token limit. Only modifies max_tokens when already present in the request.
|
||||
*
|
||||
* @param request - The chat completion request parameters
|
||||
* @param model - The model name to get the output token limit for
|
||||
* @returns The request with max_tokens adjusted to respect the model's limits (if present)
|
||||
*/
|
||||
private applyOutputTokenLimit<T extends { max_tokens?: number | null }>(
|
||||
request: T,
|
||||
model: string,
|
||||
): T {
|
||||
const currentMaxTokens = request.max_tokens;
|
||||
|
||||
// Only process if max_tokens is already present in the request
|
||||
if (currentMaxTokens === undefined || currentMaxTokens === null) {
|
||||
return request; // No max_tokens parameter, return unchanged
|
||||
}
|
||||
|
||||
const modelLimit = tokenLimit(model, 'output');
|
||||
|
||||
// If max_tokens exceeds the model limit, cap it to the model's limit
|
||||
if (currentMaxTokens > modelLimit) {
|
||||
return {
|
||||
...request,
|
||||
max_tokens: modelLimit,
|
||||
};
|
||||
}
|
||||
|
||||
// If max_tokens is within the limit, return the request unchanged
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if cache control should be disabled based on configuration.
|
||||
*
|
||||
|
||||
@@ -364,6 +364,120 @@ describe('URL matching with trailing slash compatibility', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model-specific tool call formats', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.stubEnv('SANDBOX', undefined);
|
||||
});
|
||||
|
||||
it('should use XML format for qwen3-coder model', () => {
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen3-coder-7b');
|
||||
|
||||
// Should contain XML-style tool calls
|
||||
expect(prompt).toContain('<tool_call>');
|
||||
expect(prompt).toContain('<function=run_shell_command>');
|
||||
expect(prompt).toContain('<parameter=command>');
|
||||
expect(prompt).toContain('</function>');
|
||||
expect(prompt).toContain('</tool_call>');
|
||||
|
||||
// Should NOT contain bracket-style tool calls
|
||||
expect(prompt).not.toContain('[tool_call: run_shell_command for');
|
||||
|
||||
// Should NOT contain JSON-style tool calls
|
||||
expect(prompt).not.toContain('{"name": "run_shell_command"');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should use JSON format for qwen-vl model', () => {
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen-vl-max');
|
||||
|
||||
// Should contain JSON-style tool calls
|
||||
expect(prompt).toContain('<tool_call>');
|
||||
expect(prompt).toContain('{"name": "run_shell_command"');
|
||||
expect(prompt).toContain('"arguments": {"command": "node server.js &"}');
|
||||
expect(prompt).toContain('</tool_call>');
|
||||
|
||||
// Should NOT contain bracket-style tool calls
|
||||
expect(prompt).not.toContain('[tool_call: run_shell_command for');
|
||||
|
||||
// Should NOT contain XML-style tool calls with parameters
|
||||
expect(prompt).not.toContain('<function=run_shell_command>');
|
||||
expect(prompt).not.toContain('<parameter=command>');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should use bracket format for generic models', () => {
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const prompt = getCoreSystemPrompt(undefined, undefined, 'gpt-4');
|
||||
|
||||
// Should contain bracket-style tool calls
|
||||
expect(prompt).toContain('[tool_call: run_shell_command for');
|
||||
expect(prompt).toContain('because it must run in the background]');
|
||||
|
||||
// Should NOT contain XML-style tool calls
|
||||
expect(prompt).not.toContain('<function=run_shell_command>');
|
||||
expect(prompt).not.toContain('<parameter=command>');
|
||||
|
||||
// Should NOT contain JSON-style tool calls
|
||||
expect(prompt).not.toContain('{"name": "run_shell_command"');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should use bracket format when no model is specified', () => {
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const prompt = getCoreSystemPrompt();
|
||||
|
||||
// Should contain bracket-style tool calls (default behavior)
|
||||
expect(prompt).toContain('[tool_call: run_shell_command for');
|
||||
expect(prompt).toContain('because it must run in the background]');
|
||||
|
||||
// Should NOT contain XML or JSON formats
|
||||
expect(prompt).not.toContain('<function=run_shell_command>');
|
||||
expect(prompt).not.toContain('{"name": "run_shell_command"');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should preserve model-specific formats with user memory', () => {
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const userMemory = 'User prefers concise responses.';
|
||||
const prompt = getCoreSystemPrompt(
|
||||
userMemory,
|
||||
undefined,
|
||||
'qwen3-coder-14b',
|
||||
);
|
||||
|
||||
// Should contain XML-style tool calls
|
||||
expect(prompt).toContain('<tool_call>');
|
||||
expect(prompt).toContain('<function=run_shell_command>');
|
||||
|
||||
// Should contain user memory with separator
|
||||
expect(prompt).toContain('---');
|
||||
expect(prompt).toContain('User prefers concise responses.');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('should preserve model-specific formats with sandbox environment', () => {
|
||||
vi.stubEnv('SANDBOX', 'true');
|
||||
vi.mocked(isGitRepository).mockReturnValue(false);
|
||||
const prompt = getCoreSystemPrompt(undefined, undefined, 'qwen-vl-plus');
|
||||
|
||||
// Should contain JSON-style tool calls
|
||||
expect(prompt).toContain('{"name": "run_shell_command"');
|
||||
|
||||
// Should contain sandbox instructions
|
||||
expect(prompt).toContain('# Sandbox');
|
||||
|
||||
expect(prompt).toMatchSnapshot();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCustomSystemPrompt', () => {
|
||||
it('should handle string custom instruction without user memory', () => {
|
||||
const customInstruction =
|
||||
|
||||
@@ -7,18 +7,10 @@
|
||||
import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import { GlobTool } from '../tools/glob.js';
|
||||
import { GrepTool } from '../tools/grep.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||
import { ShellTool } from '../tools/shell.js';
|
||||
import { WriteFileTool } from '../tools/write-file.js';
|
||||
import { ToolNames } from '../tools/tool-names.js';
|
||||
import process from 'node:process';
|
||||
import { isGitRepository } from '../utils/gitUtils.js';
|
||||
import { MemoryTool, GEMINI_CONFIG_DIR } from '../tools/memoryTool.js';
|
||||
import { TodoWriteTool } from '../tools/todoWrite.js';
|
||||
import { TaskTool } from '../tools/task.js';
|
||||
import { GEMINI_CONFIG_DIR } from '../tools/memoryTool.js';
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
|
||||
export interface ModelTemplateMapping {
|
||||
@@ -91,6 +83,7 @@ export function getCustomSystemPrompt(
|
||||
export function getCoreSystemPrompt(
|
||||
userMemory?: string,
|
||||
config?: SystemPromptConfig,
|
||||
model?: string,
|
||||
): string {
|
||||
// if GEMINI_SYSTEM_MD is set (and not 0|false), override system prompt from file
|
||||
// default path is .gemini/system.md but can be modified via custom path in GEMINI_SYSTEM_MD
|
||||
@@ -177,11 +170,11 @@ You are Qwen Code, an interactive CLI agent developed by Alibaba Group, speciali
|
||||
- **Proactiveness:** Fulfill the user's request thoroughly, including reasonable, directly implied follow-up actions.
|
||||
- **Confirm Ambiguity/Expansion:** Do not take significant actions beyond the clear scope of the request without confirming with the user. If asked *how* to do something, explain first, don't just do it.
|
||||
- **Explaining Changes:** After completing a code modification or file operation *do not* provide summaries unless asked.
|
||||
- **Path Construction:** Before using any file system tool (e.g., ${ReadFileTool.Name}' or '${WriteFileTool.Name}'), you must construct the full absolute path for the file_path argument. Always combine the absolute path of the project's root directory with the file's path relative to the root. For example, if the project root is /path/to/project/ and the file is foo/bar/baz.txt, the final path you must use is /path/to/project/foo/bar/baz.txt. If the user provides a relative path, you must resolve it against the root directory to create an absolute path.
|
||||
- **Path Construction:** Before using any file system tool (e.g., ${ToolNames.READ_FILE}' or '${ToolNames.WRITE_FILE}'), you must construct the full absolute path for the file_path argument. Always combine the absolute path of the project's root directory with the file's path relative to the root. For example, if the project root is /path/to/project/ and the file is foo/bar/baz.txt, the final path you must use is /path/to/project/foo/bar/baz.txt. If the user provides a relative path, you must resolve it against the root directory to create an absolute path.
|
||||
- **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes.
|
||||
|
||||
# Task Management
|
||||
You have access to the ${TodoWriteTool.Name} tool to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.
|
||||
You have access to the ${ToolNames.TODO_WRITE} tool to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.
|
||||
These tools are also EXTREMELY helpful for planning tasks, and for breaking down larger complex tasks into smaller steps. If you do not use this tool when planning, you may forget to do important tasks - and that is unacceptable.
|
||||
|
||||
It is critical that you mark todos as completed as soon as you are done with a task. Do not batch up multiple tasks before marking them as completed.
|
||||
@@ -190,13 +183,13 @@ Examples:
|
||||
|
||||
<example>
|
||||
user: Run the build and fix any type errors
|
||||
assistant: I'm going to use the ${TodoWriteTool.Name} tool to write the following items to the todo list:
|
||||
assistant: I'm going to use the ${ToolNames.TODO_WRITE} tool to write the following items to the todo list:
|
||||
- Run the build
|
||||
- Fix any type errors
|
||||
|
||||
I'm now going to run the build using Bash.
|
||||
|
||||
Looks like I found 10 type errors. I'm going to use the ${TodoWriteTool.Name} tool to write 10 items to the todo list.
|
||||
Looks like I found 10 type errors. I'm going to use the ${ToolNames.TODO_WRITE} tool to write 10 items to the todo list.
|
||||
|
||||
marking the first todo as in_progress
|
||||
|
||||
@@ -211,7 +204,7 @@ In the above example, the assistant completes all the tasks, including the 10 er
|
||||
<example>
|
||||
user: Help me write a new feature that allows users to track their usage metrics and export them to various formats
|
||||
|
||||
A: I'll help you implement a usage metrics tracking and export feature. Let me first use the ${TodoWriteTool.Name} tool to plan this task.
|
||||
A: I'll help you implement a usage metrics tracking and export feature. Let me first use the ${ToolNames.TODO_WRITE} tool to plan this task.
|
||||
Adding the following todos to the todo list:
|
||||
1. Research existing metrics tracking in the codebase
|
||||
2. Design the metrics collection system
|
||||
@@ -232,8 +225,8 @@ I've found some existing telemetry code. Let me mark the first todo as in_progre
|
||||
|
||||
## Software Engineering Tasks
|
||||
When requested to perform tasks like fixing bugs, adding features, refactoring, or explaining code, follow this iterative approach:
|
||||
- **Plan:** After understanding the user's request, create an initial plan based on your existing knowledge and any immediately obvious context. Use the '${TodoWriteTool.Name}' tool to capture this rough plan for complex or multi-step work. Don't wait for complete understanding - start with what you know.
|
||||
- **Implement:** Begin implementing the plan while gathering additional context as needed. Use '${GrepTool.Name}', '${GlobTool.Name}', '${ReadFileTool.Name}', and '${ReadManyFilesTool.Name}' tools strategically when you encounter specific unknowns during implementation. Use the available tools (e.g., '${EditTool.Name}', '${WriteFileTool.Name}' '${ShellTool.Name}' ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates').
|
||||
- **Plan:** After understanding the user's request, create an initial plan based on your existing knowledge and any immediately obvious context. Use the '${ToolNames.TODO_WRITE}' tool to capture this rough plan for complex or multi-step work. Don't wait for complete understanding - start with what you know.
|
||||
- **Implement:** Begin implementing the plan while gathering additional context as needed. Use '${ToolNames.GREP}', '${ToolNames.GLOB}', '${ToolNames.READ_FILE}', and '${ToolNames.READ_MANY_FILES}' tools strategically when you encounter specific unknowns during implementation. Use the available tools (e.g., '${ToolNames.EDIT}', '${ToolNames.WRITE_FILE}' '${ToolNames.SHELL}' ...) to act on the plan, strictly adhering to the project's established conventions (detailed under 'Core Mandates').
|
||||
- **Adapt:** As you discover new information or encounter obstacles, update your plan and todos accordingly. Mark todos as in_progress when starting and completed when finishing each task. Add new todos if the scope expands. Refine your approach based on what you learn.
|
||||
- **Verify (Tests):** If applicable and feasible, verify the changes using the project's testing procedures. Identify the correct test commands and frameworks by examining 'README' files, build/package configuration (e.g., 'package.json'), or existing test execution patterns. NEVER assume standard test commands.
|
||||
- **Verify (Standards):** VERY IMPORTANT: After making code changes, execute the project-specific build, linting and type-checking commands (e.g., 'tsc', 'npm run lint', 'ruff check .') that you have identified for this project (or obtained from the user). This ensures code quality and adherence to standards. If unsure about these commands, you can ask the user if they'd like you to run them and if so how to.
|
||||
@@ -242,11 +235,11 @@ When requested to perform tasks like fixing bugs, adding features, refactoring,
|
||||
|
||||
- Tool results and user messages may include <system-reminder> tags. <system-reminder> tags contain useful information and reminders. They are NOT part of the user's provided input or the tool result.
|
||||
|
||||
IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks throughout the conversation.
|
||||
IMPORTANT: Always use the ${ToolNames.TODO_WRITE} tool to plan and track tasks throughout the conversation.
|
||||
|
||||
## New Applications
|
||||
|
||||
**Goal:** Autonomously implement and deliver a visually appealing, substantially complete, and functional prototype. Utilize all tools at your disposal to implement the application. Some tools you may especially find useful are '${WriteFileTool.Name}', '${EditTool.Name}' and '${ShellTool.Name}'.
|
||||
**Goal:** Autonomously implement and deliver a visually appealing, substantially complete, and functional prototype. Utilize all tools at your disposal to implement the application. Some tools you may especially find useful are '${ToolNames.WRITE_FILE}', '${ToolNames.EDIT}' and '${ToolNames.SHELL}'.
|
||||
|
||||
1. **Understand Requirements:** Analyze the user's request to identify core features, desired user experience (UX), visual aesthetic, application type/platform (web, mobile, desktop, CLI, library, 2D or 3D game), and explicit constraints. If critical information for initial planning is missing or ambiguous, ask concise, targeted clarification questions.
|
||||
2. **Propose Plan:** Formulate an internal development plan. Present a clear, concise, high-level summary to the user. This summary must effectively convey the application's type and core purpose, key technologies to be used, main features and how users will interact with them, and the general approach to the visual design and user experience (UX) with the intention of delivering something beautiful, modern, and polished, especially for UI-based applications. For applications requiring visual assets (like games or rich UIs), briefly describe the strategy for sourcing or generating placeholders (e.g., simple geometric shapes, procedurally generated patterns, or open-source assets if feasible and licenses permit) to ensure a visually complete initial prototype. Ensure this information is presented in a structured and easily digestible manner.
|
||||
@@ -259,7 +252,7 @@ IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks thr
|
||||
- **3d Games:** HTML/CSS/JavaScript with Three.js.
|
||||
- **2d Games:** HTML/CSS/JavaScript.
|
||||
3. **User Approval:** Obtain user approval for the proposed plan.
|
||||
4. **Implementation:** Use the '${TodoWriteTool.Name}' tool to convert the approved plan into a structured todo list with specific, actionable tasks, then autonomously implement each task utilizing all available tools. When starting ensure you scaffold the application using '${ShellTool.Name}' for commands like 'npm init', 'npx create-react-app'. Aim for full scope completion. Proactively create or source necessary placeholder assets (e.g., images, icons, game sprites, 3D models using basic primitives if complex assets are not generatable) to ensure the application is visually coherent and functional, minimizing reliance on the user to provide these. If the model can generate simple assets (e.g., a uniformly colored square sprite, a simple 3D cube), it should do so. Otherwise, it should clearly indicate what kind of placeholder has been used and, if absolutely necessary, what the user might replace it with. Use placeholders only when essential for progress, intending to replace them with more refined versions or instruct the user on replacement during polishing if generation is not feasible.
|
||||
4. **Implementation:** Use the '${ToolNames.TODO_WRITE}' tool to convert the approved plan into a structured todo list with specific, actionable tasks, then autonomously implement each task utilizing all available tools. When starting ensure you scaffold the application using '${ToolNames.SHELL}' for commands like 'npm init', 'npx create-react-app'. Aim for full scope completion. Proactively create or source necessary placeholder assets (e.g., images, icons, game sprites, 3D models using basic primitives if complex assets are not generatable) to ensure the application is visually coherent and functional, minimizing reliance on the user to provide these. If the model can generate simple assets (e.g., a uniformly colored square sprite, a simple 3D cube), it should do so. Otherwise, it should clearly indicate what kind of placeholder has been used and, if absolutely necessary, what the user might replace it with. Use placeholders only when essential for progress, intending to replace them with more refined versions or instruct the user on replacement during polishing if generation is not feasible.
|
||||
5. **Verify:** Review work against the original request, the approved plan. Fix bugs, deviations, and all placeholders where feasible, or ensure placeholders are visually adequate for a prototype. Ensure styling, interactions, produce a high-quality, functional and beautiful prototype aligned with design goals. Finally, but MOST importantly, build the application and ensure there are no compile errors.
|
||||
6. **Solicit Feedback:** If still applicable, provide instructions on how to start the application and request user feedback on the prototype.
|
||||
|
||||
@@ -275,18 +268,18 @@ IMPORTANT: Always use the ${TodoWriteTool.Name} tool to plan and track tasks thr
|
||||
- **Handling Inability:** If unable/unwilling to fulfill a request, state so briefly (1-2 sentences) without excessive justification. Offer alternatives if appropriate.
|
||||
|
||||
## Security and Safety Rules
|
||||
- **Explain Critical Commands:** Before executing commands with '${ShellTool.Name}' that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. You should not ask permission to use the tool; the user will be presented with a confirmation dialogue upon use (you do not need to tell them this).
|
||||
- **Explain Critical Commands:** Before executing commands with '${ToolNames.SHELL}' that modify the file system, codebase, or system state, you *must* provide a brief explanation of the command's purpose and potential impact. Prioritize user understanding and safety. You should not ask permission to use the tool; the user will be presented with a confirmation dialogue upon use (you do not need to tell them this).
|
||||
- **Security First:** Always apply security best practices. Never introduce code that exposes, logs, or commits secrets, API keys, or other sensitive information.
|
||||
|
||||
## Tool Usage
|
||||
- **File Paths:** Always use absolute paths when referring to files with tools like '${ReadFileTool.Name}' or '${WriteFileTool.Name}'. Relative paths are not supported. You must provide an absolute path.
|
||||
- **File Paths:** Always use absolute paths when referring to files with tools like '${ToolNames.READ_FILE}' or '${ToolNames.WRITE_FILE}'. Relative paths are not supported. You must provide an absolute path.
|
||||
- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase).
|
||||
- **Command Execution:** Use the '${ShellTool.Name}' tool for running shell commands, remembering the safety rule to explain modifying commands first.
|
||||
- **Command Execution:** Use the '${ToolNames.SHELL}' tool for running shell commands, remembering the safety rule to explain modifying commands first.
|
||||
- **Background Processes:** Use background processes (via \`&\`) for commands that are unlikely to stop on their own, e.g. \`node server.js &\`. If unsure, ask the user.
|
||||
- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. \`git rebase -i\`). Use non-interactive versions of commands (e.g. \`npm init -y\` instead of \`npm init\`) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user.
|
||||
- **Task Management:** Use the '${TodoWriteTool.Name}' tool proactively for complex, multi-step tasks to track progress and provide visibility to users. This tool helps organize work systematically and ensures no requirements are missed.
|
||||
- **Subagent Delegation:** When doing file search, prefer to use the '${TaskTool.Name}' tool in order to reduce context usage. You should proactively use the '${TaskTool.Name}' tool with specialized agents when the task at hand matches the agent's description.
|
||||
- **Remembering Facts:** Use the '${MemoryTool.Name}' tool to remember specific, *user-related* facts or preferences when the user explicitly asks, or when they state a clear, concise piece of information that would help personalize or streamline *your future interactions with them* (e.g., preferred coding style, common project paths they use, personal tool aliases). This tool is for user-specific information that should persist across sessions. Do *not* use it for general project context or information. If unsure whether to save something, you can ask the user, "Should I remember that for you?"
|
||||
- **Task Management:** Use the '${ToolNames.TODO_WRITE}' tool proactively for complex, multi-step tasks to track progress and provide visibility to users. This tool helps organize work systematically and ensures no requirements are missed.
|
||||
- **Subagent Delegation:** When doing file search, prefer to use the '${ToolNames.TASK}' tool in order to reduce context usage. You should proactively use the '${ToolNames.TASK}' tool with specialized agents when the task at hand matches the agent's description.
|
||||
- **Remembering Facts:** Use the '${ToolNames.MEMORY}' tool to remember specific, *user-related* facts or preferences when the user explicitly asks, or when they state a clear, concise piece of information that would help personalize or streamline *your future interactions with them* (e.g., preferred coding style, common project paths they use, personal tool aliases). This tool is for user-specific information that should persist across sessions. Do *not* use it for general project context or information. If unsure whether to save something, you can ask the user, "Should I remember that for you?"
|
||||
- **Respect User Confirmations:** Most tool calls (also denoted as 'function calls') will first require confirmation from the user, where they will either approve or cancel the function call. If a user cancels a function call, respect their choice and do _not_ try to make the function call again. It is okay to request the tool call again _only_ if the user requests that same tool call on a subsequent prompt. When a user cancels a function call, assume best intentions from the user and consider inquiring if they prefer any alternative paths forward.
|
||||
|
||||
## Interaction Details
|
||||
@@ -338,157 +331,10 @@ ${(function () {
|
||||
return '';
|
||||
})()}
|
||||
|
||||
# Examples (Illustrating Tone and Workflow)
|
||||
<example>
|
||||
user: 1 + 2
|
||||
model: 3
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: is 13 a prime number?
|
||||
model: true
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: start the server implemented in server.js
|
||||
model:
|
||||
<tool_call>
|
||||
<function=run_shell_command>
|
||||
<parameter=command>
|
||||
node server.js &
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
|
||||
model: Okay, I can refactor 'src/auth.py'.
|
||||
First, I'll analyze the code and check for a test safety net before planning any changes.
|
||||
<tool_call>
|
||||
<function=glob>
|
||||
<parameter=path>
|
||||
tests/test_auth.py
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=read_file>
|
||||
<parameter=path>
|
||||
/path/to/tests/test_auth.py
|
||||
</parameter>
|
||||
<parameter=offset>
|
||||
0
|
||||
</parameter>
|
||||
<parameter=limit>
|
||||
10
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
(After analysis)
|
||||
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
|
||||
I'll also confirm 'requests' is a dependency.
|
||||
<tool_call>
|
||||
<function=read_file>
|
||||
<parameter=path>
|
||||
/path/to/requirements.txt
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
(After analysis)
|
||||
Looks good, 'requests' is available.
|
||||
|
||||
Here's the plan:
|
||||
1. Replace the 'urllib' calls with 'requests'.
|
||||
2. Add proper 'try...except' error handling for the new network calls.
|
||||
3. Remove the old 'urllib' import.
|
||||
4. Run the project's linter and tests to verify the changes.
|
||||
|
||||
<tool_call>
|
||||
<function=replace>
|
||||
<parameter=path>
|
||||
src/auth.py
|
||||
</parameter>
|
||||
<parameter=old_content>
|
||||
(old code content)
|
||||
</parameter>
|
||||
<parameter=new_content>
|
||||
(new code content)
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
Refactoring complete. Running verification...
|
||||
<tool_call>
|
||||
<function=run_shell_command
|
||||
<parameter=command>
|
||||
ruff check src/auth.py && pytest
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
(After verification passes)
|
||||
All checks passed. This is a stable checkpoint.
|
||||
Would you like me to write a commit message and commit these changes?
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: Delete the temp directory.
|
||||
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: Write tests for someFile.ts
|
||||
model:
|
||||
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
|
||||
<tool_call>
|
||||
<function=read_file>
|
||||
<parameter=path>
|
||||
/path/to/someFile.ts
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
|
||||
<tool_call>
|
||||
<function>read_many_files for paths ['**/*.test.ts', 'src/**/*.spec.ts'] assuming someFile.ts is in the src directory]
|
||||
</tool_call>
|
||||
(After reviewing existing tests and the file content)
|
||||
<tool_call>
|
||||
<function=write_file>
|
||||
<parameter=path>
|
||||
/path/to/someFile.test.ts
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
I've written the tests. Now I'll run the project's test command to verify them.
|
||||
<tool_call>
|
||||
<function=run_shell_command>
|
||||
<parameter=command>
|
||||
npm run test
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
(After verification passes)
|
||||
All checks passed. This is a stable checkpoint.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: Where are all the 'app.config' files in this project? I need to check their settings.
|
||||
model:
|
||||
<tool_call>
|
||||
<function=glob>
|
||||
<parameter=pattern>
|
||||
./**/app.config
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
|
||||
I found the following 'app.config' files:
|
||||
- /path/to/moduleA/app.config
|
||||
- /path/to/moduleB/app.config
|
||||
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
|
||||
</example>
|
||||
${getToolCallExamples(model || '')}
|
||||
|
||||
# Final Reminder
|
||||
Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use '${ReadFileTool.Name}' or '${ReadManyFilesTool.Name}' to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved.
|
||||
Your core function is efficient and safe assistance. Balance extreme conciseness with the crucial need for clarity, especially regarding safety and potential system modifications. Always prioritize user control and project conventions. Never make assumptions about the contents of files; instead use '${ToolNames.READ_FILE}' or '${ToolNames.READ_MANY_FILES}' to ensure you aren't making broad assumptions. Finally, you are an agent - please keep going until the user's query is completely resolved.
|
||||
`.trim();
|
||||
|
||||
// if GEMINI_WRITE_SYSTEM_MD is set (and not 0|false), write base system prompt to file
|
||||
@@ -615,3 +461,374 @@ You are a specialized context summarizer that creates a comprehensive markdown s
|
||||
|
||||
`.trim();
|
||||
}
|
||||
|
||||
/**
 * Tool-call examples written in the generic prose style
 * ("[tool_call: <tool> for ...]"). This is the fallback example set returned
 * by getToolCallExamples when no environment override is set and the model
 * name matches neither a coder nor a vision pattern.
 *
 * NOTE: this value is interpolated into the system prompt at runtime; the
 * string content must not be reworded or reformatted.
 */
const generalToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>

<example>
user: is 13 a prime number?
model: true
</example>

<example>
user: start the server implemented in server.js
model: [tool_call: ${ToolNames.SHELL} for 'node server.js &' because it must run in the background]
</example>

<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
[tool_call: ${ToolNames.GLOB} for path 'tests/test_auth.py']
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/tests/test_auth.py' with offset 0 and limit 10]
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/requirements.txt']
(After analysis)
Looks good, 'requests' is available.

Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.

[tool_call: ${ToolNames.EDIT} for path 'src/auth.py' replacing old content with new content]
Refactoring complete. Running verification...
[tool_call: ${ToolNames.SHELL} for 'ruff check src/auth.py && pytest']
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>

<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>

<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
[tool_call: ${ToolNames.READ_FILE} for path '/path/to/someFile.ts']
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
[tool_call: ${ToolNames.READ_MANY_FILES} for paths ['**/*.test.ts', 'src/**/*.spec.ts']]
(After reviewing existing tests and the file content)
[tool_call: ${ToolNames.WRITE_FILE} for path '/path/to/someFile.test.ts']
I've written the tests. Now I'll run the project's test command to verify them.
[tool_call: ${ToolNames.SHELL} for 'npm run test']
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>

<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
[tool_call: ${ToolNames.GLOB} for pattern './**/app.config']
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
|
||||
|
||||
/**
 * Tool-call examples in the XML-tag style
 * ("<tool_call><function=...><parameter=...>") expected by Qwen coder-family
 * models. Returned by getToolCallExamples for model names matching
 * qwen*-coder or coder-model, or when QWEN_CODE_TOOL_CALL_STYLE=qwen-coder.
 *
 * NOTE: this value is interpolated into the system prompt at runtime; the
 * string content must not be reworded or reformatted.
 */
const qwenCoderToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>

<example>
user: is 13 a prime number?
model: true
</example>

<example>
user: start the server implemented in server.js
model:
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
node server.js &
</parameter>
</function>
</tool_call>
</example>

<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
<tool_call>
<function=${ToolNames.GLOB}>
<parameter=path>
tests/test_auth.py
</parameter>
</function>
</tool_call>
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/tests/test_auth.py
</parameter>
<parameter=offset>
0
</parameter>
<parameter=limit>
10
</parameter>
</function>
</tool_call>
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/requirements.txt
</parameter>
</function>
</tool_call>
(After analysis)
Looks good, 'requests' is available.

Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.

<tool_call>
<function=${ToolNames.EDIT}>
<parameter=path>
src/auth.py
</parameter>
<parameter=old_content>
(old code content)
</parameter>
<parameter=new_content>
(new code content)
</parameter>
</function>
</tool_call>
Refactoring complete. Running verification...
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
ruff check src/auth.py && pytest
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>

<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>

<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
<tool_call>
<function=${ToolNames.READ_FILE}>
<parameter=path>
/path/to/someFile.ts
</parameter>
</function>
</tool_call>
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
<tool_call>
<function=${ToolNames.READ_MANY_FILES}>
<parameter=paths>
['**/*.test.ts', 'src/**/*.spec.ts']
</parameter>
</function>
</tool_call>
(After reviewing existing tests and the file content)
<tool_call>
<function=${ToolNames.WRITE_FILE}>
<parameter=path>
/path/to/someFile.test.ts
</parameter>
</function>
</tool_call>
I've written the tests. Now I'll run the project's test command to verify them.
<tool_call>
<function=${ToolNames.SHELL}>
<parameter=command>
npm run test
</parameter>
</function>
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>

<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
<tool_call>
<function=${ToolNames.GLOB}>
<parameter=pattern>
./**/app.config
</parameter>
</function>
</tool_call>
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
|
||||
/**
 * Tool-call examples in the JSON style
 * ('<tool_call>{"name": ..., "arguments": {...}}</tool_call>') expected by
 * Qwen vision-language models. Returned by getToolCallExamples for model
 * names matching qwen*-vl or vision-model, or when
 * QWEN_CODE_TOOL_CALL_STYLE=qwen-vl.
 *
 * NOTE: this value is interpolated into the system prompt at runtime; the
 * string content must not be reworded or reformatted.
 */
const qwenVlToolCallExamples = `
# Examples (Illustrating Tone and Workflow)
<example>
user: 1 + 2
model: 3
</example>

<example>
user: is 13 a prime number?
model: true
</example>

<example>
user: start the server implemented in server.js
model:
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "node server.js &"}}
</tool_call>
</example>

<example>
user: Refactor the auth logic in src/auth.py to use the requests library instead of urllib.
model: Okay, I can refactor 'src/auth.py'.
First, I'll analyze the code and check for a test safety net before planning any changes.
<tool_call>
{"name": "${ToolNames.GLOB}", "arguments": {"path": "tests/test_auth.py"}}
</tool_call>
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/tests/test_auth.py", "offset": 0, "limit": 10}}
</tool_call>
(After analysis)
Great, 'tests/test_auth.py' exists and covers the core authentication logic. With this safety net in place, I can safely plan the refactoring.
I'll also confirm 'requests' is a dependency.
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/requirements.txt"}}
</tool_call>
(After analysis)
Looks good, 'requests' is available.

Here's the plan:
1. Replace the 'urllib' calls with 'requests'.
2. Add proper 'try...except' error handling for the new network calls.
3. Remove the old 'urllib' import.
4. Run the project's linter and tests to verify the changes.

<tool_call>
{"name": "${ToolNames.EDIT}", "arguments": {"path": "src/auth.py", "old_content": "(old code content)", "new_content": "(new code content)"}}
</tool_call>
Refactoring complete. Running verification...
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "ruff check src/auth.py && pytest"}}
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
Would you like me to write a commit message and commit these changes?
</example>

<example>
user: Delete the temp directory.
model: I can run \`rm -rf /path/to/project/temp\`. This will permanently delete the directory and all its contents.
</example>

<example>
user: Write tests for someFile.ts
model:
Okay, I can write those tests. First, I'll read someFile.ts to understand its functionality.
<tool_call>
{"name": "${ToolNames.READ_FILE}", "arguments": {"path": "/path/to/someFile.ts"}}
</tool_call>
Now I'll look for existing or related test files to understand current testing conventions and dependencies.
<tool_call>
{"name": "${ToolNames.READ_MANY_FILES}", "arguments": {"paths": ["**/*.test.ts", "src/**/*.spec.ts"]}}
</tool_call>
(After reviewing existing tests and the file content)
<tool_call>
{"name": "${ToolNames.WRITE_FILE}", "arguments": {"path": "/path/to/someFile.test.ts"}}
</tool_call>
I've written the tests. Now I'll run the project's test command to verify them.
<tool_call>
{"name": "${ToolNames.SHELL}", "arguments": {"command": "npm run test"}}
</tool_call>
(After verification passes)
All checks passed. This is a stable checkpoint.
</example>

<example>
user: Where are all the 'app.config' files in this project? I need to check their settings.
model:
<tool_call>
{"name": "${ToolNames.GLOB}", "arguments": {"pattern": "./**/app.config"}}
</tool_call>
(Assuming GlobTool returns a list of paths like ['/path/to/moduleA/app.config', '/path/to/moduleB/app.config'])
I found the following 'app.config' files:
- /path/to/moduleA/app.config
- /path/to/moduleB/app.config
To help you check their settings, I can read their contents. Which one would you like to start with, or should I read all of them?
</example>
`.trim();
|
||||
|
||||
function getToolCallExamples(model?: string): string {
|
||||
// Check for environment variable override first
|
||||
const toolCallStyle = process.env['QWEN_CODE_TOOL_CALL_STYLE'];
|
||||
if (toolCallStyle) {
|
||||
switch (toolCallStyle.toLowerCase()) {
|
||||
case 'qwen-coder':
|
||||
return qwenCoderToolCallExamples;
|
||||
case 'qwen-vl':
|
||||
return qwenVlToolCallExamples;
|
||||
case 'general':
|
||||
return generalToolCallExamples;
|
||||
default:
|
||||
console.warn(
|
||||
`Unknown QWEN_CODE_TOOL_CALL_STYLE value: ${toolCallStyle}. Using model-based detection.`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced regex-based model detection
|
||||
if (model && model.length < 100) {
|
||||
// Match qwen*-coder patterns (e.g., qwen3-coder, qwen2.5-coder, qwen-coder)
|
||||
if (/qwen[^-]*-coder/i.test(model)) {
|
||||
return qwenCoderToolCallExamples;
|
||||
}
|
||||
// Match qwen*-vl patterns (e.g., qwen-vl, qwen2-vl, qwen3-vl)
|
||||
if (/qwen[^-]*-vl/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
// Match coder-model pattern (same as qwen3-coder)
|
||||
if (/coder-model/i.test(model)) {
|
||||
return qwenCoderToolCallExamples;
|
||||
}
|
||||
// Match vision-model pattern (same as qwen3-vl)
|
||||
if (/vision-model/i.test(model)) {
|
||||
return qwenVlToolCallExamples;
|
||||
}
|
||||
}
|
||||
|
||||
return generalToolCallExamples;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { normalize, tokenLimit, DEFAULT_TOKEN_LIMIT } from './tokenLimits.js';
|
||||
import {
|
||||
normalize,
|
||||
tokenLimit,
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
} from './tokenLimits.js';
|
||||
|
||||
describe('normalize', () => {
|
||||
it('should lowercase and trim the model string', () => {
|
||||
@@ -225,3 +230,96 @@ describe('tokenLimit', () => {
|
||||
expect(tokenLimit('CLAUDE-3.5-SONNET')).toBe(200000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tokenLimit with output type', () => {
|
||||
describe('Qwen models with output limits', () => {
|
||||
it('should return the correct output limit for qwen3-coder-plus', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'output')).toBe(65536);
|
||||
expect(tokenLimit('qwen3-coder-plus-20250601', 'output')).toBe(65536);
|
||||
});
|
||||
|
||||
it('should return the correct output limit for qwen-vl-max-latest', () => {
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'output')).toBe(8192);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Default output limits', () => {
|
||||
it('should return the default output limit for unknown models', () => {
|
||||
expect(tokenLimit('unknown-model', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('gpt-4', 'output')).toBe(DEFAULT_OUTPUT_TOKEN_LIMIT);
|
||||
expect(tokenLimit('claude-3.5-sonnet', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the default output limit for models without specific output patterns', () => {
|
||||
expect(tokenLimit('qwen3-coder-7b', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('qwen-plus', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
expect(tokenLimit('qwen-vl-max', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Input vs Output limits comparison', () => {
|
||||
it('should return different limits for input vs output for qwen3-coder-plus', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'input')).toBe(1048576); // 1M input
|
||||
expect(tokenLimit('qwen3-coder-plus', 'output')).toBe(65536); // 64K output
|
||||
});
|
||||
|
||||
it('should return different limits for input vs output for qwen-vl-max-latest', () => {
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'input')).toBe(131072); // 128K input
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'output')).toBe(8192); // 8K output
|
||||
});
|
||||
|
||||
it('should return same default limits for unknown models', () => {
|
||||
expect(tokenLimit('unknown-model', 'input')).toBe(DEFAULT_TOKEN_LIMIT); // 128K input
|
||||
expect(tokenLimit('unknown-model', 'output')).toBe(
|
||||
DEFAULT_OUTPUT_TOKEN_LIMIT,
|
||||
); // 4K output
|
||||
});
|
||||
});
|
||||
|
||||
describe('Backward compatibility', () => {
|
||||
it('should default to input type when no type is specified', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus')).toBe(1048576); // Should be input limit
|
||||
expect(tokenLimit('qwen-vl-max-latest')).toBe(131072); // Should be input limit
|
||||
expect(tokenLimit('unknown-model')).toBe(DEFAULT_TOKEN_LIMIT); // Should be input default
|
||||
});
|
||||
|
||||
it('should work with explicit input type', () => {
|
||||
expect(tokenLimit('qwen3-coder-plus', 'input')).toBe(1048576);
|
||||
expect(tokenLimit('qwen-vl-max-latest', 'input')).toBe(131072);
|
||||
expect(tokenLimit('unknown-model', 'input')).toBe(DEFAULT_TOKEN_LIMIT);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model normalization with output limits', () => {
|
||||
it('should handle normalized model names for output limits', () => {
|
||||
expect(tokenLimit('QWEN3-CODER-PLUS', 'output')).toBe(65536);
|
||||
expect(tokenLimit('qwen3-coder-plus-20250601', 'output')).toBe(65536);
|
||||
expect(tokenLimit('QWEN-VL-MAX-LATEST', 'output')).toBe(8192);
|
||||
});
|
||||
|
||||
it('should handle complex model strings for output limits', () => {
|
||||
expect(
|
||||
tokenLimit(
|
||||
' a/b/c|QWEN3-CODER-PLUS:qwen3-coder-plus-2024-05-13 ',
|
||||
'output',
|
||||
),
|
||||
).toBe(65536);
|
||||
expect(
|
||||
tokenLimit(
|
||||
'provider/qwen-vl-max-latest:qwen-vl-max-latest-v1',
|
||||
'output',
|
||||
),
|
||||
).toBe(8192);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
type Model = string;
|
||||
type TokenCount = number;
|
||||
|
||||
/**
|
||||
* Token limit types for different use cases.
|
||||
* - 'input': Maximum input context window size
|
||||
* - 'output': Maximum output tokens that can be generated in a single response
|
||||
*/
|
||||
export type TokenLimitType = 'input' | 'output';
|
||||
|
||||
export const DEFAULT_TOKEN_LIMIT: TokenCount = 131_072; // 128K (power-of-two)
|
||||
export const DEFAULT_OUTPUT_TOKEN_LIMIT: TokenCount = 4_096; // 4K tokens
|
||||
|
||||
/**
|
||||
* Accurate numeric limits:
|
||||
@@ -18,6 +26,10 @@ const LIMITS = {
|
||||
'1m': 1_048_576,
|
||||
'2m': 2_097_152,
|
||||
'10m': 10_485_760, // 10 million tokens
|
||||
// Output token limits (typically much smaller than input limits)
|
||||
'4k': 4_096,
|
||||
'8k': 8_192,
|
||||
'16k': 16_384,
|
||||
} as const;
|
||||
|
||||
/** Robust normalizer: strips provider prefixes, pipes/colons, date/version suffixes, etc. */
|
||||
@@ -36,7 +48,7 @@ export function normalize(model: string): string {
|
||||
// - dates (e.g., -20250219), -v1, version numbers, 'latest', 'preview' etc.
|
||||
s = s.replace(/-preview/g, '');
|
||||
// Special handling for Qwen model names that include "-latest" as part of the model name
|
||||
if (!s.match(/^qwen-(?:plus|flash)-latest$/)) {
|
||||
if (!s.match(/^qwen-(?:plus|flash|vl-max)-latest$/)) {
|
||||
// \d{6,} - Match 6 or more digits (dates) like -20250219 (6+ digit dates)
|
||||
// \d+x\d+b - Match patterns like 4x8b, -7b, -70b
|
||||
// v\d+(?:\.\d+)* - Match version patterns starting with 'v' like -v1, -v1.2, -v2.1.3
|
||||
@@ -99,6 +111,12 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// Commercial Qwen3-Coder-Flash: 1M token context
|
||||
[/^qwen3-coder-flash(-.*)?$/, LIMITS['1m']], // catches "qwen3-coder-flash" and date variants
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (1M token context)
|
||||
[/^coder-model$/, LIMITS['1m']],
|
||||
|
||||
// Commercial Qwen3-Max-Preview: 256K token context
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['256k']], // catches "qwen3-max-preview" and date variants
|
||||
|
||||
// Open-source Qwen3-Coder variants: 256K native
|
||||
[/^qwen3-coder-.*$/, LIMITS['256k']],
|
||||
// Open-source Qwen3 2507 variants: 256K native
|
||||
@@ -116,6 +134,12 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
[/^qwen-flash-latest$/, LIMITS['1m']],
|
||||
[/^qwen-turbo.*$/, LIMITS['128k']],
|
||||
|
||||
// Qwen Vision Models
|
||||
[/^qwen-vl-max.*$/, LIMITS['128k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max (128K token context)
|
||||
[/^vision-model$/, LIMITS['128k']],
|
||||
|
||||
// -------------------
|
||||
// ByteDance Seed-OSS (512K)
|
||||
// -------------------
|
||||
@@ -139,16 +163,60 @@ const PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
[/^mistral-large-2.*$/, LIMITS['128k']],
|
||||
];
|
||||
|
||||
/** Return the token limit for a model string (uses normalize + ordered regex list). */
|
||||
export function tokenLimit(model: Model): TokenCount {
|
||||
/**
|
||||
* Output token limit patterns for specific model families.
|
||||
* These patterns define the maximum number of tokens that can be generated
|
||||
* in a single response for specific models.
|
||||
*/
|
||||
const OUTPUT_PATTERNS: Array<[RegExp, TokenCount]> = [
|
||||
// -------------------
|
||||
// Alibaba / Qwen - DashScope Models
|
||||
// -------------------
|
||||
// Qwen3-Coder-Plus: 65,536 max output tokens
|
||||
[/^qwen3-coder-plus(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Generic coder-model: same as qwen3-coder-plus (64K max output tokens)
|
||||
[/^coder-model$/, LIMITS['64k']],
|
||||
|
||||
// Qwen3-Max-Preview: 65,536 max output tokens
|
||||
[/^qwen3-max-preview(-.*)?$/, LIMITS['64k']],
|
||||
|
||||
// Qwen-VL-Max-Latest: 8,192 max output tokens
|
||||
[/^qwen-vl-max-latest$/, LIMITS['8k']],
|
||||
|
||||
// Generic vision-model: same as qwen-vl-max-latest (8K max output tokens)
|
||||
[/^vision-model$/, LIMITS['8k']],
|
||||
|
||||
// Qwen3-VL-Plus: 8,192 max output tokens
|
||||
[/^qwen3-vl-plus$/, LIMITS['8k']],
|
||||
];
|
||||
|
||||
/**
|
||||
* Return the token limit for a model string based on the specified type.
|
||||
*
|
||||
* This function determines the maximum number of tokens for either input context
|
||||
* or output generation based on the model and token type. It uses the same
|
||||
* normalization logic for consistency across both input and output limits.
|
||||
*
|
||||
* @param model - The model name to get the token limit for
|
||||
* @param type - The type of token limit ('input' for context window, 'output' for generation)
|
||||
* @returns The maximum number of tokens allowed for this model and type
|
||||
*/
|
||||
export function tokenLimit(
|
||||
model: Model,
|
||||
type: TokenLimitType = 'input',
|
||||
): TokenCount {
|
||||
const norm = normalize(model);
|
||||
|
||||
for (const [regex, limit] of PATTERNS) {
|
||||
// Choose the appropriate patterns based on token type
|
||||
const patterns = type === 'output' ? OUTPUT_PATTERNS : PATTERNS;
|
||||
|
||||
for (const [regex, limit] of patterns) {
|
||||
if (regex.test(norm)) {
|
||||
return limit;
|
||||
}
|
||||
}
|
||||
|
||||
// final fallback: DEFAULT_TOKEN_LIMIT (power-of-two 128K)
|
||||
return DEFAULT_TOKEN_LIMIT;
|
||||
// Return appropriate default based on token type
|
||||
return type === 'output' ? DEFAULT_OUTPUT_TOKEN_LIMIT : DEFAULT_TOKEN_LIMIT;
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ describe('Turn', () => {
|
||||
expect(turn.getDebugResponses().length).toBe(0);
|
||||
expect(reportError).toHaveBeenCalledWith(
|
||||
error,
|
||||
'Error when talking to Gemini API',
|
||||
'Error when talking to API',
|
||||
[...historyContent, reqParts],
|
||||
'Turn.run-sendMessageStream',
|
||||
);
|
||||
|
||||
@@ -310,7 +310,7 @@ export class Turn {
|
||||
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
|
||||
await reportError(
|
||||
error,
|
||||
'Error when talking to Gemini API',
|
||||
'Error when talking to API',
|
||||
contextForReport,
|
||||
'Turn.run-sendMessageStream',
|
||||
);
|
||||
|
||||
@@ -401,11 +401,9 @@ describe('QwenContentGenerator', () => {
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should count tokens with valid token', async () => {
|
||||
vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({
|
||||
token: 'valid-token',
|
||||
});
|
||||
vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials);
|
||||
it('should count tokens without requiring authentication', async () => {
|
||||
// Clear any previous mock calls
|
||||
vi.clearAllMocks();
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'qwen-turbo',
|
||||
@@ -415,7 +413,8 @@ describe('QwenContentGenerator', () => {
|
||||
const result = await qwenContentGenerator.countTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBe(15);
|
||||
expect(mockQwenClient.getAccessToken).toHaveBeenCalled();
|
||||
// countTokens is a local operation and should not require OAuth credentials
|
||||
expect(mockQwenClient.getAccessToken).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should embed content with valid token', async () => {
|
||||
@@ -1652,7 +1651,7 @@ describe('QwenContentGenerator', () => {
|
||||
SharedTokenManager.getInstance = originalGetInstance;
|
||||
});
|
||||
|
||||
it('should handle all method types with token failure', async () => {
|
||||
it('should handle method types with token failure (except countTokens)', async () => {
|
||||
const mockTokenManager = {
|
||||
getValidCredentials: vi
|
||||
.fn()
|
||||
@@ -1685,7 +1684,7 @@ describe('QwenContentGenerator', () => {
|
||||
contents: [{ parts: [{ text: 'Embed' }] }],
|
||||
};
|
||||
|
||||
// All methods should fail with the same error
|
||||
// Methods requiring authentication should fail
|
||||
await expect(
|
||||
newGenerator.generateContent(generateRequest, 'test-id'),
|
||||
).rejects.toThrow('Failed to obtain valid Qwen access token');
|
||||
@@ -1694,14 +1693,14 @@ describe('QwenContentGenerator', () => {
|
||||
newGenerator.generateContentStream(generateRequest, 'test-id'),
|
||||
).rejects.toThrow('Failed to obtain valid Qwen access token');
|
||||
|
||||
await expect(newGenerator.countTokens(countRequest)).rejects.toThrow(
|
||||
'Failed to obtain valid Qwen access token',
|
||||
);
|
||||
|
||||
await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow(
|
||||
'Failed to obtain valid Qwen access token',
|
||||
);
|
||||
|
||||
// countTokens should succeed as it's a local operation
|
||||
const countResult = await newGenerator.countTokens(countRequest);
|
||||
expect(countResult.totalTokens).toBe(15);
|
||||
|
||||
SharedTokenManager.getInstance = originalGetInstance;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -180,9 +180,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator {
|
||||
override async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.executeWithCredentialManagement(() =>
|
||||
super.countTokens(request),
|
||||
);
|
||||
return super.countTokens(request);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -712,8 +712,6 @@ async function authWithQwenDeviceFlow(
|
||||
`Polling... (attempt ${attempt + 1}/${maxAttempts})`,
|
||||
);
|
||||
|
||||
process.stdout.write('.');
|
||||
|
||||
// Wait with cancellation check every 100ms
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkInterval = 100; // Check every 100ms
|
||||
|
||||
@@ -901,5 +901,37 @@ describe('SharedTokenManager', () => {
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should properly clean up timeout when file operation completes before timeout', async () => {
|
||||
const tokenManager = SharedTokenManager.getInstance();
|
||||
tokenManager.clearCache();
|
||||
|
||||
const mockClient = {
|
||||
getCredentials: vi.fn().mockReturnValue(null),
|
||||
setCredentials: vi.fn(),
|
||||
getAccessToken: vi.fn(),
|
||||
requestDeviceAuthorization: vi.fn(),
|
||||
pollDeviceToken: vi.fn(),
|
||||
refreshAccessToken: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock clearTimeout to verify it's called
|
||||
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
|
||||
|
||||
// Mock file stat to resolve quickly (before timeout)
|
||||
mockFs.stat.mockResolvedValue({ mtimeMs: 12345 } as Stats);
|
||||
|
||||
// Call checkAndReloadIfNeeded which uses withTimeout internally
|
||||
const checkMethod = getPrivateProperty(
|
||||
tokenManager,
|
||||
'checkAndReloadIfNeeded',
|
||||
) as (client?: IQwenOAuth2Client) => Promise<void>;
|
||||
await checkMethod.call(tokenManager, mockClient);
|
||||
|
||||
// Verify that clearTimeout was called to clean up the timer
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -290,6 +290,36 @@ export class SharedTokenManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility method to add timeout to any promise operation
|
||||
* Properly cleans up the timeout when the promise completes
|
||||
*/
|
||||
private withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
operationType = 'Operation',
|
||||
): Promise<T> {
|
||||
let timeoutId: NodeJS.Timeout;
|
||||
|
||||
return Promise.race([
|
||||
promise.finally(() => {
|
||||
// Clear timeout when main promise completes (success or failure)
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}),
|
||||
new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(`${operationType} timed out after ${timeoutMs}ms`),
|
||||
),
|
||||
timeoutMs,
|
||||
);
|
||||
}),
|
||||
]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the actual file check and reload operation
|
||||
* This is separated to enable proper promise-based synchronization
|
||||
@@ -303,25 +333,12 @@ export class SharedTokenManager {
|
||||
|
||||
try {
|
||||
const filePath = this.getCredentialFilePath();
|
||||
// Add timeout to file stat operation
|
||||
const withTimeout = async <T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
): Promise<T> =>
|
||||
Promise.race([
|
||||
promise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(`File operation timed out after ${timeoutMs}ms`),
|
||||
),
|
||||
timeoutMs,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
const stats = await withTimeout(fs.stat(filePath), 3000);
|
||||
const stats = await this.withTimeout(
|
||||
fs.stat(filePath),
|
||||
3000,
|
||||
'File operation',
|
||||
);
|
||||
const fileModTime = stats.mtimeMs;
|
||||
|
||||
// Reload credentials if file has been modified since last cache
|
||||
@@ -451,7 +468,7 @@ export class SharedTokenManager {
|
||||
// Check if we have a refresh token before attempting refresh
|
||||
const currentCredentials = qwenClient.getCredentials();
|
||||
if (!currentCredentials.refresh_token) {
|
||||
console.debug('create a NO_REFRESH_TOKEN error');
|
||||
// console.debug('create a NO_REFRESH_TOKEN error');
|
||||
throw new TokenManagerError(
|
||||
TokenError.NO_REFRESH_TOKEN,
|
||||
'No refresh token available for token refresh',
|
||||
@@ -589,26 +606,12 @@ export class SharedTokenManager {
|
||||
const dirPath = path.dirname(filePath);
|
||||
const tempPath = `${filePath}.tmp.${randomUUID()}`;
|
||||
|
||||
// Add timeout wrapper for file operations
|
||||
const withTimeout = async <T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
): Promise<T> =>
|
||||
Promise.race([
|
||||
promise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Operation timed out after ${timeoutMs}ms`)),
|
||||
timeoutMs,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
// Create directory with restricted permissions
|
||||
try {
|
||||
await withTimeout(
|
||||
await this.withTimeout(
|
||||
fs.mkdir(dirPath, { recursive: true, mode: 0o700 }),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
} catch (error) {
|
||||
throw new TokenManagerError(
|
||||
@@ -622,21 +625,30 @@ export class SharedTokenManager {
|
||||
|
||||
try {
|
||||
// Write to temporary file first with restricted permissions
|
||||
await withTimeout(
|
||||
await this.withTimeout(
|
||||
fs.writeFile(tempPath, credString, { mode: 0o600 }),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
|
||||
// Atomic move to final location
|
||||
await withTimeout(fs.rename(tempPath, filePath), 5000);
|
||||
await this.withTimeout(
|
||||
fs.rename(tempPath, filePath),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
|
||||
// Update cached file modification time atomically after successful write
|
||||
const stats = await withTimeout(fs.stat(filePath), 5000);
|
||||
const stats = await this.withTimeout(
|
||||
fs.stat(filePath),
|
||||
5000,
|
||||
'File operation',
|
||||
);
|
||||
this.memoryCache.fileModTime = stats.mtimeMs;
|
||||
} catch (error) {
|
||||
// Clean up temp file if it exists
|
||||
try {
|
||||
await withTimeout(fs.unlink(tempPath), 1000);
|
||||
await this.withTimeout(fs.unlink(tempPath), 1000, 'File operation');
|
||||
} catch (_cleanupError) {
|
||||
// Ignore cleanup errors - temp file might not exist
|
||||
}
|
||||
|
||||
@@ -185,6 +185,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
validMarkdown,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.name).toBe('test-agent');
|
||||
@@ -209,6 +210,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
markdownWithTools,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.tools).toEqual(['read_file', 'write_file']);
|
||||
@@ -229,6 +231,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
markdownWithModel,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.modelConfig).toEqual({ model: 'custom-model', temp: 0.5 });
|
||||
@@ -249,6 +252,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
markdownWithRun,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.runConfig).toEqual({ max_time_minutes: 5, max_turns: 10 });
|
||||
@@ -266,6 +270,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
markdownWithNumeric,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.name).toBe('11');
|
||||
@@ -286,6 +291,7 @@ You are a helpful assistant.
|
||||
const config = manager.parseSubagentContent(
|
||||
markdownWithBoolean,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.name).toBe('true');
|
||||
@@ -301,8 +307,13 @@ You are a helpful assistant.
|
||||
const projectConfig = manager.parseSubagentContent(
|
||||
validMarkdown,
|
||||
projectPath,
|
||||
'project',
|
||||
);
|
||||
const userConfig = manager.parseSubagentContent(
|
||||
validMarkdown,
|
||||
userPath,
|
||||
'user',
|
||||
);
|
||||
const userConfig = manager.parseSubagentContent(validMarkdown, userPath);
|
||||
|
||||
expect(projectConfig.level).toBe('project');
|
||||
expect(userConfig.level).toBe('user');
|
||||
@@ -313,7 +324,11 @@ You are a helpful assistant.
|
||||
Just content`;
|
||||
|
||||
expect(() =>
|
||||
manager.parseSubagentContent(invalidMarkdown, validConfig.filePath),
|
||||
manager.parseSubagentContent(
|
||||
invalidMarkdown,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SubagentError);
|
||||
});
|
||||
|
||||
@@ -326,7 +341,11 @@ You are a helpful assistant.
|
||||
`;
|
||||
|
||||
expect(() =>
|
||||
manager.parseSubagentContent(markdownWithoutName, validConfig.filePath),
|
||||
manager.parseSubagentContent(
|
||||
markdownWithoutName,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SubagentError);
|
||||
});
|
||||
|
||||
@@ -342,39 +361,20 @@ You are a helpful assistant.
|
||||
manager.parseSubagentContent(
|
||||
markdownWithoutDescription,
|
||||
validConfig.filePath,
|
||||
'project',
|
||||
),
|
||||
).toThrow(SubagentError);
|
||||
});
|
||||
|
||||
it('should warn when filename does not match subagent name', () => {
|
||||
const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
const mismatchedPath = '/test/project/.qwen/agents/wrong-filename.md';
|
||||
|
||||
const config = manager.parseSubagentContent(
|
||||
validMarkdown,
|
||||
mismatchedPath,
|
||||
);
|
||||
|
||||
expect(config.name).toBe('test-agent');
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Warning: Subagent file "wrong-filename.md" contains name "test-agent"',
|
||||
),
|
||||
);
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Consider renaming the file to "test-agent.md"',
|
||||
),
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should not warn when filename matches subagent name', () => {
|
||||
const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
const matchingPath = '/test/project/.qwen/agents/test-agent.md';
|
||||
|
||||
const config = manager.parseSubagentContent(validMarkdown, matchingPath);
|
||||
const config = manager.parseSubagentContent(
|
||||
validMarkdown,
|
||||
matchingPath,
|
||||
'project',
|
||||
);
|
||||
|
||||
expect(config.name).toBe('test-agent');
|
||||
expect(consoleSpy).not.toHaveBeenCalled();
|
||||
|
||||
@@ -39,6 +39,7 @@ const AGENT_CONFIG_DIR = 'agents';
|
||||
*/
|
||||
export class SubagentManager {
|
||||
private readonly validator: SubagentValidator;
|
||||
private subagentsCache: Map<SubagentLevel, SubagentConfig[]> | null = null;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
this.validator = new SubagentValidator();
|
||||
@@ -92,6 +93,8 @@ export class SubagentManager {
|
||||
|
||||
try {
|
||||
await fs.writeFile(filePath, content, 'utf8');
|
||||
// Clear cache after successful creation
|
||||
this.clearCache();
|
||||
} catch (error) {
|
||||
throw new SubagentError(
|
||||
`Failed to write subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
@@ -180,6 +183,8 @@ export class SubagentManager {
|
||||
|
||||
try {
|
||||
await fs.writeFile(existing.filePath, content, 'utf8');
|
||||
// Clear cache after successful update
|
||||
this.clearCache();
|
||||
} catch (error) {
|
||||
throw new SubagentError(
|
||||
`Failed to update subagent file: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
@@ -236,6 +241,9 @@ export class SubagentManager {
|
||||
name,
|
||||
);
|
||||
}
|
||||
|
||||
// Clear cache after successful deletion
|
||||
this.clearCache();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -254,9 +262,17 @@ export class SubagentManager {
|
||||
? [options.level]
|
||||
: ['project', 'user', 'builtin'];
|
||||
|
||||
// Check if we should use cache or force refresh
|
||||
const shouldUseCache = !options.force && this.subagentsCache !== null;
|
||||
|
||||
// Initialize cache if it doesn't exist or we're forcing a refresh
|
||||
if (!shouldUseCache) {
|
||||
await this.refreshCache();
|
||||
}
|
||||
|
||||
// Collect subagents from each level (project takes precedence over user, user takes precedence over builtin)
|
||||
for (const level of levelsToCheck) {
|
||||
const levelSubagents = await this.listSubagentsAtLevel(level);
|
||||
const levelSubagents = this.subagentsCache?.get(level) || [];
|
||||
|
||||
for (const subagent of levelSubagents) {
|
||||
// Skip if we've already seen this name (precedence: project > user > builtin)
|
||||
@@ -304,6 +320,30 @@ export class SubagentManager {
|
||||
return subagents;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the subagents cache by loading all subagents from disk.
|
||||
* This method is called automatically when cache is null or when force=true.
|
||||
*
|
||||
* @private
|
||||
*/
|
||||
private async refreshCache(): Promise<void> {
|
||||
this.subagentsCache = new Map();
|
||||
|
||||
const levels: SubagentLevel[] = ['project', 'user', 'builtin'];
|
||||
|
||||
for (const level of levels) {
|
||||
const levelSubagents = await this.listSubagentsAtLevel(level);
|
||||
this.subagentsCache.set(level, levelSubagents);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the subagents cache, forcing the next listSubagents call to reload from disk.
|
||||
*/
|
||||
clearCache(): void {
|
||||
this.subagentsCache = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a subagent by name and returns its metadata.
|
||||
*
|
||||
@@ -329,7 +369,10 @@ export class SubagentManager {
|
||||
* @returns SubagentConfig
|
||||
* @throws SubagentError if parsing fails
|
||||
*/
|
||||
async parseSubagentFile(filePath: string): Promise<SubagentConfig> {
|
||||
async parseSubagentFile(
|
||||
filePath: string,
|
||||
level: SubagentLevel,
|
||||
): Promise<SubagentConfig> {
|
||||
let content: string;
|
||||
|
||||
try {
|
||||
@@ -341,7 +384,7 @@ export class SubagentManager {
|
||||
);
|
||||
}
|
||||
|
||||
return this.parseSubagentContent(content, filePath);
|
||||
return this.parseSubagentContent(content, filePath, level);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -352,7 +395,11 @@ export class SubagentManager {
|
||||
* @returns SubagentConfig
|
||||
* @throws SubagentError if parsing fails
|
||||
*/
|
||||
parseSubagentContent(content: string, filePath: string): SubagentConfig {
|
||||
parseSubagentContent(
|
||||
content: string,
|
||||
filePath: string,
|
||||
level: SubagentLevel,
|
||||
): SubagentConfig {
|
||||
try {
|
||||
// Split frontmatter and content
|
||||
const frontmatterRegex = /^---\n([\s\S]*?)\n---\n([\s\S]*)$/;
|
||||
@@ -393,31 +440,16 @@ export class SubagentManager {
|
||||
| undefined;
|
||||
const color = frontmatter['color'] as string | undefined;
|
||||
|
||||
// Determine level from file path using robust, cross-platform check
|
||||
// A project-level agent lives under <projectRoot>/.qwen/agents
|
||||
const projectAgentsDir = path.join(
|
||||
this.config.getProjectRoot(),
|
||||
QWEN_CONFIG_DIR,
|
||||
AGENT_CONFIG_DIR,
|
||||
);
|
||||
const rel = path.relative(
|
||||
path.normalize(projectAgentsDir),
|
||||
path.normalize(filePath),
|
||||
);
|
||||
const isProjectLevel =
|
||||
rel !== '' && !rel.startsWith('..') && !path.isAbsolute(rel);
|
||||
const level: SubagentLevel = isProjectLevel ? 'project' : 'user';
|
||||
|
||||
const config: SubagentConfig = {
|
||||
name,
|
||||
description,
|
||||
tools,
|
||||
systemPrompt: systemPrompt.trim(),
|
||||
level,
|
||||
filePath,
|
||||
modelConfig: modelConfig as Partial<ModelConfig>,
|
||||
runConfig: runConfig as Partial<RunConfig>,
|
||||
color,
|
||||
level,
|
||||
};
|
||||
|
||||
// Validate the parsed configuration
|
||||
@@ -426,16 +458,6 @@ export class SubagentManager {
|
||||
throw new Error(`Validation failed: ${validation.errors.join(', ')}`);
|
||||
}
|
||||
|
||||
// Warn if filename doesn't match subagent name (potential issue)
|
||||
const expectedFilename = `${config.name}.md`;
|
||||
const actualFilename = path.basename(filePath);
|
||||
if (actualFilename !== expectedFilename) {
|
||||
console.warn(
|
||||
`Warning: Subagent file "${actualFilename}" contains name "${config.name}" but filename suggests "${path.basename(actualFilename, '.md')}". ` +
|
||||
`Consider renaming the file to "${expectedFilename}" for consistency.`,
|
||||
);
|
||||
}
|
||||
|
||||
return config;
|
||||
} catch (error) {
|
||||
throw new SubagentError(
|
||||
@@ -678,14 +700,18 @@ export class SubagentManager {
|
||||
return BuiltinAgentRegistry.getBuiltinAgents();
|
||||
}
|
||||
|
||||
const baseDir =
|
||||
level === 'project'
|
||||
? path.join(
|
||||
this.config.getProjectRoot(),
|
||||
QWEN_CONFIG_DIR,
|
||||
AGENT_CONFIG_DIR,
|
||||
)
|
||||
: path.join(os.homedir(), QWEN_CONFIG_DIR, AGENT_CONFIG_DIR);
|
||||
const projectRoot = this.config.getProjectRoot();
|
||||
const homeDir = os.homedir();
|
||||
const isHomeDirectory = path.resolve(projectRoot) === path.resolve(homeDir);
|
||||
|
||||
// If project level is requested but project root is same as home directory,
|
||||
// return empty array to avoid conflicts between project and global agents
|
||||
if (level === 'project' && isHomeDirectory) {
|
||||
return [];
|
||||
}
|
||||
|
||||
let baseDir = level === 'project' ? projectRoot : homeDir;
|
||||
baseDir = path.join(baseDir, QWEN_CONFIG_DIR, AGENT_CONFIG_DIR);
|
||||
|
||||
try {
|
||||
const files = await fs.readdir(baseDir);
|
||||
@@ -697,7 +723,7 @@ export class SubagentManager {
|
||||
const filePath = path.join(baseDir, file);
|
||||
|
||||
try {
|
||||
const config = await this.parseSubagentFile(filePath);
|
||||
const config = await this.parseSubagentFile(filePath, level);
|
||||
subagents.push(config);
|
||||
} catch (_error) {
|
||||
// Ignore invalid files
|
||||
|
||||
@@ -23,7 +23,11 @@ import {
|
||||
} from 'vitest';
|
||||
import { Config, type ConfigParameters } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
import { createContentGenerator } from '../core/contentGenerator.js';
|
||||
import {
|
||||
createContentGenerator,
|
||||
createContentGeneratorConfig,
|
||||
AuthType,
|
||||
} from '../core/contentGenerator.js';
|
||||
import { GeminiChat } from '../core/geminiChat.js';
|
||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
@@ -56,8 +60,7 @@ async function createMockConfig(
|
||||
};
|
||||
const config = new Config(configParams);
|
||||
await config.initialize();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
await config.refreshAuth('test-auth' as any);
|
||||
await config.refreshAuth(AuthType.USE_GEMINI);
|
||||
|
||||
// Mock ToolRegistry
|
||||
const mockToolRegistry = {
|
||||
@@ -164,6 +167,10 @@ describe('subagent.ts', () => {
|
||||
getGenerativeModel: vi.fn(),
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} as any);
|
||||
vi.mocked(createContentGeneratorConfig).mockReturnValue({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
authType: undefined,
|
||||
});
|
||||
|
||||
mockSendMessageStream = vi.fn();
|
||||
// We mock the implementation of the constructor.
|
||||
|
||||
@@ -116,6 +116,9 @@ export interface ListSubagentsOptions {
|
||||
|
||||
/** Sort direction */
|
||||
sortOrder?: 'asc' | 'desc';
|
||||
|
||||
/** Force refresh from disk, bypassing cache. Defaults to false. */
|
||||
force?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -24,6 +24,7 @@ import { ApprovalMode } from '../config/config.js';
|
||||
import { ensureCorrectEdit } from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
|
||||
import { ReadFileTool } from './read-file.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import type {
|
||||
ModifiableDeclarativeTool,
|
||||
ModifyContext,
|
||||
@@ -461,7 +462,7 @@ export class EditTool
|
||||
extends BaseDeclarativeTool<EditToolParams, ToolResult>
|
||||
implements ModifiableDeclarativeTool<EditToolParams>
|
||||
{
|
||||
static readonly Name = 'edit';
|
||||
static readonly Name = ToolNames.EDIT;
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
EditTool.Name,
|
||||
|
||||
@@ -62,6 +62,9 @@ describe('GlobTool', () => {
|
||||
// Ensure a noticeable difference in modification time
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
await fs.writeFile(path.join(tempRootDir, 'newer.sortme'), 'newer_content');
|
||||
|
||||
// For type coercion testing
|
||||
await fs.mkdir(path.join(tempRootDir, '123'));
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
@@ -279,26 +282,20 @@ describe('GlobTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error if path is provided but is not a string (schema validation)', () => {
|
||||
it('should pass if path is provided but is not a string (type coercion)', () => {
|
||||
const params = {
|
||||
pattern: '*.ts',
|
||||
path: 123,
|
||||
};
|
||||
// @ts-expect-error - We're intentionally creating invalid params for testing
|
||||
expect(globTool.validateToolParams(params)).toBe(
|
||||
'params/path must be string',
|
||||
);
|
||||
} as unknown as GlobToolParams; // Force incorrect type
|
||||
expect(globTool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it('should return error if case_sensitive is provided but is not a boolean (schema validation)', () => {
|
||||
it('should pass if case_sensitive is provided but is not a boolean (type coercion)', () => {
|
||||
const params = {
|
||||
pattern: '*.ts',
|
||||
case_sensitive: 'true',
|
||||
};
|
||||
// @ts-expect-error - We're intentionally creating invalid params for testing
|
||||
expect(globTool.validateToolParams(params)).toBe(
|
||||
'params/case_sensitive must be boolean',
|
||||
);
|
||||
} as unknown as GlobToolParams; // Force incorrect type
|
||||
expect(globTool.validateToolParams(params)).toBeNull();
|
||||
});
|
||||
|
||||
it("should return error if search path resolves outside the tool's root directory", () => {
|
||||
|
||||
@@ -9,6 +9,7 @@ import path from 'node:path';
|
||||
import { glob, escape } from 'glob';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import { shortenPath, makeRelative } from '../utils/paths.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
@@ -252,7 +253,7 @@ class GlobToolInvocation extends BaseToolInvocation<
|
||||
* Implementation of the Glob tool logic
|
||||
*/
|
||||
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
static readonly Name = 'glob';
|
||||
static readonly Name = ToolNames.GLOB;
|
||||
|
||||
constructor(private config: Config) {
|
||||
super(
|
||||
|
||||
@@ -12,6 +12,7 @@ import { spawn } from 'node:child_process';
|
||||
import { globStream } from 'glob';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||
import { isGitRepository } from '../utils/gitUtils.js';
|
||||
@@ -597,7 +598,7 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
* Implementation of the Grep tool logic (moved from CLI)
|
||||
*/
|
||||
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
static readonly Name = 'search_file_content'; // Keep static name
|
||||
static readonly Name = ToolNames.GREP;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
|
||||
@@ -8,6 +8,7 @@ import path from 'node:path';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
|
||||
import type { PartUnion } from '@google/genai';
|
||||
import {
|
||||
@@ -136,7 +137,7 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
ReadFileToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name: string = 'read_file';
|
||||
static readonly Name: string = ToolNames.READ_FILE;
|
||||
|
||||
constructor(private config: Config) {
|
||||
super(
|
||||
|
||||
@@ -191,14 +191,12 @@ describe('ReadManyFilesTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error if include array contains non-string elements', () => {
|
||||
it('should coerce non-string elements in include array', () => {
|
||||
const params = {
|
||||
paths: ['file1.txt'],
|
||||
include: ['*.ts', 123] as string[],
|
||||
};
|
||||
expect(() => tool.build(params)).toThrow(
|
||||
'params/include/1 must be string',
|
||||
);
|
||||
expect(() => tool.build(params)).toBeDefined();
|
||||
});
|
||||
|
||||
it('should throw error if exclude array contains non-string elements', () => {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
@@ -526,7 +527,7 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
|
||||
ReadManyFilesParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name: string = 'read_many_files';
|
||||
static readonly Name: string = ToolNames.READ_MANY_FILES;
|
||||
|
||||
constructor(private config: Config) {
|
||||
const parameterSchema = {
|
||||
|
||||
@@ -8,7 +8,6 @@ import fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
import { EOL } from 'node:os';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { rgPath } from '@lvce-editor/ripgrep';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
@@ -18,6 +17,14 @@ import type { Config } from '../config/config.js';
|
||||
|
||||
const DEFAULT_TOTAL_MAX_MATCHES = 20000;
|
||||
|
||||
/**
|
||||
* Lazy loads the ripgrep binary path to avoid loading the library until needed
|
||||
*/
|
||||
async function getRipgrepPath(): Promise<string> {
|
||||
const { rgPath } = await import('@lvce-editor/ripgrep');
|
||||
return rgPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for the GrepTool
|
||||
*/
|
||||
@@ -292,8 +299,9 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
rgArgs.push(absolutePath);
|
||||
|
||||
try {
|
||||
const ripgrepPath = await getRipgrepPath();
|
||||
const output = await new Promise<string>((resolve, reject) => {
|
||||
const child = spawn(rgPath, rgArgs, {
|
||||
const child = spawn(ripgrepPath, rgArgs, {
|
||||
windowsHide: true,
|
||||
});
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import path from 'node:path';
|
||||
import os, { EOL } from 'node:os';
|
||||
import crypto from 'node:crypto';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import type {
|
||||
ToolInvocation,
|
||||
@@ -403,7 +404,7 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
ShellToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
static Name: string = 'run_shell_command';
|
||||
static Name: string = ToolNames.SHELL;
|
||||
private allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
@@ -419,6 +420,11 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
type: 'string',
|
||||
description: getCommandDescription(),
|
||||
},
|
||||
is_background: {
|
||||
type: 'boolean',
|
||||
description:
|
||||
'Whether to run the command in background. Default is false. Set to true for long-running processes like development servers, watchers, or daemons that should continue running without blocking further commands.',
|
||||
},
|
||||
description: {
|
||||
type: 'string',
|
||||
description:
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import type {
|
||||
ToolResult,
|
||||
ToolResultDisplay,
|
||||
@@ -46,7 +47,7 @@ export interface TaskParams {
|
||||
* for the model to choose from.
|
||||
*/
|
||||
export class TaskTool extends BaseDeclarativeTool<TaskParams, ToolResult> {
|
||||
static readonly Name: string = 'task';
|
||||
static readonly Name: string = ToolNames.TASK;
|
||||
|
||||
private subagentManager: SubagentManager;
|
||||
private availableSubagents: SubagentConfig[] = [];
|
||||
|
||||
23
packages/core/src/tools/tool-names.ts
Normal file
23
packages/core/src/tools/tool-names.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Tool name constants to avoid circular dependencies.
|
||||
* These constants are used across multiple files and should be kept in sync
|
||||
* with the actual tool class names.
|
||||
*/
|
||||
export const ToolNames = {
|
||||
EDIT: 'edit',
|
||||
WRITE_FILE: 'write_file',
|
||||
READ_FILE: 'read_file',
|
||||
READ_MANY_FILES: 'read_many_files',
|
||||
GREP: 'search_file_content',
|
||||
GLOB: 'glob',
|
||||
SHELL: 'run_shell_command',
|
||||
TODO_WRITE: 'todo_write',
|
||||
MEMORY: 'save_memory',
|
||||
TASK: 'task',
|
||||
} as const;
|
||||
@@ -220,14 +220,12 @@ describe('WriteFileTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if the content is null', () => {
|
||||
const dirAsFilePath = path.join(rootDir, 'a_directory');
|
||||
fs.mkdirSync(dirAsFilePath);
|
||||
it('should coerce null content into an empty string', () => {
|
||||
const params = {
|
||||
file_path: dirAsFilePath,
|
||||
file_path: path.join(rootDir, 'test.txt'),
|
||||
content: null,
|
||||
} as unknown as WriteFileToolParams; // Intentionally non-conforming
|
||||
expect(() => tool.build(params)).toThrow('params/content must be string');
|
||||
expect(() => tool.build(params)).toBeDefined();
|
||||
});
|
||||
|
||||
it('should throw error if the file_path is empty', () => {
|
||||
|
||||
@@ -31,6 +31,7 @@ import {
|
||||
ensureCorrectFileContent,
|
||||
} from '../utils/editCorrector.js';
|
||||
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
|
||||
import { ToolNames } from './tool-names.js';
|
||||
import type {
|
||||
ModifiableDeclarativeTool,
|
||||
ModifyContext,
|
||||
@@ -403,7 +404,7 @@ export class WriteFileTool
|
||||
extends BaseDeclarativeTool<WriteFileToolParams, ToolResult>
|
||||
implements ModifiableDeclarativeTool<WriteFileToolParams>
|
||||
{
|
||||
static readonly Name: string = 'write_file';
|
||||
static readonly Name: string = ToolNames.WRITE_FILE;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
super(
|
||||
|
||||
@@ -7,11 +7,7 @@
|
||||
import type { Content, GenerateContentConfig } from '@google/genai';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { EditToolParams } from '../tools/edit.js';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import { WriteFileTool } from '../tools/write-file.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||
import { GrepTool } from '../tools/grep.js';
|
||||
import { ToolNames } from '../tools/tool-names.js';
|
||||
import { LruCache } from './LruCache.js';
|
||||
import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js';
|
||||
import {
|
||||
@@ -85,14 +81,14 @@ async function findLastEditTimestamp(
|
||||
const history = (await client.getHistory()) ?? [];
|
||||
|
||||
// Tools that may reference the file path in their FunctionResponse `output`.
|
||||
const toolsInResp = new Set([
|
||||
WriteFileTool.Name,
|
||||
EditTool.Name,
|
||||
ReadManyFilesTool.Name,
|
||||
GrepTool.Name,
|
||||
const toolsInResp = new Set<string>([
|
||||
ToolNames.WRITE_FILE,
|
||||
ToolNames.EDIT,
|
||||
ToolNames.READ_MANY_FILES,
|
||||
ToolNames.GREP,
|
||||
]);
|
||||
// Tools that may reference the file path in their FunctionCall `args`.
|
||||
const toolsInCall = new Set([...toolsInResp, ReadFileTool.Name]);
|
||||
const toolsInCall = new Set<string>([...toolsInResp, ToolNames.READ_FILE]);
|
||||
|
||||
// Iterate backwards to find the most recent relevant action.
|
||||
for (const entry of history.slice().reverse()) {
|
||||
|
||||
157
packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts
Normal file
157
packages/core/src/utils/request-tokenizer/imageTokenizer.test.ts
Normal file
@@ -0,0 +1,157 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { ImageTokenizer } from './imageTokenizer.js';
|
||||
|
||||
describe('ImageTokenizer', () => {
|
||||
const tokenizer = new ImageTokenizer();
|
||||
|
||||
describe('token calculation', () => {
|
||||
it('should calculate tokens based on image dimensions with reference logic', () => {
|
||||
const metadata = {
|
||||
width: 28,
|
||||
height: 28,
|
||||
mimeType: 'image/png',
|
||||
dataSize: 1000,
|
||||
};
|
||||
|
||||
const tokens = tokenizer.calculateTokens(metadata);
|
||||
|
||||
// 28x28 = 784 pixels = 1 image token + 2 special tokens = 3 total
|
||||
// But minimum scaling may apply for small images
|
||||
expect(tokens).toBeGreaterThanOrEqual(6); // Minimum after scaling + special tokens
|
||||
});
|
||||
|
||||
it('should calculate tokens for larger images', () => {
|
||||
const metadata = {
|
||||
width: 512,
|
||||
height: 512,
|
||||
mimeType: 'image/png',
|
||||
dataSize: 10000,
|
||||
};
|
||||
|
||||
const tokens = tokenizer.calculateTokens(metadata);
|
||||
|
||||
// 512x512 with reference logic: rounded dimensions + scaling + special tokens
|
||||
expect(tokens).toBeGreaterThan(300);
|
||||
expect(tokens).toBeLessThan(400); // Should be reasonable for 512x512
|
||||
});
|
||||
|
||||
it('should enforce minimum tokens per image with scaling', () => {
|
||||
const metadata = {
|
||||
width: 1,
|
||||
height: 1,
|
||||
mimeType: 'image/png',
|
||||
dataSize: 100,
|
||||
};
|
||||
|
||||
const tokens = tokenizer.calculateTokens(metadata);
|
||||
|
||||
// Tiny images get scaled up to minimum pixels + special tokens
|
||||
expect(tokens).toBeGreaterThanOrEqual(6); // 4 image tokens + 2 special tokens
|
||||
});
|
||||
|
||||
it('should handle very large images with scaling', () => {
|
||||
const metadata = {
|
||||
width: 8192,
|
||||
height: 8192,
|
||||
mimeType: 'image/png',
|
||||
dataSize: 100000,
|
||||
};
|
||||
|
||||
const tokens = tokenizer.calculateTokens(metadata);
|
||||
|
||||
// Very large images should be scaled down to max limit + special tokens
|
||||
expect(tokens).toBeLessThanOrEqual(16386); // 16384 max + 2 special tokens
|
||||
expect(tokens).toBeGreaterThan(16000); // Should be close to the limit
|
||||
});
|
||||
});
|
||||
|
||||
describe('PNG dimension extraction', () => {
|
||||
it('should extract dimensions from valid PNG', async () => {
|
||||
// 1x1 PNG image in base64
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const metadata = await tokenizer.extractImageMetadata(
|
||||
pngBase64,
|
||||
'image/png',
|
||||
);
|
||||
|
||||
expect(metadata.width).toBe(1);
|
||||
expect(metadata.height).toBe(1);
|
||||
expect(metadata.mimeType).toBe('image/png');
|
||||
});
|
||||
|
||||
it('should handle invalid PNG gracefully', async () => {
|
||||
const invalidBase64 = 'invalid-png-data';
|
||||
|
||||
const metadata = await tokenizer.extractImageMetadata(
|
||||
invalidBase64,
|
||||
'image/png',
|
||||
);
|
||||
|
||||
// Should return default dimensions
|
||||
expect(metadata.width).toBe(512);
|
||||
expect(metadata.height).toBe(512);
|
||||
expect(metadata.mimeType).toBe('image/png');
|
||||
});
|
||||
});
|
||||
|
||||
describe('batch processing', () => {
|
||||
it('should process multiple images serially', async () => {
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const images = [
|
||||
{ data: pngBase64, mimeType: 'image/png' },
|
||||
{ data: pngBase64, mimeType: 'image/png' },
|
||||
{ data: pngBase64, mimeType: 'image/png' },
|
||||
];
|
||||
|
||||
const tokens = await tokenizer.calculateTokensBatch(images);
|
||||
|
||||
expect(tokens).toHaveLength(3);
|
||||
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least 4 tokens
|
||||
});
|
||||
|
||||
it('should handle mixed valid and invalid images', async () => {
|
||||
const validPng =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
const invalidPng = 'invalid-data';
|
||||
|
||||
const images = [
|
||||
{ data: validPng, mimeType: 'image/png' },
|
||||
{ data: invalidPng, mimeType: 'image/png' },
|
||||
];
|
||||
|
||||
const tokens = await tokenizer.calculateTokensBatch(images);
|
||||
|
||||
expect(tokens).toHaveLength(2);
|
||||
expect(tokens.every((t) => t >= 4)).toBe(true); // All should have at least minimum tokens
|
||||
});
|
||||
});
|
||||
|
||||
describe('different image formats', () => {
|
||||
it('should handle different MIME types', async () => {
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const formats = ['image/png', 'image/jpeg', 'image/webp', 'image/gif'];
|
||||
|
||||
for (const mimeType of formats) {
|
||||
const metadata = await tokenizer.extractImageMetadata(
|
||||
pngBase64,
|
||||
mimeType,
|
||||
);
|
||||
expect(metadata.mimeType).toBe(mimeType);
|
||||
expect(metadata.width).toBeGreaterThan(0);
|
||||
expect(metadata.height).toBeGreaterThan(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
505
packages/core/src/utils/request-tokenizer/imageTokenizer.ts
Normal file
505
packages/core/src/utils/request-tokenizer/imageTokenizer.ts
Normal file
@@ -0,0 +1,505 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ImageMetadata } from './types.js';
|
||||
import { isSupportedImageMimeType } from './supportedImageFormats.js';
|
||||
|
||||
/**
|
||||
* Image tokenizer for calculating image tokens based on dimensions
|
||||
*
|
||||
* Key rules:
|
||||
* - 28x28 pixels = 1 token
|
||||
* - Minimum: 4 tokens per image
|
||||
* - Maximum: 16384 tokens per image
|
||||
* - Additional: 2 special tokens (vision_bos + vision_eos)
|
||||
* - Supports: PNG, JPEG, WebP, GIF, BMP, TIFF, HEIC formats
|
||||
*/
|
||||
export class ImageTokenizer {
|
||||
/** 28x28 pixels = 1 token */
|
||||
private static readonly PIXELS_PER_TOKEN = 28 * 28;
|
||||
|
||||
/** Minimum tokens per image */
|
||||
private static readonly MIN_TOKENS_PER_IMAGE = 4;
|
||||
|
||||
/** Maximum tokens per image */
|
||||
private static readonly MAX_TOKENS_PER_IMAGE = 16384;
|
||||
|
||||
/** Special tokens for vision markers */
|
||||
private static readonly VISION_SPECIAL_TOKENS = 2;
|
||||
|
||||
/**
|
||||
* Extract image metadata from base64 data
|
||||
*
|
||||
* @param base64Data Base64-encoded image data (with or without data URL prefix)
|
||||
* @param mimeType MIME type of the image
|
||||
* @returns Promise resolving to ImageMetadata with dimensions and format info
|
||||
*/
|
||||
async extractImageMetadata(
|
||||
base64Data: string,
|
||||
mimeType: string,
|
||||
): Promise<ImageMetadata> {
|
||||
try {
|
||||
// Check if the MIME type is supported
|
||||
if (!isSupportedImageMimeType(mimeType)) {
|
||||
console.warn(`Unsupported image format: ${mimeType}`);
|
||||
// Return default metadata for unsupported formats
|
||||
return {
|
||||
width: 512,
|
||||
height: 512,
|
||||
mimeType,
|
||||
dataSize: Math.floor(base64Data.length * 0.75),
|
||||
};
|
||||
}
|
||||
|
||||
const cleanBase64 = base64Data.replace(/^data:[^;]+;base64,/, '');
|
||||
const buffer = Buffer.from(cleanBase64, 'base64');
|
||||
const dimensions = await this.extractDimensions(buffer, mimeType);
|
||||
|
||||
return {
|
||||
width: dimensions.width,
|
||||
height: dimensions.height,
|
||||
mimeType,
|
||||
dataSize: buffer.length,
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn('Failed to extract image metadata:', error);
|
||||
// Return default metadata for fallback
|
||||
return {
|
||||
width: 512,
|
||||
height: 512,
|
||||
mimeType,
|
||||
dataSize: Math.floor(base64Data.length * 0.75),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract image dimensions from buffer based on format
|
||||
*
|
||||
* @param buffer Binary image data buffer
|
||||
* @param mimeType MIME type to determine parsing strategy
|
||||
* @returns Promise resolving to width and height dimensions
|
||||
*/
|
||||
private async extractDimensions(
|
||||
buffer: Buffer,
|
||||
mimeType: string,
|
||||
): Promise<{ width: number; height: number }> {
|
||||
if (mimeType.includes('png')) {
|
||||
return this.extractPngDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('jpeg') || mimeType.includes('jpg')) {
|
||||
return this.extractJpegDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('webp')) {
|
||||
return this.extractWebpDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('gif')) {
|
||||
return this.extractGifDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('bmp')) {
|
||||
return this.extractBmpDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('tiff')) {
|
||||
return this.extractTiffDimensions(buffer);
|
||||
}
|
||||
|
||||
if (mimeType.includes('heic')) {
|
||||
return this.extractHeicDimensions(buffer);
|
||||
}
|
||||
|
||||
return { width: 512, height: 512 };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract PNG dimensions from IHDR chunk
|
||||
* PNG signature: 89 50 4E 47 0D 0A 1A 0A
|
||||
* Width/height at bytes 16-19 and 20-23 (big-endian)
|
||||
*/
|
||||
private extractPngDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 24) {
|
||||
throw new Error('Invalid PNG: buffer too short');
|
||||
}
|
||||
|
||||
// Verify PNG signature
|
||||
const signature = buffer.subarray(0, 8);
|
||||
const expectedSignature = Buffer.from([
|
||||
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a,
|
||||
]);
|
||||
if (!signature.equals(expectedSignature)) {
|
||||
throw new Error('Invalid PNG signature');
|
||||
}
|
||||
|
||||
const width = buffer.readUInt32BE(16);
|
||||
const height = buffer.readUInt32BE(20);
|
||||
|
||||
return { width, height };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract JPEG dimensions from SOF (Start of Frame) markers
|
||||
* JPEG starts with FF D8, SOF markers: 0xC0-0xC3, 0xC5-0xC7, 0xC9-0xCB, 0xCD-0xCF
|
||||
* Dimensions at offset +5 (height) and +7 (width) from SOF marker
|
||||
*/
|
||||
private extractJpegDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 4 || buffer[0] !== 0xff || buffer[1] !== 0xd8) {
|
||||
throw new Error('Invalid JPEG signature');
|
||||
}
|
||||
|
||||
let offset = 2;
|
||||
|
||||
while (offset < buffer.length - 8) {
|
||||
if (buffer[offset] !== 0xff) {
|
||||
offset++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const marker = buffer[offset + 1];
|
||||
|
||||
// SOF markers
|
||||
if (
|
||||
(marker >= 0xc0 && marker <= 0xc3) ||
|
||||
(marker >= 0xc5 && marker <= 0xc7) ||
|
||||
(marker >= 0xc9 && marker <= 0xcb) ||
|
||||
(marker >= 0xcd && marker <= 0xcf)
|
||||
) {
|
||||
const height = buffer.readUInt16BE(offset + 5);
|
||||
const width = buffer.readUInt16BE(offset + 7);
|
||||
return { width, height };
|
||||
}
|
||||
|
||||
const segmentLength = buffer.readUInt16BE(offset + 2);
|
||||
offset += 2 + segmentLength;
|
||||
}
|
||||
|
||||
throw new Error('Could not find JPEG dimensions');
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract WebP dimensions from RIFF container
|
||||
* Supports VP8, VP8L, and VP8X formats
|
||||
*/
|
||||
private extractWebpDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 30) {
|
||||
throw new Error('Invalid WebP: too short');
|
||||
}
|
||||
|
||||
const riffSignature = buffer.subarray(0, 4).toString('ascii');
|
||||
const webpSignature = buffer.subarray(8, 12).toString('ascii');
|
||||
|
||||
if (riffSignature !== 'RIFF' || webpSignature !== 'WEBP') {
|
||||
throw new Error('Invalid WebP signature');
|
||||
}
|
||||
|
||||
const format = buffer.subarray(12, 16).toString('ascii');
|
||||
|
||||
if (format === 'VP8 ') {
|
||||
const width = buffer.readUInt16LE(26) & 0x3fff;
|
||||
const height = buffer.readUInt16LE(28) & 0x3fff;
|
||||
return { width, height };
|
||||
} else if (format === 'VP8L') {
|
||||
const bits = buffer.readUInt32LE(21);
|
||||
const width = (bits & 0x3fff) + 1;
|
||||
const height = ((bits >> 14) & 0x3fff) + 1;
|
||||
return { width, height };
|
||||
} else if (format === 'VP8X') {
|
||||
const width = (buffer.readUInt32LE(24) & 0xffffff) + 1;
|
||||
const height = (buffer.readUInt32LE(26) & 0xffffff) + 1;
|
||||
return { width, height };
|
||||
}
|
||||
|
||||
throw new Error('Unsupported WebP format');
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract GIF dimensions from header
|
||||
* Supports GIF87a and GIF89a formats
|
||||
*/
|
||||
private extractGifDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 10) {
|
||||
throw new Error('Invalid GIF: too short');
|
||||
}
|
||||
|
||||
const signature = buffer.subarray(0, 6).toString('ascii');
|
||||
if (signature !== 'GIF87a' && signature !== 'GIF89a') {
|
||||
throw new Error('Invalid GIF signature');
|
||||
}
|
||||
|
||||
const width = buffer.readUInt16LE(6);
|
||||
const height = buffer.readUInt16LE(8);
|
||||
|
||||
return { width, height };
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for an image based on its metadata
|
||||
*
|
||||
* @param metadata Image metadata containing width, height, and format info
|
||||
* @returns Total token count including base image tokens and special tokens
|
||||
*/
|
||||
calculateTokens(metadata: ImageMetadata): number {
|
||||
return this.calculateTokensWithScaling(metadata.width, metadata.height);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens with scaling logic
|
||||
*
|
||||
* Steps:
|
||||
* 1. Normalize to 28-pixel multiples
|
||||
* 2. Scale large images down, small images up
|
||||
* 3. Calculate tokens: pixels / 784 + 2 special tokens
|
||||
*
|
||||
* @param width Original image width in pixels
|
||||
* @param height Original image height in pixels
|
||||
* @returns Total token count for the image
|
||||
*/
|
||||
private calculateTokensWithScaling(width: number, height: number): number {
|
||||
// Normalize to 28-pixel multiples
|
||||
let hBar = Math.round(height / 28) * 28;
|
||||
let wBar = Math.round(width / 28) * 28;
|
||||
|
||||
// Define pixel boundaries
|
||||
const minPixels =
|
||||
ImageTokenizer.MIN_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
|
||||
const maxPixels =
|
||||
ImageTokenizer.MAX_TOKENS_PER_IMAGE * ImageTokenizer.PIXELS_PER_TOKEN;
|
||||
|
||||
// Apply scaling
|
||||
if (hBar * wBar > maxPixels) {
|
||||
// Scale down large images
|
||||
const beta = Math.sqrt((height * width) / maxPixels);
|
||||
hBar = Math.floor(height / beta / 28) * 28;
|
||||
wBar = Math.floor(width / beta / 28) * 28;
|
||||
} else if (hBar * wBar < minPixels) {
|
||||
// Scale up small images
|
||||
const beta = Math.sqrt(minPixels / (height * width));
|
||||
hBar = Math.ceil((height * beta) / 28) * 28;
|
||||
wBar = Math.ceil((width * beta) / 28) * 28;
|
||||
}
|
||||
|
||||
// Calculate tokens
|
||||
const imageTokens = Math.floor(
|
||||
(hBar * wBar) / ImageTokenizer.PIXELS_PER_TOKEN,
|
||||
);
|
||||
|
||||
return imageTokens + ImageTokenizer.VISION_SPECIAL_TOKENS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for multiple images serially
|
||||
*
|
||||
* @param base64DataArray Array of image data with MIME type information
|
||||
* @returns Promise resolving to array of token counts in same order as input
|
||||
*/
|
||||
async calculateTokensBatch(
|
||||
base64DataArray: Array<{ data: string; mimeType: string }>,
|
||||
): Promise<number[]> {
|
||||
const results: number[] = [];
|
||||
|
||||
for (const { data, mimeType } of base64DataArray) {
|
||||
try {
|
||||
const metadata = await this.extractImageMetadata(data, mimeType);
|
||||
results.push(this.calculateTokens(metadata));
|
||||
} catch (error) {
|
||||
console.warn('Error calculating tokens for image:', error);
|
||||
// Return minimum tokens as fallback
|
||||
results.push(
|
||||
ImageTokenizer.MIN_TOKENS_PER_IMAGE +
|
||||
ImageTokenizer.VISION_SPECIAL_TOKENS,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract BMP dimensions from header
|
||||
* BMP signature: 42 4D (BM)
|
||||
* Width/height at bytes 18-21 and 22-25 (little-endian)
|
||||
*/
|
||||
private extractBmpDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 26) {
|
||||
throw new Error('Invalid BMP: buffer too short');
|
||||
}
|
||||
|
||||
// Verify BMP signature
|
||||
if (buffer[0] !== 0x42 || buffer[1] !== 0x4d) {
|
||||
throw new Error('Invalid BMP signature');
|
||||
}
|
||||
|
||||
const width = buffer.readUInt32LE(18);
|
||||
const height = buffer.readUInt32LE(22);
|
||||
|
||||
return { width, height: Math.abs(height) }; // Height can be negative for top-down BMPs
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract TIFF dimensions from IFD (Image File Directory)
|
||||
* TIFF can be little-endian (II) or big-endian (MM)
|
||||
* Width/height are stored in IFD entries with tags 0x0100 and 0x0101
|
||||
*/
|
||||
private extractTiffDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 8) {
|
||||
throw new Error('Invalid TIFF: buffer too short');
|
||||
}
|
||||
|
||||
// Check byte order
|
||||
const byteOrder = buffer.subarray(0, 2).toString('ascii');
|
||||
const isLittleEndian = byteOrder === 'II';
|
||||
const isBigEndian = byteOrder === 'MM';
|
||||
|
||||
if (!isLittleEndian && !isBigEndian) {
|
||||
throw new Error('Invalid TIFF byte order');
|
||||
}
|
||||
|
||||
// Read magic number (should be 42)
|
||||
const magic = isLittleEndian
|
||||
? buffer.readUInt16LE(2)
|
||||
: buffer.readUInt16BE(2);
|
||||
if (magic !== 42) {
|
||||
throw new Error('Invalid TIFF magic number');
|
||||
}
|
||||
|
||||
// Read IFD offset
|
||||
const ifdOffset = isLittleEndian
|
||||
? buffer.readUInt32LE(4)
|
||||
: buffer.readUInt32BE(4);
|
||||
|
||||
if (ifdOffset >= buffer.length) {
|
||||
throw new Error('Invalid TIFF IFD offset');
|
||||
}
|
||||
|
||||
// Read number of directory entries
|
||||
const numEntries = isLittleEndian
|
||||
? buffer.readUInt16LE(ifdOffset)
|
||||
: buffer.readUInt16BE(ifdOffset);
|
||||
|
||||
let width = 0;
|
||||
let height = 0;
|
||||
|
||||
// Parse IFD entries
|
||||
for (let i = 0; i < numEntries; i++) {
|
||||
const entryOffset = ifdOffset + 2 + i * 12;
|
||||
|
||||
if (entryOffset + 12 > buffer.length) break;
|
||||
|
||||
const tag = isLittleEndian
|
||||
? buffer.readUInt16LE(entryOffset)
|
||||
: buffer.readUInt16BE(entryOffset);
|
||||
|
||||
const type = isLittleEndian
|
||||
? buffer.readUInt16LE(entryOffset + 2)
|
||||
: buffer.readUInt16BE(entryOffset + 2);
|
||||
|
||||
const value = isLittleEndian
|
||||
? buffer.readUInt32LE(entryOffset + 8)
|
||||
: buffer.readUInt32BE(entryOffset + 8);
|
||||
|
||||
if (tag === 0x0100) {
|
||||
// ImageWidth
|
||||
width = type === 3 ? value : value; // SHORT or LONG
|
||||
} else if (tag === 0x0101) {
|
||||
// ImageLength (height)
|
||||
height = type === 3 ? value : value; // SHORT or LONG
|
||||
}
|
||||
|
||||
if (width > 0 && height > 0) break;
|
||||
}
|
||||
|
||||
if (width === 0 || height === 0) {
|
||||
throw new Error('Could not find TIFF dimensions');
|
||||
}
|
||||
|
||||
return { width, height };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract HEIC dimensions from meta box
|
||||
* HEIC is based on ISO Base Media File Format
|
||||
* This is a simplified implementation that looks for 'ispe' (Image Spatial Extents) box
|
||||
*/
|
||||
private extractHeicDimensions(buffer: Buffer): {
|
||||
width: number;
|
||||
height: number;
|
||||
} {
|
||||
if (buffer.length < 12) {
|
||||
throw new Error('Invalid HEIC: buffer too short');
|
||||
}
|
||||
|
||||
// Check for ftyp box with HEIC brand
|
||||
const ftypBox = buffer.subarray(4, 8).toString('ascii');
|
||||
if (ftypBox !== 'ftyp') {
|
||||
throw new Error('Invalid HEIC: missing ftyp box');
|
||||
}
|
||||
|
||||
const brand = buffer.subarray(8, 12).toString('ascii');
|
||||
if (!['heic', 'heix', 'hevc', 'hevx'].includes(brand)) {
|
||||
throw new Error('Invalid HEIC brand');
|
||||
}
|
||||
|
||||
// Look for meta box and then ispe box
|
||||
let offset = 0;
|
||||
while (offset < buffer.length - 8) {
|
||||
const boxSize = buffer.readUInt32BE(offset);
|
||||
const boxType = buffer.subarray(offset + 4, offset + 8).toString('ascii');
|
||||
|
||||
if (boxType === 'meta') {
|
||||
// Look for ispe box inside meta box
|
||||
const metaOffset = offset + 8;
|
||||
let innerOffset = metaOffset + 4; // Skip version and flags
|
||||
|
||||
while (innerOffset < offset + boxSize - 8) {
|
||||
const innerBoxSize = buffer.readUInt32BE(innerOffset);
|
||||
const innerBoxType = buffer
|
||||
.subarray(innerOffset + 4, innerOffset + 8)
|
||||
.toString('ascii');
|
||||
|
||||
if (innerBoxType === 'ispe') {
|
||||
// Found Image Spatial Extents box
|
||||
if (innerOffset + 20 <= buffer.length) {
|
||||
const width = buffer.readUInt32BE(innerOffset + 12);
|
||||
const height = buffer.readUInt32BE(innerOffset + 16);
|
||||
return { width, height };
|
||||
}
|
||||
}
|
||||
|
||||
if (innerBoxSize === 0) break;
|
||||
innerOffset += innerBoxSize;
|
||||
}
|
||||
}
|
||||
|
||||
if (boxSize === 0) break;
|
||||
offset += boxSize;
|
||||
}
|
||||
|
||||
// Fallback: return default dimensions if we can't parse the structure
|
||||
console.warn('Could not extract HEIC dimensions, using default');
|
||||
return { width: 512, height: 512 };
|
||||
}
|
||||
}
|
||||
40
packages/core/src/utils/request-tokenizer/index.ts
Normal file
40
packages/core/src/utils/request-tokenizer/index.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export { DefaultRequestTokenizer } from './requestTokenizer.js';
|
||||
import { DefaultRequestTokenizer } from './requestTokenizer.js';
|
||||
export { TextTokenizer } from './textTokenizer.js';
|
||||
export { ImageTokenizer } from './imageTokenizer.js';
|
||||
|
||||
export type {
|
||||
RequestTokenizer,
|
||||
TokenizerConfig,
|
||||
TokenCalculationResult,
|
||||
ImageMetadata,
|
||||
} from './types.js';
|
||||
|
||||
// Singleton instance for convenient usage
|
||||
let defaultTokenizer: DefaultRequestTokenizer | null = null;
|
||||
|
||||
/**
|
||||
* Get the default request tokenizer instance
|
||||
*/
|
||||
export function getDefaultTokenizer(): DefaultRequestTokenizer {
|
||||
if (!defaultTokenizer) {
|
||||
defaultTokenizer = new DefaultRequestTokenizer();
|
||||
}
|
||||
return defaultTokenizer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose of the default tokenizer instance
|
||||
*/
|
||||
export async function disposeDefaultTokenizer(): Promise<void> {
|
||||
if (defaultTokenizer) {
|
||||
await defaultTokenizer.dispose();
|
||||
defaultTokenizer = null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { DefaultRequestTokenizer } from './requestTokenizer.js';
|
||||
import type { CountTokensParameters } from '@google/genai';
|
||||
|
||||
describe('DefaultRequestTokenizer', () => {
|
||||
let tokenizer: DefaultRequestTokenizer;
|
||||
|
||||
beforeEach(() => {
|
||||
tokenizer = new DefaultRequestTokenizer();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await tokenizer.dispose();
|
||||
});
|
||||
|
||||
describe('text token calculation', () => {
|
||||
it('should calculate tokens for simple text content', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'Hello, world!' }],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.textTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.imageTokens).toBe(0);
|
||||
expect(result.processingTime).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should handle multiple text parts', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ text: 'First part' },
|
||||
{ text: 'Second part' },
|
||||
{ text: 'Third part' },
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.textTokens).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should handle string content', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: ['Simple string content'],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.textTokens).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('image token calculation', () => {
|
||||
it('should calculate tokens for image content', async () => {
|
||||
// Create a simple 1x1 PNG image in base64
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngBase64,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThanOrEqual(4); // Minimum 4 tokens per image
|
||||
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
|
||||
expect(result.breakdown.textTokens).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle multiple images', async () => {
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngBase64,
|
||||
},
|
||||
},
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngBase64,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThanOrEqual(8); // At least 4 tokens per image
|
||||
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(8);
|
||||
});
|
||||
});
|
||||
|
||||
describe('mixed content', () => {
|
||||
it('should handle text and image content together', async () => {
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ text: 'Here is an image:' },
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngBase64,
|
||||
},
|
||||
},
|
||||
{ text: 'What do you see?' },
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(4);
|
||||
expect(result.breakdown.textTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.imageTokens).toBeGreaterThanOrEqual(4);
|
||||
});
|
||||
});
|
||||
|
||||
describe('function content', () => {
|
||||
it('should handle function calls', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: 'test_function',
|
||||
args: { param1: 'value1', param2: 42 },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(0);
|
||||
expect(result.breakdown.otherTokens).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('empty content', () => {
|
||||
it('should handle empty request', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBe(0);
|
||||
expect(result.breakdown.textTokens).toBe(0);
|
||||
expect(result.breakdown.imageTokens).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle undefined contents', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('configuration', () => {
|
||||
it('should use custom text encoding', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'Test text for encoding' }],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request, {
|
||||
textEncoding: 'cl100k_base',
|
||||
});
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should process multiple images serially', async () => {
|
||||
const pngBase64 =
|
||||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU77yQAAAABJRU5ErkJggg==';
|
||||
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: Array(10).fill({
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: pngBase64,
|
||||
},
|
||||
}),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
expect(result.totalTokens).toBeGreaterThanOrEqual(60); // At least 6 tokens per image * 10 images
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should handle malformed image data gracefully', async () => {
|
||||
const request: CountTokensParameters = {
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
inlineData: {
|
||||
mimeType: 'image/png',
|
||||
data: 'invalid-base64-data',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await tokenizer.calculateTokens(request);
|
||||
|
||||
// Should still return some tokens (fallback to minimum)
|
||||
expect(result.totalTokens).toBeGreaterThanOrEqual(4);
|
||||
});
|
||||
});
|
||||
});
|
||||
341
packages/core/src/utils/request-tokenizer/requestTokenizer.ts
Normal file
341
packages/core/src/utils/request-tokenizer/requestTokenizer.ts
Normal file
@@ -0,0 +1,341 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
CountTokensParameters,
|
||||
Content,
|
||||
Part,
|
||||
PartUnion,
|
||||
} from '@google/genai';
|
||||
import type {
|
||||
RequestTokenizer,
|
||||
TokenizerConfig,
|
||||
TokenCalculationResult,
|
||||
} from './types.js';
|
||||
import { TextTokenizer } from './textTokenizer.js';
|
||||
import { ImageTokenizer } from './imageTokenizer.js';
|
||||
|
||||
/**
|
||||
* Simple request tokenizer that handles text and image content serially
|
||||
*/
|
||||
export class DefaultRequestTokenizer implements RequestTokenizer {
|
||||
private textTokenizer: TextTokenizer;
|
||||
private imageTokenizer: ImageTokenizer;
|
||||
|
||||
constructor() {
|
||||
this.textTokenizer = new TextTokenizer();
|
||||
this.imageTokenizer = new ImageTokenizer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for a request using serial processing
|
||||
*/
|
||||
async calculateTokens(
|
||||
request: CountTokensParameters,
|
||||
config: TokenizerConfig = {},
|
||||
): Promise<TokenCalculationResult> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Apply configuration
|
||||
if (config.textEncoding) {
|
||||
this.textTokenizer = new TextTokenizer(config.textEncoding);
|
||||
}
|
||||
|
||||
try {
|
||||
// Process request content and group by type
|
||||
const { textContents, imageContents, audioContents, otherContents } =
|
||||
this.processAndGroupContents(request);
|
||||
|
||||
if (
|
||||
textContents.length === 0 &&
|
||||
imageContents.length === 0 &&
|
||||
audioContents.length === 0 &&
|
||||
otherContents.length === 0
|
||||
) {
|
||||
return {
|
||||
totalTokens: 0,
|
||||
breakdown: {
|
||||
textTokens: 0,
|
||||
imageTokens: 0,
|
||||
audioTokens: 0,
|
||||
otherTokens: 0,
|
||||
},
|
||||
processingTime: performance.now() - startTime,
|
||||
};
|
||||
}
|
||||
|
||||
// Calculate tokens for each content type serially
|
||||
const textTokens = await this.calculateTextTokens(textContents);
|
||||
const imageTokens = await this.calculateImageTokens(imageContents);
|
||||
const audioTokens = await this.calculateAudioTokens(audioContents);
|
||||
const otherTokens = await this.calculateOtherTokens(otherContents);
|
||||
|
||||
const totalTokens = textTokens + imageTokens + audioTokens + otherTokens;
|
||||
const processingTime = performance.now() - startTime;
|
||||
|
||||
return {
|
||||
totalTokens,
|
||||
breakdown: {
|
||||
textTokens,
|
||||
imageTokens,
|
||||
audioTokens,
|
||||
otherTokens,
|
||||
},
|
||||
processingTime,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error calculating tokens:', error);
|
||||
|
||||
// Fallback calculation
|
||||
const fallbackTokens = this.calculateFallbackTokens(request);
|
||||
|
||||
return {
|
||||
totalTokens: fallbackTokens,
|
||||
breakdown: {
|
||||
textTokens: fallbackTokens,
|
||||
imageTokens: 0,
|
||||
audioTokens: 0,
|
||||
otherTokens: 0,
|
||||
},
|
||||
processingTime: performance.now() - startTime,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for text contents
|
||||
*/
|
||||
private async calculateTextTokens(textContents: string[]): Promise<number> {
|
||||
if (textContents.length === 0) return 0;
|
||||
|
||||
try {
|
||||
const tokenCounts =
|
||||
await this.textTokenizer.calculateTokensBatch(textContents);
|
||||
return tokenCounts.reduce((sum, count) => sum + count, 0);
|
||||
} catch (error) {
|
||||
console.warn('Error calculating text tokens:', error);
|
||||
// Fallback: character-based estimation
|
||||
const totalChars = textContents.join('').length;
|
||||
return Math.ceil(totalChars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for image contents using serial processing
|
||||
*/
|
||||
private async calculateImageTokens(
|
||||
imageContents: Array<{ data: string; mimeType: string }>,
|
||||
): Promise<number> {
|
||||
if (imageContents.length === 0) return 0;
|
||||
|
||||
try {
|
||||
const tokenCounts =
|
||||
await this.imageTokenizer.calculateTokensBatch(imageContents);
|
||||
return tokenCounts.reduce((sum, count) => sum + count, 0);
|
||||
} catch (error) {
|
||||
console.warn('Error calculating image tokens:', error);
|
||||
// Fallback: minimum tokens per image
|
||||
return imageContents.length * 6; // 4 image tokens + 2 special tokens as minimum
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for audio contents
|
||||
* TODO: Implement proper audio token calculation
|
||||
*/
|
||||
private async calculateAudioTokens(
|
||||
audioContents: Array<{ data: string; mimeType: string }>,
|
||||
): Promise<number> {
|
||||
if (audioContents.length === 0) return 0;
|
||||
|
||||
// Placeholder implementation - audio token calculation would depend on
|
||||
// the specific model's audio processing capabilities
|
||||
// For now, estimate based on data size
|
||||
let totalTokens = 0;
|
||||
|
||||
for (const audioContent of audioContents) {
|
||||
try {
|
||||
const dataSize = Math.floor(audioContent.data.length * 0.75); // Approximate binary size
|
||||
// Rough estimate: 1 token per 100 bytes of audio data
|
||||
totalTokens += Math.max(Math.ceil(dataSize / 100), 10); // Minimum 10 tokens per audio
|
||||
} catch (error) {
|
||||
console.warn('Error calculating audio tokens:', error);
|
||||
totalTokens += 10; // Fallback minimum
|
||||
}
|
||||
}
|
||||
|
||||
return totalTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for other content types (functions, files, etc.)
|
||||
*/
|
||||
private async calculateOtherTokens(otherContents: string[]): Promise<number> {
|
||||
if (otherContents.length === 0) return 0;
|
||||
|
||||
try {
|
||||
// Treat other content as text for token calculation
|
||||
const tokenCounts =
|
||||
await this.textTokenizer.calculateTokensBatch(otherContents);
|
||||
return tokenCounts.reduce((sum, count) => sum + count, 0);
|
||||
} catch (error) {
|
||||
console.warn('Error calculating other content tokens:', error);
|
||||
// Fallback: character-based estimation
|
||||
const totalChars = otherContents.join('').length;
|
||||
return Math.ceil(totalChars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fallback token calculation using simple string serialization
|
||||
*/
|
||||
private calculateFallbackTokens(request: CountTokensParameters): number {
|
||||
try {
|
||||
const content = JSON.stringify(request.contents);
|
||||
return Math.ceil(content.length / 4); // Rough estimate: 1 token ≈ 4 characters
|
||||
} catch (error) {
|
||||
console.warn('Error in fallback token calculation:', error);
|
||||
return 100; // Conservative fallback
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process request contents and group by type
|
||||
*/
|
||||
private processAndGroupContents(request: CountTokensParameters): {
|
||||
textContents: string[];
|
||||
imageContents: Array<{ data: string; mimeType: string }>;
|
||||
audioContents: Array<{ data: string; mimeType: string }>;
|
||||
otherContents: string[];
|
||||
} {
|
||||
const textContents: string[] = [];
|
||||
const imageContents: Array<{ data: string; mimeType: string }> = [];
|
||||
const audioContents: Array<{ data: string; mimeType: string }> = [];
|
||||
const otherContents: string[] = [];
|
||||
|
||||
if (!request.contents) {
|
||||
return { textContents, imageContents, audioContents, otherContents };
|
||||
}
|
||||
|
||||
const contents = Array.isArray(request.contents)
|
||||
? request.contents
|
||||
: [request.contents];
|
||||
|
||||
for (const content of contents) {
|
||||
this.processContent(
|
||||
content,
|
||||
textContents,
|
||||
imageContents,
|
||||
audioContents,
|
||||
otherContents,
|
||||
);
|
||||
}
|
||||
|
||||
return { textContents, imageContents, audioContents, otherContents };
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a single content item and add to appropriate arrays
|
||||
*/
|
||||
private processContent(
|
||||
content: Content | string | PartUnion,
|
||||
textContents: string[],
|
||||
imageContents: Array<{ data: string; mimeType: string }>,
|
||||
audioContents: Array<{ data: string; mimeType: string }>,
|
||||
otherContents: string[],
|
||||
): void {
|
||||
if (typeof content === 'string') {
|
||||
if (content.trim()) {
|
||||
textContents.push(content);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('parts' in content && content.parts) {
|
||||
for (const part of content.parts) {
|
||||
this.processPart(
|
||||
part,
|
||||
textContents,
|
||||
imageContents,
|
||||
audioContents,
|
||||
otherContents,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a single part and add to appropriate arrays
|
||||
*/
|
||||
private processPart(
|
||||
part: Part | string,
|
||||
textContents: string[],
|
||||
imageContents: Array<{ data: string; mimeType: string }>,
|
||||
audioContents: Array<{ data: string; mimeType: string }>,
|
||||
otherContents: string[],
|
||||
): void {
|
||||
if (typeof part === 'string') {
|
||||
if (part.trim()) {
|
||||
textContents.push(part);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('text' in part && part.text) {
|
||||
textContents.push(part.text);
|
||||
return;
|
||||
}
|
||||
|
||||
if ('inlineData' in part && part.inlineData) {
|
||||
const { data, mimeType } = part.inlineData;
|
||||
if (mimeType && mimeType.startsWith('image/')) {
|
||||
imageContents.push({ data: data || '', mimeType });
|
||||
return;
|
||||
}
|
||||
if (mimeType && mimeType.startsWith('audio/')) {
|
||||
audioContents.push({ data: data || '', mimeType });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if ('fileData' in part && part.fileData) {
|
||||
otherContents.push(JSON.stringify(part.fileData));
|
||||
return;
|
||||
}
|
||||
|
||||
if ('functionCall' in part && part.functionCall) {
|
||||
otherContents.push(JSON.stringify(part.functionCall));
|
||||
return;
|
||||
}
|
||||
|
||||
if ('functionResponse' in part && part.functionResponse) {
|
||||
otherContents.push(JSON.stringify(part.functionResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
// Unknown part type - try to serialize
|
||||
try {
|
||||
const serialized = JSON.stringify(part);
|
||||
if (serialized && serialized !== '{}') {
|
||||
otherContents.push(serialized);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to serialize unknown part type:', error);
|
||||
}
|
||||
}
|
||||
|
||||
  /**
   * Dispose of resources held by the tokenizers.
   *
   * NOTE(review): only textTokenizer is disposed here; imageTokenizer is
   * left untouched — confirm whether ImageTokenizer also holds resources
   * that need freeing.
   */
  async dispose(): Promise<void> {
    try {
      // Dispose of tokenizers
      this.textTokenizer.dispose();
    } catch (error) {
      // Best-effort cleanup: disposal failures are logged, never rethrown.
      console.warn('Error disposing request tokenizer:', error);
    }
  }
}
|
||||
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Supported image MIME types for vision models
|
||||
* These formats are supported by the vision model and can be processed by the image tokenizer
|
||||
*/
|
||||
export const SUPPORTED_IMAGE_MIME_TYPES = [
|
||||
'image/bmp',
|
||||
'image/jpeg',
|
||||
'image/jpg', // Alternative MIME type for JPEG
|
||||
'image/png',
|
||||
'image/tiff',
|
||||
'image/webp',
|
||||
'image/heic',
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Type for supported image MIME types
|
||||
*/
|
||||
export type SupportedImageMimeType =
|
||||
(typeof SUPPORTED_IMAGE_MIME_TYPES)[number];
|
||||
|
||||
/**
|
||||
* Check if a MIME type is supported for vision processing
|
||||
* @param mimeType The MIME type to check
|
||||
* @returns True if the MIME type is supported
|
||||
*/
|
||||
export function isSupportedImageMimeType(
|
||||
mimeType: string,
|
||||
): mimeType is SupportedImageMimeType {
|
||||
return SUPPORTED_IMAGE_MIME_TYPES.includes(
|
||||
mimeType as SupportedImageMimeType,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a human-readable list of supported image formats
|
||||
* @returns Comma-separated string of supported formats
|
||||
*/
|
||||
export function getSupportedImageFormatsString(): string {
|
||||
return SUPPORTED_IMAGE_MIME_TYPES.map((type) =>
|
||||
type.replace('image/', '').toUpperCase(),
|
||||
).join(', ');
|
||||
}
|
||||
|
||||
/**
|
||||
* Get warning message for unsupported image formats
|
||||
* @returns Warning message string
|
||||
*/
|
||||
export function getUnsupportedImageFormatWarning(): string {
|
||||
return `Only the following image formats are supported: ${getSupportedImageFormatsString()}. Other formats may not work as expected.`;
|
||||
}
|
||||
347
packages/core/src/utils/request-tokenizer/textTokenizer.test.ts
Normal file
347
packages/core/src/utils/request-tokenizer/textTokenizer.test.ts
Normal file
@@ -0,0 +1,347 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { TextTokenizer } from './textTokenizer.js';

// Mock tiktoken at the top level with hoisted functions so the mock is in
// place before textTokenizer.js resolves its 'tiktoken' import.
const mockEncode = vi.hoisted(() => vi.fn());
const mockFree = vi.hoisted(() => vi.fn());
const mockGetEncoding = vi.hoisted(() => vi.fn());

vi.mock('tiktoken', () => ({
  get_encoding: mockGetEncoding,
}));

describe('TextTokenizer', () => {
  let tokenizer: TextTokenizer;
  let consoleWarnSpy: ReturnType<typeof vi.spyOn>;

  beforeEach(() => {
    vi.resetAllMocks();
    // Capture (and silence) console.warn so fallback paths can be asserted.
    consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

    // Default mock implementation
    mockGetEncoding.mockReturnValue({
      encode: mockEncode,
      free: mockFree,
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    tokenizer?.dispose();
  });

  describe('constructor', () => {
    it('should create tokenizer with default encoding', () => {
      tokenizer = new TextTokenizer();
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });

    it('should create tokenizer with custom encoding', () => {
      tokenizer = new TextTokenizer('gpt2');
      expect(tokenizer).toBeInstanceOf(TextTokenizer);
    });
  });

  describe('calculateTokens', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should return 0 for empty text', async () => {
      const result = await tokenizer.calculateTokens('');
      expect(result).toBe(0);
    });

    it('should return 0 for null/undefined text', async () => {
      const result1 = await tokenizer.calculateTokens(
        null as unknown as string,
      );
      const result2 = await tokenizer.calculateTokens(
        undefined as unknown as string,
      );
      expect(result1).toBe(0);
      expect(result2).toBe(0);
    });

    it('should calculate tokens using tiktoken when available', async () => {
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3, 4, 5]; // 5 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('cl100k_base');
      expect(mockEncode).toHaveBeenCalledWith(testText);
      expect(result).toBe(5);
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should use fallback calculation when encoding fails', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const testText = 'Hello, world!'; // 13 characters
      const result = await tokenizer.calculateTokens(testText);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding text with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(13 / 4) = 4
      expect(result).toBe(4);
    });

    it('should handle very long text', async () => {
      const longText = 'a'.repeat(10000);
      const mockTokens = new Array(2500); // 2500 tokens
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(longText);

      expect(result).toBe(2500);
    });

    it('should handle unicode characters', async () => {
      const unicodeText = '你好世界 🌍';
      const mockTokens = [1, 2, 3, 4, 5, 6];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(unicodeText);

      expect(result).toBe(6);
    });

    it('should use custom encoding when specified', async () => {
      tokenizer = new TextTokenizer('gpt2');
      const testText = 'Hello, world!';
      const mockTokens = [1, 2, 3];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(testText);

      expect(mockGetEncoding).toHaveBeenCalledWith('gpt2');
      expect(result).toBe(3);
    });
  });

  describe('calculateTokensBatch', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should process multiple texts and return token counts', async () => {
      const texts = ['Hello', 'world', 'test'];
      mockEncode
        .mockReturnValueOnce([1, 2]) // 2 tokens for 'Hello'
        .mockReturnValueOnce([3, 4, 5]) // 3 tokens for 'world'
        .mockReturnValueOnce([6]); // 1 token for 'test'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([2, 3, 1]);
      expect(mockEncode).toHaveBeenCalledTimes(3);
    });

    it('should handle empty array', async () => {
      const result = await tokenizer.calculateTokensBatch([]);
      expect(result).toEqual([]);
    });

    it('should handle array with empty strings', async () => {
      const texts = ['', 'hello', ''];
      mockEncode.mockReturnValue([1, 2, 3]); // Only called for 'hello'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0]);
      expect(mockEncode).toHaveBeenCalledTimes(1);
      expect(mockEncode).toHaveBeenCalledWith('hello');
    });

    it('should use fallback calculation when tiktoken fails to load', async () => {
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load tiktoken');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Failed to load tiktoken with encoding cl100k_base:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should use fallback calculation when encoding fails during batch processing', async () => {
      mockEncode.mockImplementation(() => {
        throw new Error('Encoding failed');
      });

      const texts = ['Hello', 'world']; // 5 and 5 characters
      const result = await tokenizer.calculateTokensBatch(texts);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error encoding texts with tiktoken:',
        expect.any(Error),
      );
      // Fallback: Math.ceil(5/4) = 2 for both
      expect(result).toEqual([2, 2]);
    });

    it('should handle null and undefined values in batch', async () => {
      const texts = [null, 'hello', undefined, 'world'] as unknown as string[];
      mockEncode
        .mockReturnValueOnce([1, 2, 3]) // 3 tokens for 'hello'
        .mockReturnValueOnce([4, 5]); // 2 tokens for 'world'

      const result = await tokenizer.calculateTokensBatch(texts);

      expect(result).toEqual([0, 3, 0, 2]);
    });
  });

  describe('dispose', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should free tiktoken encoding when disposing', async () => {
      // Initialize the encoding by calling calculateTokens
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();

      expect(mockFree).toHaveBeenCalled();
    });

    it('should handle disposal when encoding is not initialized', () => {
      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle disposal when encoding is null', async () => {
      // Force encoding to be null by making tiktoken fail
      mockGetEncoding.mockImplementation(() => {
        throw new Error('Failed to load');
      });

      await tokenizer.calculateTokens('test');

      expect(() => tokenizer.dispose()).not.toThrow();
      expect(mockFree).not.toHaveBeenCalled();
    });

    it('should handle errors during disposal gracefully', async () => {
      await tokenizer.calculateTokens('test');

      mockFree.mockImplementation(() => {
        throw new Error('Free failed');
      });

      tokenizer.dispose();

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        'Error freeing tiktoken encoding:',
        expect.any(Error),
      );
    });

    it('should allow multiple calls to dispose', async () => {
      await tokenizer.calculateTokens('test');

      tokenizer.dispose();
      tokenizer.dispose(); // Second call should not throw

      expect(mockFree).toHaveBeenCalledTimes(1);
    });
  });

  describe('lazy initialization', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should not initialize tiktoken until first use', () => {
      expect(mockGetEncoding).not.toHaveBeenCalled();
    });

    it('should initialize tiktoken on first calculateTokens call', async () => {
      await tokenizer.calculateTokens('test');
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should not reinitialize tiktoken on subsequent calls', async () => {
      await tokenizer.calculateTokens('test1');
      await tokenizer.calculateTokens('test2');

      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });

    it('should initialize tiktoken on first calculateTokensBatch call', async () => {
      await tokenizer.calculateTokensBatch(['test']);
      expect(mockGetEncoding).toHaveBeenCalledTimes(1);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      tokenizer = new TextTokenizer();
    });

    it('should handle very short text', async () => {
      const result = await tokenizer.calculateTokens('a');

      if (mockGetEncoding.mock.calls.length > 0) {
        // If tiktoken was called, use its result
        expect(mockEncode).toHaveBeenCalledWith('a');
      } else {
        // If tiktoken failed, should use fallback: Math.ceil(1/4) = 1
        expect(result).toBe(1);
      }
    });

    it('should handle text with only whitespace', async () => {
      const whitespaceText = ' \n\t ';
      const mockTokens = [1];
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(whitespaceText);

      expect(result).toBe(1);
    });

    it('should handle special characters and symbols', async () => {
      const specialText = '!@#$%^&*()_+-=[]{}|;:,.<>?';
      const mockTokens = new Array(10);
      mockEncode.mockReturnValue(mockTokens);

      const result = await tokenizer.calculateTokens(specialText);

      expect(result).toBe(10);
    });
  });
});
|
||||
97
packages/core/src/utils/request-tokenizer/textTokenizer.ts
Normal file
97
packages/core/src/utils/request-tokenizer/textTokenizer.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { TiktokenEncoding, Tiktoken } from 'tiktoken';
|
||||
import { get_encoding } from 'tiktoken';
|
||||
|
||||
/**
|
||||
* Text tokenizer for calculating text tokens using tiktoken
|
||||
*/
|
||||
export class TextTokenizer {
|
||||
private encoding: Tiktoken | null = null;
|
||||
private encodingName: string;
|
||||
|
||||
constructor(encodingName: string = 'cl100k_base') {
|
||||
this.encodingName = encodingName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the tokenizer (lazy loading)
|
||||
*/
|
||||
private async ensureEncoding(): Promise<void> {
|
||||
if (this.encoding) return;
|
||||
|
||||
try {
|
||||
// Use type assertion since we know the encoding name is valid
|
||||
this.encoding = get_encoding(this.encodingName as TiktokenEncoding);
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
`Failed to load tiktoken with encoding ${this.encodingName}:`,
|
||||
error,
|
||||
);
|
||||
this.encoding = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for text content
|
||||
*/
|
||||
async calculateTokens(text: string): Promise<number> {
|
||||
if (!text) return 0;
|
||||
|
||||
await this.ensureEncoding();
|
||||
|
||||
if (this.encoding) {
|
||||
try {
|
||||
return this.encoding.encode(text).length;
|
||||
} catch (error) {
|
||||
console.warn('Error encoding text with tiktoken:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: rough approximation using character count
|
||||
// This is a conservative estimate: 1 token ≈ 4 characters for most languages
|
||||
return Math.ceil(text.length / 4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate tokens for multiple text strings in parallel
|
||||
*/
|
||||
async calculateTokensBatch(texts: string[]): Promise<number[]> {
|
||||
await this.ensureEncoding();
|
||||
|
||||
if (this.encoding) {
|
||||
try {
|
||||
return texts.map((text) => {
|
||||
if (!text) return 0;
|
||||
// this.encoding may be null, add a null check to satisfy lint
|
||||
return this.encoding ? this.encoding.encode(text).length : 0;
|
||||
});
|
||||
} catch (error) {
|
||||
console.warn('Error encoding texts with tiktoken:', error);
|
||||
// In case of error, return fallback estimation for all texts
|
||||
return texts.map((text) => Math.ceil((text || '').length / 4));
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback for batch processing
|
||||
return texts.map((text) => Math.ceil((text || '').length / 4));
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose of resources
|
||||
*/
|
||||
dispose(): void {
|
||||
if (this.encoding) {
|
||||
try {
|
||||
this.encoding.free();
|
||||
} catch (error) {
|
||||
console.warn('Error freeing tiktoken encoding:', error);
|
||||
}
|
||||
this.encoding = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
64
packages/core/src/utils/request-tokenizer/types.ts
Normal file
64
packages/core/src/utils/request-tokenizer/types.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { CountTokensParameters } from '@google/genai';

/**
 * Token calculation result for different content types.
 */
export interface TokenCalculationResult {
  /** Total tokens calculated (sum of the breakdown fields) */
  totalTokens: number;
  /** Breakdown by content type */
  breakdown: {
    textTokens: number;
    imageTokens: number;
    audioTokens: number;
    otherTokens: number;
  };
  /** Processing time in milliseconds */
  processingTime: number;
}

/**
 * Configuration for token calculation.
 */
export interface TokenizerConfig {
  /** Custom text tokenizer encoding (defaults to cl100k_base) */
  textEncoding?: string;
}

/**
 * Image metadata extracted from base64 data.
 */
export interface ImageMetadata {
  /** Image width in pixels */
  width: number;
  /** Image height in pixels */
  height: number;
  /** MIME type of the image */
  mimeType: string;
  /** Size of the base64 data in bytes */
  dataSize: number;
}

/**
 * Request tokenizer interface.
 */
export interface RequestTokenizer {
  /**
   * Calculate tokens for a request.
   */
  calculateTokens(
    request: CountTokensParameters,
    config?: TokenizerConfig,
  ): Promise<TokenCalculationResult>;

  /**
   * Dispose of resources (worker threads, etc.)
   */
  dispose(): Promise<void>;
}
|
||||
@@ -9,7 +9,7 @@ import * as addFormats from 'ajv-formats';
|
||||
// Ajv's ESM/CJS interop: use 'any' for compatibility as recommended by Ajv docs
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const AjvClass = (AjvPkg as any).default || AjvPkg;
|
||||
const ajValidator = new AjvClass();
|
||||
const ajValidator = new AjvClass({ coerceTypes: true });
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const addFormatsFunc = (addFormats as any).default || addFormats;
|
||||
addFormatsFunc(ajValidator);
|
||||
@@ -32,8 +32,27 @@ export class SchemaValidator {
|
||||
const validate = ajValidator.compile(schema);
|
||||
const valid = validate(data);
|
||||
if (!valid && validate.errors) {
|
||||
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
|
||||
// Find any True or False values and lowercase them
|
||||
fixBooleanCasing(data as Record<string, unknown>);
|
||||
|
||||
const validate = ajValidator.compile(schema);
|
||||
const valid = validate(data);
|
||||
|
||||
if (!valid && validate.errors) {
|
||||
return ajValidator.errorsText(validate.errors, { dataVar: 'params' });
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function fixBooleanCasing(data: Record<string, unknown>) {
|
||||
for (const key of Object.keys(data)) {
|
||||
if (!(key in data)) continue;
|
||||
|
||||
if (typeof data[key] === 'object') {
|
||||
fixBooleanCasing(data[key] as Record<string, unknown>);
|
||||
} else if (data[key] === 'True') data[key] = 'true';
|
||||
else if (data[key] === 'False') data[key] = 'false';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@qwen-code/qwen-code-test-utils",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"private": true,
|
||||
"main": "src/index.ts",
|
||||
"license": "Apache-2.0",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"name": "qwen-code-vscode-ide-companion",
|
||||
"displayName": "Qwen Code Companion",
|
||||
"description": "Enable Qwen Code with direct access to your VS Code workspace.",
|
||||
"version": "0.0.11",
|
||||
"version": "0.0.13-nightly.2",
|
||||
"publisher": "qwenlm",
|
||||
"icon": "assets/icon.png",
|
||||
"repository": {
|
||||
|
||||
Reference in New Issue
Block a user