feat(cli): Add --allowed-tools flag to bypass tool confirmation (#2417) (#6453)

This commit is contained in:
Andrew Garrett
2025-08-27 02:17:43 +10:00
committed by GitHub
parent c33a0da1df
commit 52dae2c583
14 changed files with 524 additions and 135 deletions

View File

@@ -91,6 +91,11 @@ If you are experiencing performance issues with file searching (e.g., with `@` c
- **Default:** All tools available for use by the Gemini model.
- **Example:** `"coreTools": ["ReadFileTool", "GlobTool", "ShellTool(ls)"]`.
- **`allowedTools`** (array of strings):
- **Default:** `undefined`
- **Description:** A list of tool names that will bypass the confirmation dialog. This is useful for tools that you trust and use frequently. The match semantics are the same as `coreTools`.
- **Example:** `"allowedTools": ["ShellTool(git status)"]`.
- **`excludeTools`** (array of strings):
- **Description:** Allows you to specify a list of core tool names that should be excluded from the model. A tool listed in both `excludeTools` and `coreTools` is excluded. You can also specify command-specific restrictions for tools that support it, like the `ShellTool`. For example, `"excludeTools": ["ShellTool(rm -rf)"]` will block the `rm -rf` command.
- **Default**: No tools excluded.
@@ -479,6 +484,9 @@ Arguments passed directly when running the CLI can override other configurations
- `yolo`: Automatically approve all tool calls (equivalent to `--yolo`)
- Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach.
- Example: `gemini --approval-mode auto_edit`
- **`--allowed-tools <tool1,tool2,...>`**:
- A comma-separated list of tool names that will bypass the confirmation dialog.
- Example: `gemini --allowed-tools "ShellTool(git status)"`
- **`--telemetry`**:
- Enables [telemetry](../telemetry.md).
- **`--telemetry-target`**:

View File

@@ -70,6 +70,7 @@ export interface CliArgs {
telemetryLogPrompts: boolean | undefined;
telemetryOutfile: string | undefined;
allowedMcpServerNames: string[] | undefined;
allowedTools: string[] | undefined;
experimentalAcp: boolean | undefined;
extensions: string[] | undefined;
listExtensions: boolean | undefined;
@@ -189,6 +190,11 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
string: true,
description: 'Allowed MCP server names',
})
.option('allowed-tools', {
type: 'array',
string: true,
description: 'Tools that are allowed to run without confirmation',
})
.option('extensions', {
alias: 'e',
type: 'array',
@@ -489,6 +495,7 @@ export async function loadCliConfig(
question,
fullContext: argv.allFiles || false,
coreTools: settings.coreTools || undefined,
allowedTools: argv.allowedTools || settings.allowedTools || undefined,
excludeTools,
toolDiscoveryCommand: settings.toolDiscoveryCommand,
toolCallCommand: settings.toolCallCommand,

View File

@@ -344,6 +344,16 @@ export const SETTINGS_SCHEMA = {
description: 'Paths to core tool definitions.',
showInDialog: false,
},
allowedTools: {
type: 'array',
label: 'Allowed Tools',
category: 'Advanced',
requiresRestart: true,
default: undefined as string[] | undefined,
description:
'A list of tool names that will bypass the confirmation dialog.',
showInDialog: false,
},
excludeTools: {
type: 'array',
label: 'Exclude Tools',

View File

@@ -53,14 +53,15 @@ const mockToolRegistry = {
const mockConfig = {
getToolRegistry: vi.fn(() => mockToolRegistry as unknown as ToolRegistry),
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getSessionId: () => 'test-session-id',
getAllowedTools: vi.fn(() => []),
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
};
} as unknown as Config;
class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
constructor(
@@ -218,11 +219,6 @@ describe('useReactToolScheduler in YOLO Mode', () => {
await vi.runAllTimersAsync(); // Process execution
});
// Check that shouldConfirmExecute was NOT called
expect(
mockToolRequiresConfirmation.shouldConfirmExecute,
).not.toHaveBeenCalled();
// Check that execute WAS called
expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith(
request.args,

View File

@@ -23,6 +23,9 @@ import { GeminiClient } from '../core/client.js';
import { GitService } from '../services/gitService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { ShellTool } from '../tools/shell.js';
import { ReadFileTool } from '../tools/read-file.js';
vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('fs')>();
return {
@@ -629,6 +632,36 @@ describe('Server Config (config.ts)', () => {
expect(config.getUseRipgrep()).toBe(false);
});
});
describe('createToolRegistry', () => {
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
const params: ConfigParameters = {
...baseParams,
coreTools: ['ShellTool(git status)'],
};
const config = new Config(params);
await config.initialize();
// The ToolRegistry class is mocked, so we can inspect its prototype's methods.
const registerToolMock = (
(await vi.importMock('../tools/tool-registry')) as {
ToolRegistry: { prototype: { registerTool: Mock } };
}
).ToolRegistry.prototype.registerTool;
// Check that registerTool was called for ShellTool
const wasShellToolRegistered = (registerToolMock as Mock).mock.calls.some(
(call) => call[0] instanceof vi.mocked(ShellTool),
);
expect(wasShellToolRegistered).toBe(true);
// Check that registerTool was NOT called for ReadFileTool
const wasReadFileToolRegistered = (
registerToolMock as Mock
).mock.calls.some((call) => call[0] instanceof vi.mocked(ReadFileTool));
expect(wasReadFileToolRegistered).toBe(false);
});
});
});
describe('setApprovalMode with folder trust', () => {

View File

@@ -49,7 +49,8 @@ import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
// Re-export OAuth config type
export type { MCPOAuthConfig };
export type { MCPOAuthConfig, AnyToolInvocation };
import type { AnyToolInvocation } from '../tools/tools.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { Storage } from './storage.js';
import { FileExclusions } from '../utils/ignorePatterns.js';
@@ -159,6 +160,7 @@ export interface ConfigParameters {
question?: string;
fullContext?: boolean;
coreTools?: string[];
allowedTools?: string[];
excludeTools?: string[];
toolDiscoveryCommand?: string;
toolCallCommand?: string;
@@ -221,6 +223,7 @@ export class Config {
private readonly question: string | undefined;
private readonly fullContext: boolean;
private readonly coreTools: string[] | undefined;
private readonly allowedTools: string[] | undefined;
private readonly excludeTools: string[] | undefined;
private readonly toolDiscoveryCommand: string | undefined;
private readonly toolCallCommand: string | undefined;
@@ -295,6 +298,7 @@ export class Config {
this.question = params.question;
this.fullContext = params.fullContext ?? false;
this.coreTools = params.coreTools;
this.allowedTools = params.allowedTools;
this.excludeTools = params.excludeTools;
this.toolDiscoveryCommand = params.toolDiscoveryCommand;
this.toolCallCommand = params.toolCallCommand;
@@ -523,6 +527,10 @@ export class Config {
return this.coreTools;
}
getAllowedTools(): string[] | undefined {
return this.allowedTools;
}
getExcludeTools(): string[] | undefined {
return this.excludeTools;
}
@@ -807,12 +815,10 @@ export class Config {
const className = ToolClass.name;
const toolName = ToolClass.Name || className;
const coreTools = this.getCoreTools();
const excludeTools = this.getExcludeTools();
const excludeTools = this.getExcludeTools() || [];
let isEnabled = false;
if (coreTools === undefined) {
isEnabled = true;
} else {
let isEnabled = true; // Enabled by default if coreTools is not set.
if (coreTools) {
isEnabled = coreTools.some(
(tool) =>
tool === className ||
@@ -822,10 +828,11 @@ export class Config {
);
}
if (
excludeTools?.includes(className) ||
excludeTools?.includes(toolName)
) {
const isExcluded = excludeTools.some(
(tool) => tool === className || tool === toolName,
);
if (isExcluded) {
isEnabled = false;
}

View File

@@ -5,6 +5,7 @@
*/
import { describe, it, expect, vi } from 'vitest';
import type { Mock } from 'vitest';
import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js';
import {
CoreToolScheduler,
@@ -99,6 +100,41 @@ class TestApprovalInvocation extends BaseToolInvocation<
}
}
async function waitForStatus(
onToolCallsUpdate: Mock,
status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled',
timeout = 5000,
): Promise<ToolCall> {
return new Promise((resolve, reject) => {
const startTime = Date.now();
const check = () => {
if (Date.now() - startTime > timeout) {
const seenStatuses = onToolCallsUpdate.mock.calls
.flatMap((call) => call[0])
.map((toolCall: ToolCall) => toolCall.status);
reject(
new Error(
`Timed out waiting for status "${status}". Seen statuses: ${seenStatuses.join(
', ',
)}`,
),
);
return;
}
const foundCall = onToolCallsUpdate.mock.calls
.flatMap((call) => call[0])
.find((toolCall: ToolCall) => toolCall.status === status);
if (foundCall) {
resolve(foundCall);
} else {
setTimeout(check, 10); // Check again in 10ms
}
};
check();
});
}
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
@@ -126,6 +162,7 @@ describe('CoreToolScheduler', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -186,6 +223,7 @@ describe('CoreToolScheduler with payload', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -212,16 +250,10 @@ describe('CoreToolScheduler with payload', () => {
await scheduler.schedule([request], abortController.signal);
await vi.waitFor(() => {
const awaitingCall = onToolCallsUpdate.mock.calls.find(
(call) => call[0][0].status === 'awaiting_approval',
)?.[0][0];
expect(awaitingCall).toBeDefined();
});
const awaitingCall = onToolCallsUpdate.mock.calls.find(
(call) => call[0][0].status === 'awaiting_approval',
)?.[0][0];
const awaitingCall = (await waitForStatus(
onToolCallsUpdate,
'awaiting_approval',
)) as WaitingToolCall;
const confirmationDetails = awaitingCall.confirmationDetails;
if (confirmationDetails) {
@@ -497,6 +529,7 @@ describe('CoreToolScheduler edit cancellation', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -523,12 +556,10 @@ describe('CoreToolScheduler edit cancellation', () => {
await scheduler.schedule([request], abortController.signal);
// Wait for the tool to reach awaiting_approval state
const awaitingCall = onToolCallsUpdate.mock.calls.find(
(call) => call[0][0].status === 'awaiting_approval',
)?.[0][0];
expect(awaitingCall).toBeDefined();
const awaitingCall = (await waitForStatus(
onToolCallsUpdate,
'awaiting_approval',
)) as WaitingToolCall;
// Cancel the edit
const confirmationDetails = awaitingCall.confirmationDetails;
@@ -589,6 +620,7 @@ describe('CoreToolScheduler YOLO mode', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -678,6 +710,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -713,10 +746,7 @@ describe('CoreToolScheduler request queueing', () => {
scheduler.schedule([request1], abortController.signal);
// Wait for the first call to be in the 'executing' state.
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
await waitForStatus(onToolCallsUpdate, 'executing');
// Schedule the second call while the first is "running".
const schedulePromise2 = scheduler.schedule(
@@ -737,16 +767,6 @@ describe('CoreToolScheduler request queueing', () => {
// Wait for the second schedule promise to resolve.
await schedulePromise2;
// Wait for the second call to be in the 'executing' state.
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Let the second call finish.
const secondCallResult = {
llmContent: 'Second call complete',
@@ -756,6 +776,12 @@ describe('CoreToolScheduler request queueing', () => {
// In a real scenario, a new promise would be created for the second call.
resolveFirstCall!(secondCallResult);
await vi.waitFor(() => {
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
});
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Wait for the second completion.
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
@@ -766,6 +792,96 @@ describe('CoreToolScheduler request queueing', () => {
expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
});
it('should auto-approve a tool call if it is on the allowedTools list', async () => {
// Arrange
const mockTool = new MockTool('mockTool');
mockTool.executeFn.mockReturnValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation.
mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
// Configure the scheduler to auto-approve the specific tool call.
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT, // Not YOLO mode
getAllowedTools: () => ['mockTool'], // Auto-approve this tool
getToolRegistry: () => toolRegistry,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const request = {
callId: '1',
name: 'mockTool',
args: { param: 'value' },
isClientInitiated: false,
prompt_id: 'prompt-auto-approved',
};
// Act
await scheduler.schedule([request], abortController.signal);
// Assert
// 1. The tool's execute method was called directly.
expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' });
// 2. The tool call status never entered 'awaiting_approval'.
const statusUpdates = onToolCallsUpdate.mock.calls
.map((call) => (call[0][0] as ToolCall)?.status)
.filter(Boolean);
expect(statusUpdates).not.toContain('awaiting_approval');
expect(statusUpdates).toEqual([
'validating',
'scheduled',
'executing',
'success',
]);
// 3. The final callback indicates the tool call was successful.
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls).toHaveLength(1);
const completedCall = completedCalls[0];
expect(completedCall.status).toBe('success');
if (completedCall.status === 'success') {
expect(completedCall.response.resultDisplay).toBe('Tool executed');
}
});
it('should handle two synchronous calls to schedule', async () => {
const mockTool = new MockTool();
const declarativeTool = mockTool;
@@ -782,7 +898,6 @@ describe('CoreToolScheduler request queueing', () => {
getAllTools: () => [],
getToolsByServer: () => [],
} as unknown as ToolRegistry;
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
@@ -791,6 +906,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
@@ -851,6 +967,7 @@ describe('CoreToolScheduler request queueing', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => approvalMode,
getAllowedTools: () => [],
setApprovalMode: (mode: ApprovalMode) => {
approvalMode = mode;
},

View File

@@ -32,6 +32,7 @@ import {
modifyWithEditor,
} from '../tools/modifiable-tool.js';
import * as Diff from 'diff';
import { doesToolInvocationMatch } from '../utils/tool-utils.js';
export type ValidatingToolCall = {
status: 'validating';
@@ -615,68 +616,74 @@ export class CoreToolScheduler {
);
continue;
}
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (!confirmationDetails) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
continue;
}
const allowedTools = this.config.getAllowedTools() || [];
if (
this.config.getApprovalMode() === ApprovalMode.YOLO ||
doesToolInvocationMatch(toolCall.tool, invocation, allowedTools)
) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
} else {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (confirmationDetails) {
// Allow IDE to resolve confirmation
if (
confirmationDetails.type === 'edit' &&
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
// Allow IDE to resolve confirmation
if (
confirmationDetails.type === 'edit' &&
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} else {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
}
} catch (error) {
this.setStatusInternal(

View File

@@ -32,6 +32,7 @@ describe('executeToolCall', () => {
mockConfig = {
getToolRegistry: () => mockToolRegistry,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,

View File

@@ -307,6 +307,21 @@ export abstract class BaseDeclarativeTool<
*/
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
/**
* Type guard to check if an object is a Tool.
* @param obj The object to check.
* @returns True if the object is a Tool, false otherwise.
*/
export function isTool(obj: unknown): obj is AnyDeclarativeTool {
return (
typeof obj === 'object' &&
obj !== null &&
'name' in obj &&
'build' in obj &&
typeof (obj as AnyDeclarativeTool).build === 'function'
);
}
export interface ToolResult {
/**
* Content meant to be included in LLM history.

View File

@@ -16,11 +16,14 @@ import {
import type { Config } from '../config/config.js';
const mockPlatform = vi.hoisted(() => vi.fn());
const mockHomedir = vi.hoisted(() => vi.fn());
vi.mock('os', () => ({
default: {
platform: mockPlatform,
homedir: mockHomedir,
},
platform: mockPlatform,
homedir: mockHomedir,
}));
const mockQuote = vi.hoisted(() => vi.fn());
@@ -38,6 +41,7 @@ beforeEach(() => {
config = {
getCoreTools: () => [],
getExcludeTools: () => [],
getAllowedTools: () => [],
} as unknown as Config;
});

View File

@@ -4,9 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { AnyToolInvocation } from '../index.js';
import type { Config } from '../config/config.js';
import os from 'node:os';
import { quote } from 'shell-quote';
import { doesToolInvocationMatch } from './tool-utils.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/**
* An identifier for the shell type.
@@ -319,32 +323,19 @@ export function checkCommandPermissions(
};
}
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
if (!cmd.startsWith(prefix)) {
return false;
}
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
};
const extractCommands = (tools: string[]): string[] =>
tools.flatMap((tool) => {
for (const toolName of SHELL_TOOL_NAMES) {
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
return [normalize(tool.slice(toolName.length + 1, -1))];
}
}
return [];
});
const coreTools = config.getCoreTools() || [];
const excludeTools = config.getExcludeTools() || [];
const commandsToValidate = splitCommands(command).map(normalize);
const invocation: AnyToolInvocation & { params: { command: string } } = {
params: { command: '' },
} as AnyToolInvocation & { params: { command: string } };
// 1. Blocklist Check (Highest Priority)
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
const excludeTools = config.getExcludeTools() || [];
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
excludeTools.includes(name),
);
if (isWildcardBlocked) {
return {
allAllowed: false,
disallowedCommands: commandsToValidate,
@@ -352,9 +343,12 @@ export function checkCommandPermissions(
isHardDenial: true,
};
}
const blockedCommands = extractCommands(excludeTools);
for (const cmd of commandsToValidate) {
if (blockedCommands.some((blocked) => isPrefixedBy(cmd, blocked))) {
invocation.params['command'] = cmd;
if (
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
) {
return {
allAllowed: false,
disallowedCommands: [cmd],
@@ -364,7 +358,7 @@ export function checkCommandPermissions(
}
}
const globallyAllowedCommands = extractCommands(coreTools);
const coreTools = config.getCoreTools() || [];
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
coreTools.includes(name),
);
@@ -375,18 +369,30 @@ export function checkCommandPermissions(
return { allAllowed: true, disallowedCommands: [] };
}
const disallowedCommands: string[] = [];
if (sessionAllowlist) {
// "DEFAULT DENY" MODE: A session allowlist is provided.
// All commands must be in either the session or global allowlist.
const disallowedCommands: string[] = [];
const normalizedSessionAllowlist = new Set(
[...sessionAllowlist].flatMap((cmd) =>
SHELL_TOOL_NAMES.map((name) => `${name}(${cmd})`),
),
);
for (const cmd of commandsToValidate) {
const isSessionAllowed = [...sessionAllowlist].some((allowed) =>
isPrefixedBy(cmd, normalize(allowed)),
invocation.params['command'] = cmd;
const isSessionAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
[...normalizedSessionAllowlist],
);
if (isSessionAllowed) continue;
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
isPrefixedBy(cmd, allowed),
const isGloballyAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
coreTools,
);
if (isGloballyAllowed) continue;
@@ -405,12 +411,18 @@ export function checkCommandPermissions(
}
} else {
// "DEFAULT ALLOW" MODE: No session allowlist.
const hasSpecificAllowedCommands = globallyAllowedCommands.length > 0;
const hasSpecificAllowedCommands =
coreTools.filter((tool) =>
SHELL_TOOL_NAMES.some((name) => tool.startsWith(`${name}(`)),
).length > 0;
if (hasSpecificAllowedCommands) {
const disallowedCommands: string[] = [];
for (const cmd of commandsToValidate) {
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
isPrefixedBy(cmd, allowed),
invocation.params['command'] = cmd;
const isGloballyAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
coreTools,
);
if (!isGloballyAllowed) {
disallowedCommands.push(cmd);
@@ -420,7 +432,9 @@ export function checkCommandPermissions(
return {
allAllowed: false,
disallowedCommands,
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands.map((c) => JSON.stringify(c)).join(', ')}`,
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands
.map((c) => JSON.stringify(c))
.join(', ')}`,
isHardDenial: false, // This is a soft denial.
};
}

View File

@@ -0,0 +1,94 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { expect, describe, it } from 'vitest';
import { doesToolInvocationMatch } from './tool-utils.js';
import type { AnyToolInvocation, Config } from '../index.js';
import { ReadFileTool } from '../tools/read-file.js';
describe('doesToolInvocationMatch', () => {
it('should not match a partial command prefix', () => {
const invocation = {
params: { command: 'git commitsomething' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git commit)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match an exact command', () => {
const invocation = {
params: { command: 'git status' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match a command that is a prefix', () => {
const invocation = {
params: { command: 'git status -v' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
describe('for non-shell tools', () => {
const readFileTool = new ReadFileTool({} as Config);
const invocation = {
params: { file: 'test.txt' },
} as AnyToolInvocation;
it('should match by tool name', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match by tool class name', () => {
const patterns = ['ReadFileTool'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should not match if neither name is in the patterns', () => {
const patterns = ['some_other_tool', 'AnotherToolClass'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match by tool name when passed as a string', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch('read_file', invocation, patterns);
expect(result).toBe(true);
});
});
});

View File

@@ -0,0 +1,76 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js';
import { isTool } from '../index.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/**
* Checks if a tool invocation matches any of a list of patterns.
*
* @param toolOrToolName The tool object or the name of the tool being invoked.
* @param invocation The invocation object for the tool.
* @param patterns A list of patterns to match against.
* Patterns can be:
* - A tool name (e.g., "ReadFileTool") to match any invocation of that tool.
* - A tool name with a prefix (e.g., "ShellTool(git status)") to match
* invocations where the arguments start with that prefix.
* @returns True if the invocation matches any pattern, false otherwise.
*/
export function doesToolInvocationMatch(
toolOrToolName: AnyDeclarativeTool | string,
invocation: AnyToolInvocation,
patterns: string[],
): boolean {
let toolNames: string[];
if (isTool(toolOrToolName)) {
toolNames = [toolOrToolName.name, toolOrToolName.constructor.name];
} else {
toolNames = [toolOrToolName as string];
}
if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) {
toolNames = [...new Set([...toolNames, ...SHELL_TOOL_NAMES])];
}
for (const pattern of patterns) {
const openParen = pattern.indexOf('(');
if (openParen === -1) {
// No arguments, just a tool name
if (toolNames.includes(pattern)) {
return true;
}
continue;
}
const patternToolName = pattern.substring(0, openParen);
if (!toolNames.includes(patternToolName)) {
continue;
}
if (!pattern.endsWith(')')) {
continue;
}
const argPattern = pattern.substring(openParen + 1, pattern.length - 1);
if (
'command' in invocation.params &&
toolNames.includes('run_shell_command')
) {
const argValue = String(
(invocation.params as { command: string }).command,
);
if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) {
return true;
}
}
}
return false;
}