refactor: session and canUseTool support

This commit is contained in:
mingholy.lmh
2025-11-21 16:22:18 +08:00
parent 2fe6ba7c56
commit f635cd3070
21 changed files with 1266 additions and 639 deletions

1
.vscode/launch.json vendored
View File

@@ -79,7 +79,6 @@
"--",
"-p",
"${input:prompt}",
"-y",
"--output-format",
"stream-json"
],

View File

@@ -758,8 +758,14 @@ export async function loadCliConfig(
interactive = false;
}
// In non-interactive mode, exclude tools that require a prompt.
// However, if stream-json input is used, control can be requested via JSON messages,
// so tools should not be excluded in that case.
const extraExcludes: string[] = [];
if (!interactive && !argv.experimentalAcp) {
if (
!interactive &&
!argv.experimentalAcp &&
inputFormat !== InputFormat.STREAM_JSON
) {
switch (approvalMode) {
case ApprovalMode.PLAN:
case ApprovalMode.DEFAULT:

View File

@@ -479,6 +479,10 @@ describe('gemini.tsx main function kitty protocol', () => {
inputFormat: undefined,
outputFormat: undefined,
includePartialMessages: undefined,
coreTools: undefined,
excludeTools: undefined,
authType: undefined,
maxSessionTurns: undefined,
});
await main();

View File

@@ -26,7 +26,7 @@
import type { IControlContext } from './ControlContext.js';
import type { IPendingRequestRegistry } from './controllers/baseController.js';
import { SystemController } from './controllers/systemController.js';
// import { PermissionController } from './controllers/permissionController.js';
import { PermissionController } from './controllers/permissionController.js';
// import { MCPController } from './controllers/mcpController.js';
// import { HookController } from './controllers/hookController.js';
import type {
@@ -64,7 +64,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
// Make controllers publicly accessible
readonly systemController: SystemController;
// readonly permissionController: PermissionController;
readonly permissionController: PermissionController;
// readonly mcpController: MCPController;
// readonly hookController: HookController;
@@ -83,11 +83,11 @@ export class ControlDispatcher implements IPendingRequestRegistry {
this,
'SystemController',
);
// this.permissionController = new PermissionController(
// context,
// this,
// 'PermissionController',
// );
this.permissionController = new PermissionController(
context,
this,
'PermissionController',
);
// this.mcpController = new MCPController(context, this, 'MCPController');
// this.hookController = new HookController(context, this, 'HookController');
@@ -230,7 +230,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
// Cleanup controllers (MCP controller will close all clients)
this.systemController.cleanup();
// this.permissionController.cleanup();
this.permissionController.cleanup();
// this.mcpController.cleanup();
// this.hookController.cleanup();
}
@@ -302,9 +302,9 @@ export class ControlDispatcher implements IPendingRequestRegistry {
case 'supported_commands':
return this.systemController;
// case 'can_use_tool':
// case 'set_permission_mode':
// return this.permissionController;
case 'can_use_tool':
case 'set_permission_mode':
return this.permissionController;
// case 'mcp_message':
// case 'mcp_server_status':

View File

@@ -29,7 +29,7 @@
import type { IControlContext } from './ControlContext.js';
import type { ControlDispatcher } from './ControlDispatcher.js';
import type {
// PermissionServiceAPI,
PermissionServiceAPI,
SystemServiceAPI,
// McpServiceAPI,
// HookServiceAPI,
@@ -61,43 +61,43 @@ export class ControlService {
* Handles tool execution permissions, approval checks, and callbacks.
* Delegates to the shared PermissionController instance.
*/
// get permission(): PermissionServiceAPI {
// const controller = this.dispatcher.permissionController;
// return {
// /**
// * Check if a tool should be allowed based on current permission settings
// *
// * Evaluates permission mode and tool registry to determine if execution
// * should proceed. Can optionally modify tool arguments based on confirmation details.
// *
// * @param toolRequest - Tool call request information
// * @param confirmationDetails - Optional confirmation details for UI
// * @returns Permission decision with optional updated arguments
// */
// shouldAllowTool: controller.shouldAllowTool.bind(controller),
//
// /**
// * Build UI suggestions for tool confirmation dialogs
// *
// * Creates actionable permission suggestions based on tool confirmation details.
// *
// * @param confirmationDetails - Tool confirmation details
// * @returns Array of permission suggestions or null
// */
// buildPermissionSuggestions:
// controller.buildPermissionSuggestions.bind(controller),
//
// /**
// * Get callback for monitoring tool call status updates
// *
// * Returns callback function for integration with CoreToolScheduler.
// *
// * @returns Callback function for tool call updates
// */
// getToolCallUpdateCallback:
// controller.getToolCallUpdateCallback.bind(controller),
// };
// }
get permission(): PermissionServiceAPI {
const controller = this.dispatcher.permissionController;
return {
/**
* Check if a tool should be allowed based on current permission settings
*
* Evaluates permission mode and tool registry to determine if execution
* should proceed. Can optionally modify tool arguments based on confirmation details.
*
* @param toolRequest - Tool call request information
* @param confirmationDetails - Optional confirmation details for UI
* @returns Permission decision with optional updated arguments
*/
shouldAllowTool: controller.shouldAllowTool.bind(controller),
/**
* Build UI suggestions for tool confirmation dialogs
*
* Creates actionable permission suggestions based on tool confirmation details.
*
* @param confirmationDetails - Tool confirmation details
* @returns Array of permission suggestions or null
*/
buildPermissionSuggestions:
controller.buildPermissionSuggestions.bind(controller),
/**
* Get callback for monitoring tool call status updates
*
* Returns callback function for integration with CoreToolScheduler.
*
* @returns Callback function for tool call updates
*/
getToolCallUpdateCallback:
controller.getToolCallUpdateCallback.bind(controller),
};
}
/**
* System Domain API

View File

@@ -17,6 +17,8 @@
import type {
ToolCallRequestInfo,
WaitingToolCall,
ToolExecuteConfirmationDetails,
ToolMcpConfirmationDetails,
} from '@qwen-code/qwen-code-core';
import {
InputFormat,
@@ -430,17 +432,14 @@ export class PermissionController extends BaseController {
toolCall.confirmationDetails,
);
const response = await this.sendControlRequest(
{
const response = await this.sendControlRequest({
subtype: 'can_use_tool',
tool_name: toolCall.request.name,
tool_use_id: toolCall.request.callId,
input: toolCall.request.args,
permission_suggestions: permissionSuggestions,
blocked_path: null,
} as CLIControlPermissionRequest,
30000,
);
} as CLIControlPermissionRequest);
if (response.subtype !== 'success') {
await toolCall.confirmationDetails.onConfirm(
@@ -462,8 +461,15 @@ export class PermissionController extends BaseController {
ToolConfirmationOutcome.ProceedOnce,
);
} else {
// Extract cancel message from response if available
const cancelMessage =
typeof payload['message'] === 'string'
? payload['message']
: undefined;
await toolCall.confirmationDetails.onConfirm(
ToolConfirmationOutcome.Cancel,
cancelMessage ? { cancelMessage } : undefined,
);
}
} catch (error) {
@@ -473,9 +479,23 @@ export class PermissionController extends BaseController {
error,
);
}
// On error, use default cancel message
// Only pass payload for exec and mcp types that support it
const confirmationType = toolCall.confirmationDetails.type;
if (confirmationType === 'exec' || confirmationType === 'mcp') {
const execOrMcpDetails = toolCall.confirmationDetails as
| ToolExecuteConfirmationDetails
| ToolMcpConfirmationDetails;
await execOrMcpDetails.onConfirm(
ToolConfirmationOutcome.Cancel,
undefined,
);
} else {
// For other types, don't pass payload (backward compatible)
await toolCall.confirmationDetails.onConfirm(
ToolConfirmationOutcome.Cancel,
);
}
} finally {
this.pendingOutgoingRequests.delete(toolCall.request.callId);
}

View File

@@ -939,9 +939,25 @@ export abstract class BaseJsonOutputAdapter {
this.emitMessageImpl(message);
}
/**
* Checks if responseParts contain any functionResponse with an error.
* This handles cancelled responses and other error cases where the error
* is embedded in responseParts rather than the top-level error field.
* @param responseParts - Array of Part objects
* @returns Error message if found, undefined otherwise
*/
private checkResponsePartsForError(
responseParts: Part[] | undefined,
): string | undefined {
// Use the shared helper function defined at file level
return checkResponsePartsForError(responseParts);
}
/**
* Emits a tool result message.
* Collects execution denied tool calls for inclusion in result messages.
* Handles both explicit errors (response.error) and errors embedded in
* responseParts (e.g., cancelled responses).
* @param request - Tool call request info
* @param response - Tool call response info
* @param parentToolUseId - Parent tool use ID (null for main agent)
@@ -951,6 +967,14 @@ export abstract class BaseJsonOutputAdapter {
response: ToolCallResponseInfo,
parentToolUseId: string | null = null,
): void {
// Check for errors in responseParts (e.g., cancelled responses)
const responsePartsError = this.checkResponsePartsForError(
response.responseParts,
);
// Determine if this is an error response
const hasError = Boolean(response.error) || Boolean(responsePartsError);
// Track permission denials (execution denied errors)
if (
response.error &&
@@ -967,7 +991,7 @@ export abstract class BaseJsonOutputAdapter {
const block: ToolResultBlock = {
type: 'tool_result',
tool_use_id: request.callId,
is_error: Boolean(response.error),
is_error: hasError,
};
const content = toolResultContent(response);
if (content !== undefined) {
@@ -1173,11 +1197,41 @@ export function partsToString(parts: Part[]): string {
.join('');
}
/**
* Checks if responseParts contain any functionResponse with an error.
* Helper function for extracting error messages from responseParts.
* @param responseParts - Array of Part objects
* @returns Error message if found, undefined otherwise
*/
function checkResponsePartsForError(
responseParts: Part[] | undefined,
): string | undefined {
if (!responseParts || responseParts.length === 0) {
return undefined;
}
for (const part of responseParts) {
if (
'functionResponse' in part &&
part.functionResponse?.response &&
typeof part.functionResponse.response === 'object' &&
'error' in part.functionResponse.response &&
part.functionResponse.response['error']
) {
const error = part.functionResponse.response['error'];
return typeof error === 'string' ? error : String(error);
}
}
return undefined;
}
/**
* Extracts content from tool response.
* Uses functionResponsePartsToString to properly handle functionResponse parts,
* which correctly extracts output content from functionResponse objects rather
* than simply concatenating text or JSON.stringify.
* Also handles errors embedded in responseParts (e.g., cancelled responses).
*
* @param response - Tool call response
* @returns String content or undefined
@@ -1188,6 +1242,11 @@ export function toolResultContent(
if (response.error) {
return response.error.message;
}
// Check for errors in responseParts (e.g., cancelled responses)
const responsePartsError = checkResponsePartsForError(response.responseParts);
if (responsePartsError) {
return responsePartsError;
}
if (
typeof response.resultDisplay === 'string' &&
response.resultDisplay.trim().length > 0

View File

@@ -134,7 +134,7 @@ function createControlCancel(requestId: string): ControlCancelRequest {
};
}
describe('runNonInteractiveStreamJson', () => {
describe('runNonInteractiveStreamJson (refactored)', () => {
let config: Config;
let mockInputReader: {
read: () => AsyncGenerator<

View File

@@ -4,17 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Stream JSON Runner with Session State Machine
*
* Handles stream-json input/output format with:
* - Initialize handshake
* - Message routing (control vs user messages)
* - FIFO user message queue
* - Sequential message processing
* - Graceful shutdown
*/
import type { Config } from '@qwen-code/qwen-code-core';
import { StreamJsonInputReader } from './io/StreamJsonInputReader.js';
import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js';
@@ -42,48 +31,7 @@ import { createMinimalSettings } from '../config/settings.js';
import { runNonInteractive } from '../nonInteractiveCli.js';
import { ConsolePatcher } from '../ui/utils/ConsolePatcher.js';
const SESSION_STATE = {
INITIALIZING: 'initializing',
IDLE: 'idle',
PROCESSING_QUERY: 'processing_query',
SHUTTING_DOWN: 'shutting_down',
} as const;
type SessionState = (typeof SESSION_STATE)[keyof typeof SESSION_STATE];
/**
* Message type classification for routing
*/
type MessageType =
| 'control_request'
| 'control_response'
| 'control_cancel'
| 'user'
| 'assistant'
| 'system'
| 'result'
| 'stream_event'
| 'unknown';
/**
* Routed message with classification
*/
interface RoutedMessage {
type: MessageType;
message:
| CLIMessage
| CLIControlRequest
| CLIControlResponse
| ControlCancelRequest;
}
/**
* Session Manager
*
* Manages the session lifecycle and message processing state machine.
*/
class SessionManager {
private state: SessionState = SESSION_STATE.INITIALIZING;
class Session {
private userMessageQueue: CLIUserMessage[] = [];
private abortController: AbortController;
private config: Config;
@@ -98,6 +46,8 @@ class SessionManager {
private debugMode: boolean;
private shutdownHandler: (() => void) | null = null;
private initialPrompt: CLIUserMessage | null = null;
private processingPromise: Promise<void> | null = null;
private isShuttingDown: boolean = false;
constructor(config: Config, initialPrompt?: CLIUserMessage) {
this.config = config;
@@ -112,161 +62,18 @@ class SessionManager {
config.getIncludePartialMessages(),
);
// Setup signal handlers for graceful shutdown
this.setupSignalHandlers();
}
/**
* Get next prompt ID
*/
private getNextPromptId(): string {
this.promptIdCounter++;
return `${this.sessionId}########${this.promptIdCounter}`;
}
/**
* Route a message to the appropriate handler based on its type
*
* Classifies incoming messages and routes them to appropriate handlers.
*/
private route(
message:
| CLIMessage
| CLIControlRequest
| CLIControlResponse
| ControlCancelRequest,
): RoutedMessage {
// Check control messages first
if (isControlRequest(message)) {
return { type: 'control_request', message };
}
if (isControlResponse(message)) {
return { type: 'control_response', message };
}
if (isControlCancel(message)) {
return { type: 'control_cancel', message };
}
// Check data messages
if (isCLIUserMessage(message)) {
return { type: 'user', message };
}
if (isCLIAssistantMessage(message)) {
return { type: 'assistant', message };
}
if (isCLISystemMessage(message)) {
return { type: 'system', message };
}
if (isCLIResultMessage(message)) {
return { type: 'result', message };
}
if (isCLIPartialAssistantMessage(message)) {
return { type: 'stream_event', message };
}
// Unknown message type
if (this.debugMode) {
console.error(
'[SessionManager] Unknown message type:',
JSON.stringify(message, null, 2),
);
}
return { type: 'unknown', message };
}
/**
* Process a single message with unified logic for both initial prompt and stream messages.
*
* Handles:
* - Abort check
* - First message detection and handling
* - Normal message processing
* - Shutdown state checks
*
* @param message - Message to process
* @returns true if the calling code should exit (break/return), false to continue
*/
private async processSingleMessage(
message:
| CLIMessage
| CLIControlRequest
| CLIControlResponse
| ControlCancelRequest,
): Promise<boolean> {
// Check for abort
if (this.abortController.signal.aborted) {
return true;
}
// Handle first message if control system not yet initialized
if (this.controlSystemEnabled === null) {
const handled = await this.handleFirstMessage(message);
if (handled) {
// If handled, check if we should shutdown
return this.state === SESSION_STATE.SHUTTING_DOWN;
}
// If not handled, fall through to normal processing
}
// Process message normally
await this.processMessage(message);
// Check for shutdown after processing
return this.state === SESSION_STATE.SHUTTING_DOWN;
}
/**
* Main entry point - run the session
*/
async run(): Promise<void> {
try {
if (this.debugMode) {
console.error('[SessionManager] Starting session', this.sessionId);
}
// Process initial prompt if provided
if (this.initialPrompt !== null) {
const shouldExit = await this.processSingleMessage(this.initialPrompt);
if (shouldExit) {
await this.shutdown();
return;
}
}
// Process messages from stream
for await (const message of this.inputReader.read()) {
const shouldExit = await this.processSingleMessage(message);
if (shouldExit) {
break;
}
}
// Stream closed, shutdown
await this.shutdown();
} catch (error) {
if (this.debugMode) {
console.error('[SessionManager] Error:', error);
}
await this.shutdown();
throw error;
} finally {
// Ensure signal handlers are always cleaned up even if shutdown wasn't called
this.cleanupSignalHandlers();
}
}
private ensureControlSystem(): void {
if (this.controlContext && this.dispatcher && this.controlService) {
return;
}
// The control system follows a strict three-layer architecture:
// 1. ControlContext (shared session state)
// 2. ControlDispatcher (protocol routing SDK ↔ CLI)
// 3. ControlService (programmatic API for CLI runtime)
//
// Application code MUST interact with the control plane exclusively through
// ControlService. ControlDispatcher is reserved for protocol-level message
// routing and should never be used directly outside of this file.
this.controlContext = new ControlContext({
config: this.config,
streamJson: this.outputAdapter,
@@ -299,25 +106,25 @@ class SessionManager {
| CLIControlResponse
| ControlCancelRequest,
): Promise<boolean> {
const routed = this.route(message);
if (routed.type === 'control_request') {
const request = routed.message as CLIControlRequest;
if (isControlRequest(message)) {
const request = message as CLIControlRequest;
this.controlSystemEnabled = true;
this.ensureControlSystem();
if (request.request.subtype === 'initialize') {
await this.dispatcher?.dispatch(request);
this.state = SESSION_STATE.IDLE;
return true;
}
return false;
if (this.debugMode) {
console.error(
'[Session] Ignoring non-initialize control request during initialization',
);
}
return true;
}
if (routed.type === 'user') {
if (isCLIUserMessage(message)) {
this.controlSystemEnabled = false;
this.state = SESSION_STATE.PROCESSING_QUERY;
this.userMessageQueue.push(routed.message as CLIUserMessage);
await this.processUserMessageQueue();
this.enqueueUserMessage(message as CLIUserMessage);
return true;
}
@@ -325,237 +132,43 @@ class SessionManager {
return false;
}
/**
* Process a single message from the stream
*/
private async processMessage(
message:
| CLIMessage
| CLIControlRequest
| CLIControlResponse
| ControlCancelRequest,
private async handleControlRequest(
request: CLIControlRequest,
): Promise<void> {
const routed = this.route(message);
if (this.debugMode) {
console.error(
`[SessionManager] State: ${this.state}, Message type: ${routed.type}`,
);
}
switch (this.state) {
case SESSION_STATE.INITIALIZING:
await this.handleInitializingState(routed);
break;
case SESSION_STATE.IDLE:
await this.handleIdleState(routed);
break;
case SESSION_STATE.PROCESSING_QUERY:
await this.handleProcessingState(routed);
break;
case SESSION_STATE.SHUTTING_DOWN:
// Ignore all messages during shutdown
break;
default: {
// Exhaustive check
const _exhaustiveCheck: never = this.state;
if (this.debugMode) {
console.error('[SessionManager] Unknown state:', _exhaustiveCheck);
}
break;
}
}
}
/**
* Handle messages in initializing state
*/
private async handleInitializingState(routed: RoutedMessage): Promise<void> {
if (routed.type === 'control_request') {
const request = routed.message as CLIControlRequest;
const dispatcher = this.getDispatcher();
if (!dispatcher) {
if (this.debugMode) {
console.error(
'[SessionManager] Control request received before control system initialization',
);
console.error('[Session] Control system not enabled');
}
return;
}
if (request.request.subtype === 'initialize') {
await dispatcher.dispatch(request);
this.state = SESSION_STATE.IDLE;
if (this.debugMode) {
console.error('[SessionManager] Initialized, transitioning to idle');
}
} else {
if (this.debugMode) {
console.error(
'[SessionManager] Ignoring non-initialize control request during initialization',
);
}
}
} else {
if (this.debugMode) {
console.error(
'[SessionManager] Ignoring non-control message during initialization',
);
}
}
}
/**
* Handle messages in idle state
*/
private async handleIdleState(routed: RoutedMessage): Promise<void> {
const dispatcher = this.getDispatcher();
if (routed.type === 'control_request') {
if (!dispatcher) {
if (this.debugMode) {
console.error('[SessionManager] Ignoring control request (disabled)');
}
return;
}
const request = routed.message as CLIControlRequest;
await dispatcher.dispatch(request);
// Stay in idle state
} else if (routed.type === 'control_response') {
}
private handleControlResponse(response: CLIControlResponse): void {
const dispatcher = this.getDispatcher();
if (!dispatcher) {
return;
}
const response = routed.message as CLIControlResponse;
dispatcher.handleControlResponse(response);
// Stay in idle state
} else if (routed.type === 'control_cancel') {
}
private handleControlCancel(cancelRequest: ControlCancelRequest): void {
const dispatcher = this.getDispatcher();
if (!dispatcher) {
return;
}
const cancelRequest = routed.message as ControlCancelRequest;
dispatcher.handleCancel(cancelRequest.request_id);
} else if (routed.type === 'user') {
const userMessage = routed.message as CLIUserMessage;
this.userMessageQueue.push(userMessage);
// Start processing queue
await this.processUserMessageQueue();
} else {
if (this.debugMode) {
console.error(
'[SessionManager] Ignoring message type in idle state:',
routed.type,
);
}
}
}
/**
* Handle messages in processing state
*/
private async handleProcessingState(routed: RoutedMessage): Promise<void> {
const dispatcher = this.getDispatcher();
if (routed.type === 'control_request') {
if (!dispatcher) {
if (this.debugMode) {
console.error(
'[SessionManager] Control request ignored during processing (disabled)',
);
}
return;
}
const request = routed.message as CLIControlRequest;
await dispatcher.dispatch(request);
// Continue processing
} else if (routed.type === 'control_response') {
if (!dispatcher) {
return;
}
const response = routed.message as CLIControlResponse;
dispatcher.handleControlResponse(response);
// Continue processing
} else if (routed.type === 'user') {
// Enqueue for later
const userMessage = routed.message as CLIUserMessage;
this.userMessageQueue.push(userMessage);
if (this.debugMode) {
console.error(
'[SessionManager] Enqueued user message during processing',
);
}
} else {
if (this.debugMode) {
console.error(
'[SessionManager] Ignoring message type during processing:',
routed.type,
);
}
}
}
/**
* Process user message queue (FIFO)
*/
private async processUserMessageQueue(): Promise<void> {
while (
this.userMessageQueue.length > 0 &&
!this.abortController.signal.aborted
) {
this.state = SESSION_STATE.PROCESSING_QUERY;
const userMessage = this.userMessageQueue.shift()!;
try {
await this.processUserMessage(userMessage);
} catch (error) {
if (this.debugMode) {
console.error(
'[SessionManager] Error processing user message:',
error,
);
}
// Send error result
this.emitErrorResult(error);
}
}
// If control system is disabled (single-query mode) and queue is empty,
// automatically shutdown instead of returning to idle
if (
!this.abortController.signal.aborted &&
this.state === SESSION_STATE.PROCESSING_QUERY &&
this.controlSystemEnabled === false &&
this.userMessageQueue.length === 0
) {
if (this.debugMode) {
console.error(
'[SessionManager] Single-query mode: queue processed, shutting down',
);
}
this.state = SESSION_STATE.SHUTTING_DOWN;
return;
}
// Return to idle after processing queue (for multi-query mode with control system)
if (
!this.abortController.signal.aborted &&
this.state === SESSION_STATE.PROCESSING_QUERY
) {
this.state = SESSION_STATE.IDLE;
if (this.debugMode) {
console.error('[SessionManager] Queue processed, returning to idle');
}
}
}
/**
* Process a single user message
*/
private async processUserMessage(userMessage: CLIUserMessage): Promise<void> {
const input = extractUserMessageText(userMessage);
if (!input) {
if (this.debugMode) {
console.error('[SessionManager] No text content in user message');
console.error('[Session] No text content in user message');
}
return;
}
@@ -575,16 +188,56 @@ class SessionManager {
},
);
} catch (error) {
// Error already handled by runNonInteractive via adapter.emitResult
if (this.debugMode) {
console.error('[SessionManager] Query execution error:', error);
console.error('[Session] Query execution error:', error);
}
}
}
/**
* Send tool results as user message
*/
private async processUserMessageQueue(): Promise<void> {
if (this.isShuttingDown || this.abortController.signal.aborted) {
return;
}
while (
this.userMessageQueue.length > 0 &&
!this.isShuttingDown &&
!this.abortController.signal.aborted
) {
const userMessage = this.userMessageQueue.shift()!;
try {
await this.processUserMessage(userMessage);
} catch (error) {
if (this.debugMode) {
console.error('[Session] Error processing user message:', error);
}
this.emitErrorResult(error);
}
}
}
private enqueueUserMessage(userMessage: CLIUserMessage): void {
this.userMessageQueue.push(userMessage);
this.ensureProcessingStarted();
}
private ensureProcessingStarted(): void {
if (this.processingPromise) {
return;
}
this.processingPromise = this.processUserMessageQueue().finally(() => {
this.processingPromise = null;
if (
this.userMessageQueue.length > 0 &&
!this.isShuttingDown &&
!this.abortController.signal.aborted
) {
this.ensureProcessingStarted();
}
});
}
private emitErrorResult(
error: unknown,
numTurns: number = 0,
@@ -602,52 +255,51 @@ class SessionManager {
});
}
/**
* Handle interrupt control request
*/
private handleInterrupt(): void {
if (this.debugMode) {
console.error('[SessionManager] Interrupt requested');
console.error('[Session] Interrupt requested');
}
// Abort current query if processing
if (this.state === SESSION_STATE.PROCESSING_QUERY) {
this.abortController.abort();
this.abortController = new AbortController(); // Create new controller for next query
}
this.abortController = new AbortController();
}
/**
* Setup signal handlers for graceful shutdown
*/
private setupSignalHandlers(): void {
this.shutdownHandler = () => {
if (this.debugMode) {
console.error('[SessionManager] Shutdown signal received');
console.error('[Session] Shutdown signal received');
}
this.isShuttingDown = true;
this.abortController.abort();
this.state = SESSION_STATE.SHUTTING_DOWN;
};
process.on('SIGINT', this.shutdownHandler);
process.on('SIGTERM', this.shutdownHandler);
}
/**
* Shutdown session and cleanup resources
*/
private async shutdown(): Promise<void> {
if (this.debugMode) {
console.error('[SessionManager] Shutting down');
console.error('[Session] Shutting down');
}
this.isShuttingDown = true;
if (this.processingPromise) {
try {
await this.processingPromise;
} catch (error) {
if (this.debugMode) {
console.error(
'[Session] Error waiting for processing to complete:',
error,
);
}
}
}
this.state = SESSION_STATE.SHUTTING_DOWN;
this.dispatcher?.shutdown();
this.cleanupSignalHandlers();
}
/**
* Remove signal handlers to prevent memory leaks
*/
private cleanupSignalHandlers(): void {
if (this.shutdownHandler) {
process.removeListener('SIGINT', this.shutdownHandler);
@@ -655,6 +307,94 @@ class SessionManager {
this.shutdownHandler = null;
}
}
async run(): Promise<void> {
try {
if (this.debugMode) {
console.error('[Session] Starting session', this.sessionId);
}
if (this.initialPrompt !== null) {
const handled = await this.handleFirstMessage(this.initialPrompt);
if (handled && this.isShuttingDown) {
await this.shutdown();
return;
}
}
try {
for await (const message of this.inputReader.read()) {
if (this.abortController.signal.aborted) {
break;
}
if (this.controlSystemEnabled === null) {
const handled = await this.handleFirstMessage(message);
if (handled) {
if (this.isShuttingDown) {
break;
}
continue;
}
}
if (isControlRequest(message)) {
await this.handleControlRequest(message as CLIControlRequest);
} else if (isControlResponse(message)) {
this.handleControlResponse(message as CLIControlResponse);
} else if (isControlCancel(message)) {
this.handleControlCancel(message as ControlCancelRequest);
} else if (isCLIUserMessage(message)) {
this.enqueueUserMessage(message as CLIUserMessage);
} else if (this.debugMode) {
if (
!isCLIAssistantMessage(message) &&
!isCLISystemMessage(message) &&
!isCLIResultMessage(message) &&
!isCLIPartialAssistantMessage(message)
) {
console.error(
'[Session] Unknown message type:',
JSON.stringify(message, null, 2),
);
}
}
if (this.isShuttingDown) {
break;
}
}
} catch (streamError) {
if (this.debugMode) {
console.error('[Session] Stream reading error:', streamError);
}
throw streamError;
}
while (this.processingPromise) {
if (this.debugMode) {
console.error('[Session] Waiting for final processing to complete');
}
try {
await this.processingPromise;
} catch (error) {
if (this.debugMode) {
console.error('[Session] Error in final processing:', error);
}
}
}
await this.shutdown();
} catch (error) {
if (this.debugMode) {
console.error('[Session] Error:', error);
}
await this.shutdown();
throw error;
} finally {
this.cleanupSignalHandlers();
}
}
}
function extractUserMessageText(message: CLIUserMessage): string | null {
@@ -682,12 +422,6 @@ function extractUserMessageText(message: CLIUserMessage): string | null {
return null;
}
/**
* Entry point for stream-json mode
*
* @param config - Configuration object
* @param input - Optional initial prompt input to process before reading from stream
*/
export async function runNonInteractiveStreamJson(
config: Config,
input: string,
@@ -698,7 +432,6 @@ export async function runNonInteractiveStreamJson(
consolePatcher.patch();
try {
// Create initial user message from prompt input if provided
let initialPrompt: CLIUserMessage | undefined = undefined;
if (input && input.trim().length > 0) {
const sessionId = config.getSessionId();
@@ -713,7 +446,7 @@ export async function runNonInteractiveStreamJson(
};
}
const manager = new SessionManager(config, initialPrompt);
const manager = new Session(config, initialPrompt);
await manager.run();
} finally {
consolePatcher.cleanup();

View File

@@ -15,6 +15,7 @@ import {
FatalInputError,
promptIdContext,
OutputFormat,
InputFormat,
uiTelemetryService,
} from '@qwen-code/qwen-code-core';
import type { Content, Part, PartListUnion } from '@google/genai';
@@ -254,11 +255,17 @@ export async function runNonInteractive(
};
}
}
*/
const toolCallUpdateCallback = options.controlService
// Get toolCallUpdateCallback for SDK mode (stream-json)
const inputFormat =
typeof config.getInputFormat === 'function'
? config.getInputFormat()
: InputFormat.TEXT;
const toolCallUpdateCallback =
inputFormat === InputFormat.STREAM_JSON && options.controlService
? options.controlService.permission.getToolCallUpdateCallback()
: undefined;
*/
// Only pass outputUpdateHandler for Task tool
const isTaskTool = finalRequestInfo.name === 'task';
@@ -277,11 +284,11 @@ export async function runNonInteractive(
isTaskTool && taskToolProgressHandler
? {
outputUpdateHandler: taskToolProgressHandler,
/*
toolCallUpdateCallback
? { onToolCallsUpdate: toolCallUpdateCallback }
: undefined,
*/
onToolCallsUpdate: toolCallUpdateCallback,
}
: toolCallUpdateCallback
? {
onToolCallsUpdate: toolCallUpdateCallback,
}
: undefined,
);
@@ -303,9 +310,6 @@ export async function runNonInteractive(
? toolResponse.resultDisplay
: undefined,
);
// Note: We no longer emit a separate system message for tool errors
// in JSON/STREAM_JSON mode, as the error is already captured in the
// tool_result block with is_error=true.
}
if (adapter) {

View File

@@ -909,7 +909,10 @@ export class CoreToolScheduler {
async handleConfirmationResponse(
callId: string,
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
originalOnConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) => Promise<void>,
outcome: ToolConfirmationOutcome,
signal: AbortSignal,
payload?: ToolConfirmationPayload,
@@ -918,9 +921,7 @@ export class CoreToolScheduler {
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
);
if (toolCall && toolCall.status === 'awaiting_approval') {
await originalOnConfirm(outcome);
}
await originalOnConfirm(outcome, payload);
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
await this.autoApproveCompatiblePendingTools(signal, callId);
@@ -929,11 +930,10 @@ export class CoreToolScheduler {
this.setToolCallOutcome(callId, outcome);
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
'User did not allow tool call',
);
// Use custom cancel message from payload if provided, otherwise use default
const cancelMessage =
payload?.cancelMessage || 'User did not allow tool call';
this.setStatusInternal(callId, 'cancelled', cancelMessage);
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
const waitingToolCall = toolCall as WaitingToolCall;
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
@@ -991,7 +991,8 @@ export class CoreToolScheduler {
): Promise<void> {
if (
toolCall.confirmationDetails.type !== 'edit' ||
!isModifiableDeclarativeTool(toolCall.tool)
!isModifiableDeclarativeTool(toolCall.tool) ||
!payload.newContent
) {
return;
}

View File

@@ -10,6 +10,7 @@ import type {
ToolInvocation,
ToolMcpConfirmationDetails,
ToolResult,
ToolConfirmationPayload,
} from './tools.js';
import {
BaseDeclarativeTool,
@@ -98,7 +99,10 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
onConfirm: async (
outcome: ToolConfirmationOutcome,
_payload?: ToolConfirmationPayload,
) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {

View File

@@ -17,6 +17,7 @@ import type {
ToolResultDisplay,
ToolCallConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolConfirmationPayload,
} from './tools.js';
import {
BaseDeclarativeTool,
@@ -102,7 +103,10 @@ export class ShellToolInvocation extends BaseToolInvocation<
title: 'Confirm Shell Command',
command: this.params.command,
rootCommand: commandsToConfirm.join(', '),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
onConfirm: async (
outcome: ToolConfirmationOutcome,
_payload?: ToolConfirmationPayload,
) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
commandsToConfirm.forEach((command) => this.allowlist.add(command));
}

View File

@@ -531,13 +531,18 @@ export interface ToolEditConfirmationDetails {
export interface ToolConfirmationPayload {
// used to override `modifiedProposedContent` for modifiable tools in the
// inline modify flow
newContent: string;
newContent?: string;
// used to provide custom cancellation message when outcome is Cancel
cancelMessage?: string;
}
export interface ToolExecuteConfirmationDetails {
type: 'exec';
title: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) => Promise<void>;
command: string;
rootCommand: string;
}
@@ -548,7 +553,10 @@ export interface ToolMcpConfirmationDetails {
serverName: string;
toolName: string;
toolDisplayName: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) => Promise<void>;
}
export interface ToolInfoConfirmationDetails {
@@ -573,6 +581,11 @@ export interface ToolPlanConfirmationDetails {
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
}
/**
* TODO:
* 1. support explicit denied outcome
* 2. support proceed with modified input
*/
export enum ToolConfirmationOutcome {
ProceedOnce = 'proceed_once',
ProceedAlways = 'proceed_always',

View File

@@ -12,7 +12,6 @@ import type {
CLIControlRequest,
CLIControlResponse,
ControlCancelRequest,
PermissionApproval,
PermissionSuggestion,
} from '../types/protocol.js';
import {
@@ -299,7 +298,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return;
}
if (process.env['DEBUG_SDK']) {
if (process.env['DEBUG']) {
console.warn('[Query] Unknown message type:', message);
}
this.inputStream.enqueue(message as CLIMessage);
@@ -320,12 +319,12 @@ export class Query implements AsyncIterable<CLIMessage> {
switch (payload.subtype) {
case 'can_use_tool':
response = (await this.handlePermissionRequest(
response = await this.handlePermissionRequest(
payload.tool_name,
payload.input as Record<string, unknown>,
payload.permission_suggestions,
requestAbortController.signal,
)) as unknown as Record<string, unknown>;
);
break;
case 'mcp_message':
@@ -360,15 +359,17 @@ export class Query implements AsyncIterable<CLIMessage> {
/**
* Handle permission request (can_use_tool)
* Converts PermissionResult to CLI-expected format: { behavior: 'allow', updatedInput: ... } or { behavior: 'deny', message: ... }
*/
private async handlePermissionRequest(
toolName: string,
toolInput: Record<string, unknown>,
permissionSuggestions: PermissionSuggestion[] | null,
signal: AbortSignal,
): Promise<PermissionApproval> {
): Promise<Record<string, unknown>> {
/* Default deny all wildcard tool requests */
if (!this.options.canUseTool) {
return { allowed: true };
return { behavior: 'deny', message: 'Denied' };
}
try {
@@ -390,21 +391,51 @@ export class Query implements AsyncIterable<CLIMessage> {
timeoutPromise,
]);
// Handle boolean return (backward compatibility)
if (typeof result === 'boolean') {
return { allowed: result };
return result
? { behavior: 'allow', updatedInput: toolInput }
: { behavior: 'deny', message: 'Denied' };
}
// Handle PermissionResult format
const permissionResult = result as {
behavior: 'allow' | 'deny';
updatedInput?: Record<string, unknown>;
message?: string;
interrupt?: boolean;
};
if (permissionResult.behavior === 'allow') {
return {
behavior: 'allow',
updatedInput: permissionResult.updatedInput ?? toolInput,
};
} else {
return {
behavior: 'deny',
message: permissionResult.message ?? 'Denied',
...(permissionResult.interrupt !== undefined
? { interrupt: permissionResult.interrupt }
: {}),
};
}
return result as PermissionApproval;
} catch (error) {
/**
* Timeout or error → deny (fail-safe).
* This ensures that any issues with the permission callback
* result in a safe default of denying access.
*/
const errorMessage =
error instanceof Error ? error.message : String(error);
console.warn(
'[Query] Permission callback error (denying by default):',
error instanceof Error ? error.message : String(error),
errorMessage,
);
return { allowed: false };
return {
behavior: 'deny',
message: `Permission check failed: ${errorMessage}`,
};
}
}

View File

@@ -283,6 +283,10 @@ export class ProcessTransport implements Transport {
throw new Error('Cannot write to closed transport');
}
if (this.childStdin.writableEnded) {
throw new Error('Cannot write to ended stream');
}
if (this.childProcess?.killed || this.childProcess?.exitCode !== null) {
throw new Error('Cannot write to terminated process');
}
@@ -293,17 +297,21 @@ export class ProcessTransport implements Transport {
);
}
if (process.env['DEBUG_SDK']) {
if (process.env['DEBUG']) {
this.logForDebugging(
`[ProcessTransport] Writing to stdin: ${message.substring(0, 100)}`,
`[ProcessTransport] Writing to stdin (${message.length} bytes): ${message.substring(0, 100)}`,
);
}
try {
const written = this.childStdin.write(message);
if (!written && process.env['DEBUG_SDK']) {
if (!written) {
this.logForDebugging(
'[ProcessTransport] Write buffer full, data queued',
`[ProcessTransport] Write buffer full (${message.length} bytes), data queued. Waiting for drain event...`,
);
} else if (process.env['DEBUG']) {
this.logForDebugging(
`[ProcessTransport] Write successful (${message.length} bytes)`,
);
}
} catch (error) {
@@ -322,6 +330,7 @@ export class ProcessTransport implements Transport {
const rl = readline.createInterface({
input: this.childStdout,
crlfDelay: Infinity,
terminal: false,
});
try {

View File

@@ -3,7 +3,7 @@
*/
import type { ToolDefinition as ToolDef } from './mcp.js';
import type { PermissionMode } from './protocol.js';
import type { PermissionMode, PermissionSuggestion } from './protocol.js';
import type { ExternalMcpServerConfig } from './queryOptionsSchema.js';
export type { ToolDef as ToolDefinition };
@@ -161,14 +161,15 @@ type ToolInput = Record<string, unknown>;
*
* @param toolName - Name of the tool being executed
* @param input - Input parameters for the tool
* @param options - Options including abort signal
* @param options - Options including abort signal and suggestions
* @returns Promise with permission result
*/
type CanUseTool = (
export type CanUseTool = (
toolName: string,
input: ToolInput,
options: {
signal: AbortSignal;
suggestions?: PermissionSuggestion[] | null;
},
) => Promise<PermissionResult>;

View File

@@ -3,7 +3,7 @@
*/
import { z } from 'zod';
import type { PermissionCallback } from './config.js';
import type { CanUseTool } from './config.js';
/**
* Schema for external MCP server configuration
@@ -35,7 +35,7 @@ export const QueryOptionsSchema = z
env: z.record(z.string(), z.string()).optional(),
permissionMode: z.enum(['default', 'plan', 'auto-edit', 'yolo']).optional(),
canUseTool: z
.custom<PermissionCallback>((val) => typeof val === 'function', {
.custom<CanUseTool>((val) => typeof val === 'function', {
message: 'canUseTool must be a function',
})
.optional(),

View File

@@ -81,7 +81,7 @@ describe('AbortController and Process Lifecycle (E2E)', () => {
const controller = new AbortController();
const q = query({
prompt: 'Write a detailed explanation about TypeScript',
prompt: 'Hello',
options: {
...SHARED_TEST_OPTIONS,
abortController: controller,

View File

@@ -160,8 +160,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
session_id: sessionId,
message: {
role: 'user',
content:
'My name is Alice. Remember this during our current conversation.',
content: 'My name is Alice. Hello!',
},
parent_tool_use_id: null,
} as CLIUserMessage;
@@ -212,9 +211,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
});
describe('Tool Usage in Multi-Turn', () => {
it(
'should handle tool usage across multiple turns',
async () => {
it('should handle tool usage across multiple turns', async () => {
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
const sessionId = crypto.randomUUID();
@@ -271,10 +268,6 @@ describe('Multi-Turn Conversations (E2E)', () => {
if (isCLIAssistantMessage(message)) {
assistantCount++;
}
if (isCLIResultMessage(message)) {
break;
}
}
expect(messages.length).toBeGreaterThan(0);
@@ -283,9 +276,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
}, 60000); //TEST_TIMEOUT,
});
describe('Message Flow and Sequencing', () => {

View File

@@ -0,0 +1,748 @@
/**
* E2E tests for permission control features:
* - canUseTool callback parameter
* - setPermissionMode API
*/
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLIResultMessage,
isCLIUserMessage,
type CLIUserMessage,
type ToolUseBlock,
type ContentBlock,
} from '../../src/types/protocol.js';
const TEST_CLI_PATH =
'/Users/mingholy/Work/Projects/qwen-code/packages/cli/index.ts';
const TEST_TIMEOUT = 1600000;
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
debug: false,
// env here sets environment variables for the CLI child process
env: {
// DEBUG: '1',
},
};
/**
* Factory function that creates a streaming input with a control point.
* After the first message is yielded, the generator waits for a resume signal,
* allowing the test code to call query instance methods like setPermissionMode.
*/
function createStreamingInputWithControlPoint(
firstMessage: string,
secondMessage: string,
): {
generator: AsyncIterable<CLIUserMessage>;
resume: () => void;
} {
let resumeResolve: (() => void) | null = null;
const resumePromise = new Promise<void>((resolve) => {
resumeResolve = resolve;
});
const generator = (async function* () {
const sessionId = crypto.randomUUID();
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: firstMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise;
await new Promise((resolve) => setTimeout(resolve, 200));
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: secondMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
})();
const resume = () => {
if (resumeResolve) {
resumeResolve();
}
};
return { generator, resume };
}
describe('Permission Control (E2E)', () => {
beforeAll(() => {
//process.env['DEBUG'] = '1';
});
afterAll(() => {
delete process.env['DEBUG'];
});
describe('canUseTool callback parameter', () => {
it(
'should invoke canUseTool callback when tool is requested',
async () => {
const toolCalls: Array<{
toolName: string;
input: Record<string, unknown>;
}> = [];
const q = query({
prompt: 'Write a js hello world to file.',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
toolCalls.push({ toolName, input });
console.log(toolName, input);
/*
{
behavior: 'allow',
updatedInput: input,
};
*/
return {
behavior: 'deny',
message: 'Tool execution denied by user.',
};
},
},
});
try {
let hasToolUse = false;
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (toolUseBlock) {
hasToolUse = true;
}
}
}
expect(hasToolUse).toBe(true);
expect(toolCalls.length).toBeGreaterThan(0);
expect(toolCalls[0].toolName).toBeDefined();
expect(toolCalls[0].input).toBeDefined();
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should allow tool execution when canUseTool returns allow',
async () => {
let callbackInvoked = false;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
callbackInvoked = true;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (
Array.isArray(message.message.content) &&
message.message.content.some(
(block) => block.type === 'tool_result',
)
) {
hasToolResult = true;
}
}
if (isCLIResultMessage(message)) {
break;
}
}
expect(callbackInvoked).toBe(true);
expect(hasToolResult).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should deny tool execution when canUseTool returns deny',
async () => {
let callbackInvoked = false;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async () => {
callbackInvoked = true;
return {
behavior: 'deny',
message: 'Tool execution denied by test',
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(callbackInvoked).toBe(true);
// Tool use might still appear, but execution should be denied
// The exact behavior depends on CLI implementation
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should pass suggestions to canUseTool callback',
async () => {
let receivedSuggestions: unknown = null;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input, options) => {
receivedSuggestions = options.suggestions;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
// Suggestions may be null or an array, depending on CLI implementation
expect(receivedSuggestions !== undefined).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should pass abort signal to canUseTool callback',
async () => {
let receivedSignal: AbortSignal | undefined = undefined;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input, options) => {
receivedSignal = options.signal;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(receivedSignal).toBeDefined();
expect(receivedSignal).toBeInstanceOf(AbortSignal);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should allow updatedInput modification in canUseTool callback',
async () => {
const originalInputs: Record<string, unknown>[] = [];
const updatedInputs: Record<string, unknown>[] = [];
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
originalInputs.push({ ...input });
const updatedInput = {
...input,
modified: true,
testKey: 'testValue',
};
updatedInputs.push(updatedInput);
return {
behavior: 'allow',
updatedInput,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(originalInputs.length).toBeGreaterThan(0);
expect(updatedInputs.length).toBeGreaterThan(0);
expect(updatedInputs[0]?.['modified']).toBe(true);
expect(updatedInputs[0]?.['testKey']).toBe('testValue');
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should default to deny when canUseTool is not provided',
async () => {
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
// canUseTool not provided
},
});
try {
// When canUseTool is not provided, tools should be denied by default
// The exact behavior depends on CLI implementation
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
// Test passes if no errors occur
expect(true).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
});
describe('setPermissionMode API', () => {
it(
'should change permission mode from default to yolo',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('yolo');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should change permission mode from yolo to plan',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'yolo',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('plan');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should change permission mode to auto-edit',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('auto-edit');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should throw error when setPermissionMode is called on closed query',
async () => {
const q = query({
prompt: 'Hello',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
await q.close();
await expect(q.setPermissionMode('yolo')).rejects.toThrow(
'Query is closed',
);
},
TEST_TIMEOUT,
);
});
describe('canUseTool and setPermissionMode integration', () => {
it(
'should work together - canUseTool callback with dynamic permission mode change',
async () => {
const toolCalls: Array<{
toolName: string;
input: Record<string, unknown>;
}> = [];
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
toolCalls.push({ toolName, input });
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
expect(toolCalls.length).toBeGreaterThan(0);
await q.setPermissionMode('yolo');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
});
});