feat(cli): Add --allowed-tools flag to bypass tool confirmation (#2417) (#6453)

This commit is contained in:
Andrew Garrett
2025-08-27 02:17:43 +10:00
committed by GitHub
parent c33a0da1df
commit 52dae2c583
14 changed files with 524 additions and 135 deletions

View File

@@ -16,11 +16,14 @@ import {
import type { Config } from '../config/config.js';
const mockPlatform = vi.hoisted(() => vi.fn());
const mockHomedir = vi.hoisted(() => vi.fn());
vi.mock('os', () => ({
default: {
platform: mockPlatform,
homedir: mockHomedir,
},
platform: mockPlatform,
homedir: mockHomedir,
}));
const mockQuote = vi.hoisted(() => vi.fn());
@@ -38,6 +41,7 @@ beforeEach(() => {
config = {
getCoreTools: () => [],
getExcludeTools: () => [],
getAllowedTools: () => [],
} as unknown as Config;
});

View File

@@ -4,9 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { AnyToolInvocation } from '../index.js';
import type { Config } from '../config/config.js';
import os from 'node:os';
import { quote } from 'shell-quote';
import { doesToolInvocationMatch } from './tool-utils.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/**
* An identifier for the shell type.
@@ -319,32 +323,19 @@ export function checkCommandPermissions(
};
}
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' ');
const isPrefixedBy = (cmd: string, prefix: string): boolean => {
if (!cmd.startsWith(prefix)) {
return false;
}
return cmd.length === prefix.length || cmd[prefix.length] === ' ';
};
const extractCommands = (tools: string[]): string[] =>
tools.flatMap((tool) => {
for (const toolName of SHELL_TOOL_NAMES) {
if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) {
return [normalize(tool.slice(toolName.length + 1, -1))];
}
}
return [];
});
const coreTools = config.getCoreTools() || [];
const excludeTools = config.getExcludeTools() || [];
const commandsToValidate = splitCommands(command).map(normalize);
const invocation: AnyToolInvocation & { params: { command: string } } = {
params: { command: '' },
} as AnyToolInvocation & { params: { command: string } };
// 1. Blocklist Check (Highest Priority)
if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) {
const excludeTools = config.getExcludeTools() || [];
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
excludeTools.includes(name),
);
if (isWildcardBlocked) {
return {
allAllowed: false,
disallowedCommands: commandsToValidate,
@@ -352,9 +343,12 @@ export function checkCommandPermissions(
isHardDenial: true,
};
}
const blockedCommands = extractCommands(excludeTools);
for (const cmd of commandsToValidate) {
if (blockedCommands.some((blocked) => isPrefixedBy(cmd, blocked))) {
invocation.params['command'] = cmd;
if (
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
) {
return {
allAllowed: false,
disallowedCommands: [cmd],
@@ -364,7 +358,7 @@ export function checkCommandPermissions(
}
}
const globallyAllowedCommands = extractCommands(coreTools);
const coreTools = config.getCoreTools() || [];
const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) =>
coreTools.includes(name),
);
@@ -375,18 +369,30 @@ export function checkCommandPermissions(
return { allAllowed: true, disallowedCommands: [] };
}
const disallowedCommands: string[] = [];
if (sessionAllowlist) {
// "DEFAULT DENY" MODE: A session allowlist is provided.
// All commands must be in either the session or global allowlist.
const disallowedCommands: string[] = [];
const normalizedSessionAllowlist = new Set(
[...sessionAllowlist].flatMap((cmd) =>
SHELL_TOOL_NAMES.map((name) => `${name}(${cmd})`),
),
);
for (const cmd of commandsToValidate) {
const isSessionAllowed = [...sessionAllowlist].some((allowed) =>
isPrefixedBy(cmd, normalize(allowed)),
invocation.params['command'] = cmd;
const isSessionAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
[...normalizedSessionAllowlist],
);
if (isSessionAllowed) continue;
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
isPrefixedBy(cmd, allowed),
const isGloballyAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
coreTools,
);
if (isGloballyAllowed) continue;
@@ -405,12 +411,18 @@ export function checkCommandPermissions(
}
} else {
// "DEFAULT ALLOW" MODE: No session allowlist.
const hasSpecificAllowedCommands = globallyAllowedCommands.length > 0;
const hasSpecificAllowedCommands =
coreTools.filter((tool) =>
SHELL_TOOL_NAMES.some((name) => tool.startsWith(`${name}(`)),
).length > 0;
if (hasSpecificAllowedCommands) {
const disallowedCommands: string[] = [];
for (const cmd of commandsToValidate) {
const isGloballyAllowed = globallyAllowedCommands.some((allowed) =>
isPrefixedBy(cmd, allowed),
invocation.params['command'] = cmd;
const isGloballyAllowed = doesToolInvocationMatch(
'run_shell_command',
invocation,
coreTools,
);
if (!isGloballyAllowed) {
disallowedCommands.push(cmd);
@@ -420,7 +432,9 @@ export function checkCommandPermissions(
return {
allAllowed: false,
disallowedCommands,
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands.map((c) => JSON.stringify(c)).join(', ')}`,
blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands
.map((c) => JSON.stringify(c))
.join(', ')}`,
isHardDenial: false, // This is a soft denial.
};
}

View File

@@ -0,0 +1,94 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { expect, describe, it } from 'vitest';
import { doesToolInvocationMatch } from './tool-utils.js';
import type { AnyToolInvocation, Config } from '../index.js';
import { ReadFileTool } from '../tools/read-file.js';
describe('doesToolInvocationMatch', () => {
it('should not match a partial command prefix', () => {
const invocation = {
params: { command: 'git commitsomething' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git commit)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match an exact command', () => {
const invocation = {
params: { command: 'git status' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match a command that is a prefix', () => {
const invocation = {
params: { command: 'git status -v' },
} as AnyToolInvocation;
const patterns = ['ShellTool(git status)'];
const result = doesToolInvocationMatch(
'run_shell_command',
invocation,
patterns,
);
expect(result).toBe(true);
});
describe('for non-shell tools', () => {
const readFileTool = new ReadFileTool({} as Config);
const invocation = {
params: { file: 'test.txt' },
} as AnyToolInvocation;
it('should match by tool name', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should match by tool class name', () => {
const patterns = ['ReadFileTool'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(true);
});
it('should not match if neither name is in the patterns', () => {
const patterns = ['some_other_tool', 'AnotherToolClass'];
const result = doesToolInvocationMatch(
readFileTool,
invocation,
patterns,
);
expect(result).toBe(false);
});
it('should match by tool name when passed as a string', () => {
const patterns = ['read_file'];
const result = doesToolInvocationMatch('read_file', invocation, patterns);
expect(result).toBe(true);
});
});
});

View File

@@ -0,0 +1,76 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js';
import { isTool } from '../index.js';
const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool'];
/**
* Checks if a tool invocation matches any of a list of patterns.
*
* @param toolOrToolName The tool object or the name of the tool being invoked.
* @param invocation The invocation object for the tool.
* @param patterns A list of patterns to match against.
* Patterns can be:
* - A tool name (e.g., "ReadFileTool") to match any invocation of that tool.
* - A tool name with a prefix (e.g., "ShellTool(git status)") to match
* invocations where the arguments start with that prefix.
* @returns True if the invocation matches any pattern, false otherwise.
*/
export function doesToolInvocationMatch(
toolOrToolName: AnyDeclarativeTool | string,
invocation: AnyToolInvocation,
patterns: string[],
): boolean {
let toolNames: string[];
if (isTool(toolOrToolName)) {
toolNames = [toolOrToolName.name, toolOrToolName.constructor.name];
} else {
toolNames = [toolOrToolName as string];
}
if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) {
toolNames = [...new Set([...toolNames, ...SHELL_TOOL_NAMES])];
}
for (const pattern of patterns) {
const openParen = pattern.indexOf('(');
if (openParen === -1) {
// No arguments, just a tool name
if (toolNames.includes(pattern)) {
return true;
}
continue;
}
const patternToolName = pattern.substring(0, openParen);
if (!toolNames.includes(patternToolName)) {
continue;
}
if (!pattern.endsWith(')')) {
continue;
}
const argPattern = pattern.substring(openParen + 1, pattern.length - 1);
if (
'command' in invocation.params &&
toolNames.includes('run_shell_command')
) {
const argValue = String(
(invocation.params as { command: string }).command,
);
if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) {
return true;
}
}
}
return false;
}