feat(cli): Add --allowed-tools flag to bypass tool confirmation (#2417) (#6453)

This commit is contained in:
Andrew Garrett
2025-08-27 02:17:43 +10:00
committed by GitHub
parent c33a0da1df
commit 52dae2c583
14 changed files with 524 additions and 135 deletions

View File

@@ -91,6 +91,11 @@ If you are experiencing performance issues with file searching (e.g., with `@` c
- **Default:** All tools available for use by the Gemini model. - **Default:** All tools available for use by the Gemini model.
- **Example:** `"coreTools": ["ReadFileTool", "GlobTool", "ShellTool(ls)"]`. - **Example:** `"coreTools": ["ReadFileTool", "GlobTool", "ShellTool(ls)"]`.
- **`allowedTools`** (array of strings):
- **Default:** `undefined`
- **Description:** A list of tool names that will bypass the confirmation dialog. This is useful for tools that you trust and use frequently. The match semantics are the same as `coreTools`.
- **Example:** `"allowedTools": ["ShellTool(git status)"]`.
- **`excludeTools`** (array of strings): - **`excludeTools`** (array of strings):
- **Description:** Allows you to specify a list of core tool names that should be excluded from the model. A tool listed in both `excludeTools` and `coreTools` is excluded. You can also specify command-specific restrictions for tools that support it, like the `ShellTool`. For example, `"excludeTools": ["ShellTool(rm -rf)"]` will block the `rm -rf` command. - **Description:** Allows you to specify a list of core tool names that should be excluded from the model. A tool listed in both `excludeTools` and `coreTools` is excluded. You can also specify command-specific restrictions for tools that support it, like the `ShellTool`. For example, `"excludeTools": ["ShellTool(rm -rf)"]` will block the `rm -rf` command.
- **Default**: No tools excluded. - **Default**: No tools excluded.
@@ -479,6 +484,9 @@ Arguments passed directly when running the CLI can override other configurations
- `yolo`: Automatically approve all tool calls (equivalent to `--yolo`) - `yolo`: Automatically approve all tool calls (equivalent to `--yolo`)
- Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach. - Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach.
- Example: `gemini --approval-mode auto_edit` - Example: `gemini --approval-mode auto_edit`
- **`--allowed-tools <tool1,tool2,...>`**:
- A comma-separated list of tool names that will bypass the confirmation dialog.
- Example: `gemini --allowed-tools "ShellTool(git status)"`
- **`--telemetry`**: - **`--telemetry`**:
- Enables [telemetry](../telemetry.md). - Enables [telemetry](../telemetry.md).
- **`--telemetry-target`**: - **`--telemetry-target`**:

View File

@@ -70,6 +70,7 @@ export interface CliArgs {
telemetryLogPrompts: boolean | undefined; telemetryLogPrompts: boolean | undefined;
telemetryOutfile: string | undefined; telemetryOutfile: string | undefined;
allowedMcpServerNames: string[] | undefined; allowedMcpServerNames: string[] | undefined;
allowedTools: string[] | undefined;
experimentalAcp: boolean | undefined; experimentalAcp: boolean | undefined;
extensions: string[] | undefined; extensions: string[] | undefined;
listExtensions: boolean | undefined; listExtensions: boolean | undefined;
@@ -189,6 +190,11 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
string: true, string: true,
description: 'Allowed MCP server names', description: 'Allowed MCP server names',
}) })
.option('allowed-tools', {
type: 'array',
string: true,
description: 'Tools that are allowed to run without confirmation',
})
.option('extensions', { .option('extensions', {
alias: 'e', alias: 'e',
type: 'array', type: 'array',
@@ -489,6 +495,7 @@ export async function loadCliConfig(
question, question,
fullContext: argv.allFiles || false, fullContext: argv.allFiles || false,
coreTools: settings.coreTools || undefined, coreTools: settings.coreTools || undefined,
allowedTools: argv.allowedTools || settings.allowedTools || undefined,
excludeTools, excludeTools,
toolDiscoveryCommand: settings.toolDiscoveryCommand, toolDiscoveryCommand: settings.toolDiscoveryCommand,
toolCallCommand: settings.toolCallCommand, toolCallCommand: settings.toolCallCommand,

View File

@@ -344,6 +344,16 @@ export const SETTINGS_SCHEMA = {
description: 'Paths to core tool definitions.', description: 'Paths to core tool definitions.',
showInDialog: false, showInDialog: false,
}, },
allowedTools: {
type: 'array',
label: 'Allowed Tools',
category: 'Advanced',
requiresRestart: true,
default: undefined as string[] | undefined,
description:
'A list of tool names that will bypass the confirmation dialog.',
showInDialog: false,
},
excludeTools: { excludeTools: {
type: 'array', type: 'array',
label: 'Exclude Tools', label: 'Exclude Tools',

View File

@@ -53,14 +53,15 @@ const mockToolRegistry = {
const mockConfig = { const mockConfig = {
getToolRegistry: vi.fn(() => mockToolRegistry as unknown as ToolRegistry), getToolRegistry: vi.fn(() => mockToolRegistry as unknown as ToolRegistry),
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT), getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getSessionId: () => 'test-session-id', getAllowedTools: vi.fn(() => []),
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
}), }),
}; } as unknown as Config;
class MockToolInvocation extends BaseToolInvocation<object, ToolResult> { class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
constructor( constructor(
@@ -218,11 +219,6 @@ describe('useReactToolScheduler in YOLO Mode', () => {
await vi.runAllTimersAsync(); // Process execution await vi.runAllTimersAsync(); // Process execution
}); });
// Check that shouldConfirmExecute was NOT called
expect(
mockToolRequiresConfirmation.shouldConfirmExecute,
).not.toHaveBeenCalled();
// Check that execute WAS called // Check that execute WAS called
expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith( expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith(
request.args, request.args,

View File

@@ -23,6 +23,9 @@ import { GeminiClient } from '../core/client.js';
import { GitService } from '../services/gitService.js'; import { GitService } from '../services/gitService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { ShellTool } from '../tools/shell.js';
import { ReadFileTool } from '../tools/read-file.js';
vi.mock('fs', async (importOriginal) => { vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('fs')>(); const actual = await importOriginal<typeof import('fs')>();
return { return {
@@ -629,6 +632,36 @@ describe('Server Config (config.ts)', () => {
expect(config.getUseRipgrep()).toBe(false); expect(config.getUseRipgrep()).toBe(false);
}); });
}); });
describe('createToolRegistry', () => {
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
const params: ConfigParameters = {
...baseParams,
coreTools: ['ShellTool(git status)'],
};
const config = new Config(params);
await config.initialize();
// The ToolRegistry class is mocked, so we can inspect its prototype's methods.
const registerToolMock = (
(await vi.importMock('../tools/tool-registry')) as {
ToolRegistry: { prototype: { registerTool: Mock } };
}
).ToolRegistry.prototype.registerTool;
// Check that registerTool was called for ShellTool
const wasShellToolRegistered = (registerToolMock as Mock).mock.calls.some(
(call) => call[0] instanceof vi.mocked(ShellTool),
);
expect(wasShellToolRegistered).toBe(true);
// Check that registerTool was NOT called for ReadFileTool
const wasReadFileToolRegistered = (
registerToolMock as Mock
).mock.calls.some((call) => call[0] instanceof vi.mocked(ReadFileTool));
expect(wasReadFileToolRegistered).toBe(false);
});
});
}); });
describe('setApprovalMode with folder trust', () => { describe('setApprovalMode with folder trust', () => {

View File

@@ -49,7 +49,8 @@ import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js'; import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
// Re-export OAuth config type // Re-export OAuth config type
export type { MCPOAuthConfig }; export type { MCPOAuthConfig, AnyToolInvocation };
import type { AnyToolInvocation } from '../tools/tools.js';
import { WorkspaceContext } from '../utils/workspaceContext.js'; import { WorkspaceContext } from '../utils/workspaceContext.js';
import { Storage } from './storage.js'; import { Storage } from './storage.js';
import { FileExclusions } from '../utils/ignorePatterns.js'; import { FileExclusions } from '../utils/ignorePatterns.js';
@@ -159,6 +160,7 @@ export interface ConfigParameters {
question?: string; question?: string;
fullContext?: boolean; fullContext?: boolean;
coreTools?: string[]; coreTools?: string[];
allowedTools?: string[];
excludeTools?: string[]; excludeTools?: string[];
toolDiscoveryCommand?: string; toolDiscoveryCommand?: string;
toolCallCommand?: string; toolCallCommand?: string;
@@ -221,6 +223,7 @@ export class Config {
private readonly question: string | undefined; private readonly question: string | undefined;
private readonly fullContext: boolean; private readonly fullContext: boolean;
private readonly coreTools: string[] | undefined; private readonly coreTools: string[] | undefined;
private readonly allowedTools: string[] | undefined;
private readonly excludeTools: string[] | undefined; private readonly excludeTools: string[] | undefined;
private readonly toolDiscoveryCommand: string | undefined; private readonly toolDiscoveryCommand: string | undefined;
private readonly toolCallCommand: string | undefined; private readonly toolCallCommand: string | undefined;
@@ -295,6 +298,7 @@ export class Config {
this.question = params.question; this.question = params.question;
this.fullContext = params.fullContext ?? false; this.fullContext = params.fullContext ?? false;
this.coreTools = params.coreTools; this.coreTools = params.coreTools;
this.allowedTools = params.allowedTools;
this.excludeTools = params.excludeTools; this.excludeTools = params.excludeTools;
this.toolDiscoveryCommand = params.toolDiscoveryCommand; this.toolDiscoveryCommand = params.toolDiscoveryCommand;
this.toolCallCommand = params.toolCallCommand; this.toolCallCommand = params.toolCallCommand;
@@ -523,6 +527,10 @@ export class Config {
return this.coreTools; return this.coreTools;
} }
getAllowedTools(): string[] | undefined {
return this.allowedTools;
}
getExcludeTools(): string[] | undefined { getExcludeTools(): string[] | undefined {
return this.excludeTools; return this.excludeTools;
} }
@@ -807,12 +815,10 @@ export class Config {
const className = ToolClass.name; const className = ToolClass.name;
const toolName = ToolClass.Name || className; const toolName = ToolClass.Name || className;
const coreTools = this.getCoreTools(); const coreTools = this.getCoreTools();
const excludeTools = this.getExcludeTools(); const excludeTools = this.getExcludeTools() || [];
let isEnabled = false; let isEnabled = true; // Enabled by default if coreTools is not set.
if (coreTools === undefined) { if (coreTools) {
isEnabled = true;
} else {
isEnabled = coreTools.some( isEnabled = coreTools.some(
(tool) => (tool) =>
tool === className || tool === className ||
@@ -822,10 +828,11 @@ export class Config {
); );
} }
if ( const isExcluded = excludeTools.some(
excludeTools?.includes(className) || (tool) => tool === className || tool === toolName,
excludeTools?.includes(toolName) );
) {
if (isExcluded) {
isEnabled = false; isEnabled = false;
} }

View File

@@ -5,6 +5,7 @@
*/ */
import { describe, it, expect, vi } from 'vitest'; import { describe, it, expect, vi } from 'vitest';
import type { Mock } from 'vitest';
import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js'; import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js';
import { import {
CoreToolScheduler, CoreToolScheduler,
@@ -99,6 +100,41 @@ class TestApprovalInvocation extends BaseToolInvocation<
} }
} }
async function waitForStatus(
onToolCallsUpdate: Mock,
status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled',
timeout = 5000,
): Promise<ToolCall> {
return new Promise((resolve, reject) => {
const startTime = Date.now();
const check = () => {
if (Date.now() - startTime > timeout) {
const seenStatuses = onToolCallsUpdate.mock.calls
.flatMap((call) => call[0])
.map((toolCall: ToolCall) => toolCall.status);
reject(
new Error(
`Timed out waiting for status "${status}". Seen statuses: ${seenStatuses.join(
', ',
)}`,
),
);
return;
}
const foundCall = onToolCallsUpdate.mock.calls
.flatMap((call) => call[0])
.find((toolCall: ToolCall) => toolCall.status === status);
if (foundCall) {
resolve(foundCall);
} else {
setTimeout(check, 10); // Check again in 10ms
}
};
check();
});
}
describe('CoreToolScheduler', () => { describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool(); const mockTool = new MockTool();
@@ -126,6 +162,7 @@ describe('CoreToolScheduler', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT, getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -186,6 +223,7 @@ describe('CoreToolScheduler with payload', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT, getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -212,16 +250,10 @@ describe('CoreToolScheduler with payload', () => {
await scheduler.schedule([request], abortController.signal); await scheduler.schedule([request], abortController.signal);
await vi.waitFor(() => { const awaitingCall = (await waitForStatus(
const awaitingCall = onToolCallsUpdate.mock.calls.find( onToolCallsUpdate,
(call) => call[0][0].status === 'awaiting_approval', 'awaiting_approval',
)?.[0][0]; )) as WaitingToolCall;
expect(awaitingCall).toBeDefined();
});
const awaitingCall = onToolCallsUpdate.mock.calls.find(
(call) => call[0][0].status === 'awaiting_approval',
)?.[0][0];
const confirmationDetails = awaitingCall.confirmationDetails; const confirmationDetails = awaitingCall.confirmationDetails;
if (confirmationDetails) { if (confirmationDetails) {
@@ -497,6 +529,7 @@ describe('CoreToolScheduler edit cancellation', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT, getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -523,12 +556,10 @@ describe('CoreToolScheduler edit cancellation', () => {
await scheduler.schedule([request], abortController.signal); await scheduler.schedule([request], abortController.signal);
// Wait for the tool to reach awaiting_approval state const awaitingCall = (await waitForStatus(
const awaitingCall = onToolCallsUpdate.mock.calls.find( onToolCallsUpdate,
(call) => call[0][0].status === 'awaiting_approval', 'awaiting_approval',
)?.[0][0]; )) as WaitingToolCall;
expect(awaitingCall).toBeDefined();
// Cancel the edit // Cancel the edit
const confirmationDetails = awaitingCall.confirmationDetails; const confirmationDetails = awaitingCall.confirmationDetails;
@@ -589,6 +620,7 @@ describe('CoreToolScheduler YOLO mode', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO, getApprovalMode: () => ApprovalMode.YOLO,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -678,6 +710,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -713,10 +746,7 @@ describe('CoreToolScheduler request queueing', () => {
scheduler.schedule([request1], abortController.signal); scheduler.schedule([request1], abortController.signal);
// Wait for the first call to be in the 'executing' state. // Wait for the first call to be in the 'executing' state.
await vi.waitFor(() => { await waitForStatus(onToolCallsUpdate, 'executing');
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
// Schedule the second call while the first is "running". // Schedule the second call while the first is "running".
const schedulePromise2 = scheduler.schedule( const schedulePromise2 = scheduler.schedule(
@@ -737,16 +767,6 @@ describe('CoreToolScheduler request queueing', () => {
// Wait for the second schedule promise to resolve. // Wait for the second schedule promise to resolve.
await schedulePromise2; await schedulePromise2;
// Wait for the second call to be in the 'executing' state.
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Let the second call finish. // Let the second call finish.
const secondCallResult = { const secondCallResult = {
llmContent: 'Second call complete', llmContent: 'Second call complete',
@@ -756,6 +776,12 @@ describe('CoreToolScheduler request queueing', () => {
// In a real scenario, a new promise would be created for the second call. // In a real scenario, a new promise would be created for the second call.
resolveFirstCall!(secondCallResult); resolveFirstCall!(secondCallResult);
await vi.waitFor(() => {
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
});
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Wait for the second completion. // Wait for the second completion.
await vi.waitFor(() => { await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
@@ -766,6 +792,96 @@ describe('CoreToolScheduler request queueing', () => {
expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success'); expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
}); });
it('should auto-approve a tool call if it is on the allowedTools list', async () => {
// Arrange
const mockTool = new MockTool('mockTool');
mockTool.executeFn.mockReturnValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation.
mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
// Configure the scheduler to auto-approve the specific tool call.
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT, // Not YOLO mode
getAllowedTools: () => ['mockTool'], // Auto-approve this tool
getToolRegistry: () => toolRegistry,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const request = {
callId: '1',
name: 'mockTool',
args: { param: 'value' },
isClientInitiated: false,
prompt_id: 'prompt-auto-approved',
};
// Act
await scheduler.schedule([request], abortController.signal);
// Assert
// 1. The tool's execute method was called directly.
expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' });
// 2. The tool call status never entered 'awaiting_approval'.
const statusUpdates = onToolCallsUpdate.mock.calls
.map((call) => (call[0][0] as ToolCall)?.status)
.filter(Boolean);
expect(statusUpdates).not.toContain('awaiting_approval');
expect(statusUpdates).toEqual([
'validating',
'scheduled',
'executing',
'success',
]);
// 3. The final callback indicates the tool call was successful.
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls).toHaveLength(1);
const completedCall = completedCalls[0];
expect(completedCall.status).toBe('success');
if (completedCall.status === 'success') {
expect(completedCall.response.resultDisplay).toBe('Tool executed');
}
});
it('should handle two synchronous calls to schedule', async () => { it('should handle two synchronous calls to schedule', async () => {
const mockTool = new MockTool(); const mockTool = new MockTool();
const declarativeTool = mockTool; const declarativeTool = mockTool;
@@ -782,7 +898,6 @@ describe('CoreToolScheduler request queueing', () => {
getAllTools: () => [], getAllTools: () => [],
getToolsByServer: () => [], getToolsByServer: () => [],
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const onAllToolCallsComplete = vi.fn(); const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn(); const onToolCallsUpdate = vi.fn();
@@ -791,6 +906,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO, getApprovalMode: () => ApprovalMode.YOLO,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({ getContentGeneratorConfig: () => ({
model: 'test-model', model: 'test-model',
authType: 'oauth-personal', authType: 'oauth-personal',
@@ -851,6 +967,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getApprovalMode: () => approvalMode, getApprovalMode: () => approvalMode,
getAllowedTools: () => [],
setApprovalMode: (mode: ApprovalMode) => { setApprovalMode: (mode: ApprovalMode) => {
approvalMode = mode; approvalMode = mode;
}, },

View File

@@ -32,6 +32,7 @@ import {
modifyWithEditor, modifyWithEditor,
} from '../tools/modifiable-tool.js'; } from '../tools/modifiable-tool.js';
import * as Diff from 'diff'; import * as Diff from 'diff';
import { doesToolInvocationMatch } from '../utils/tool-utils.js';
export type ValidatingToolCall = { export type ValidatingToolCall = {
status: 'validating'; status: 'validating';
@@ -615,68 +616,74 @@ export class CoreToolScheduler {
); );
continue; continue;
} }
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (!confirmationDetails) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
continue;
}
const allowedTools = this.config.getAllowedTools() || [];
if (
this.config.getApprovalMode() === ApprovalMode.YOLO ||
doesToolInvocationMatch(toolCall.tool, invocation, allowedTools)
) {
this.setToolCallOutcome( this.setToolCallOutcome(
reqInfo.callId, reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways, ToolConfirmationOutcome.ProceedAlways,
); );
this.setStatusInternal(reqInfo.callId, 'scheduled'); this.setStatusInternal(reqInfo.callId, 'scheduled');
} else { } else {
const confirmationDetails = // Allow IDE to resolve confirmation
await invocation.shouldConfirmExecute(signal); if (
confirmationDetails.type === 'edit' &&
if (confirmationDetails) { confirmationDetails.ideConfirmation
// Allow IDE to resolve confirmation ) {
if ( confirmationDetails.ideConfirmation.then((resolution) => {
confirmationDetails.type === 'edit' && if (resolution.status === 'accepted') {
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse( this.handleConfirmationResponse(
reqInfo.callId, reqInfo.callId,
originalOnConfirm, confirmationDetails.onConfirm,
outcome, ToolConfirmationOutcome.ProceedOnce,
signal, signal,
payload, );
), } else {
}; this.handleConfirmationResponse(
this.setStatusInternal( reqInfo.callId,
reqInfo.callId, confirmationDetails.onConfirm,
'awaiting_approval', ToolConfirmationOutcome.Cancel,
wrappedConfirmationDetails, signal,
); );
} else { }
this.setToolCallOutcome( });
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
} }
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} }
} catch (error) { } catch (error) {
this.setStatusInternal( this.setStatusInternal(

View File

@@ -32,6 +32,7 @@ describe('executeToolCall', () => {
mockConfig = { mockConfig = {
getToolRegistry: () => mockToolRegistry, getToolRegistry: () => mockToolRegistry,
getApprovalMode: () => ApprovalMode.DEFAULT, getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getSessionId: () => 'test-session-id', getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,

View File

@@ -307,6 +307,21 @@ export abstract class BaseDeclarativeTool<
*/ */
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>; export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
/**
* Type guard to check if an object is a Tool.
* @param obj The object to check.
* @returns True if the object is a Tool, false otherwise.
*/
export function isTool(obj: unknown): obj is AnyDeclarativeTool {
return (
typeof obj === 'object' &&
obj !== null &&
'name' in obj &&
'build' in obj &&
typeof (obj as AnyDeclarativeTool).build === 'function'
);
}
export interface ToolResult { export interface ToolResult {
/** /**
* Content meant to be included in LLM history. * Content meant to be included in LLM history.

View File

@@ -16,11 +16,14 @@ import {
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
const mockPlatform = vi.hoisted(() => vi.fn()); const mockPlatform = vi.hoisted(() => vi.fn());
const mockHomedir = vi.hoisted(() => vi.fn());
vi.mock('os', () => ({ vi.mock('os', () => ({
default: { default: {
platform: mockPlatform, platform: mockPlatform,
homedir: mockHomedir,
}, },
platform: mockPlatform, platform: mockPlatform,
homedir: mockHomedir,
})); }));
const mockQuote = vi.hoisted(() => vi.fn()); const mockQuote = vi.hoisted(() => vi.fn());
@@ -38,6 +41,7 @@ beforeEach(() => {
config = { config = {
getCoreTools: () => [], getCoreTools: () => [],
getExcludeTools: () => [], getExcludeTools: () => [],
getAllowedTools: () => [],
} as unknown as Config; } as unknown as Config;
}); });

View File

@@ -4,9 +4,13 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { AnyToolInvocation } from '../index.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import os from 'node:os'; import os from 'node:os';
import { quote } from 'shell-quote'; import { quote } from 'shell-quote';
import { doesToolInvocationMatch } from './tool-utils.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/** /**
* An identifier for the shell type. * An identifier for the shell type.
@@ -319,32 +323,19 @@ export function checkCommandPermissions(
}; };
} }
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' '); const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
if (!cmd.startsWith(prefix)) {
return false;
}
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
};
const extractCommands = (tools: string[]): string[] =>
tools.flatMap((tool) => {
for (const toolName of SHELL_TOOL_NAMES) {
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
return [normalize(tool.slice(toolName.length + 1, -1))];
}
}
return [];
});
const coreTools = config.getCoreTools() || [];
const excludeTools = config.getExcludeTools() || [];
const commandsToValidate = splitCommands(command).map(normalize); const commandsToValidate = splitCommands(command).map(normalize);
const invocation: AnyToolInvocation & { params: { command: string } } = {
params: { command: '' },
} as AnyToolInvocation & { params: { command: string } };
// 1. Blocklist Check (Highest Priority) // 1. Blocklist Check (Highest Priority)
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) { const excludeTools = config.getExcludeTools() || [];
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
excludeTools.includes(name),
);
if (isWildcardBlocked) {
return { return {
allAllowed: false, allAllowed: false,
disallowedCommands: commandsToValidate, disallowedCommands: commandsToValidate,
@@ -352,9 +343,12 @@ export function checkCommandPermissions(
isHardDenial: true, isHardDenial: true,
}; };
} }
const blockedCommands = extractCommands(excludeTools);
for (const cmd of commandsToValidate) { for (const cmd of commandsToValidate) {
if (blockedCommands.some((blocked) => isPrefixedBy(cmd, blocked))) { invocation.params['command'] = cmd;
if (
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
) {
return { return {
allAllowed: false, allAllowed: false,
disallowedCommands: [cmd], disallowedCommands: [cmd],
@@ -364,7 +358,7 @@ export function checkCommandPermissions(
} }
} }
const globallyAllowedCommands = extractCommands(coreTools); const coreTools = config.getCoreTools() || [];
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) => const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
coreTools.includes(name), coreTools.includes(name),
); );
@@ -375,18 +369,30 @@ export function checkCommandPermissions(
return { allAllowed: true, disallowedCommands: [] }; return { allAllowed: true, disallowedCommands: [] };
} }
const disallowedCommands: string[] = [];
if (sessionAllowlist) { if (sessionAllowlist) {
// "DEFAULT DENY" MODE: A session allowlist is provided. // "DEFAULT DENY" MODE: A session allowlist is provided.
// All commands must be in either the session or global allowlist. // All commands must be in either the session or global allowlist.
const disallowedCommands: string[] = []; const normalizedSessionAllowlist = new Set(
[...sessionAllowlist].flatMap((cmd) =>
SHELL_TOOL_NAMES.map((name) => `${name}(${cmd})`),
),
);
for (const cmd of commandsToValidate) { for (const cmd of commandsToValidate) {
const isSessionAllowed = [...sessionAllowlist].some((allowed) => invocation.params['command'] = cmd;
isPrefixedBy(cmd, normalize(allowed)), const isSessionAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
[...normalizedSessionAllowlist],
); );
if (isSessionAllowed) continue; if (isSessionAllowed) continue;
const isGloballyAllowed = globallyAllowedCommands.some((allowed) => const isGloballyAllowed = doesToolInvocationMatch(
isPrefixedBy(cmd, allowed), 'run_shell_command',
invocation,
coreTools,
); );
if (isGloballyAllowed) continue; if (isGloballyAllowed) continue;
@@ -405,12 +411,18 @@ export function checkCommandPermissions(
} }
} else { } else {
// "DEFAULT ALLOW" MODE: No session allowlist. // "DEFAULT ALLOW" MODE: No session allowlist.
const hasSpecificAllowedCommands = globallyAllowedCommands.length > 0; const hasSpecificAllowedCommands =
coreTools.filter((tool) =>
SHELL_TOOL_NAMES.some((name) => tool.startsWith(`${name}(`)),
).length > 0;
if (hasSpecificAllowedCommands) { if (hasSpecificAllowedCommands) {
const disallowedCommands: string[] = [];
for (const cmd of commandsToValidate) { for (const cmd of commandsToValidate) {
const isGloballyAllowed = globallyAllowedCommands.some((allowed) => invocation.params['command'] = cmd;
isPrefixedBy(cmd, allowed), const isGloballyAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
coreTools,
); );
if (!isGloballyAllowed) { if (!isGloballyAllowed) {
disallowedCommands.push(cmd); disallowedCommands.push(cmd);
@@ -420,7 +432,9 @@ export function checkCommandPermissions(
return { return {
allAllowed: false, allAllowed: false,
disallowedCommands, disallowedCommands,
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands.map((c) => JSON.stringify(c)).join(', ')}`, blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands
.map((c) => JSON.stringify(c))
.join(', ')}`,
isHardDenial: false, // This is a soft denial. isHardDenial: false, // This is a soft denial.
}; };
} }

View File

@@ -0,0 +1,94 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { expect, describe, it } from 'vitest';
import { doesToolInvocationMatch } from './tool-utils.js';
import type { AnyToolInvocation, Config } from '../index.js';
import { ReadFileTool } from '../tools/read-file.js';
describe('doesToolInvocationMatch', () => {
it('should not match a partial command prefix', () => {
const invocation = {
params: { command: 'git commitsomething' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git commit)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match an exact command', () => {
const invocation = {
params: { command: 'git status' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match a command that is a prefix', () => {
const invocation = {
params: { command: 'git status -v' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
describe('for non-shell tools', () => {
const readFileTool = new ReadFileTool({} as Config);
const invocation = {
params: { file: 'test.txt' },
} as AnyToolInvocation;
it('should match by tool name', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match by tool class name', () => {
const patterns = ['ReadFileTool'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should not match if neither name is in the patterns', () => {
const patterns = ['some_other_tool', 'AnotherToolClass'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match by tool name when passed as a string', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch('read_file', invocation, patterns);
expect(result).toBe(true);
});
});
});

View File

@@ -0,0 +1,76 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js';
import { isTool } from '../index.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/**
* Checks if a tool invocation matches any of a list of patterns.
*
* @param toolOrToolName The tool object or the name of the tool being invoked.
* @param invocation The invocation object for the tool.
* @param patterns A list of patterns to match against.
* Patterns can be:
* - A tool name (e.g., "ReadFileTool") to match any invocation of that tool.
* - A tool name with a prefix (e.g., "ShellTool(git status)") to match
* invocations where the arguments start with that prefix.
* @returns True if the invocation matches any pattern, false otherwise.
*/
export function doesToolInvocationMatch(
toolOrToolName: AnyDeclarativeTool | string,
invocation: AnyToolInvocation,
patterns: string[],
): boolean {
let toolNames: string[];
if (isTool(toolOrToolName)) {
toolNames = [toolOrToolName.name, toolOrToolName.constructor.name];
} else {
toolNames = [toolOrToolName as string];
}
if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) {
toolNames = [...new Set([...toolNames, ...SHELL_TOOL_NAMES])];
}
for (const pattern of patterns) {
const openParen = pattern.indexOf('(');
if (openParen === -1) {
// No arguments, just a tool name
if (toolNames.includes(pattern)) {
return true;
}
continue;
}
const patternToolName = pattern.substring(0, openParen);
if (!toolNames.includes(patternToolName)) {
continue;
}
if (!pattern.endsWith(')')) {
continue;
}
const argPattern = pattern.substring(openParen + 1, pattern.length - 1);
if (
'command' in invocation.params &&
toolNames.includes('run_shell_command')
) {
const argValue = String(
(invocation.params as { command: string }).command,
);
if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) {
return true;
}
}
}
return false;
}