mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
This commit is contained in:
@@ -91,6 +91,11 @@ If you are experiencing performance issues with file searching (e.g., with `@` c
|
||||
- **Default:** All tools available for use by the Gemini model.
|
||||
- **Example:** `"coreTools": ["ReadFileTool", "GlobTool", "ShellTool(ls)"]`.
|
||||
|
||||
- **`allowedTools`** (array of strings):
|
||||
- **Default:** `undefined`
|
||||
- **Description:** A list of tool names that will bypass the confirmation dialog. This is useful for tools that you trust and use frequently. The match semantics are the same as `coreTools`.
|
||||
- **Example:** `"allowedTools": ["ShellTool(git status)"]`.
|
||||
|
||||
- **`excludeTools`** (array of strings):
|
||||
- **Description:** Allows you to specify a list of core tool names that should be excluded from the model. A tool listed in both `excludeTools` and `coreTools` is excluded. You can also specify command-specific restrictions for tools that support it, like the `ShellTool`. For example, `"excludeTools": ["ShellTool(rm -rf)"]` will block the `rm -rf` command.
|
||||
- **Default**: No tools excluded.
|
||||
@@ -479,6 +484,9 @@ Arguments passed directly when running the CLI can override other configurations
|
||||
- `yolo`: Automatically approve all tool calls (equivalent to `--yolo`)
|
||||
- Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach.
|
||||
- Example: `gemini --approval-mode auto_edit`
|
||||
- **`--allowed-tools <tool1,tool2,...>`**:
|
||||
- A comma-separated list of tool names that will bypass the confirmation dialog.
|
||||
- Example: `gemini --allowed-tools "ShellTool(git status)"`
|
||||
- **`--telemetry`**:
|
||||
- Enables [telemetry](../telemetry.md).
|
||||
- **`--telemetry-target`**:
|
||||
|
||||
@@ -70,6 +70,7 @@ export interface CliArgs {
|
||||
telemetryLogPrompts: boolean | undefined;
|
||||
telemetryOutfile: string | undefined;
|
||||
allowedMcpServerNames: string[] | undefined;
|
||||
allowedTools: string[] | undefined;
|
||||
experimentalAcp: boolean | undefined;
|
||||
extensions: string[] | undefined;
|
||||
listExtensions: boolean | undefined;
|
||||
@@ -189,6 +190,11 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
string: true,
|
||||
description: 'Allowed MCP server names',
|
||||
})
|
||||
.option('allowed-tools', {
|
||||
type: 'array',
|
||||
string: true,
|
||||
description: 'Tools that are allowed to run without confirmation',
|
||||
})
|
||||
.option('extensions', {
|
||||
alias: 'e',
|
||||
type: 'array',
|
||||
@@ -489,6 +495,7 @@ export async function loadCliConfig(
|
||||
question,
|
||||
fullContext: argv.allFiles || false,
|
||||
coreTools: settings.coreTools || undefined,
|
||||
allowedTools: argv.allowedTools || settings.allowedTools || undefined,
|
||||
excludeTools,
|
||||
toolDiscoveryCommand: settings.toolDiscoveryCommand,
|
||||
toolCallCommand: settings.toolCallCommand,
|
||||
|
||||
@@ -344,6 +344,16 @@ export const SETTINGS_SCHEMA = {
|
||||
description: 'Paths to core tool definitions.',
|
||||
showInDialog: false,
|
||||
},
|
||||
allowedTools: {
|
||||
type: 'array',
|
||||
label: 'Allowed Tools',
|
||||
category: 'Advanced',
|
||||
requiresRestart: true,
|
||||
default: undefined as string[] | undefined,
|
||||
description:
|
||||
'A list of tool names that will bypass the confirmation dialog.',
|
||||
showInDialog: false,
|
||||
},
|
||||
excludeTools: {
|
||||
type: 'array',
|
||||
label: 'Exclude Tools',
|
||||
|
||||
@@ -53,14 +53,15 @@ const mockToolRegistry = {
|
||||
const mockConfig = {
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry as unknown as ToolRegistry),
|
||||
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getSessionId: () => 'test-session-id',
|
||||
getAllowedTools: vi.fn(() => []),
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
};
|
||||
} as unknown as Config;
|
||||
|
||||
class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
@@ -218,11 +219,6 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||
await vi.runAllTimersAsync(); // Process execution
|
||||
});
|
||||
|
||||
// Check that shouldConfirmExecute was NOT called
|
||||
expect(
|
||||
mockToolRequiresConfirmation.shouldConfirmExecute,
|
||||
).not.toHaveBeenCalled();
|
||||
|
||||
// Check that execute WAS called
|
||||
expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith(
|
||||
request.args,
|
||||
|
||||
@@ -23,6 +23,9 @@ import { GeminiClient } from '../core/client.js';
|
||||
import { GitService } from '../services/gitService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
|
||||
import { ShellTool } from '../tools/shell.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
|
||||
vi.mock('fs', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('fs')>();
|
||||
return {
|
||||
@@ -629,6 +632,36 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(config.getUseRipgrep()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createToolRegistry', () => {
|
||||
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
||||
const params: ConfigParameters = {
|
||||
...baseParams,
|
||||
coreTools: ['ShellTool(git status)'],
|
||||
};
|
||||
const config = new Config(params);
|
||||
await config.initialize();
|
||||
|
||||
// The ToolRegistry class is mocked, so we can inspect its prototype's methods.
|
||||
const registerToolMock = (
|
||||
(await vi.importMock('../tools/tool-registry')) as {
|
||||
ToolRegistry: { prototype: { registerTool: Mock } };
|
||||
}
|
||||
).ToolRegistry.prototype.registerTool;
|
||||
|
||||
// Check that registerTool was called for ShellTool
|
||||
const wasShellToolRegistered = (registerToolMock as Mock).mock.calls.some(
|
||||
(call) => call[0] instanceof vi.mocked(ShellTool),
|
||||
);
|
||||
expect(wasShellToolRegistered).toBe(true);
|
||||
|
||||
// Check that registerTool was NOT called for ReadFileTool
|
||||
const wasReadFileToolRegistered = (
|
||||
registerToolMock as Mock
|
||||
).mock.calls.some((call) => call[0] instanceof vi.mocked(ReadFileTool));
|
||||
expect(wasReadFileToolRegistered).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('setApprovalMode with folder trust', () => {
|
||||
|
||||
@@ -49,7 +49,8 @@ import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
|
||||
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { MCPOAuthConfig };
|
||||
export type { MCPOAuthConfig, AnyToolInvocation };
|
||||
import type { AnyToolInvocation } from '../tools/tools.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { Storage } from './storage.js';
|
||||
import { FileExclusions } from '../utils/ignorePatterns.js';
|
||||
@@ -159,6 +160,7 @@ export interface ConfigParameters {
|
||||
question?: string;
|
||||
fullContext?: boolean;
|
||||
coreTools?: string[];
|
||||
allowedTools?: string[];
|
||||
excludeTools?: string[];
|
||||
toolDiscoveryCommand?: string;
|
||||
toolCallCommand?: string;
|
||||
@@ -221,6 +223,7 @@ export class Config {
|
||||
private readonly question: string | undefined;
|
||||
private readonly fullContext: boolean;
|
||||
private readonly coreTools: string[] | undefined;
|
||||
private readonly allowedTools: string[] | undefined;
|
||||
private readonly excludeTools: string[] | undefined;
|
||||
private readonly toolDiscoveryCommand: string | undefined;
|
||||
private readonly toolCallCommand: string | undefined;
|
||||
@@ -295,6 +298,7 @@ export class Config {
|
||||
this.question = params.question;
|
||||
this.fullContext = params.fullContext ?? false;
|
||||
this.coreTools = params.coreTools;
|
||||
this.allowedTools = params.allowedTools;
|
||||
this.excludeTools = params.excludeTools;
|
||||
this.toolDiscoveryCommand = params.toolDiscoveryCommand;
|
||||
this.toolCallCommand = params.toolCallCommand;
|
||||
@@ -523,6 +527,10 @@ export class Config {
|
||||
return this.coreTools;
|
||||
}
|
||||
|
||||
getAllowedTools(): string[] | undefined {
|
||||
return this.allowedTools;
|
||||
}
|
||||
|
||||
getExcludeTools(): string[] | undefined {
|
||||
return this.excludeTools;
|
||||
}
|
||||
@@ -807,12 +815,10 @@ export class Config {
|
||||
const className = ToolClass.name;
|
||||
const toolName = ToolClass.Name || className;
|
||||
const coreTools = this.getCoreTools();
|
||||
const excludeTools = this.getExcludeTools();
|
||||
const excludeTools = this.getExcludeTools() || [];
|
||||
|
||||
let isEnabled = false;
|
||||
if (coreTools === undefined) {
|
||||
isEnabled = true;
|
||||
} else {
|
||||
let isEnabled = true; // Enabled by default if coreTools is not set.
|
||||
if (coreTools) {
|
||||
isEnabled = coreTools.some(
|
||||
(tool) =>
|
||||
tool === className ||
|
||||
@@ -822,10 +828,11 @@ export class Config {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
excludeTools?.includes(className) ||
|
||||
excludeTools?.includes(toolName)
|
||||
) {
|
||||
const isExcluded = excludeTools.some(
|
||||
(tool) => tool === className || tool === toolName,
|
||||
);
|
||||
|
||||
if (isExcluded) {
|
||||
isEnabled = false;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js';
|
||||
import {
|
||||
CoreToolScheduler,
|
||||
@@ -99,6 +100,41 @@ class TestApprovalInvocation extends BaseToolInvocation<
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForStatus(
|
||||
onToolCallsUpdate: Mock,
|
||||
status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled',
|
||||
timeout = 5000,
|
||||
): Promise<ToolCall> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const startTime = Date.now();
|
||||
const check = () => {
|
||||
if (Date.now() - startTime > timeout) {
|
||||
const seenStatuses = onToolCallsUpdate.mock.calls
|
||||
.flatMap((call) => call[0])
|
||||
.map((toolCall: ToolCall) => toolCall.status);
|
||||
reject(
|
||||
new Error(
|
||||
`Timed out waiting for status "${status}". Seen statuses: ${seenStatuses.join(
|
||||
', ',
|
||||
)}`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const foundCall = onToolCallsUpdate.mock.calls
|
||||
.flatMap((call) => call[0])
|
||||
.find((toolCall: ToolCall) => toolCall.status === status);
|
||||
if (foundCall) {
|
||||
resolve(foundCall);
|
||||
} else {
|
||||
setTimeout(check, 10); // Check again in 10ms
|
||||
}
|
||||
};
|
||||
check();
|
||||
});
|
||||
}
|
||||
|
||||
describe('CoreToolScheduler', () => {
|
||||
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||
const mockTool = new MockTool();
|
||||
@@ -126,6 +162,7 @@ describe('CoreToolScheduler', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -186,6 +223,7 @@ describe('CoreToolScheduler with payload', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -212,16 +250,10 @@ describe('CoreToolScheduler with payload', () => {
|
||||
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const awaitingCall = onToolCallsUpdate.mock.calls.find(
|
||||
(call) => call[0][0].status === 'awaiting_approval',
|
||||
)?.[0][0];
|
||||
expect(awaitingCall).toBeDefined();
|
||||
});
|
||||
|
||||
const awaitingCall = onToolCallsUpdate.mock.calls.find(
|
||||
(call) => call[0][0].status === 'awaiting_approval',
|
||||
)?.[0][0];
|
||||
const awaitingCall = (await waitForStatus(
|
||||
onToolCallsUpdate,
|
||||
'awaiting_approval',
|
||||
)) as WaitingToolCall;
|
||||
const confirmationDetails = awaitingCall.confirmationDetails;
|
||||
|
||||
if (confirmationDetails) {
|
||||
@@ -497,6 +529,7 @@ describe('CoreToolScheduler edit cancellation', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -523,12 +556,10 @@ describe('CoreToolScheduler edit cancellation', () => {
|
||||
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
// Wait for the tool to reach awaiting_approval state
|
||||
const awaitingCall = onToolCallsUpdate.mock.calls.find(
|
||||
(call) => call[0][0].status === 'awaiting_approval',
|
||||
)?.[0][0];
|
||||
|
||||
expect(awaitingCall).toBeDefined();
|
||||
const awaitingCall = (await waitForStatus(
|
||||
onToolCallsUpdate,
|
||||
'awaiting_approval',
|
||||
)) as WaitingToolCall;
|
||||
|
||||
// Cancel the edit
|
||||
const confirmationDetails = awaitingCall.confirmationDetails;
|
||||
@@ -589,6 +620,7 @@ describe('CoreToolScheduler YOLO mode', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -678,6 +710,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -713,10 +746,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
scheduler.schedule([request1], abortController.signal);
|
||||
|
||||
// Wait for the first call to be in the 'executing' state.
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
expect(calls?.[0]?.status).toBe('executing');
|
||||
});
|
||||
await waitForStatus(onToolCallsUpdate, 'executing');
|
||||
|
||||
// Schedule the second call while the first is "running".
|
||||
const schedulePromise2 = scheduler.schedule(
|
||||
@@ -737,16 +767,6 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
// Wait for the second schedule promise to resolve.
|
||||
await schedulePromise2;
|
||||
|
||||
// Wait for the second call to be in the 'executing' state.
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
expect(calls?.[0]?.status).toBe('executing');
|
||||
});
|
||||
|
||||
// Now the second tool call should have been executed.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
|
||||
// Let the second call finish.
|
||||
const secondCallResult = {
|
||||
llmContent: 'Second call complete',
|
||||
@@ -756,6 +776,12 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
// In a real scenario, a new promise would be created for the second call.
|
||||
resolveFirstCall!(secondCallResult);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
// Now the second tool call should have been executed.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
|
||||
// Wait for the second completion.
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
|
||||
@@ -766,6 +792,96 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
|
||||
});
|
||||
|
||||
it('should auto-approve a tool call if it is on the allowedTools list', async () => {
|
||||
// Arrange
|
||||
const mockTool = new MockTool('mockTool');
|
||||
mockTool.executeFn.mockReturnValue({
|
||||
llmContent: 'Tool executed',
|
||||
returnDisplay: 'Tool executed',
|
||||
});
|
||||
// This tool would normally require confirmation.
|
||||
mockTool.shouldConfirm = true;
|
||||
const declarativeTool = mockTool;
|
||||
|
||||
const toolRegistry = {
|
||||
getTool: () => declarativeTool,
|
||||
getToolByName: () => declarativeTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => declarativeTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
// Configure the scheduler to auto-approve the specific tool call.
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT, // Not YOLO mode
|
||||
getAllowedTools: () => ['mockTool'], // Auto-approve this tool
|
||||
getToolRegistry: () => toolRegistry,
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: { param: 'value' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-auto-approved',
|
||||
};
|
||||
|
||||
// Act
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
// Assert
|
||||
// 1. The tool's execute method was called directly.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' });
|
||||
|
||||
// 2. The tool call status never entered 'awaiting_approval'.
|
||||
const statusUpdates = onToolCallsUpdate.mock.calls
|
||||
.map((call) => (call[0][0] as ToolCall)?.status)
|
||||
.filter(Boolean);
|
||||
expect(statusUpdates).not.toContain('awaiting_approval');
|
||||
expect(statusUpdates).toEqual([
|
||||
'validating',
|
||||
'scheduled',
|
||||
'executing',
|
||||
'success',
|
||||
]);
|
||||
|
||||
// 3. The final callback indicates the tool call was successful.
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls).toHaveLength(1);
|
||||
const completedCall = completedCalls[0];
|
||||
expect(completedCall.status).toBe('success');
|
||||
if (completedCall.status === 'success') {
|
||||
expect(completedCall.response.resultDisplay).toBe('Tool executed');
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle two synchronous calls to schedule', async () => {
|
||||
const mockTool = new MockTool();
|
||||
const declarativeTool = mockTool;
|
||||
@@ -782,7 +898,6 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
@@ -791,6 +906,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
@@ -851,6 +967,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => approvalMode,
|
||||
getAllowedTools: () => [],
|
||||
setApprovalMode: (mode: ApprovalMode) => {
|
||||
approvalMode = mode;
|
||||
},
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
modifyWithEditor,
|
||||
} from '../tools/modifiable-tool.js';
|
||||
import * as Diff from 'diff';
|
||||
import { doesToolInvocationMatch } from '../utils/tool-utils.js';
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: 'validating';
|
||||
@@ -615,68 +616,74 @@ export class CoreToolScheduler {
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!confirmationDetails) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
continue;
|
||||
}
|
||||
|
||||
const allowedTools = this.config.getAllowedTools() || [];
|
||||
if (
|
||||
this.config.getApprovalMode() === ApprovalMode.YOLO ||
|
||||
doesToolInvocationMatch(toolCall.tool, invocation, allowedTools)
|
||||
) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
} else {
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (confirmationDetails) {
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
confirmationDetails.type === 'edit' &&
|
||||
confirmationDetails.ideConfirmation
|
||||
) {
|
||||
confirmationDetails.ideConfirmation.then((resolution) => {
|
||||
if (resolution.status === 'accepted') {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
signal,
|
||||
);
|
||||
} else {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
...confirmationDetails,
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) =>
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
confirmationDetails.type === 'edit' &&
|
||||
confirmationDetails.ideConfirmation
|
||||
) {
|
||||
confirmationDetails.ideConfirmation.then((resolution) => {
|
||||
if (resolution.status === 'accepted') {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
signal,
|
||||
payload,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
} else {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
);
|
||||
} else {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
...confirmationDetails,
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) =>
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
signal,
|
||||
payload,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
this.setStatusInternal(
|
||||
|
||||
@@ -32,6 +32,7 @@ describe('executeToolCall', () => {
|
||||
mockConfig = {
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
|
||||
@@ -307,6 +307,21 @@ export abstract class BaseDeclarativeTool<
|
||||
*/
|
||||
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
|
||||
|
||||
/**
|
||||
* Type guard to check if an object is a Tool.
|
||||
* @param obj The object to check.
|
||||
* @returns True if the object is a Tool, false otherwise.
|
||||
*/
|
||||
export function isTool(obj: unknown): obj is AnyDeclarativeTool {
|
||||
return (
|
||||
typeof obj === 'object' &&
|
||||
obj !== null &&
|
||||
'name' in obj &&
|
||||
'build' in obj &&
|
||||
typeof (obj as AnyDeclarativeTool).build === 'function'
|
||||
);
|
||||
}
|
||||
|
||||
export interface ToolResult {
|
||||
/**
|
||||
* Content meant to be included in LLM history.
|
||||
|
||||
@@ -16,11 +16,14 @@ import {
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
const mockPlatform = vi.hoisted(() => vi.fn());
|
||||
const mockHomedir = vi.hoisted(() => vi.fn());
|
||||
vi.mock('os', () => ({
|
||||
default: {
|
||||
platform: mockPlatform,
|
||||
homedir: mockHomedir,
|
||||
},
|
||||
platform: mockPlatform,
|
||||
homedir: mockHomedir,
|
||||
}));
|
||||
|
||||
const mockQuote = vi.hoisted(() => vi.fn());
|
||||
@@ -38,6 +41,7 @@ beforeEach(() => {
|
||||
config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => [],
|
||||
getAllowedTools: () => [],
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
|
||||
@@ -4,9 +4,13 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { AnyToolInvocation } from '../index.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import os from 'node:os';
|
||||
import { quote } from 'shell-quote';
|
||||
import { doesToolInvocationMatch } from './tool-utils.js';
|
||||
|
||||
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
|
||||
|
||||
/**
|
||||
* An identifier for the shell type.
|
||||
@@ -319,32 +323,19 @@ export function checkCommandPermissions(
|
||||
};
|
||||
}
|
||||
|
||||
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
|
||||
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
|
||||
|
||||
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
|
||||
if (!cmd.startsWith(prefix)) {
|
||||
return false;
|
||||
}
|
||||
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
|
||||
};
|
||||
|
||||
const extractCommands = (tools: string[]): string[] =>
|
||||
tools.flatMap((tool) => {
|
||||
for (const toolName of SHELL_TOOL_NAMES) {
|
||||
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
|
||||
return [normalize(tool.slice(toolName.length + 1, -1))];
|
||||
}
|
||||
}
|
||||
return [];
|
||||
});
|
||||
|
||||
const coreTools = config.getCoreTools() || [];
|
||||
const excludeTools = config.getExcludeTools() || [];
|
||||
const commandsToValidate = splitCommands(command).map(normalize);
|
||||
const invocation: AnyToolInvocation & { params: { command: string } } = {
|
||||
params: { command: '' },
|
||||
} as AnyToolInvocation & { params: { command: string } };
|
||||
|
||||
// 1. Blocklist Check (Highest Priority)
|
||||
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
|
||||
const excludeTools = config.getExcludeTools() || [];
|
||||
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
|
||||
excludeTools.includes(name),
|
||||
);
|
||||
|
||||
if (isWildcardBlocked) {
|
||||
return {
|
||||
allAllowed: false,
|
||||
disallowedCommands: commandsToValidate,
|
||||
@@ -352,9 +343,12 @@ export function checkCommandPermissions(
|
||||
isHardDenial: true,
|
||||
};
|
||||
}
|
||||
const blockedCommands = extractCommands(excludeTools);
|
||||
|
||||
for (const cmd of commandsToValidate) {
|
||||
if (blockedCommands.some((blocked) => isPrefixedBy(cmd, blocked))) {
|
||||
invocation.params['command'] = cmd;
|
||||
if (
|
||||
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
|
||||
) {
|
||||
return {
|
||||
allAllowed: false,
|
||||
disallowedCommands: [cmd],
|
||||
@@ -364,7 +358,7 @@ export function checkCommandPermissions(
|
||||
}
|
||||
}
|
||||
|
||||
const globallyAllowedCommands = extractCommands(coreTools);
|
||||
const coreTools = config.getCoreTools() || [];
|
||||
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
|
||||
coreTools.includes(name),
|
||||
);
|
||||
@@ -375,18 +369,30 @@ export function checkCommandPermissions(
|
||||
return { allAllowed: true, disallowedCommands: [] };
|
||||
}
|
||||
|
||||
const disallowedCommands: string[] = [];
|
||||
|
||||
if (sessionAllowlist) {
|
||||
// "DEFAULT DENY" MODE: A session allowlist is provided.
|
||||
// All commands must be in either the session or global allowlist.
|
||||
const disallowedCommands: string[] = [];
|
||||
const normalizedSessionAllowlist = new Set(
|
||||
[...sessionAllowlist].flatMap((cmd) =>
|
||||
SHELL_TOOL_NAMES.map((name) => `${name}(${cmd})`),
|
||||
),
|
||||
);
|
||||
|
||||
for (const cmd of commandsToValidate) {
|
||||
const isSessionAllowed = [...sessionAllowlist].some((allowed) =>
|
||||
isPrefixedBy(cmd, normalize(allowed)),
|
||||
invocation.params['command'] = cmd;
|
||||
const isSessionAllowed = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
[...normalizedSessionAllowlist],
|
||||
);
|
||||
if (isSessionAllowed) continue;
|
||||
|
||||
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
|
||||
isPrefixedBy(cmd, allowed),
|
||||
const isGloballyAllowed = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
coreTools,
|
||||
);
|
||||
if (isGloballyAllowed) continue;
|
||||
|
||||
@@ -405,12 +411,18 @@ export function checkCommandPermissions(
|
||||
}
|
||||
} else {
|
||||
// "DEFAULT ALLOW" MODE: No session allowlist.
|
||||
const hasSpecificAllowedCommands = globallyAllowedCommands.length > 0;
|
||||
const hasSpecificAllowedCommands =
|
||||
coreTools.filter((tool) =>
|
||||
SHELL_TOOL_NAMES.some((name) => tool.startsWith(`${name}(`)),
|
||||
).length > 0;
|
||||
|
||||
if (hasSpecificAllowedCommands) {
|
||||
const disallowedCommands: string[] = [];
|
||||
for (const cmd of commandsToValidate) {
|
||||
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
|
||||
isPrefixedBy(cmd, allowed),
|
||||
invocation.params['command'] = cmd;
|
||||
const isGloballyAllowed = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
coreTools,
|
||||
);
|
||||
if (!isGloballyAllowed) {
|
||||
disallowedCommands.push(cmd);
|
||||
@@ -420,7 +432,9 @@ export function checkCommandPermissions(
|
||||
return {
|
||||
allAllowed: false,
|
||||
disallowedCommands,
|
||||
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands.map((c) => JSON.stringify(c)).join(', ')}`,
|
||||
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands
|
||||
.map((c) => JSON.stringify(c))
|
||||
.join(', ')}`,
|
||||
isHardDenial: false, // This is a soft denial.
|
||||
};
|
||||
}
|
||||
|
||||
94
packages/core/src/utils/tool-utils.test.ts
Normal file
94
packages/core/src/utils/tool-utils.test.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { expect, describe, it } from 'vitest';
|
||||
import { doesToolInvocationMatch } from './tool-utils.js';
|
||||
import type { AnyToolInvocation, Config } from '../index.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
|
||||
describe('doesToolInvocationMatch', () => {
|
||||
it('should not match a partial command prefix', () => {
|
||||
const invocation = {
|
||||
params: { command: 'git commitsomething' },
|
||||
} as AnyToolInvocation;
|
||||
const patterns = ['ShellTool(git commit)'];
|
||||
const result = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should match an exact command', () => {
|
||||
const invocation = {
|
||||
params: { command: 'git status' },
|
||||
} as AnyToolInvocation;
|
||||
const patterns = ['ShellTool(git status)'];
|
||||
const result = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should match a command that is a prefix', () => {
|
||||
const invocation = {
|
||||
params: { command: 'git status -v' },
|
||||
} as AnyToolInvocation;
|
||||
const patterns = ['ShellTool(git status)'];
|
||||
const result = doesToolInvocationMatch(
|
||||
'run_shell_command',
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
describe('for non-shell tools', () => {
|
||||
const readFileTool = new ReadFileTool({} as Config);
|
||||
const invocation = {
|
||||
params: { file: 'test.txt' },
|
||||
} as AnyToolInvocation;
|
||||
|
||||
it('should match by tool name', () => {
|
||||
const patterns = ['read_file'];
|
||||
const result = doesToolInvocationMatch(
|
||||
readFileTool,
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should match by tool class name', () => {
|
||||
const patterns = ['ReadFileTool'];
|
||||
const result = doesToolInvocationMatch(
|
||||
readFileTool,
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should not match if neither name is in the patterns', () => {
|
||||
const patterns = ['some_other_tool', 'AnotherToolClass'];
|
||||
const result = doesToolInvocationMatch(
|
||||
readFileTool,
|
||||
invocation,
|
||||
patterns,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should match by tool name when passed as a string', () => {
|
||||
const patterns = ['read_file'];
|
||||
const result = doesToolInvocationMatch('read_file', invocation, patterns);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
76
packages/core/src/utils/tool-utils.ts
Normal file
76
packages/core/src/utils/tool-utils.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js';
|
||||
import { isTool } from '../index.js';
|
||||
|
||||
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
|
||||
|
||||
/**
|
||||
* Checks if a tool invocation matches any of a list of patterns.
|
||||
*
|
||||
* @param toolOrToolName The tool object or the name of the tool being invoked.
|
||||
* @param invocation The invocation object for the tool.
|
||||
* @param patterns A list of patterns to match against.
|
||||
* Patterns can be:
|
||||
* - A tool name (e.g., "ReadFileTool") to match any invocation of that tool.
|
||||
* - A tool name with a prefix (e.g., "ShellTool(git status)") to match
|
||||
* invocations where the arguments start with that prefix.
|
||||
* @returns True if the invocation matches any pattern, false otherwise.
|
||||
*/
|
||||
export function doesToolInvocationMatch(
|
||||
toolOrToolName: AnyDeclarativeTool | string,
|
||||
invocation: AnyToolInvocation,
|
||||
patterns: string[],
|
||||
): boolean {
|
||||
let toolNames: string[];
|
||||
if (isTool(toolOrToolName)) {
|
||||
toolNames = [toolOrToolName.name, toolOrToolName.constructor.name];
|
||||
} else {
|
||||
toolNames = [toolOrToolName as string];
|
||||
}
|
||||
|
||||
if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) {
|
||||
toolNames = [...new Set([...toolNames, ...SHELL_TOOL_NAMES])];
|
||||
}
|
||||
|
||||
for (const pattern of patterns) {
|
||||
const openParen = pattern.indexOf('(');
|
||||
|
||||
if (openParen === -1) {
|
||||
// No arguments, just a tool name
|
||||
if (toolNames.includes(pattern)) {
|
||||
return true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const patternToolName = pattern.substring(0, openParen);
|
||||
if (!toolNames.includes(patternToolName)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!pattern.endsWith(')')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const argPattern = pattern.substring(openParen + 1, pattern.length - 1);
|
||||
|
||||
if (
|
||||
'command' in invocation.params &&
|
||||
toolNames.includes('run_shell_command')
|
||||
) {
|
||||
const argValue = String(
|
||||
(invocation.params as { command: string }).command,
|
||||
);
|
||||
if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
Reference in New Issue
Block a user