mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
refactor: session and canUseTool support
This commit is contained in:
1
.vscode/launch.json
vendored
1
.vscode/launch.json
vendored
@@ -79,7 +79,6 @@
|
|||||||
"--",
|
"--",
|
||||||
"-p",
|
"-p",
|
||||||
"${input:prompt}",
|
"${input:prompt}",
|
||||||
"-y",
|
|
||||||
"--output-format",
|
"--output-format",
|
||||||
"stream-json"
|
"stream-json"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -758,8 +758,14 @@ export async function loadCliConfig(
|
|||||||
interactive = false;
|
interactive = false;
|
||||||
}
|
}
|
||||||
// In non-interactive mode, exclude tools that require a prompt.
|
// In non-interactive mode, exclude tools that require a prompt.
|
||||||
|
// However, if stream-json input is used, control can be requested via JSON messages,
|
||||||
|
// so tools should not be excluded in that case.
|
||||||
const extraExcludes: string[] = [];
|
const extraExcludes: string[] = [];
|
||||||
if (!interactive && !argv.experimentalAcp) {
|
if (
|
||||||
|
!interactive &&
|
||||||
|
!argv.experimentalAcp &&
|
||||||
|
inputFormat !== InputFormat.STREAM_JSON
|
||||||
|
) {
|
||||||
switch (approvalMode) {
|
switch (approvalMode) {
|
||||||
case ApprovalMode.PLAN:
|
case ApprovalMode.PLAN:
|
||||||
case ApprovalMode.DEFAULT:
|
case ApprovalMode.DEFAULT:
|
||||||
|
|||||||
@@ -479,6 +479,10 @@ describe('gemini.tsx main function kitty protocol', () => {
|
|||||||
inputFormat: undefined,
|
inputFormat: undefined,
|
||||||
outputFormat: undefined,
|
outputFormat: undefined,
|
||||||
includePartialMessages: undefined,
|
includePartialMessages: undefined,
|
||||||
|
coreTools: undefined,
|
||||||
|
excludeTools: undefined,
|
||||||
|
authType: undefined,
|
||||||
|
maxSessionTurns: undefined,
|
||||||
});
|
});
|
||||||
|
|
||||||
await main();
|
await main();
|
||||||
|
|||||||
@@ -26,7 +26,7 @@
|
|||||||
import type { IControlContext } from './ControlContext.js';
|
import type { IControlContext } from './ControlContext.js';
|
||||||
import type { IPendingRequestRegistry } from './controllers/baseController.js';
|
import type { IPendingRequestRegistry } from './controllers/baseController.js';
|
||||||
import { SystemController } from './controllers/systemController.js';
|
import { SystemController } from './controllers/systemController.js';
|
||||||
// import { PermissionController } from './controllers/permissionController.js';
|
import { PermissionController } from './controllers/permissionController.js';
|
||||||
// import { MCPController } from './controllers/mcpController.js';
|
// import { MCPController } from './controllers/mcpController.js';
|
||||||
// import { HookController } from './controllers/hookController.js';
|
// import { HookController } from './controllers/hookController.js';
|
||||||
import type {
|
import type {
|
||||||
@@ -64,7 +64,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
|||||||
|
|
||||||
// Make controllers publicly accessible
|
// Make controllers publicly accessible
|
||||||
readonly systemController: SystemController;
|
readonly systemController: SystemController;
|
||||||
// readonly permissionController: PermissionController;
|
readonly permissionController: PermissionController;
|
||||||
// readonly mcpController: MCPController;
|
// readonly mcpController: MCPController;
|
||||||
// readonly hookController: HookController;
|
// readonly hookController: HookController;
|
||||||
|
|
||||||
@@ -83,11 +83,11 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
|||||||
this,
|
this,
|
||||||
'SystemController',
|
'SystemController',
|
||||||
);
|
);
|
||||||
// this.permissionController = new PermissionController(
|
this.permissionController = new PermissionController(
|
||||||
// context,
|
context,
|
||||||
// this,
|
this,
|
||||||
// 'PermissionController',
|
'PermissionController',
|
||||||
// );
|
);
|
||||||
// this.mcpController = new MCPController(context, this, 'MCPController');
|
// this.mcpController = new MCPController(context, this, 'MCPController');
|
||||||
// this.hookController = new HookController(context, this, 'HookController');
|
// this.hookController = new HookController(context, this, 'HookController');
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
|||||||
|
|
||||||
// Cleanup controllers (MCP controller will close all clients)
|
// Cleanup controllers (MCP controller will close all clients)
|
||||||
this.systemController.cleanup();
|
this.systemController.cleanup();
|
||||||
// this.permissionController.cleanup();
|
this.permissionController.cleanup();
|
||||||
// this.mcpController.cleanup();
|
// this.mcpController.cleanup();
|
||||||
// this.hookController.cleanup();
|
// this.hookController.cleanup();
|
||||||
}
|
}
|
||||||
@@ -302,9 +302,9 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
|||||||
case 'supported_commands':
|
case 'supported_commands':
|
||||||
return this.systemController;
|
return this.systemController;
|
||||||
|
|
||||||
// case 'can_use_tool':
|
case 'can_use_tool':
|
||||||
// case 'set_permission_mode':
|
case 'set_permission_mode':
|
||||||
// return this.permissionController;
|
return this.permissionController;
|
||||||
|
|
||||||
// case 'mcp_message':
|
// case 'mcp_message':
|
||||||
// case 'mcp_server_status':
|
// case 'mcp_server_status':
|
||||||
|
|||||||
@@ -29,7 +29,7 @@
|
|||||||
import type { IControlContext } from './ControlContext.js';
|
import type { IControlContext } from './ControlContext.js';
|
||||||
import type { ControlDispatcher } from './ControlDispatcher.js';
|
import type { ControlDispatcher } from './ControlDispatcher.js';
|
||||||
import type {
|
import type {
|
||||||
// PermissionServiceAPI,
|
PermissionServiceAPI,
|
||||||
SystemServiceAPI,
|
SystemServiceAPI,
|
||||||
// McpServiceAPI,
|
// McpServiceAPI,
|
||||||
// HookServiceAPI,
|
// HookServiceAPI,
|
||||||
@@ -61,43 +61,43 @@ export class ControlService {
|
|||||||
* Handles tool execution permissions, approval checks, and callbacks.
|
* Handles tool execution permissions, approval checks, and callbacks.
|
||||||
* Delegates to the shared PermissionController instance.
|
* Delegates to the shared PermissionController instance.
|
||||||
*/
|
*/
|
||||||
// get permission(): PermissionServiceAPI {
|
get permission(): PermissionServiceAPI {
|
||||||
// const controller = this.dispatcher.permissionController;
|
const controller = this.dispatcher.permissionController;
|
||||||
// return {
|
return {
|
||||||
// /**
|
/**
|
||||||
// * Check if a tool should be allowed based on current permission settings
|
* Check if a tool should be allowed based on current permission settings
|
||||||
// *
|
*
|
||||||
// * Evaluates permission mode and tool registry to determine if execution
|
* Evaluates permission mode and tool registry to determine if execution
|
||||||
// * should proceed. Can optionally modify tool arguments based on confirmation details.
|
* should proceed. Can optionally modify tool arguments based on confirmation details.
|
||||||
// *
|
*
|
||||||
// * @param toolRequest - Tool call request information
|
* @param toolRequest - Tool call request information
|
||||||
// * @param confirmationDetails - Optional confirmation details for UI
|
* @param confirmationDetails - Optional confirmation details for UI
|
||||||
// * @returns Permission decision with optional updated arguments
|
* @returns Permission decision with optional updated arguments
|
||||||
// */
|
*/
|
||||||
// shouldAllowTool: controller.shouldAllowTool.bind(controller),
|
shouldAllowTool: controller.shouldAllowTool.bind(controller),
|
||||||
//
|
|
||||||
// /**
|
/**
|
||||||
// * Build UI suggestions for tool confirmation dialogs
|
* Build UI suggestions for tool confirmation dialogs
|
||||||
// *
|
*
|
||||||
// * Creates actionable permission suggestions based on tool confirmation details.
|
* Creates actionable permission suggestions based on tool confirmation details.
|
||||||
// *
|
*
|
||||||
// * @param confirmationDetails - Tool confirmation details
|
* @param confirmationDetails - Tool confirmation details
|
||||||
// * @returns Array of permission suggestions or null
|
* @returns Array of permission suggestions or null
|
||||||
// */
|
*/
|
||||||
// buildPermissionSuggestions:
|
buildPermissionSuggestions:
|
||||||
// controller.buildPermissionSuggestions.bind(controller),
|
controller.buildPermissionSuggestions.bind(controller),
|
||||||
//
|
|
||||||
// /**
|
/**
|
||||||
// * Get callback for monitoring tool call status updates
|
* Get callback for monitoring tool call status updates
|
||||||
// *
|
*
|
||||||
// * Returns callback function for integration with CoreToolScheduler.
|
* Returns callback function for integration with CoreToolScheduler.
|
||||||
// *
|
*
|
||||||
// * @returns Callback function for tool call updates
|
* @returns Callback function for tool call updates
|
||||||
// */
|
*/
|
||||||
// getToolCallUpdateCallback:
|
getToolCallUpdateCallback:
|
||||||
// controller.getToolCallUpdateCallback.bind(controller),
|
controller.getToolCallUpdateCallback.bind(controller),
|
||||||
// };
|
};
|
||||||
// }
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* System Domain API
|
* System Domain API
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
import type {
|
import type {
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
WaitingToolCall,
|
WaitingToolCall,
|
||||||
|
ToolExecuteConfirmationDetails,
|
||||||
|
ToolMcpConfirmationDetails,
|
||||||
} from '@qwen-code/qwen-code-core';
|
} from '@qwen-code/qwen-code-core';
|
||||||
import {
|
import {
|
||||||
InputFormat,
|
InputFormat,
|
||||||
@@ -430,17 +432,14 @@ export class PermissionController extends BaseController {
|
|||||||
toolCall.confirmationDetails,
|
toolCall.confirmationDetails,
|
||||||
);
|
);
|
||||||
|
|
||||||
const response = await this.sendControlRequest(
|
const response = await this.sendControlRequest({
|
||||||
{
|
subtype: 'can_use_tool',
|
||||||
subtype: 'can_use_tool',
|
tool_name: toolCall.request.name,
|
||||||
tool_name: toolCall.request.name,
|
tool_use_id: toolCall.request.callId,
|
||||||
tool_use_id: toolCall.request.callId,
|
input: toolCall.request.args,
|
||||||
input: toolCall.request.args,
|
permission_suggestions: permissionSuggestions,
|
||||||
permission_suggestions: permissionSuggestions,
|
blocked_path: null,
|
||||||
blocked_path: null,
|
} as CLIControlPermissionRequest);
|
||||||
} as CLIControlPermissionRequest,
|
|
||||||
30000,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (response.subtype !== 'success') {
|
if (response.subtype !== 'success') {
|
||||||
await toolCall.confirmationDetails.onConfirm(
|
await toolCall.confirmationDetails.onConfirm(
|
||||||
@@ -462,8 +461,15 @@ export class PermissionController extends BaseController {
|
|||||||
ToolConfirmationOutcome.ProceedOnce,
|
ToolConfirmationOutcome.ProceedOnce,
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
|
// Extract cancel message from response if available
|
||||||
|
const cancelMessage =
|
||||||
|
typeof payload['message'] === 'string'
|
||||||
|
? payload['message']
|
||||||
|
: undefined;
|
||||||
|
|
||||||
await toolCall.confirmationDetails.onConfirm(
|
await toolCall.confirmationDetails.onConfirm(
|
||||||
ToolConfirmationOutcome.Cancel,
|
ToolConfirmationOutcome.Cancel,
|
||||||
|
cancelMessage ? { cancelMessage } : undefined,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -473,9 +479,23 @@ export class PermissionController extends BaseController {
|
|||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
await toolCall.confirmationDetails.onConfirm(
|
// On error, use default cancel message
|
||||||
ToolConfirmationOutcome.Cancel,
|
// Only pass payload for exec and mcp types that support it
|
||||||
);
|
const confirmationType = toolCall.confirmationDetails.type;
|
||||||
|
if (confirmationType === 'exec' || confirmationType === 'mcp') {
|
||||||
|
const execOrMcpDetails = toolCall.confirmationDetails as
|
||||||
|
| ToolExecuteConfirmationDetails
|
||||||
|
| ToolMcpConfirmationDetails;
|
||||||
|
await execOrMcpDetails.onConfirm(
|
||||||
|
ToolConfirmationOutcome.Cancel,
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// For other types, don't pass payload (backward compatible)
|
||||||
|
await toolCall.confirmationDetails.onConfirm(
|
||||||
|
ToolConfirmationOutcome.Cancel,
|
||||||
|
);
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
this.pendingOutgoingRequests.delete(toolCall.request.callId);
|
this.pendingOutgoingRequests.delete(toolCall.request.callId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -939,9 +939,25 @@ export abstract class BaseJsonOutputAdapter {
|
|||||||
this.emitMessageImpl(message);
|
this.emitMessageImpl(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if responseParts contain any functionResponse with an error.
|
||||||
|
* This handles cancelled responses and other error cases where the error
|
||||||
|
* is embedded in responseParts rather than the top-level error field.
|
||||||
|
* @param responseParts - Array of Part objects
|
||||||
|
* @returns Error message if found, undefined otherwise
|
||||||
|
*/
|
||||||
|
private checkResponsePartsForError(
|
||||||
|
responseParts: Part[] | undefined,
|
||||||
|
): string | undefined {
|
||||||
|
// Use the shared helper function defined at file level
|
||||||
|
return checkResponsePartsForError(responseParts);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Emits a tool result message.
|
* Emits a tool result message.
|
||||||
* Collects execution denied tool calls for inclusion in result messages.
|
* Collects execution denied tool calls for inclusion in result messages.
|
||||||
|
* Handles both explicit errors (response.error) and errors embedded in
|
||||||
|
* responseParts (e.g., cancelled responses).
|
||||||
* @param request - Tool call request info
|
* @param request - Tool call request info
|
||||||
* @param response - Tool call response info
|
* @param response - Tool call response info
|
||||||
* @param parentToolUseId - Parent tool use ID (null for main agent)
|
* @param parentToolUseId - Parent tool use ID (null for main agent)
|
||||||
@@ -951,6 +967,14 @@ export abstract class BaseJsonOutputAdapter {
|
|||||||
response: ToolCallResponseInfo,
|
response: ToolCallResponseInfo,
|
||||||
parentToolUseId: string | null = null,
|
parentToolUseId: string | null = null,
|
||||||
): void {
|
): void {
|
||||||
|
// Check for errors in responseParts (e.g., cancelled responses)
|
||||||
|
const responsePartsError = this.checkResponsePartsForError(
|
||||||
|
response.responseParts,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Determine if this is an error response
|
||||||
|
const hasError = Boolean(response.error) || Boolean(responsePartsError);
|
||||||
|
|
||||||
// Track permission denials (execution denied errors)
|
// Track permission denials (execution denied errors)
|
||||||
if (
|
if (
|
||||||
response.error &&
|
response.error &&
|
||||||
@@ -967,7 +991,7 @@ export abstract class BaseJsonOutputAdapter {
|
|||||||
const block: ToolResultBlock = {
|
const block: ToolResultBlock = {
|
||||||
type: 'tool_result',
|
type: 'tool_result',
|
||||||
tool_use_id: request.callId,
|
tool_use_id: request.callId,
|
||||||
is_error: Boolean(response.error),
|
is_error: hasError,
|
||||||
};
|
};
|
||||||
const content = toolResultContent(response);
|
const content = toolResultContent(response);
|
||||||
if (content !== undefined) {
|
if (content !== undefined) {
|
||||||
@@ -1173,11 +1197,41 @@ export function partsToString(parts: Part[]): string {
|
|||||||
.join('');
|
.join('');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if responseParts contain any functionResponse with an error.
|
||||||
|
* Helper function for extracting error messages from responseParts.
|
||||||
|
* @param responseParts - Array of Part objects
|
||||||
|
* @returns Error message if found, undefined otherwise
|
||||||
|
*/
|
||||||
|
function checkResponsePartsForError(
|
||||||
|
responseParts: Part[] | undefined,
|
||||||
|
): string | undefined {
|
||||||
|
if (!responseParts || responseParts.length === 0) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const part of responseParts) {
|
||||||
|
if (
|
||||||
|
'functionResponse' in part &&
|
||||||
|
part.functionResponse?.response &&
|
||||||
|
typeof part.functionResponse.response === 'object' &&
|
||||||
|
'error' in part.functionResponse.response &&
|
||||||
|
part.functionResponse.response['error']
|
||||||
|
) {
|
||||||
|
const error = part.functionResponse.response['error'];
|
||||||
|
return typeof error === 'string' ? error : String(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts content from tool response.
|
* Extracts content from tool response.
|
||||||
* Uses functionResponsePartsToString to properly handle functionResponse parts,
|
* Uses functionResponsePartsToString to properly handle functionResponse parts,
|
||||||
* which correctly extracts output content from functionResponse objects rather
|
* which correctly extracts output content from functionResponse objects rather
|
||||||
* than simply concatenating text or JSON.stringify.
|
* than simply concatenating text or JSON.stringify.
|
||||||
|
* Also handles errors embedded in responseParts (e.g., cancelled responses).
|
||||||
*
|
*
|
||||||
* @param response - Tool call response
|
* @param response - Tool call response
|
||||||
* @returns String content or undefined
|
* @returns String content or undefined
|
||||||
@@ -1188,6 +1242,11 @@ export function toolResultContent(
|
|||||||
if (response.error) {
|
if (response.error) {
|
||||||
return response.error.message;
|
return response.error.message;
|
||||||
}
|
}
|
||||||
|
// Check for errors in responseParts (e.g., cancelled responses)
|
||||||
|
const responsePartsError = checkResponsePartsForError(response.responseParts);
|
||||||
|
if (responsePartsError) {
|
||||||
|
return responsePartsError;
|
||||||
|
}
|
||||||
if (
|
if (
|
||||||
typeof response.resultDisplay === 'string' &&
|
typeof response.resultDisplay === 'string' &&
|
||||||
response.resultDisplay.trim().length > 0
|
response.resultDisplay.trim().length > 0
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ function createControlCancel(requestId: string): ControlCancelRequest {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('runNonInteractiveStreamJson', () => {
|
describe('runNonInteractiveStreamJson (refactored)', () => {
|
||||||
let config: Config;
|
let config: Config;
|
||||||
let mockInputReader: {
|
let mockInputReader: {
|
||||||
read: () => AsyncGenerator<
|
read: () => AsyncGenerator<
|
||||||
|
|||||||
@@ -4,17 +4,6 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
|
||||||
* Stream JSON Runner with Session State Machine
|
|
||||||
*
|
|
||||||
* Handles stream-json input/output format with:
|
|
||||||
* - Initialize handshake
|
|
||||||
* - Message routing (control vs user messages)
|
|
||||||
* - FIFO user message queue
|
|
||||||
* - Sequential message processing
|
|
||||||
* - Graceful shutdown
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { Config } from '@qwen-code/qwen-code-core';
|
import type { Config } from '@qwen-code/qwen-code-core';
|
||||||
import { StreamJsonInputReader } from './io/StreamJsonInputReader.js';
|
import { StreamJsonInputReader } from './io/StreamJsonInputReader.js';
|
||||||
import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js';
|
import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js';
|
||||||
@@ -42,48 +31,7 @@ import { createMinimalSettings } from '../config/settings.js';
|
|||||||
import { runNonInteractive } from '../nonInteractiveCli.js';
|
import { runNonInteractive } from '../nonInteractiveCli.js';
|
||||||
import { ConsolePatcher } from '../ui/utils/ConsolePatcher.js';
|
import { ConsolePatcher } from '../ui/utils/ConsolePatcher.js';
|
||||||
|
|
||||||
const SESSION_STATE = {
|
class Session {
|
||||||
INITIALIZING: 'initializing',
|
|
||||||
IDLE: 'idle',
|
|
||||||
PROCESSING_QUERY: 'processing_query',
|
|
||||||
SHUTTING_DOWN: 'shutting_down',
|
|
||||||
} as const;
|
|
||||||
|
|
||||||
type SessionState = (typeof SESSION_STATE)[keyof typeof SESSION_STATE];
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Message type classification for routing
|
|
||||||
*/
|
|
||||||
type MessageType =
|
|
||||||
| 'control_request'
|
|
||||||
| 'control_response'
|
|
||||||
| 'control_cancel'
|
|
||||||
| 'user'
|
|
||||||
| 'assistant'
|
|
||||||
| 'system'
|
|
||||||
| 'result'
|
|
||||||
| 'stream_event'
|
|
||||||
| 'unknown';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Routed message with classification
|
|
||||||
*/
|
|
||||||
interface RoutedMessage {
|
|
||||||
type: MessageType;
|
|
||||||
message:
|
|
||||||
| CLIMessage
|
|
||||||
| CLIControlRequest
|
|
||||||
| CLIControlResponse
|
|
||||||
| ControlCancelRequest;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Session Manager
|
|
||||||
*
|
|
||||||
* Manages the session lifecycle and message processing state machine.
|
|
||||||
*/
|
|
||||||
class SessionManager {
|
|
||||||
private state: SessionState = SESSION_STATE.INITIALIZING;
|
|
||||||
private userMessageQueue: CLIUserMessage[] = [];
|
private userMessageQueue: CLIUserMessage[] = [];
|
||||||
private abortController: AbortController;
|
private abortController: AbortController;
|
||||||
private config: Config;
|
private config: Config;
|
||||||
@@ -98,6 +46,8 @@ class SessionManager {
|
|||||||
private debugMode: boolean;
|
private debugMode: boolean;
|
||||||
private shutdownHandler: (() => void) | null = null;
|
private shutdownHandler: (() => void) | null = null;
|
||||||
private initialPrompt: CLIUserMessage | null = null;
|
private initialPrompt: CLIUserMessage | null = null;
|
||||||
|
private processingPromise: Promise<void> | null = null;
|
||||||
|
private isShuttingDown: boolean = false;
|
||||||
|
|
||||||
constructor(config: Config, initialPrompt?: CLIUserMessage) {
|
constructor(config: Config, initialPrompt?: CLIUserMessage) {
|
||||||
this.config = config;
|
this.config = config;
|
||||||
@@ -112,161 +62,18 @@ class SessionManager {
|
|||||||
config.getIncludePartialMessages(),
|
config.getIncludePartialMessages(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Setup signal handlers for graceful shutdown
|
|
||||||
this.setupSignalHandlers();
|
this.setupSignalHandlers();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get next prompt ID
|
|
||||||
*/
|
|
||||||
private getNextPromptId(): string {
|
private getNextPromptId(): string {
|
||||||
this.promptIdCounter++;
|
this.promptIdCounter++;
|
||||||
return `${this.sessionId}########${this.promptIdCounter}`;
|
return `${this.sessionId}########${this.promptIdCounter}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Route a message to the appropriate handler based on its type
|
|
||||||
*
|
|
||||||
* Classifies incoming messages and routes them to appropriate handlers.
|
|
||||||
*/
|
|
||||||
private route(
|
|
||||||
message:
|
|
||||||
| CLIMessage
|
|
||||||
| CLIControlRequest
|
|
||||||
| CLIControlResponse
|
|
||||||
| ControlCancelRequest,
|
|
||||||
): RoutedMessage {
|
|
||||||
// Check control messages first
|
|
||||||
if (isControlRequest(message)) {
|
|
||||||
return { type: 'control_request', message };
|
|
||||||
}
|
|
||||||
if (isControlResponse(message)) {
|
|
||||||
return { type: 'control_response', message };
|
|
||||||
}
|
|
||||||
if (isControlCancel(message)) {
|
|
||||||
return { type: 'control_cancel', message };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check data messages
|
|
||||||
if (isCLIUserMessage(message)) {
|
|
||||||
return { type: 'user', message };
|
|
||||||
}
|
|
||||||
if (isCLIAssistantMessage(message)) {
|
|
||||||
return { type: 'assistant', message };
|
|
||||||
}
|
|
||||||
if (isCLISystemMessage(message)) {
|
|
||||||
return { type: 'system', message };
|
|
||||||
}
|
|
||||||
if (isCLIResultMessage(message)) {
|
|
||||||
return { type: 'result', message };
|
|
||||||
}
|
|
||||||
if (isCLIPartialAssistantMessage(message)) {
|
|
||||||
return { type: 'stream_event', message };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unknown message type
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Unknown message type:',
|
|
||||||
JSON.stringify(message, null, 2),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return { type: 'unknown', message };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Process a single message with unified logic for both initial prompt and stream messages.
|
|
||||||
*
|
|
||||||
* Handles:
|
|
||||||
* - Abort check
|
|
||||||
* - First message detection and handling
|
|
||||||
* - Normal message processing
|
|
||||||
* - Shutdown state checks
|
|
||||||
*
|
|
||||||
* @param message - Message to process
|
|
||||||
* @returns true if the calling code should exit (break/return), false to continue
|
|
||||||
*/
|
|
||||||
private async processSingleMessage(
|
|
||||||
message:
|
|
||||||
| CLIMessage
|
|
||||||
| CLIControlRequest
|
|
||||||
| CLIControlResponse
|
|
||||||
| ControlCancelRequest,
|
|
||||||
): Promise<boolean> {
|
|
||||||
// Check for abort
|
|
||||||
if (this.abortController.signal.aborted) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle first message if control system not yet initialized
|
|
||||||
if (this.controlSystemEnabled === null) {
|
|
||||||
const handled = await this.handleFirstMessage(message);
|
|
||||||
if (handled) {
|
|
||||||
// If handled, check if we should shutdown
|
|
||||||
return this.state === SESSION_STATE.SHUTTING_DOWN;
|
|
||||||
}
|
|
||||||
// If not handled, fall through to normal processing
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process message normally
|
|
||||||
await this.processMessage(message);
|
|
||||||
|
|
||||||
// Check for shutdown after processing
|
|
||||||
return this.state === SESSION_STATE.SHUTTING_DOWN;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Main entry point - run the session
|
|
||||||
*/
|
|
||||||
async run(): Promise<void> {
|
|
||||||
try {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error('[SessionManager] Starting session', this.sessionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process initial prompt if provided
|
|
||||||
if (this.initialPrompt !== null) {
|
|
||||||
const shouldExit = await this.processSingleMessage(this.initialPrompt);
|
|
||||||
if (shouldExit) {
|
|
||||||
await this.shutdown();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages from stream
|
|
||||||
for await (const message of this.inputReader.read()) {
|
|
||||||
const shouldExit = await this.processSingleMessage(message);
|
|
||||||
if (shouldExit) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream closed, shutdown
|
|
||||||
await this.shutdown();
|
|
||||||
} catch (error) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error('[SessionManager] Error:', error);
|
|
||||||
}
|
|
||||||
await this.shutdown();
|
|
||||||
throw error;
|
|
||||||
} finally {
|
|
||||||
// Ensure signal handlers are always cleaned up even if shutdown wasn't called
|
|
||||||
this.cleanupSignalHandlers();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private ensureControlSystem(): void {
|
private ensureControlSystem(): void {
|
||||||
if (this.controlContext && this.dispatcher && this.controlService) {
|
if (this.controlContext && this.dispatcher && this.controlService) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// The control system follows a strict three-layer architecture:
|
|
||||||
// 1. ControlContext (shared session state)
|
|
||||||
// 2. ControlDispatcher (protocol routing SDK ↔ CLI)
|
|
||||||
// 3. ControlService (programmatic API for CLI runtime)
|
|
||||||
//
|
|
||||||
// Application code MUST interact with the control plane exclusively through
|
|
||||||
// ControlService. ControlDispatcher is reserved for protocol-level message
|
|
||||||
// routing and should never be used directly outside of this file.
|
|
||||||
this.controlContext = new ControlContext({
|
this.controlContext = new ControlContext({
|
||||||
config: this.config,
|
config: this.config,
|
||||||
streamJson: this.outputAdapter,
|
streamJson: this.outputAdapter,
|
||||||
@@ -299,25 +106,25 @@ class SessionManager {
|
|||||||
| CLIControlResponse
|
| CLIControlResponse
|
||||||
| ControlCancelRequest,
|
| ControlCancelRequest,
|
||||||
): Promise<boolean> {
|
): Promise<boolean> {
|
||||||
const routed = this.route(message);
|
if (isControlRequest(message)) {
|
||||||
|
const request = message as CLIControlRequest;
|
||||||
if (routed.type === 'control_request') {
|
|
||||||
const request = routed.message as CLIControlRequest;
|
|
||||||
this.controlSystemEnabled = true;
|
this.controlSystemEnabled = true;
|
||||||
this.ensureControlSystem();
|
this.ensureControlSystem();
|
||||||
if (request.request.subtype === 'initialize') {
|
if (request.request.subtype === 'initialize') {
|
||||||
await this.dispatcher?.dispatch(request);
|
await this.dispatcher?.dispatch(request);
|
||||||
this.state = SESSION_STATE.IDLE;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
if (this.debugMode) {
|
||||||
|
console.error(
|
||||||
|
'[Session] Ignoring non-initialize control request during initialization',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (routed.type === 'user') {
|
if (isCLIUserMessage(message)) {
|
||||||
this.controlSystemEnabled = false;
|
this.controlSystemEnabled = false;
|
||||||
this.state = SESSION_STATE.PROCESSING_QUERY;
|
this.enqueueUserMessage(message as CLIUserMessage);
|
||||||
this.userMessageQueue.push(routed.message as CLIUserMessage);
|
|
||||||
await this.processUserMessageQueue();
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,237 +132,43 @@ class SessionManager {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private async handleControlRequest(
|
||||||
* Process a single message from the stream
|
request: CLIControlRequest,
|
||||||
*/
|
|
||||||
private async processMessage(
|
|
||||||
message:
|
|
||||||
| CLIMessage
|
|
||||||
| CLIControlRequest
|
|
||||||
| CLIControlResponse
|
|
||||||
| ControlCancelRequest,
|
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const routed = this.route(message);
|
|
||||||
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
`[SessionManager] State: ${this.state}, Message type: ${routed.type}`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (this.state) {
|
|
||||||
case SESSION_STATE.INITIALIZING:
|
|
||||||
await this.handleInitializingState(routed);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case SESSION_STATE.IDLE:
|
|
||||||
await this.handleIdleState(routed);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case SESSION_STATE.PROCESSING_QUERY:
|
|
||||||
await this.handleProcessingState(routed);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case SESSION_STATE.SHUTTING_DOWN:
|
|
||||||
// Ignore all messages during shutdown
|
|
||||||
break;
|
|
||||||
|
|
||||||
default: {
|
|
||||||
// Exhaustive check
|
|
||||||
const _exhaustiveCheck: never = this.state;
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error('[SessionManager] Unknown state:', _exhaustiveCheck);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle messages in initializing state
|
|
||||||
*/
|
|
||||||
private async handleInitializingState(routed: RoutedMessage): Promise<void> {
|
|
||||||
if (routed.type === 'control_request') {
|
|
||||||
const request = routed.message as CLIControlRequest;
|
|
||||||
const dispatcher = this.getDispatcher();
|
|
||||||
if (!dispatcher) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Control request received before control system initialization',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (request.request.subtype === 'initialize') {
|
|
||||||
await dispatcher.dispatch(request);
|
|
||||||
this.state = SESSION_STATE.IDLE;
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error('[SessionManager] Initialized, transitioning to idle');
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Ignoring non-initialize control request during initialization',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Ignoring non-control message during initialization',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle messages in idle state
|
|
||||||
*/
|
|
||||||
private async handleIdleState(routed: RoutedMessage): Promise<void> {
|
|
||||||
const dispatcher = this.getDispatcher();
|
const dispatcher = this.getDispatcher();
|
||||||
if (routed.type === 'control_request') {
|
if (!dispatcher) {
|
||||||
if (!dispatcher) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error('[SessionManager] Ignoring control request (disabled)');
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const request = routed.message as CLIControlRequest;
|
|
||||||
await dispatcher.dispatch(request);
|
|
||||||
// Stay in idle state
|
|
||||||
} else if (routed.type === 'control_response') {
|
|
||||||
if (!dispatcher) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const response = routed.message as CLIControlResponse;
|
|
||||||
dispatcher.handleControlResponse(response);
|
|
||||||
// Stay in idle state
|
|
||||||
} else if (routed.type === 'control_cancel') {
|
|
||||||
if (!dispatcher) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const cancelRequest = routed.message as ControlCancelRequest;
|
|
||||||
dispatcher.handleCancel(cancelRequest.request_id);
|
|
||||||
} else if (routed.type === 'user') {
|
|
||||||
const userMessage = routed.message as CLIUserMessage;
|
|
||||||
this.userMessageQueue.push(userMessage);
|
|
||||||
// Start processing queue
|
|
||||||
await this.processUserMessageQueue();
|
|
||||||
} else {
|
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error(
|
console.error('[Session] Control system not enabled');
|
||||||
'[SessionManager] Ignoring message type in idle state:',
|
|
||||||
routed.type,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle messages in processing state
|
|
||||||
*/
|
|
||||||
private async handleProcessingState(routed: RoutedMessage): Promise<void> {
|
|
||||||
const dispatcher = this.getDispatcher();
|
|
||||||
if (routed.type === 'control_request') {
|
|
||||||
if (!dispatcher) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Control request ignored during processing (disabled)',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const request = routed.message as CLIControlRequest;
|
|
||||||
await dispatcher.dispatch(request);
|
|
||||||
// Continue processing
|
|
||||||
} else if (routed.type === 'control_response') {
|
|
||||||
if (!dispatcher) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const response = routed.message as CLIControlResponse;
|
|
||||||
dispatcher.handleControlResponse(response);
|
|
||||||
// Continue processing
|
|
||||||
} else if (routed.type === 'user') {
|
|
||||||
// Enqueue for later
|
|
||||||
const userMessage = routed.message as CLIUserMessage;
|
|
||||||
this.userMessageQueue.push(userMessage);
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Enqueued user message during processing',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Ignoring message type during processing:',
|
|
||||||
routed.type,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Process user message queue (FIFO)
|
|
||||||
*/
|
|
||||||
private async processUserMessageQueue(): Promise<void> {
|
|
||||||
while (
|
|
||||||
this.userMessageQueue.length > 0 &&
|
|
||||||
!this.abortController.signal.aborted
|
|
||||||
) {
|
|
||||||
this.state = SESSION_STATE.PROCESSING_QUERY;
|
|
||||||
const userMessage = this.userMessageQueue.shift()!;
|
|
||||||
|
|
||||||
try {
|
|
||||||
await this.processUserMessage(userMessage);
|
|
||||||
} catch (error) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Error processing user message:',
|
|
||||||
error,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
// Send error result
|
|
||||||
this.emitErrorResult(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If control system is disabled (single-query mode) and queue is empty,
|
|
||||||
// automatically shutdown instead of returning to idle
|
|
||||||
if (
|
|
||||||
!this.abortController.signal.aborted &&
|
|
||||||
this.state === SESSION_STATE.PROCESSING_QUERY &&
|
|
||||||
this.controlSystemEnabled === false &&
|
|
||||||
this.userMessageQueue.length === 0
|
|
||||||
) {
|
|
||||||
if (this.debugMode) {
|
|
||||||
console.error(
|
|
||||||
'[SessionManager] Single-query mode: queue processed, shutting down',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
this.state = SESSION_STATE.SHUTTING_DOWN;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return to idle after processing queue (for multi-query mode with control system)
|
await dispatcher.dispatch(request);
|
||||||
if (
|
}
|
||||||
!this.abortController.signal.aborted &&
|
|
||||||
this.state === SESSION_STATE.PROCESSING_QUERY
|
private handleControlResponse(response: CLIControlResponse): void {
|
||||||
) {
|
const dispatcher = this.getDispatcher();
|
||||||
this.state = SESSION_STATE.IDLE;
|
if (!dispatcher) {
|
||||||
if (this.debugMode) {
|
return;
|
||||||
console.error('[SessionManager] Queue processed, returning to idle');
|
}
|
||||||
}
|
|
||||||
}
|
dispatcher.handleControlResponse(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleControlCancel(cancelRequest: ControlCancelRequest): void {
|
||||||
|
const dispatcher = this.getDispatcher();
|
||||||
|
if (!dispatcher) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatcher.handleCancel(cancelRequest.request_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Process a single user message
|
|
||||||
*/
|
|
||||||
private async processUserMessage(userMessage: CLIUserMessage): Promise<void> {
|
private async processUserMessage(userMessage: CLIUserMessage): Promise<void> {
|
||||||
const input = extractUserMessageText(userMessage);
|
const input = extractUserMessageText(userMessage);
|
||||||
if (!input) {
|
if (!input) {
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error('[SessionManager] No text content in user message');
|
console.error('[Session] No text content in user message');
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -575,16 +188,56 @@ class SessionManager {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// Error already handled by runNonInteractive via adapter.emitResult
|
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error('[SessionManager] Query execution error:', error);
|
console.error('[Session] Query execution error:', error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private async processUserMessageQueue(): Promise<void> {
|
||||||
* Send tool results as user message
|
if (this.isShuttingDown || this.abortController.signal.aborted) {
|
||||||
*/
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (
|
||||||
|
this.userMessageQueue.length > 0 &&
|
||||||
|
!this.isShuttingDown &&
|
||||||
|
!this.abortController.signal.aborted
|
||||||
|
) {
|
||||||
|
const userMessage = this.userMessageQueue.shift()!;
|
||||||
|
try {
|
||||||
|
await this.processUserMessage(userMessage);
|
||||||
|
} catch (error) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Error processing user message:', error);
|
||||||
|
}
|
||||||
|
this.emitErrorResult(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private enqueueUserMessage(userMessage: CLIUserMessage): void {
|
||||||
|
this.userMessageQueue.push(userMessage);
|
||||||
|
this.ensureProcessingStarted();
|
||||||
|
}
|
||||||
|
|
||||||
|
private ensureProcessingStarted(): void {
|
||||||
|
if (this.processingPromise) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.processingPromise = this.processUserMessageQueue().finally(() => {
|
||||||
|
this.processingPromise = null;
|
||||||
|
if (
|
||||||
|
this.userMessageQueue.length > 0 &&
|
||||||
|
!this.isShuttingDown &&
|
||||||
|
!this.abortController.signal.aborted
|
||||||
|
) {
|
||||||
|
this.ensureProcessingStarted();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private emitErrorResult(
|
private emitErrorResult(
|
||||||
error: unknown,
|
error: unknown,
|
||||||
numTurns: number = 0,
|
numTurns: number = 0,
|
||||||
@@ -602,52 +255,51 @@ class SessionManager {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle interrupt control request
|
|
||||||
*/
|
|
||||||
private handleInterrupt(): void {
|
private handleInterrupt(): void {
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error('[SessionManager] Interrupt requested');
|
console.error('[Session] Interrupt requested');
|
||||||
}
|
|
||||||
// Abort current query if processing
|
|
||||||
if (this.state === SESSION_STATE.PROCESSING_QUERY) {
|
|
||||||
this.abortController.abort();
|
|
||||||
this.abortController = new AbortController(); // Create new controller for next query
|
|
||||||
}
|
}
|
||||||
|
this.abortController.abort();
|
||||||
|
this.abortController = new AbortController();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Setup signal handlers for graceful shutdown
|
|
||||||
*/
|
|
||||||
private setupSignalHandlers(): void {
|
private setupSignalHandlers(): void {
|
||||||
this.shutdownHandler = () => {
|
this.shutdownHandler = () => {
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error('[SessionManager] Shutdown signal received');
|
console.error('[Session] Shutdown signal received');
|
||||||
}
|
}
|
||||||
|
this.isShuttingDown = true;
|
||||||
this.abortController.abort();
|
this.abortController.abort();
|
||||||
this.state = SESSION_STATE.SHUTTING_DOWN;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
process.on('SIGINT', this.shutdownHandler);
|
process.on('SIGINT', this.shutdownHandler);
|
||||||
process.on('SIGTERM', this.shutdownHandler);
|
process.on('SIGTERM', this.shutdownHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Shutdown session and cleanup resources
|
|
||||||
*/
|
|
||||||
private async shutdown(): Promise<void> {
|
private async shutdown(): Promise<void> {
|
||||||
if (this.debugMode) {
|
if (this.debugMode) {
|
||||||
console.error('[SessionManager] Shutting down');
|
console.error('[Session] Shutting down');
|
||||||
|
}
|
||||||
|
|
||||||
|
this.isShuttingDown = true;
|
||||||
|
|
||||||
|
if (this.processingPromise) {
|
||||||
|
try {
|
||||||
|
await this.processingPromise;
|
||||||
|
} catch (error) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error(
|
||||||
|
'[Session] Error waiting for processing to complete:',
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.state = SESSION_STATE.SHUTTING_DOWN;
|
|
||||||
this.dispatcher?.shutdown();
|
this.dispatcher?.shutdown();
|
||||||
this.cleanupSignalHandlers();
|
this.cleanupSignalHandlers();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove signal handlers to prevent memory leaks
|
|
||||||
*/
|
|
||||||
private cleanupSignalHandlers(): void {
|
private cleanupSignalHandlers(): void {
|
||||||
if (this.shutdownHandler) {
|
if (this.shutdownHandler) {
|
||||||
process.removeListener('SIGINT', this.shutdownHandler);
|
process.removeListener('SIGINT', this.shutdownHandler);
|
||||||
@@ -655,6 +307,94 @@ class SessionManager {
|
|||||||
this.shutdownHandler = null;
|
this.shutdownHandler = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async run(): Promise<void> {
|
||||||
|
try {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Starting session', this.sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.initialPrompt !== null) {
|
||||||
|
const handled = await this.handleFirstMessage(this.initialPrompt);
|
||||||
|
if (handled && this.isShuttingDown) {
|
||||||
|
await this.shutdown();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const message of this.inputReader.read()) {
|
||||||
|
if (this.abortController.signal.aborted) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.controlSystemEnabled === null) {
|
||||||
|
const handled = await this.handleFirstMessage(message);
|
||||||
|
if (handled) {
|
||||||
|
if (this.isShuttingDown) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isControlRequest(message)) {
|
||||||
|
await this.handleControlRequest(message as CLIControlRequest);
|
||||||
|
} else if (isControlResponse(message)) {
|
||||||
|
this.handleControlResponse(message as CLIControlResponse);
|
||||||
|
} else if (isControlCancel(message)) {
|
||||||
|
this.handleControlCancel(message as ControlCancelRequest);
|
||||||
|
} else if (isCLIUserMessage(message)) {
|
||||||
|
this.enqueueUserMessage(message as CLIUserMessage);
|
||||||
|
} else if (this.debugMode) {
|
||||||
|
if (
|
||||||
|
!isCLIAssistantMessage(message) &&
|
||||||
|
!isCLISystemMessage(message) &&
|
||||||
|
!isCLIResultMessage(message) &&
|
||||||
|
!isCLIPartialAssistantMessage(message)
|
||||||
|
) {
|
||||||
|
console.error(
|
||||||
|
'[Session] Unknown message type:',
|
||||||
|
JSON.stringify(message, null, 2),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.isShuttingDown) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (streamError) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Stream reading error:', streamError);
|
||||||
|
}
|
||||||
|
throw streamError;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (this.processingPromise) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Waiting for final processing to complete');
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
await this.processingPromise;
|
||||||
|
} catch (error) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Error in final processing:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.shutdown();
|
||||||
|
} catch (error) {
|
||||||
|
if (this.debugMode) {
|
||||||
|
console.error('[Session] Error:', error);
|
||||||
|
}
|
||||||
|
await this.shutdown();
|
||||||
|
throw error;
|
||||||
|
} finally {
|
||||||
|
this.cleanupSignalHandlers();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function extractUserMessageText(message: CLIUserMessage): string | null {
|
function extractUserMessageText(message: CLIUserMessage): string | null {
|
||||||
@@ -682,12 +422,6 @@ function extractUserMessageText(message: CLIUserMessage): string | null {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Entry point for stream-json mode
|
|
||||||
*
|
|
||||||
* @param config - Configuration object
|
|
||||||
* @param input - Optional initial prompt input to process before reading from stream
|
|
||||||
*/
|
|
||||||
export async function runNonInteractiveStreamJson(
|
export async function runNonInteractiveStreamJson(
|
||||||
config: Config,
|
config: Config,
|
||||||
input: string,
|
input: string,
|
||||||
@@ -698,7 +432,6 @@ export async function runNonInteractiveStreamJson(
|
|||||||
consolePatcher.patch();
|
consolePatcher.patch();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Create initial user message from prompt input if provided
|
|
||||||
let initialPrompt: CLIUserMessage | undefined = undefined;
|
let initialPrompt: CLIUserMessage | undefined = undefined;
|
||||||
if (input && input.trim().length > 0) {
|
if (input && input.trim().length > 0) {
|
||||||
const sessionId = config.getSessionId();
|
const sessionId = config.getSessionId();
|
||||||
@@ -713,7 +446,7 @@ export async function runNonInteractiveStreamJson(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const manager = new SessionManager(config, initialPrompt);
|
const manager = new Session(config, initialPrompt);
|
||||||
await manager.run();
|
await manager.run();
|
||||||
} finally {
|
} finally {
|
||||||
consolePatcher.cleanup();
|
consolePatcher.cleanup();
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
FatalInputError,
|
FatalInputError,
|
||||||
promptIdContext,
|
promptIdContext,
|
||||||
OutputFormat,
|
OutputFormat,
|
||||||
|
InputFormat,
|
||||||
uiTelemetryService,
|
uiTelemetryService,
|
||||||
} from '@qwen-code/qwen-code-core';
|
} from '@qwen-code/qwen-code-core';
|
||||||
import type { Content, Part, PartListUnion } from '@google/genai';
|
import type { Content, Part, PartListUnion } from '@google/genai';
|
||||||
@@ -254,12 +255,18 @@ export async function runNonInteractive(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolCallUpdateCallback = options.controlService
|
|
||||||
? options.controlService.permission.getToolCallUpdateCallback()
|
|
||||||
: undefined;
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Get toolCallUpdateCallback for SDK mode (stream-json)
|
||||||
|
const inputFormat =
|
||||||
|
typeof config.getInputFormat === 'function'
|
||||||
|
? config.getInputFormat()
|
||||||
|
: InputFormat.TEXT;
|
||||||
|
const toolCallUpdateCallback =
|
||||||
|
inputFormat === InputFormat.STREAM_JSON && options.controlService
|
||||||
|
? options.controlService.permission.getToolCallUpdateCallback()
|
||||||
|
: undefined;
|
||||||
|
|
||||||
// Only pass outputUpdateHandler for Task tool
|
// Only pass outputUpdateHandler for Task tool
|
||||||
const isTaskTool = finalRequestInfo.name === 'task';
|
const isTaskTool = finalRequestInfo.name === 'task';
|
||||||
const taskToolProgress = isTaskTool
|
const taskToolProgress = isTaskTool
|
||||||
@@ -277,13 +284,13 @@ export async function runNonInteractive(
|
|||||||
isTaskTool && taskToolProgressHandler
|
isTaskTool && taskToolProgressHandler
|
||||||
? {
|
? {
|
||||||
outputUpdateHandler: taskToolProgressHandler,
|
outputUpdateHandler: taskToolProgressHandler,
|
||||||
/*
|
onToolCallsUpdate: toolCallUpdateCallback,
|
||||||
toolCallUpdateCallback
|
|
||||||
? { onToolCallsUpdate: toolCallUpdateCallback }
|
|
||||||
: undefined,
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
: undefined,
|
: toolCallUpdateCallback
|
||||||
|
? {
|
||||||
|
onToolCallsUpdate: toolCallUpdateCallback,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Note: In JSON mode, subagent messages are automatically added to the main
|
// Note: In JSON mode, subagent messages are automatically added to the main
|
||||||
@@ -303,9 +310,6 @@ export async function runNonInteractive(
|
|||||||
? toolResponse.resultDisplay
|
? toolResponse.resultDisplay
|
||||||
: undefined,
|
: undefined,
|
||||||
);
|
);
|
||||||
// Note: We no longer emit a separate system message for tool errors
|
|
||||||
// in JSON/STREAM_JSON mode, as the error is already captured in the
|
|
||||||
// tool_result block with is_error=true.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (adapter) {
|
if (adapter) {
|
||||||
|
|||||||
@@ -909,7 +909,10 @@ export class CoreToolScheduler {
|
|||||||
|
|
||||||
async handleConfirmationResponse(
|
async handleConfirmationResponse(
|
||||||
callId: string,
|
callId: string,
|
||||||
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
originalOnConfirm: (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
|
) => Promise<void>,
|
||||||
outcome: ToolConfirmationOutcome,
|
outcome: ToolConfirmationOutcome,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
payload?: ToolConfirmationPayload,
|
payload?: ToolConfirmationPayload,
|
||||||
@@ -918,9 +921,7 @@ export class CoreToolScheduler {
|
|||||||
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||||
);
|
);
|
||||||
|
|
||||||
if (toolCall && toolCall.status === 'awaiting_approval') {
|
await originalOnConfirm(outcome, payload);
|
||||||
await originalOnConfirm(outcome);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||||
await this.autoApproveCompatiblePendingTools(signal, callId);
|
await this.autoApproveCompatiblePendingTools(signal, callId);
|
||||||
@@ -929,11 +930,10 @@ export class CoreToolScheduler {
|
|||||||
this.setToolCallOutcome(callId, outcome);
|
this.setToolCallOutcome(callId, outcome);
|
||||||
|
|
||||||
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||||
this.setStatusInternal(
|
// Use custom cancel message from payload if provided, otherwise use default
|
||||||
callId,
|
const cancelMessage =
|
||||||
'cancelled',
|
payload?.cancelMessage || 'User did not allow tool call';
|
||||||
'User did not allow tool call',
|
this.setStatusInternal(callId, 'cancelled', cancelMessage);
|
||||||
);
|
|
||||||
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
||||||
const waitingToolCall = toolCall as WaitingToolCall;
|
const waitingToolCall = toolCall as WaitingToolCall;
|
||||||
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
|
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
|
||||||
@@ -991,7 +991,8 @@ export class CoreToolScheduler {
|
|||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
if (
|
if (
|
||||||
toolCall.confirmationDetails.type !== 'edit' ||
|
toolCall.confirmationDetails.type !== 'edit' ||
|
||||||
!isModifiableDeclarativeTool(toolCall.tool)
|
!isModifiableDeclarativeTool(toolCall.tool) ||
|
||||||
|
!payload.newContent
|
||||||
) {
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import type {
|
|||||||
ToolInvocation,
|
ToolInvocation,
|
||||||
ToolMcpConfirmationDetails,
|
ToolMcpConfirmationDetails,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
|
ToolConfirmationPayload,
|
||||||
} from './tools.js';
|
} from './tools.js';
|
||||||
import {
|
import {
|
||||||
BaseDeclarativeTool,
|
BaseDeclarativeTool,
|
||||||
@@ -98,7 +99,10 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||||||
serverName: this.serverName,
|
serverName: this.serverName,
|
||||||
toolName: this.serverToolName, // Display original tool name in confirmation
|
toolName: this.serverToolName, // Display original tool name in confirmation
|
||||||
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
|
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
|
||||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
onConfirm: async (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
_payload?: ToolConfirmationPayload,
|
||||||
|
) => {
|
||||||
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
|
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
|
||||||
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
|
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
|
||||||
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
|
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import type {
|
|||||||
ToolResultDisplay,
|
ToolResultDisplay,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
ToolExecuteConfirmationDetails,
|
ToolExecuteConfirmationDetails,
|
||||||
|
ToolConfirmationPayload,
|
||||||
} from './tools.js';
|
} from './tools.js';
|
||||||
import {
|
import {
|
||||||
BaseDeclarativeTool,
|
BaseDeclarativeTool,
|
||||||
@@ -102,7 +103,10 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
|||||||
title: 'Confirm Shell Command',
|
title: 'Confirm Shell Command',
|
||||||
command: this.params.command,
|
command: this.params.command,
|
||||||
rootCommand: commandsToConfirm.join(', '),
|
rootCommand: commandsToConfirm.join(', '),
|
||||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
onConfirm: async (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
_payload?: ToolConfirmationPayload,
|
||||||
|
) => {
|
||||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||||
commandsToConfirm.forEach((command) => this.allowlist.add(command));
|
commandsToConfirm.forEach((command) => this.allowlist.add(command));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -531,13 +531,18 @@ export interface ToolEditConfirmationDetails {
|
|||||||
export interface ToolConfirmationPayload {
|
export interface ToolConfirmationPayload {
|
||||||
// used to override `modifiedProposedContent` for modifiable tools in the
|
// used to override `modifiedProposedContent` for modifiable tools in the
|
||||||
// inline modify flow
|
// inline modify flow
|
||||||
newContent: string;
|
newContent?: string;
|
||||||
|
// used to provide custom cancellation message when outcome is Cancel
|
||||||
|
cancelMessage?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolExecuteConfirmationDetails {
|
export interface ToolExecuteConfirmationDetails {
|
||||||
type: 'exec';
|
type: 'exec';
|
||||||
title: string;
|
title: string;
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
onConfirm: (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
|
) => Promise<void>;
|
||||||
command: string;
|
command: string;
|
||||||
rootCommand: string;
|
rootCommand: string;
|
||||||
}
|
}
|
||||||
@@ -548,7 +553,10 @@ export interface ToolMcpConfirmationDetails {
|
|||||||
serverName: string;
|
serverName: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
toolDisplayName: string;
|
toolDisplayName: string;
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
onConfirm: (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
|
) => Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolInfoConfirmationDetails {
|
export interface ToolInfoConfirmationDetails {
|
||||||
@@ -573,6 +581,11 @@ export interface ToolPlanConfirmationDetails {
|
|||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TODO:
|
||||||
|
* 1. support explicit denied outcome
|
||||||
|
* 2. support proceed with modified input
|
||||||
|
*/
|
||||||
export enum ToolConfirmationOutcome {
|
export enum ToolConfirmationOutcome {
|
||||||
ProceedOnce = 'proceed_once',
|
ProceedOnce = 'proceed_once',
|
||||||
ProceedAlways = 'proceed_always',
|
ProceedAlways = 'proceed_always',
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import type {
|
|||||||
CLIControlRequest,
|
CLIControlRequest,
|
||||||
CLIControlResponse,
|
CLIControlResponse,
|
||||||
ControlCancelRequest,
|
ControlCancelRequest,
|
||||||
PermissionApproval,
|
|
||||||
PermissionSuggestion,
|
PermissionSuggestion,
|
||||||
} from '../types/protocol.js';
|
} from '../types/protocol.js';
|
||||||
import {
|
import {
|
||||||
@@ -299,7 +298,7 @@ export class Query implements AsyncIterable<CLIMessage> {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (process.env['DEBUG_SDK']) {
|
if (process.env['DEBUG']) {
|
||||||
console.warn('[Query] Unknown message type:', message);
|
console.warn('[Query] Unknown message type:', message);
|
||||||
}
|
}
|
||||||
this.inputStream.enqueue(message as CLIMessage);
|
this.inputStream.enqueue(message as CLIMessage);
|
||||||
@@ -320,12 +319,12 @@ export class Query implements AsyncIterable<CLIMessage> {
|
|||||||
|
|
||||||
switch (payload.subtype) {
|
switch (payload.subtype) {
|
||||||
case 'can_use_tool':
|
case 'can_use_tool':
|
||||||
response = (await this.handlePermissionRequest(
|
response = await this.handlePermissionRequest(
|
||||||
payload.tool_name,
|
payload.tool_name,
|
||||||
payload.input as Record<string, unknown>,
|
payload.input as Record<string, unknown>,
|
||||||
payload.permission_suggestions,
|
payload.permission_suggestions,
|
||||||
requestAbortController.signal,
|
requestAbortController.signal,
|
||||||
)) as unknown as Record<string, unknown>;
|
);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case 'mcp_message':
|
case 'mcp_message':
|
||||||
@@ -360,15 +359,17 @@ export class Query implements AsyncIterable<CLIMessage> {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Handle permission request (can_use_tool)
|
* Handle permission request (can_use_tool)
|
||||||
|
* Converts PermissionResult to CLI-expected format: { behavior: 'allow', updatedInput: ... } or { behavior: 'deny', message: ... }
|
||||||
*/
|
*/
|
||||||
private async handlePermissionRequest(
|
private async handlePermissionRequest(
|
||||||
toolName: string,
|
toolName: string,
|
||||||
toolInput: Record<string, unknown>,
|
toolInput: Record<string, unknown>,
|
||||||
permissionSuggestions: PermissionSuggestion[] | null,
|
permissionSuggestions: PermissionSuggestion[] | null,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
): Promise<PermissionApproval> {
|
): Promise<Record<string, unknown>> {
|
||||||
|
/* Default deny all wildcard tool requests */
|
||||||
if (!this.options.canUseTool) {
|
if (!this.options.canUseTool) {
|
||||||
return { allowed: true };
|
return { behavior: 'deny', message: 'Denied' };
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -390,21 +391,51 @@ export class Query implements AsyncIterable<CLIMessage> {
|
|||||||
timeoutPromise,
|
timeoutPromise,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
// Handle boolean return (backward compatibility)
|
||||||
if (typeof result === 'boolean') {
|
if (typeof result === 'boolean') {
|
||||||
return { allowed: result };
|
return result
|
||||||
|
? { behavior: 'allow', updatedInput: toolInput }
|
||||||
|
: { behavior: 'deny', message: 'Denied' };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PermissionResult format
|
||||||
|
const permissionResult = result as {
|
||||||
|
behavior: 'allow' | 'deny';
|
||||||
|
updatedInput?: Record<string, unknown>;
|
||||||
|
message?: string;
|
||||||
|
interrupt?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (permissionResult.behavior === 'allow') {
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: permissionResult.updatedInput ?? toolInput,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
behavior: 'deny',
|
||||||
|
message: permissionResult.message ?? 'Denied',
|
||||||
|
...(permissionResult.interrupt !== undefined
|
||||||
|
? { interrupt: permissionResult.interrupt }
|
||||||
|
: {}),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
return result as PermissionApproval;
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
/**
|
/**
|
||||||
* Timeout or error → deny (fail-safe).
|
* Timeout or error → deny (fail-safe).
|
||||||
* This ensures that any issues with the permission callback
|
* This ensures that any issues with the permission callback
|
||||||
* result in a safe default of denying access.
|
* result in a safe default of denying access.
|
||||||
*/
|
*/
|
||||||
|
const errorMessage =
|
||||||
|
error instanceof Error ? error.message : String(error);
|
||||||
console.warn(
|
console.warn(
|
||||||
'[Query] Permission callback error (denying by default):',
|
'[Query] Permission callback error (denying by default):',
|
||||||
error instanceof Error ? error.message : String(error),
|
errorMessage,
|
||||||
);
|
);
|
||||||
return { allowed: false };
|
return {
|
||||||
|
behavior: 'deny',
|
||||||
|
message: `Permission check failed: ${errorMessage}`,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -283,6 +283,10 @@ export class ProcessTransport implements Transport {
|
|||||||
throw new Error('Cannot write to closed transport');
|
throw new Error('Cannot write to closed transport');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (this.childStdin.writableEnded) {
|
||||||
|
throw new Error('Cannot write to ended stream');
|
||||||
|
}
|
||||||
|
|
||||||
if (this.childProcess?.killed || this.childProcess?.exitCode !== null) {
|
if (this.childProcess?.killed || this.childProcess?.exitCode !== null) {
|
||||||
throw new Error('Cannot write to terminated process');
|
throw new Error('Cannot write to terminated process');
|
||||||
}
|
}
|
||||||
@@ -293,17 +297,21 @@ export class ProcessTransport implements Transport {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (process.env['DEBUG_SDK']) {
|
if (process.env['DEBUG']) {
|
||||||
this.logForDebugging(
|
this.logForDebugging(
|
||||||
`[ProcessTransport] Writing to stdin: ${message.substring(0, 100)}`,
|
`[ProcessTransport] Writing to stdin (${message.length} bytes): ${message.substring(0, 100)}`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const written = this.childStdin.write(message);
|
const written = this.childStdin.write(message);
|
||||||
if (!written && process.env['DEBUG_SDK']) {
|
if (!written) {
|
||||||
this.logForDebugging(
|
this.logForDebugging(
|
||||||
'[ProcessTransport] Write buffer full, data queued',
|
`[ProcessTransport] Write buffer full (${message.length} bytes), data queued. Waiting for drain event...`,
|
||||||
|
);
|
||||||
|
} else if (process.env['DEBUG']) {
|
||||||
|
this.logForDebugging(
|
||||||
|
`[ProcessTransport] Write successful (${message.length} bytes)`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -322,6 +330,7 @@ export class ProcessTransport implements Transport {
|
|||||||
const rl = readline.createInterface({
|
const rl = readline.createInterface({
|
||||||
input: this.childStdout,
|
input: this.childStdout,
|
||||||
crlfDelay: Infinity,
|
crlfDelay: Infinity,
|
||||||
|
terminal: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import type { ToolDefinition as ToolDef } from './mcp.js';
|
import type { ToolDefinition as ToolDef } from './mcp.js';
|
||||||
import type { PermissionMode } from './protocol.js';
|
import type { PermissionMode, PermissionSuggestion } from './protocol.js';
|
||||||
import type { ExternalMcpServerConfig } from './queryOptionsSchema.js';
|
import type { ExternalMcpServerConfig } from './queryOptionsSchema.js';
|
||||||
|
|
||||||
export type { ToolDef as ToolDefinition };
|
export type { ToolDef as ToolDefinition };
|
||||||
@@ -161,14 +161,15 @@ type ToolInput = Record<string, unknown>;
|
|||||||
*
|
*
|
||||||
* @param toolName - Name of the tool being executed
|
* @param toolName - Name of the tool being executed
|
||||||
* @param input - Input parameters for the tool
|
* @param input - Input parameters for the tool
|
||||||
* @param options - Options including abort signal
|
* @param options - Options including abort signal and suggestions
|
||||||
* @returns Promise with permission result
|
* @returns Promise with permission result
|
||||||
*/
|
*/
|
||||||
type CanUseTool = (
|
export type CanUseTool = (
|
||||||
toolName: string,
|
toolName: string,
|
||||||
input: ToolInput,
|
input: ToolInput,
|
||||||
options: {
|
options: {
|
||||||
signal: AbortSignal;
|
signal: AbortSignal;
|
||||||
|
suggestions?: PermissionSuggestion[] | null;
|
||||||
},
|
},
|
||||||
) => Promise<PermissionResult>;
|
) => Promise<PermissionResult>;
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import type { PermissionCallback } from './config.js';
|
import type { CanUseTool } from './config.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Schema for external MCP server configuration
|
* Schema for external MCP server configuration
|
||||||
@@ -35,7 +35,7 @@ export const QueryOptionsSchema = z
|
|||||||
env: z.record(z.string(), z.string()).optional(),
|
env: z.record(z.string(), z.string()).optional(),
|
||||||
permissionMode: z.enum(['default', 'plan', 'auto-edit', 'yolo']).optional(),
|
permissionMode: z.enum(['default', 'plan', 'auto-edit', 'yolo']).optional(),
|
||||||
canUseTool: z
|
canUseTool: z
|
||||||
.custom<PermissionCallback>((val) => typeof val === 'function', {
|
.custom<CanUseTool>((val) => typeof val === 'function', {
|
||||||
message: 'canUseTool must be a function',
|
message: 'canUseTool must be a function',
|
||||||
})
|
})
|
||||||
.optional(),
|
.optional(),
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ describe('AbortController and Process Lifecycle (E2E)', () => {
|
|||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
|
|
||||||
const q = query({
|
const q = query({
|
||||||
prompt: 'Write a detailed explanation about TypeScript',
|
prompt: 'Hello',
|
||||||
options: {
|
options: {
|
||||||
...SHARED_TEST_OPTIONS,
|
...SHARED_TEST_OPTIONS,
|
||||||
abortController: controller,
|
abortController: controller,
|
||||||
|
|||||||
@@ -160,8 +160,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
|
|||||||
session_id: sessionId,
|
session_id: sessionId,
|
||||||
message: {
|
message: {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content:
|
content: 'My name is Alice. Hello!',
|
||||||
'My name is Alice. Remember this during our current conversation.',
|
|
||||||
},
|
},
|
||||||
parent_tool_use_id: null,
|
parent_tool_use_id: null,
|
||||||
} as CLIUserMessage;
|
} as CLIUserMessage;
|
||||||
@@ -212,80 +211,72 @@ describe('Multi-Turn Conversations (E2E)', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('Tool Usage in Multi-Turn', () => {
|
describe('Tool Usage in Multi-Turn', () => {
|
||||||
it(
|
it('should handle tool usage across multiple turns', async () => {
|
||||||
'should handle tool usage across multiple turns',
|
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
|
||||||
async () => {
|
const sessionId = crypto.randomUUID();
|
||||||
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
|
|
||||||
const sessionId = crypto.randomUUID();
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
type: 'user',
|
type: 'user',
|
||||||
session_id: sessionId,
|
session_id: sessionId,
|
||||||
message: {
|
message: {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'List the files in the current directory',
|
content: 'List the files in the current directory',
|
||||||
},
|
|
||||||
parent_tool_use_id: null,
|
|
||||||
} as CLIUserMessage;
|
|
||||||
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
|
||||||
|
|
||||||
yield {
|
|
||||||
type: 'user',
|
|
||||||
session_id: sessionId,
|
|
||||||
message: {
|
|
||||||
role: 'user',
|
|
||||||
content: 'Now tell me about the package.json file specifically',
|
|
||||||
},
|
|
||||||
parent_tool_use_id: null,
|
|
||||||
} as CLIUserMessage;
|
|
||||||
}
|
|
||||||
|
|
||||||
const q = query({
|
|
||||||
prompt: createToolConversation(),
|
|
||||||
options: {
|
|
||||||
...SHARED_TEST_OPTIONS,
|
|
||||||
cwd: process.cwd(),
|
|
||||||
debug: false,
|
|
||||||
},
|
},
|
||||||
});
|
parent_tool_use_id: null,
|
||||||
|
} as CLIUserMessage;
|
||||||
|
|
||||||
const messages: CLIMessage[] = [];
|
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||||
let toolUseCount = 0;
|
|
||||||
let assistantCount = 0;
|
|
||||||
|
|
||||||
try {
|
yield {
|
||||||
for await (const message of q) {
|
type: 'user',
|
||||||
messages.push(message);
|
session_id: sessionId,
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: 'Now tell me about the package.json file specifically',
|
||||||
|
},
|
||||||
|
parent_tool_use_id: null,
|
||||||
|
} as CLIUserMessage;
|
||||||
|
}
|
||||||
|
|
||||||
if (isCLIAssistantMessage(message)) {
|
const q = query({
|
||||||
const hasToolUseBlock = message.message.content.some(
|
prompt: createToolConversation(),
|
||||||
(block: ContentBlock): block is ToolUseBlock =>
|
options: {
|
||||||
block.type === 'tool_use',
|
...SHARED_TEST_OPTIONS,
|
||||||
);
|
cwd: process.cwd(),
|
||||||
if (hasToolUseBlock) {
|
debug: false,
|
||||||
toolUseCount++;
|
},
|
||||||
}
|
});
|
||||||
}
|
|
||||||
|
|
||||||
if (isCLIAssistantMessage(message)) {
|
const messages: CLIMessage[] = [];
|
||||||
assistantCount++;
|
let toolUseCount = 0;
|
||||||
}
|
let assistantCount = 0;
|
||||||
|
|
||||||
if (isCLIResultMessage(message)) {
|
try {
|
||||||
break;
|
for await (const message of q) {
|
||||||
|
messages.push(message);
|
||||||
|
|
||||||
|
if (isCLIAssistantMessage(message)) {
|
||||||
|
const hasToolUseBlock = message.message.content.some(
|
||||||
|
(block: ContentBlock): block is ToolUseBlock =>
|
||||||
|
block.type === 'tool_use',
|
||||||
|
);
|
||||||
|
if (hasToolUseBlock) {
|
||||||
|
toolUseCount++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(messages.length).toBeGreaterThan(0);
|
if (isCLIAssistantMessage(message)) {
|
||||||
expect(toolUseCount).toBeGreaterThan(0); // Should use tools
|
assistantCount++;
|
||||||
expect(assistantCount).toBeGreaterThanOrEqual(2); // Should have responses to both questions
|
}
|
||||||
} finally {
|
|
||||||
await q.close();
|
|
||||||
}
|
}
|
||||||
},
|
|
||||||
TEST_TIMEOUT,
|
expect(messages.length).toBeGreaterThan(0);
|
||||||
);
|
expect(toolUseCount).toBeGreaterThan(0); // Should use tools
|
||||||
|
expect(assistantCount).toBeGreaterThanOrEqual(2); // Should have responses to both questions
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
}, 60000); //TEST_TIMEOUT,
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Message Flow and Sequencing', () => {
|
describe('Message Flow and Sequencing', () => {
|
||||||
|
|||||||
748
packages/sdk/typescript/test/e2e/permission-control.test.ts
Normal file
748
packages/sdk/typescript/test/e2e/permission-control.test.ts
Normal file
@@ -0,0 +1,748 @@
|
|||||||
|
/**
|
||||||
|
* E2E tests for permission control features:
|
||||||
|
* - canUseTool callback parameter
|
||||||
|
* - setPermissionMode API
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
|
||||||
|
import { query } from '../../src/index.js';
|
||||||
|
import {
|
||||||
|
isCLIAssistantMessage,
|
||||||
|
isCLIResultMessage,
|
||||||
|
isCLIUserMessage,
|
||||||
|
type CLIUserMessage,
|
||||||
|
type ToolUseBlock,
|
||||||
|
type ContentBlock,
|
||||||
|
} from '../../src/types/protocol.js';
|
||||||
|
|
||||||
|
const TEST_CLI_PATH =
|
||||||
|
'/Users/mingholy/Work/Projects/qwen-code/packages/cli/index.ts';
|
||||||
|
const TEST_TIMEOUT = 1600000;
|
||||||
|
|
||||||
|
const SHARED_TEST_OPTIONS = {
|
||||||
|
pathToQwenExecutable: TEST_CLI_PATH,
|
||||||
|
debug: false,
|
||||||
|
// env here sets environment variables for the CLI child process
|
||||||
|
env: {
|
||||||
|
// DEBUG: '1',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory function that creates a streaming input with a control point.
|
||||||
|
* After the first message is yielded, the generator waits for a resume signal,
|
||||||
|
* allowing the test code to call query instance methods like setPermissionMode.
|
||||||
|
*/
|
||||||
|
function createStreamingInputWithControlPoint(
|
||||||
|
firstMessage: string,
|
||||||
|
secondMessage: string,
|
||||||
|
): {
|
||||||
|
generator: AsyncIterable<CLIUserMessage>;
|
||||||
|
resume: () => void;
|
||||||
|
} {
|
||||||
|
let resumeResolve: (() => void) | null = null;
|
||||||
|
const resumePromise = new Promise<void>((resolve) => {
|
||||||
|
resumeResolve = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
const generator = (async function* () {
|
||||||
|
const sessionId = crypto.randomUUID();
|
||||||
|
|
||||||
|
yield {
|
||||||
|
type: 'user',
|
||||||
|
session_id: sessionId,
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: firstMessage,
|
||||||
|
},
|
||||||
|
parent_tool_use_id: null,
|
||||||
|
} as CLIUserMessage;
|
||||||
|
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||||
|
|
||||||
|
await resumePromise;
|
||||||
|
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||||
|
|
||||||
|
yield {
|
||||||
|
type: 'user',
|
||||||
|
session_id: sessionId,
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: secondMessage,
|
||||||
|
},
|
||||||
|
parent_tool_use_id: null,
|
||||||
|
} as CLIUserMessage;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const resume = () => {
|
||||||
|
if (resumeResolve) {
|
||||||
|
resumeResolve();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return { generator, resume };
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('Permission Control (E2E)', () => {
|
||||||
|
beforeAll(() => {
|
||||||
|
//process.env['DEBUG'] = '1';
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
delete process.env['DEBUG'];
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('canUseTool callback parameter', () => {
|
||||||
|
it(
|
||||||
|
'should invoke canUseTool callback when tool is requested',
|
||||||
|
async () => {
|
||||||
|
const toolCalls: Array<{
|
||||||
|
toolName: string;
|
||||||
|
input: Record<string, unknown>;
|
||||||
|
}> = [];
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'Write a js hello world to file.',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
|
||||||
|
canUseTool: async (toolName, input) => {
|
||||||
|
toolCalls.push({ toolName, input });
|
||||||
|
console.log(toolName, input);
|
||||||
|
/*
|
||||||
|
{
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: input,
|
||||||
|
};
|
||||||
|
*/
|
||||||
|
return {
|
||||||
|
behavior: 'deny',
|
||||||
|
message: 'Tool execution denied by user.',
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
let hasToolUse = false;
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIAssistantMessage(message)) {
|
||||||
|
const toolUseBlock = message.message.content.find(
|
||||||
|
(block: ContentBlock): block is ToolUseBlock =>
|
||||||
|
block.type === 'tool_use',
|
||||||
|
);
|
||||||
|
if (toolUseBlock) {
|
||||||
|
hasToolUse = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(hasToolUse).toBe(true);
|
||||||
|
expect(toolCalls.length).toBeGreaterThan(0);
|
||||||
|
expect(toolCalls[0].toolName).toBeDefined();
|
||||||
|
expect(toolCalls[0].input).toBeDefined();
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should allow tool execution when canUseTool returns allow',
|
||||||
|
async () => {
|
||||||
|
let callbackInvoked = false;
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async (toolName, input) => {
|
||||||
|
callbackInvoked = true;
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: input,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
let hasToolResult = false;
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIUserMessage(message)) {
|
||||||
|
if (
|
||||||
|
Array.isArray(message.message.content) &&
|
||||||
|
message.message.content.some(
|
||||||
|
(block) => block.type === 'tool_result',
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
hasToolResult = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(callbackInvoked).toBe(true);
|
||||||
|
expect(hasToolResult).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should deny tool execution when canUseTool returns deny',
|
||||||
|
async () => {
|
||||||
|
let callbackInvoked = false;
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async () => {
|
||||||
|
callbackInvoked = true;
|
||||||
|
return {
|
||||||
|
behavior: 'deny',
|
||||||
|
message: 'Tool execution denied by test',
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(callbackInvoked).toBe(true);
|
||||||
|
// Tool use might still appear, but execution should be denied
|
||||||
|
// The exact behavior depends on CLI implementation
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should pass suggestions to canUseTool callback',
|
||||||
|
async () => {
|
||||||
|
let receivedSuggestions: unknown = null;
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async (toolName, input, options) => {
|
||||||
|
receivedSuggestions = options.suggestions;
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: input,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Suggestions may be null or an array, depending on CLI implementation
|
||||||
|
expect(receivedSuggestions !== undefined).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should pass abort signal to canUseTool callback',
|
||||||
|
async () => {
|
||||||
|
let receivedSignal: AbortSignal | undefined = undefined;
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async (toolName, input, options) => {
|
||||||
|
receivedSignal = options.signal;
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: input,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(receivedSignal).toBeDefined();
|
||||||
|
expect(receivedSignal).toBeInstanceOf(AbortSignal);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should allow updatedInput modification in canUseTool callback',
|
||||||
|
async () => {
|
||||||
|
const originalInputs: Record<string, unknown>[] = [];
|
||||||
|
const updatedInputs: Record<string, unknown>[] = [];
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async (toolName, input) => {
|
||||||
|
originalInputs.push({ ...input });
|
||||||
|
const updatedInput = {
|
||||||
|
...input,
|
||||||
|
modified: true,
|
||||||
|
testKey: 'testValue',
|
||||||
|
};
|
||||||
|
updatedInputs.push(updatedInput);
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(originalInputs.length).toBeGreaterThan(0);
|
||||||
|
expect(updatedInputs.length).toBeGreaterThan(0);
|
||||||
|
expect(updatedInputs[0]?.['modified']).toBe(true);
|
||||||
|
expect(updatedInputs[0]?.['testKey']).toBe('testValue');
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should default to deny when canUseTool is not provided',
|
||||||
|
async () => {
|
||||||
|
const q = query({
|
||||||
|
prompt: 'List files in the current directory',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
// canUseTool not provided
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
// When canUseTool is not provided, tools should be denied by default
|
||||||
|
// The exact behavior depends on CLI implementation
|
||||||
|
for await (const message of q) {
|
||||||
|
if (isCLIResultMessage(message)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Test passes if no errors occur
|
||||||
|
expect(true).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setPermissionMode API', () => {
|
||||||
|
it(
|
||||||
|
'should change permission mode from default to yolo',
|
||||||
|
async () => {
|
||||||
|
const { generator, resume } = createStreamingInputWithControlPoint(
|
||||||
|
'List files in the current directory',
|
||||||
|
'Now read the package.json file',
|
||||||
|
);
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: generator,
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resolvers: {
|
||||||
|
first?: () => void;
|
||||||
|
second?: () => void;
|
||||||
|
} = {};
|
||||||
|
const firstResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.first = resolve;
|
||||||
|
});
|
||||||
|
const secondResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.second = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
let firstResponseReceived = false;
|
||||||
|
let secondResponseReceived = false;
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (
|
||||||
|
isCLIAssistantMessage(message) ||
|
||||||
|
isCLIResultMessage(message)
|
||||||
|
) {
|
||||||
|
if (!firstResponseReceived) {
|
||||||
|
firstResponseReceived = true;
|
||||||
|
resolvers.first?.();
|
||||||
|
} else if (!secondResponseReceived) {
|
||||||
|
secondResponseReceived = true;
|
||||||
|
resolvers.second?.();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
firstResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for first response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(firstResponseReceived).toBe(true);
|
||||||
|
|
||||||
|
await q.setPermissionMode('yolo');
|
||||||
|
|
||||||
|
resume();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
secondResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for second response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(secondResponseReceived).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should change permission mode from yolo to plan',
|
||||||
|
async () => {
|
||||||
|
const { generator, resume } = createStreamingInputWithControlPoint(
|
||||||
|
'List files in the current directory',
|
||||||
|
'Now read the package.json file',
|
||||||
|
);
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: generator,
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'yolo',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resolvers: {
|
||||||
|
first?: () => void;
|
||||||
|
second?: () => void;
|
||||||
|
} = {};
|
||||||
|
const firstResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.first = resolve;
|
||||||
|
});
|
||||||
|
const secondResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.second = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
let firstResponseReceived = false;
|
||||||
|
let secondResponseReceived = false;
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (
|
||||||
|
isCLIAssistantMessage(message) ||
|
||||||
|
isCLIResultMessage(message)
|
||||||
|
) {
|
||||||
|
if (!firstResponseReceived) {
|
||||||
|
firstResponseReceived = true;
|
||||||
|
resolvers.first?.();
|
||||||
|
} else if (!secondResponseReceived) {
|
||||||
|
secondResponseReceived = true;
|
||||||
|
resolvers.second?.();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
firstResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for first response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(firstResponseReceived).toBe(true);
|
||||||
|
|
||||||
|
await q.setPermissionMode('plan');
|
||||||
|
|
||||||
|
resume();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
secondResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for second response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(secondResponseReceived).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should change permission mode to auto-edit',
|
||||||
|
async () => {
|
||||||
|
const { generator, resume } = createStreamingInputWithControlPoint(
|
||||||
|
'List files in the current directory',
|
||||||
|
'Now read the package.json file',
|
||||||
|
);
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: generator,
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resolvers: {
|
||||||
|
first?: () => void;
|
||||||
|
second?: () => void;
|
||||||
|
} = {};
|
||||||
|
const firstResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.first = resolve;
|
||||||
|
});
|
||||||
|
const secondResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.second = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
let firstResponseReceived = false;
|
||||||
|
let secondResponseReceived = false;
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (
|
||||||
|
isCLIAssistantMessage(message) ||
|
||||||
|
isCLIResultMessage(message)
|
||||||
|
) {
|
||||||
|
if (!firstResponseReceived) {
|
||||||
|
firstResponseReceived = true;
|
||||||
|
resolvers.first?.();
|
||||||
|
} else if (!secondResponseReceived) {
|
||||||
|
secondResponseReceived = true;
|
||||||
|
resolvers.second?.();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
firstResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for first response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(firstResponseReceived).toBe(true);
|
||||||
|
|
||||||
|
await q.setPermissionMode('auto-edit');
|
||||||
|
|
||||||
|
resume();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
secondResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for second response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(secondResponseReceived).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
|
||||||
|
it(
|
||||||
|
'should throw error when setPermissionMode is called on closed query',
|
||||||
|
async () => {
|
||||||
|
const q = query({
|
||||||
|
prompt: 'Hello',
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await q.close();
|
||||||
|
|
||||||
|
await expect(q.setPermissionMode('yolo')).rejects.toThrow(
|
||||||
|
'Query is closed',
|
||||||
|
);
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('canUseTool and setPermissionMode integration', () => {
|
||||||
|
it(
|
||||||
|
'should work together - canUseTool callback with dynamic permission mode change',
|
||||||
|
async () => {
|
||||||
|
const toolCalls: Array<{
|
||||||
|
toolName: string;
|
||||||
|
input: Record<string, unknown>;
|
||||||
|
}> = [];
|
||||||
|
|
||||||
|
const { generator, resume } = createStreamingInputWithControlPoint(
|
||||||
|
'List files in the current directory',
|
||||||
|
'Now read the package.json file',
|
||||||
|
);
|
||||||
|
|
||||||
|
const q = query({
|
||||||
|
prompt: generator,
|
||||||
|
options: {
|
||||||
|
...SHARED_TEST_OPTIONS,
|
||||||
|
permissionMode: 'default',
|
||||||
|
canUseTool: async (toolName, input) => {
|
||||||
|
toolCalls.push({ toolName, input });
|
||||||
|
return {
|
||||||
|
behavior: 'allow',
|
||||||
|
updatedInput: input,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resolvers: {
|
||||||
|
first?: () => void;
|
||||||
|
second?: () => void;
|
||||||
|
} = {};
|
||||||
|
const firstResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.first = resolve;
|
||||||
|
});
|
||||||
|
const secondResponsePromise = new Promise<void>((resolve) => {
|
||||||
|
resolvers.second = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
let firstResponseReceived = false;
|
||||||
|
let secondResponseReceived = false;
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
for await (const message of q) {
|
||||||
|
if (
|
||||||
|
isCLIAssistantMessage(message) ||
|
||||||
|
isCLIResultMessage(message)
|
||||||
|
) {
|
||||||
|
if (!firstResponseReceived) {
|
||||||
|
firstResponseReceived = true;
|
||||||
|
resolvers.first?.();
|
||||||
|
} else if (!secondResponseReceived) {
|
||||||
|
secondResponseReceived = true;
|
||||||
|
resolvers.second?.();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
firstResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for first response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(firstResponseReceived).toBe(true);
|
||||||
|
expect(toolCalls.length).toBeGreaterThan(0);
|
||||||
|
|
||||||
|
await q.setPermissionMode('yolo');
|
||||||
|
|
||||||
|
resume();
|
||||||
|
|
||||||
|
await Promise.race([
|
||||||
|
secondResponsePromise,
|
||||||
|
new Promise((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error('Timeout waiting for second response')),
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(secondResponseReceived).toBe(true);
|
||||||
|
} finally {
|
||||||
|
await q.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
TEST_TIMEOUT,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user