diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index e4ba1526..72d73b4f 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -91,6 +91,11 @@ If you are experiencing performance issues with file searching (e.g., with `@` c - **Default:** All tools available for use by the Gemini model. - **Example:** `"coreTools": ["ReadFileTool", "GlobTool", "ShellTool(ls)"]`. +- **`allowedTools`** (array of strings): + - **Default:** `undefined` + - **Description:** A list of tool names that will bypass the confirmation dialog. This is useful for tools that you trust and use frequently. The match semantics are the same as `coreTools`. + - **Example:** `"allowedTools": ["ShellTool(git status)"]`. + - **`excludeTools`** (array of strings): - **Description:** Allows you to specify a list of core tool names that should be excluded from the model. A tool listed in both `excludeTools` and `coreTools` is excluded. You can also specify command-specific restrictions for tools that support it, like the `ShellTool`. For example, `"excludeTools": ["ShellTool(rm -rf)"]` will block the `rm -rf` command. - **Default**: No tools excluded. @@ -479,6 +484,9 @@ Arguments passed directly when running the CLI can override other configurations - `yolo`: Automatically approve all tool calls (equivalent to `--yolo`) - Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach. - Example: `gemini --approval-mode auto_edit` +- **`--allowed-tools `**: + - A comma-separated list of tool names that will bypass the confirmation dialog. + - Example: `gemini --allowed-tools "ShellTool(git status)"` - **`--telemetry`**: - Enables [telemetry](../telemetry.md). - **`--telemetry-target`**: diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index c7296a5e..969b45fb 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -70,6 +70,7 @@ export interface CliArgs { telemetryLogPrompts: boolean | undefined; telemetryOutfile: string | undefined; allowedMcpServerNames: string[] | undefined; + allowedTools: string[] | undefined; experimentalAcp: boolean | undefined; extensions: string[] | undefined; listExtensions: boolean | undefined; @@ -189,6 +190,11 @@ export async function parseArguments(settings: Settings): Promise { string: true, description: 'Allowed MCP server names', }) + .option('allowed-tools', { + type: 'array', + string: true, + description: 'Tools that are allowed to run without confirmation', + }) .option('extensions', { alias: 'e', type: 'array', @@ -489,6 +495,7 @@ export async function loadCliConfig( question, fullContext: argv.allFiles || false, coreTools: settings.coreTools || undefined, + allowedTools: argv.allowedTools || settings.allowedTools || undefined, excludeTools, toolDiscoveryCommand: settings.toolDiscoveryCommand, toolCallCommand: settings.toolCallCommand, diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 89461b52..c9e845b9 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -344,6 +344,16 @@ export const SETTINGS_SCHEMA = { description: 'Paths to core tool definitions.', showInDialog: false, }, + allowedTools: { + type: 'array', + label: 'Allowed Tools', + category: 'Advanced', + requiresRestart: true, + default: undefined as string[] | undefined, + description: + 'A list of tool names that will bypass the confirmation dialog.', + showInDialog: false, + }, excludeTools: { type: 'array', label: 'Exclude Tools', diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index f3817764..3a06dd27 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -53,14 +53,15 @@ const mockToolRegistry = { const mockConfig = { getToolRegistry: vi.fn(() => mockToolRegistry as unknown as ToolRegistry), getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT), + getSessionId: () => 'test-session-id', getUsageStatisticsEnabled: () => true, getDebugMode: () => false, - getSessionId: () => 'test-session-id', + getAllowedTools: vi.fn(() => []), getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', }), -}; +} as unknown as Config; class MockToolInvocation extends BaseToolInvocation { constructor( @@ -218,11 +219,6 @@ describe('useReactToolScheduler in YOLO Mode', () => { await vi.runAllTimersAsync(); // Process execution }); - // Check that shouldConfirmExecute was NOT called - expect( - mockToolRequiresConfirmation.shouldConfirmExecute, - ).not.toHaveBeenCalled(); - // Check that execute WAS called expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith( request.args, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 3326ebf9..d5647bee 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -23,6 +23,9 @@ import { GeminiClient } from '../core/client.js'; import { GitService } from '../services/gitService.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; +import { ShellTool } from '../tools/shell.js'; +import { ReadFileTool } from '../tools/read-file.js'; + vi.mock('fs', async (importOriginal) => { const actual = await importOriginal(); return { @@ -629,6 +632,36 @@ describe('Server Config (config.ts)', () => { expect(config.getUseRipgrep()).toBe(false); }); }); + + describe('createToolRegistry', () => { + it('should register a tool if coreTools contains an argument-specific pattern', async () => { + const params: ConfigParameters = { + ...baseParams, + coreTools: ['ShellTool(git status)'], + }; + const config = new Config(params); + await config.initialize(); + + // The ToolRegistry class is mocked, so we can inspect its prototype's methods. + const registerToolMock = ( + (await vi.importMock('../tools/tool-registry')) as { + ToolRegistry: { prototype: { registerTool: Mock } }; + } + ).ToolRegistry.prototype.registerTool; + + // Check that registerTool was called for ShellTool + const wasShellToolRegistered = (registerToolMock as Mock).mock.calls.some( + (call) => call[0] instanceof vi.mocked(ShellTool), + ); + expect(wasShellToolRegistered).toBe(true); + + // Check that registerTool was NOT called for ReadFileTool + const wasReadFileToolRegistered = ( + registerToolMock as Mock + ).mock.calls.some((call) => call[0] instanceof vi.mocked(ReadFileTool)); + expect(wasReadFileToolRegistered).toBe(false); + }); + }); }); describe('setApprovalMode with folder trust', () => { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 86d3ddd5..ed2bac45 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -49,7 +49,8 @@ import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js'; import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js'; // Re-export OAuth config type -export type { MCPOAuthConfig }; +export type { MCPOAuthConfig, AnyToolInvocation }; +import type { AnyToolInvocation } from '../tools/tools.js'; import { WorkspaceContext } from '../utils/workspaceContext.js'; import { Storage } from './storage.js'; import { FileExclusions } from '../utils/ignorePatterns.js'; @@ -159,6 +160,7 @@ export interface ConfigParameters { question?: string; fullContext?: boolean; coreTools?: string[]; + allowedTools?: string[]; excludeTools?: string[]; toolDiscoveryCommand?: string; toolCallCommand?: string; @@ -221,6 +223,7 @@ export class Config { private readonly question: string | undefined; private readonly fullContext: boolean; private readonly coreTools: string[] | undefined; + private readonly allowedTools: string[] | undefined; private readonly excludeTools: string[] | undefined; private readonly toolDiscoveryCommand: string | undefined; private readonly toolCallCommand: string | undefined; @@ -295,6 +298,7 @@ export class Config { this.question = params.question; this.fullContext = params.fullContext ?? false; this.coreTools = params.coreTools; + this.allowedTools = params.allowedTools; this.excludeTools = params.excludeTools; this.toolDiscoveryCommand = params.toolDiscoveryCommand; this.toolCallCommand = params.toolCallCommand; @@ -523,6 +527,10 @@ export class Config { return this.coreTools; } + getAllowedTools(): string[] | undefined { + return this.allowedTools; + } + getExcludeTools(): string[] | undefined { return this.excludeTools; } @@ -807,12 +815,10 @@ export class Config { const className = ToolClass.name; const toolName = ToolClass.Name || className; const coreTools = this.getCoreTools(); - const excludeTools = this.getExcludeTools(); + const excludeTools = this.getExcludeTools() || []; - let isEnabled = false; - if (coreTools === undefined) { - isEnabled = true; - } else { + let isEnabled = true; // Enabled by default if coreTools is not set. + if (coreTools) { isEnabled = coreTools.some( (tool) => tool === className || @@ -822,10 +828,11 @@ export class Config { ); } - if ( - excludeTools?.includes(className) || - excludeTools?.includes(toolName) - ) { + const isExcluded = excludeTools.some( + (tool) => tool === className || tool === toolName, + ); + + if (isExcluded) { isEnabled = false; } diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index ee3692f6..e4b4ac5b 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -5,6 +5,7 @@ */ import { describe, it, expect, vi } from 'vitest'; +import type { Mock } from 'vitest'; import type { ToolCall, WaitingToolCall } from './coreToolScheduler.js'; import { CoreToolScheduler, @@ -99,6 +100,41 @@ class TestApprovalInvocation extends BaseToolInvocation< } } +async function waitForStatus( + onToolCallsUpdate: Mock, + status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled', + timeout = 5000, +): Promise { + return new Promise((resolve, reject) => { + const startTime = Date.now(); + const check = () => { + if (Date.now() - startTime > timeout) { + const seenStatuses = onToolCallsUpdate.mock.calls + .flatMap((call) => call[0]) + .map((toolCall: ToolCall) => toolCall.status); + reject( + new Error( + `Timed out waiting for status "${status}". Seen statuses: ${seenStatuses.join( + ', ', + )}`, + ), + ); + return; + } + + const foundCall = onToolCallsUpdate.mock.calls + .flatMap((call) => call[0]) + .find((toolCall: ToolCall) => toolCall.status === status); + if (foundCall) { + resolve(foundCall); + } else { + setTimeout(check, 10); // Check again in 10ms + } + }; + check(); + }); +} + describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { const mockTool = new MockTool(); @@ -126,6 +162,7 @@ describe('CoreToolScheduler', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -186,6 +223,7 @@ describe('CoreToolScheduler with payload', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -212,16 +250,10 @@ describe('CoreToolScheduler with payload', () => { await scheduler.schedule([request], abortController.signal); - await vi.waitFor(() => { - const awaitingCall = onToolCallsUpdate.mock.calls.find( - (call) => call[0][0].status === 'awaiting_approval', - )?.[0][0]; - expect(awaitingCall).toBeDefined(); - }); - - const awaitingCall = onToolCallsUpdate.mock.calls.find( - (call) => call[0][0].status === 'awaiting_approval', - )?.[0][0]; + const awaitingCall = (await waitForStatus( + onToolCallsUpdate, + 'awaiting_approval', + )) as WaitingToolCall; const confirmationDetails = awaitingCall.confirmationDetails; if (confirmationDetails) { @@ -497,6 +529,7 @@ describe('CoreToolScheduler edit cancellation', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -523,12 +556,10 @@ describe('CoreToolScheduler edit cancellation', () => { await scheduler.schedule([request], abortController.signal); - // Wait for the tool to reach awaiting_approval state - const awaitingCall = onToolCallsUpdate.mock.calls.find( - (call) => call[0][0].status === 'awaiting_approval', - )?.[0][0]; - - expect(awaitingCall).toBeDefined(); + const awaitingCall = (await waitForStatus( + onToolCallsUpdate, + 'awaiting_approval', + )) as WaitingToolCall; // Cancel the edit const confirmationDetails = awaitingCall.confirmationDetails; @@ -589,6 +620,7 @@ describe('CoreToolScheduler YOLO mode', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.YOLO, + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -678,6 +710,7 @@ describe('CoreToolScheduler request queueing', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -713,10 +746,7 @@ describe('CoreToolScheduler request queueing', () => { scheduler.schedule([request1], abortController.signal); // Wait for the first call to be in the 'executing' state. - await vi.waitFor(() => { - const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; - expect(calls?.[0]?.status).toBe('executing'); - }); + await waitForStatus(onToolCallsUpdate, 'executing'); // Schedule the second call while the first is "running". const schedulePromise2 = scheduler.schedule( @@ -737,16 +767,6 @@ describe('CoreToolScheduler request queueing', () => { // Wait for the second schedule promise to resolve. await schedulePromise2; - // Wait for the second call to be in the 'executing' state. - await vi.waitFor(() => { - const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; - expect(calls?.[0]?.status).toBe('executing'); - }); - - // Now the second tool call should have been executed. - expect(mockTool.executeFn).toHaveBeenCalledTimes(2); - expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); - // Let the second call finish. const secondCallResult = { llmContent: 'Second call complete', @@ -756,6 +776,12 @@ describe('CoreToolScheduler request queueing', () => { // In a real scenario, a new promise would be created for the second call. resolveFirstCall!(secondCallResult); + await vi.waitFor(() => { + // Now the second tool call should have been executed. + expect(mockTool.executeFn).toHaveBeenCalledTimes(2); + }); + expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); + // Wait for the second completion. await vi.waitFor(() => { expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); @@ -766,6 +792,96 @@ describe('CoreToolScheduler request queueing', () => { expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success'); }); + it('should auto-approve a tool call if it is on the allowedTools list', async () => { + // Arrange + const mockTool = new MockTool('mockTool'); + mockTool.executeFn.mockReturnValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + // This tool would normally require confirmation. + mockTool.shouldConfirm = true; + const declarativeTool = mockTool; + + const toolRegistry = { + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + }; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + // Configure the scheduler to auto-approve the specific tool call. + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.DEFAULT, // Not YOLO mode + getAllowedTools: () => ['mockTool'], // Auto-approve this tool + getToolRegistry: () => toolRegistry, + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const request = { + callId: '1', + name: 'mockTool', + args: { param: 'value' }, + isClientInitiated: false, + prompt_id: 'prompt-auto-approved', + }; + + // Act + await scheduler.schedule([request], abortController.signal); + + // Assert + // 1. The tool's execute method was called directly. + expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' }); + + // 2. The tool call status never entered 'awaiting_approval'. + const statusUpdates = onToolCallsUpdate.mock.calls + .map((call) => (call[0][0] as ToolCall)?.status) + .filter(Boolean); + expect(statusUpdates).not.toContain('awaiting_approval'); + expect(statusUpdates).toEqual([ + 'validating', + 'scheduled', + 'executing', + 'success', + ]); + + // 3. The final callback indicates the tool call was successful. + expect(onAllToolCallsComplete).toHaveBeenCalled(); + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls).toHaveLength(1); + const completedCall = completedCalls[0]; + expect(completedCall.status).toBe('success'); + if (completedCall.status === 'success') { + expect(completedCall.response.resultDisplay).toBe('Tool executed'); + } + }); + it('should handle two synchronous calls to schedule', async () => { const mockTool = new MockTool(); const declarativeTool = mockTool; @@ -782,7 +898,6 @@ describe('CoreToolScheduler request queueing', () => { getAllTools: () => [], getToolsByServer: () => [], } as unknown as ToolRegistry; - const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); @@ -791,6 +906,7 @@ describe('CoreToolScheduler request queueing', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => ApprovalMode.YOLO, + getAllowedTools: () => [], getContentGeneratorConfig: () => ({ model: 'test-model', authType: 'oauth-personal', @@ -851,6 +967,7 @@ describe('CoreToolScheduler request queueing', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, getApprovalMode: () => approvalMode, + getAllowedTools: () => [], setApprovalMode: (mode: ApprovalMode) => { approvalMode = mode; }, diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 8c175094..8a4a7a1c 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -32,6 +32,7 @@ import { modifyWithEditor, } from '../tools/modifiable-tool.js'; import * as Diff from 'diff'; +import { doesToolInvocationMatch } from '../utils/tool-utils.js'; export type ValidatingToolCall = { status: 'validating'; @@ -615,68 +616,74 @@ export class CoreToolScheduler { ); continue; } - if (this.config.getApprovalMode() === ApprovalMode.YOLO) { + + const confirmationDetails = + await invocation.shouldConfirmExecute(signal); + + if (!confirmationDetails) { + this.setToolCallOutcome( + reqInfo.callId, + ToolConfirmationOutcome.ProceedAlways, + ); + this.setStatusInternal(reqInfo.callId, 'scheduled'); + continue; + } + + const allowedTools = this.config.getAllowedTools() || []; + if ( + this.config.getApprovalMode() === ApprovalMode.YOLO || + doesToolInvocationMatch(toolCall.tool, invocation, allowedTools) + ) { this.setToolCallOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, ); this.setStatusInternal(reqInfo.callId, 'scheduled'); } else { - const confirmationDetails = - await invocation.shouldConfirmExecute(signal); - - if (confirmationDetails) { - // Allow IDE to resolve confirmation - if ( - confirmationDetails.type === 'edit' && - confirmationDetails.ideConfirmation - ) { - confirmationDetails.ideConfirmation.then((resolution) => { - if (resolution.status === 'accepted') { - this.handleConfirmationResponse( - reqInfo.callId, - confirmationDetails.onConfirm, - ToolConfirmationOutcome.ProceedOnce, - signal, - ); - } else { - this.handleConfirmationResponse( - reqInfo.callId, - confirmationDetails.onConfirm, - ToolConfirmationOutcome.Cancel, - signal, - ); - } - }); - } - - const originalOnConfirm = confirmationDetails.onConfirm; - const wrappedConfirmationDetails: ToolCallConfirmationDetails = { - ...confirmationDetails, - onConfirm: ( - outcome: ToolConfirmationOutcome, - payload?: ToolConfirmationPayload, - ) => + // Allow IDE to resolve confirmation + if ( + confirmationDetails.type === 'edit' && + confirmationDetails.ideConfirmation + ) { + confirmationDetails.ideConfirmation.then((resolution) => { + if (resolution.status === 'accepted') { this.handleConfirmationResponse( reqInfo.callId, - originalOnConfirm, - outcome, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.ProceedOnce, signal, - payload, - ), - }; - this.setStatusInternal( - reqInfo.callId, - 'awaiting_approval', - wrappedConfirmationDetails, - ); - } else { - this.setToolCallOutcome( - reqInfo.callId, - ToolConfirmationOutcome.ProceedAlways, - ); - this.setStatusInternal(reqInfo.callId, 'scheduled'); + ); + } else { + this.handleConfirmationResponse( + reqInfo.callId, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.Cancel, + signal, + ); + } + }); } + + const originalOnConfirm = confirmationDetails.onConfirm; + const wrappedConfirmationDetails: ToolCallConfirmationDetails = { + ...confirmationDetails, + onConfirm: ( + outcome: ToolConfirmationOutcome, + payload?: ToolConfirmationPayload, + ) => + this.handleConfirmationResponse( + reqInfo.callId, + originalOnConfirm, + outcome, + signal, + payload, + ), + }; + this.setStatusInternal( + reqInfo.callId, + 'awaiting_approval', + wrappedConfirmationDetails, + ); } } catch (error) { this.setStatusInternal( diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index c5e71239..c46328bf 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -32,6 +32,7 @@ describe('executeToolCall', () => { mockConfig = { getToolRegistry: () => mockToolRegistry, getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], getSessionId: () => 'test-session-id', getUsageStatisticsEnabled: () => true, getDebugMode: () => false, diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 432fdc18..5a61fecd 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -307,6 +307,21 @@ export abstract class BaseDeclarativeTool< */ export type AnyDeclarativeTool = DeclarativeTool; +/** + * Type guard to check if an object is a Tool. + * @param obj The object to check. + * @returns True if the object is a Tool, false otherwise. + */ +export function isTool(obj: unknown): obj is AnyDeclarativeTool { + return ( + typeof obj === 'object' && + obj !== null && + 'name' in obj && + 'build' in obj && + typeof (obj as AnyDeclarativeTool).build === 'function' + ); +} + export interface ToolResult { /** * Content meant to be included in LLM history. diff --git a/packages/core/src/utils/shell-utils.test.ts b/packages/core/src/utils/shell-utils.test.ts index 71f33dbf..3c18fff4 100644 --- a/packages/core/src/utils/shell-utils.test.ts +++ b/packages/core/src/utils/shell-utils.test.ts @@ -16,11 +16,14 @@ import { import type { Config } from '../config/config.js'; const mockPlatform = vi.hoisted(() => vi.fn()); +const mockHomedir = vi.hoisted(() => vi.fn()); vi.mock('os', () => ({ default: { platform: mockPlatform, + homedir: mockHomedir, }, platform: mockPlatform, + homedir: mockHomedir, })); const mockQuote = vi.hoisted(() => vi.fn()); @@ -38,6 +41,7 @@ beforeEach(() => { config = { getCoreTools: () => [], getExcludeTools: () => [], + getAllowedTools: () => [], } as unknown as Config; }); diff --git a/packages/core/src/utils/shell-utils.ts b/packages/core/src/utils/shell-utils.ts index 3af4959b..cedebe6c 100644 --- a/packages/core/src/utils/shell-utils.ts +++ b/packages/core/src/utils/shell-utils.ts @@ -4,9 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { AnyToolInvocation } from '../index.js'; import type { Config } from '../config/config.js'; import os from 'node:os'; import { quote } from 'shell-quote'; +import { doesToolInvocationMatch } from './tool-utils.js'; + +const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool']; /** * An identifier for the shell type. @@ -319,32 +323,19 @@ export function checkCommandPermissions( }; } - const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool']; const normalize = (cmd: string): string => cmd.trim().replace(/\s+/g, ' '); - - const isPrefixedBy = (cmd: string, prefix: string): boolean => { - if (!cmd.startsWith(prefix)) { - return false; - } - return cmd.length === prefix.length || cmd[prefix.length] === ' '; - }; - - const extractCommands = (tools: string[]): string[] => - tools.flatMap((tool) => { - for (const toolName of SHELL_TOOL_NAMES) { - if (tool.startsWith(`${toolName}(`) && tool.endsWith(')')) { - return [normalize(tool.slice(toolName.length + 1, -1))]; - } - } - return []; - }); - - const coreTools = config.getCoreTools() || []; - const excludeTools = config.getExcludeTools() || []; const commandsToValidate = splitCommands(command).map(normalize); + const invocation: AnyToolInvocation & { params: { command: string } } = { + params: { command: '' }, + } as AnyToolInvocation & { params: { command: string } }; // 1. Blocklist Check (Highest Priority) - if (SHELL_TOOL_NAMES.some((name) => excludeTools.includes(name))) { + const excludeTools = config.getExcludeTools() || []; + const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) => + excludeTools.includes(name), + ); + + if (isWildcardBlocked) { return { allAllowed: false, disallowedCommands: commandsToValidate, @@ -352,9 +343,12 @@ export function checkCommandPermissions( isHardDenial: true, }; } - const blockedCommands = extractCommands(excludeTools); + for (const cmd of commandsToValidate) { - if (blockedCommands.some((blocked) => isPrefixedBy(cmd, blocked))) { + invocation.params['command'] = cmd; + if ( + doesToolInvocationMatch('run_shell_command', invocation, excludeTools) + ) { return { allAllowed: false, disallowedCommands: [cmd], @@ -364,7 +358,7 @@ export function checkCommandPermissions( } } - const globallyAllowedCommands = extractCommands(coreTools); + const coreTools = config.getCoreTools() || []; const isWildcardAllowed = SHELL_TOOL_NAMES.some((name) => coreTools.includes(name), ); @@ -375,18 +369,30 @@ export function checkCommandPermissions( return { allAllowed: true, disallowedCommands: [] }; } + const disallowedCommands: string[] = []; + if (sessionAllowlist) { // "DEFAULT DENY" MODE: A session allowlist is provided. // All commands must be in either the session or global allowlist. - const disallowedCommands: string[] = []; + const normalizedSessionAllowlist = new Set( + [...sessionAllowlist].flatMap((cmd) => + SHELL_TOOL_NAMES.map((name) => `${name}(${cmd})`), + ), + ); + for (const cmd of commandsToValidate) { - const isSessionAllowed = [...sessionAllowlist].some((allowed) => - isPrefixedBy(cmd, normalize(allowed)), + invocation.params['command'] = cmd; + const isSessionAllowed = doesToolInvocationMatch( + 'run_shell_command', + invocation, + [...normalizedSessionAllowlist], ); if (isSessionAllowed) continue; - const isGloballyAllowed = globallyAllowedCommands.some((allowed) => - isPrefixedBy(cmd, allowed), + const isGloballyAllowed = doesToolInvocationMatch( + 'run_shell_command', + invocation, + coreTools, ); if (isGloballyAllowed) continue; @@ -405,12 +411,18 @@ export function checkCommandPermissions( } } else { // "DEFAULT ALLOW" MODE: No session allowlist. - const hasSpecificAllowedCommands = globallyAllowedCommands.length > 0; + const hasSpecificAllowedCommands = + coreTools.filter((tool) => + SHELL_TOOL_NAMES.some((name) => tool.startsWith(`${name}(`)), + ).length > 0; + if (hasSpecificAllowedCommands) { - const disallowedCommands: string[] = []; for (const cmd of commandsToValidate) { - const isGloballyAllowed = globallyAllowedCommands.some((allowed) => - isPrefixedBy(cmd, allowed), + invocation.params['command'] = cmd; + const isGloballyAllowed = doesToolInvocationMatch( + 'run_shell_command', + invocation, + coreTools, ); if (!isGloballyAllowed) { disallowedCommands.push(cmd); @@ -420,7 +432,9 @@ export function checkCommandPermissions( return { allAllowed: false, disallowedCommands, - blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands.map((c) => JSON.stringify(c)).join(', ')}`, + blockReason: `Command(s) not in the allowed commands list. Disallowed commands: ${disallowedCommands + .map((c) => JSON.stringify(c)) + .join(', ')}`, isHardDenial: false, // This is a soft denial. }; } diff --git a/packages/core/src/utils/tool-utils.test.ts b/packages/core/src/utils/tool-utils.test.ts new file mode 100644 index 00000000..5527e186 --- /dev/null +++ b/packages/core/src/utils/tool-utils.test.ts @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { expect, describe, it } from 'vitest'; +import { doesToolInvocationMatch } from './tool-utils.js'; +import type { AnyToolInvocation, Config } from '../index.js'; +import { ReadFileTool } from '../tools/read-file.js'; + +describe('doesToolInvocationMatch', () => { + it('should not match a partial command prefix', () => { + const invocation = { + params: { command: 'git commitsomething' }, + } as AnyToolInvocation; + const patterns = ['ShellTool(git commit)']; + const result = doesToolInvocationMatch( + 'run_shell_command', + invocation, + patterns, + ); + expect(result).toBe(false); + }); + + it('should match an exact command', () => { + const invocation = { + params: { command: 'git status' }, + } as AnyToolInvocation; + const patterns = ['ShellTool(git status)']; + const result = doesToolInvocationMatch( + 'run_shell_command', + invocation, + patterns, + ); + expect(result).toBe(true); + }); + + it('should match a command that is a prefix', () => { + const invocation = { + params: { command: 'git status -v' }, + } as AnyToolInvocation; + const patterns = ['ShellTool(git status)']; + const result = doesToolInvocationMatch( + 'run_shell_command', + invocation, + patterns, + ); + expect(result).toBe(true); + }); + + describe('for non-shell tools', () => { + const readFileTool = new ReadFileTool({} as Config); + const invocation = { + params: { file: 'test.txt' }, + } as AnyToolInvocation; + + it('should match by tool name', () => { + const patterns = ['read_file']; + const result = doesToolInvocationMatch( + readFileTool, + invocation, + patterns, + ); + expect(result).toBe(true); + }); + + it('should match by tool class name', () => { + const patterns = ['ReadFileTool']; + const result = doesToolInvocationMatch( + readFileTool, + invocation, + patterns, + ); + expect(result).toBe(true); + }); + + it('should not match if neither name is in the patterns', () => { + const patterns = ['some_other_tool', 'AnotherToolClass']; + const result = doesToolInvocationMatch( + readFileTool, + invocation, + patterns, + ); + expect(result).toBe(false); + }); + + it('should match by tool name when passed as a string', () => { + const patterns = ['read_file']; + const result = doesToolInvocationMatch('read_file', invocation, patterns); + expect(result).toBe(true); + }); + }); +}); diff --git a/packages/core/src/utils/tool-utils.ts b/packages/core/src/utils/tool-utils.ts new file mode 100644 index 00000000..cd3053ff --- /dev/null +++ b/packages/core/src/utils/tool-utils.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js'; +import { isTool } from '../index.js'; + +const SHELL_TOOL_NAMES = ['run_shell_command', 'ShellTool']; + +/** + * Checks if a tool invocation matches any of a list of patterns. + * + * @param toolOrToolName The tool object or the name of the tool being invoked. + * @param invocation The invocation object for the tool. + * @param patterns A list of patterns to match against. + * Patterns can be: + * - A tool name (e.g., "ReadFileTool") to match any invocation of that tool. + * - A tool name with a prefix (e.g., "ShellTool(git status)") to match + * invocations where the arguments start with that prefix. + * @returns True if the invocation matches any pattern, false otherwise. + */ +export function doesToolInvocationMatch( + toolOrToolName: AnyDeclarativeTool | string, + invocation: AnyToolInvocation, + patterns: string[], +): boolean { + let toolNames: string[]; + if (isTool(toolOrToolName)) { + toolNames = [toolOrToolName.name, toolOrToolName.constructor.name]; + } else { + toolNames = [toolOrToolName as string]; + } + + if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) { + toolNames = [...new Set([...toolNames, ...SHELL_TOOL_NAMES])]; + } + + for (const pattern of patterns) { + const openParen = pattern.indexOf('('); + + if (openParen === -1) { + // No arguments, just a tool name + if (toolNames.includes(pattern)) { + return true; + } + continue; + } + + const patternToolName = pattern.substring(0, openParen); + if (!toolNames.includes(patternToolName)) { + continue; + } + + if (!pattern.endsWith(')')) { + continue; + } + + const argPattern = pattern.substring(openParen + 1, pattern.length - 1); + + if ( + 'command' in invocation.params && + toolNames.includes('run_shell_command') + ) { + const argValue = String( + (invocation.params as { command: string }).command, + ); + if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) { + return true; + } + } + } + + return false; +}