refactor: session and canUseTool support

This commit is contained in:
mingholy.lmh
2025-11-21 16:22:18 +08:00
parent 2fe6ba7c56
commit f635cd3070
21 changed files with 1266 additions and 639 deletions

View File

@@ -12,7 +12,6 @@ import type {
CLIControlRequest,
CLIControlResponse,
ControlCancelRequest,
PermissionApproval,
PermissionSuggestion,
} from '../types/protocol.js';
import {
@@ -299,7 +298,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return;
}
if (process.env['DEBUG_SDK']) {
if (process.env['DEBUG']) {
console.warn('[Query] Unknown message type:', message);
}
this.inputStream.enqueue(message as CLIMessage);
@@ -320,12 +319,12 @@ export class Query implements AsyncIterable<CLIMessage> {
switch (payload.subtype) {
case 'can_use_tool':
response = (await this.handlePermissionRequest(
response = await this.handlePermissionRequest(
payload.tool_name,
payload.input as Record<string, unknown>,
payload.permission_suggestions,
requestAbortController.signal,
)) as unknown as Record<string, unknown>;
);
break;
case 'mcp_message':
@@ -360,15 +359,17 @@ export class Query implements AsyncIterable<CLIMessage> {
/**
* Handle permission request (can_use_tool)
* Converts PermissionResult to CLI-expected format: { behavior: 'allow', updatedInput: ... } or { behavior: 'deny', message: ... }
*/
private async handlePermissionRequest(
toolName: string,
toolInput: Record<string, unknown>,
permissionSuggestions: PermissionSuggestion[] | null,
signal: AbortSignal,
): Promise<PermissionApproval> {
): Promise<Record<string, unknown>> {
/* Default deny all wildcard tool requests */
if (!this.options.canUseTool) {
return { allowed: true };
return { behavior: 'deny', message: 'Denied' };
}
try {
@@ -390,21 +391,51 @@ export class Query implements AsyncIterable<CLIMessage> {
timeoutPromise,
]);
// Handle boolean return (backward compatibility)
if (typeof result === 'boolean') {
return { allowed: result };
return result
? { behavior: 'allow', updatedInput: toolInput }
: { behavior: 'deny', message: 'Denied' };
}
// Handle PermissionResult format
const permissionResult = result as {
behavior: 'allow' | 'deny';
updatedInput?: Record<string, unknown>;
message?: string;
interrupt?: boolean;
};
if (permissionResult.behavior === 'allow') {
return {
behavior: 'allow',
updatedInput: permissionResult.updatedInput ?? toolInput,
};
} else {
return {
behavior: 'deny',
message: permissionResult.message ?? 'Denied',
...(permissionResult.interrupt !== undefined
? { interrupt: permissionResult.interrupt }
: {}),
};
}
return result as PermissionApproval;
} catch (error) {
/**
* Timeout or error → deny (fail-safe).
* This ensures that any issues with the permission callback
* result in a safe default of denying access.
*/
const errorMessage =
error instanceof Error ? error.message : String(error);
console.warn(
'[Query] Permission callback error (denying by default):',
error instanceof Error ? error.message : String(error),
errorMessage,
);
return { allowed: false };
return {
behavior: 'deny',
message: `Permission check failed: ${errorMessage}`,
};
}
}

View File

@@ -283,6 +283,10 @@ export class ProcessTransport implements Transport {
throw new Error('Cannot write to closed transport');
}
if (this.childStdin.writableEnded) {
throw new Error('Cannot write to ended stream');
}
if (this.childProcess?.killed || this.childProcess?.exitCode !== null) {
throw new Error('Cannot write to terminated process');
}
@@ -293,17 +297,21 @@ export class ProcessTransport implements Transport {
);
}
if (process.env['DEBUG_SDK']) {
if (process.env['DEBUG']) {
this.logForDebugging(
`[ProcessTransport] Writing to stdin: ${message.substring(0, 100)}`,
`[ProcessTransport] Writing to stdin (${message.length} bytes): ${message.substring(0, 100)}`,
);
}
try {
const written = this.childStdin.write(message);
if (!written && process.env['DEBUG_SDK']) {
if (!written) {
this.logForDebugging(
'[ProcessTransport] Write buffer full, data queued',
`[ProcessTransport] Write buffer full (${message.length} bytes), data queued. Waiting for drain event...`,
);
} else if (process.env['DEBUG']) {
this.logForDebugging(
`[ProcessTransport] Write successful (${message.length} bytes)`,
);
}
} catch (error) {
@@ -322,6 +330,7 @@ export class ProcessTransport implements Transport {
const rl = readline.createInterface({
input: this.childStdout,
crlfDelay: Infinity,
terminal: false,
});
try {

View File

@@ -3,7 +3,7 @@
*/
import type { ToolDefinition as ToolDef } from './mcp.js';
import type { PermissionMode } from './protocol.js';
import type { PermissionMode, PermissionSuggestion } from './protocol.js';
import type { ExternalMcpServerConfig } from './queryOptionsSchema.js';
export type { ToolDef as ToolDefinition };
@@ -161,14 +161,15 @@ type ToolInput = Record<string, unknown>;
*
* @param toolName - Name of the tool being executed
* @param input - Input parameters for the tool
* @param options - Options including abort signal
* @param options - Options including abort signal and suggestions
* @returns Promise with permission result
*/
type CanUseTool = (
export type CanUseTool = (
toolName: string,
input: ToolInput,
options: {
signal: AbortSignal;
suggestions?: PermissionSuggestion[] | null;
},
) => Promise<PermissionResult>;

View File

@@ -3,7 +3,7 @@
*/
import { z } from 'zod';
import type { PermissionCallback } from './config.js';
import type { CanUseTool } from './config.js';
/**
* Schema for external MCP server configuration
@@ -35,7 +35,7 @@ export const QueryOptionsSchema = z
env: z.record(z.string(), z.string()).optional(),
permissionMode: z.enum(['default', 'plan', 'auto-edit', 'yolo']).optional(),
canUseTool: z
.custom<PermissionCallback>((val) => typeof val === 'function', {
.custom<CanUseTool>((val) => typeof val === 'function', {
message: 'canUseTool must be a function',
})
.optional(),

View File

@@ -81,7 +81,7 @@ describe('AbortController and Process Lifecycle (E2E)', () => {
const controller = new AbortController();
const q = query({
prompt: 'Write a detailed explanation about TypeScript',
prompt: 'Hello',
options: {
...SHARED_TEST_OPTIONS,
abortController: controller,

View File

@@ -160,8 +160,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
session_id: sessionId,
message: {
role: 'user',
content:
'My name is Alice. Remember this during our current conversation.',
content: 'My name is Alice. Hello!',
},
parent_tool_use_id: null,
} as CLIUserMessage;
@@ -212,80 +211,72 @@ describe('Multi-Turn Conversations (E2E)', () => {
});
describe('Tool Usage in Multi-Turn', () => {
it(
'should handle tool usage across multiple turns',
async () => {
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
const sessionId = crypto.randomUUID();
it('should handle tool usage across multiple turns', async () => {
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
const sessionId = crypto.randomUUID();
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: 'List the files in the current directory',
},
parent_tool_use_id: null,
} as CLIUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: 'Now tell me about the package.json file specifically',
},
parent_tool_use_id: null,
} as CLIUserMessage;
}
const q = query({
prompt: createToolConversation(),
options: {
...SHARED_TEST_OPTIONS,
cwd: process.cwd(),
debug: false,
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: 'List the files in the current directory',
},
});
parent_tool_use_id: null,
} as CLIUserMessage;
const messages: CLIMessage[] = [];
let toolUseCount = 0;
let assistantCount = 0;
await new Promise((resolve) => setTimeout(resolve, 200));
try {
for await (const message of q) {
messages.push(message);
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: 'Now tell me about the package.json file specifically',
},
parent_tool_use_id: null,
} as CLIUserMessage;
}
if (isCLIAssistantMessage(message)) {
const hasToolUseBlock = message.message.content.some(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (hasToolUseBlock) {
toolUseCount++;
}
}
const q = query({
prompt: createToolConversation(),
options: {
...SHARED_TEST_OPTIONS,
cwd: process.cwd(),
debug: false,
},
});
if (isCLIAssistantMessage(message)) {
assistantCount++;
}
const messages: CLIMessage[] = [];
let toolUseCount = 0;
let assistantCount = 0;
if (isCLIResultMessage(message)) {
break;
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
const hasToolUseBlock = message.message.content.some(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (hasToolUseBlock) {
toolUseCount++;
}
}
expect(messages.length).toBeGreaterThan(0);
expect(toolUseCount).toBeGreaterThan(0); // Should use tools
expect(assistantCount).toBeGreaterThanOrEqual(2); // Should have responses to both questions
} finally {
await q.close();
if (isCLIAssistantMessage(message)) {
assistantCount++;
}
}
},
TEST_TIMEOUT,
);
expect(messages.length).toBeGreaterThan(0);
expect(toolUseCount).toBeGreaterThan(0); // Should use tools
expect(assistantCount).toBeGreaterThanOrEqual(2); // Should have responses to both questions
} finally {
await q.close();
}
}, 60000); //TEST_TIMEOUT,
});
describe('Message Flow and Sequencing', () => {

View File

@@ -0,0 +1,748 @@
/**
* E2E tests for permission control features:
* - canUseTool callback parameter
* - setPermissionMode API
*/
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLIResultMessage,
isCLIUserMessage,
type CLIUserMessage,
type ToolUseBlock,
type ContentBlock,
} from '../../src/types/protocol.js';
const TEST_CLI_PATH =
'/Users/mingholy/Work/Projects/qwen-code/packages/cli/index.ts';
const TEST_TIMEOUT = 1600000;
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
debug: false,
// env here sets environment variables for the CLI child process
env: {
// DEBUG: '1',
},
};
/**
* Factory function that creates a streaming input with a control point.
* After the first message is yielded, the generator waits for a resume signal,
* allowing the test code to call query instance methods like setPermissionMode.
*/
function createStreamingInputWithControlPoint(
firstMessage: string,
secondMessage: string,
): {
generator: AsyncIterable<CLIUserMessage>;
resume: () => void;
} {
let resumeResolve: (() => void) | null = null;
const resumePromise = new Promise<void>((resolve) => {
resumeResolve = resolve;
});
const generator = (async function* () {
const sessionId = crypto.randomUUID();
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: firstMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise;
await new Promise((resolve) => setTimeout(resolve, 200));
yield {
type: 'user',
session_id: sessionId,
message: {
role: 'user',
content: secondMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
})();
const resume = () => {
if (resumeResolve) {
resumeResolve();
}
};
return { generator, resume };
}
describe('Permission Control (E2E)', () => {
beforeAll(() => {
//process.env['DEBUG'] = '1';
});
afterAll(() => {
delete process.env['DEBUG'];
});
describe('canUseTool callback parameter', () => {
it(
'should invoke canUseTool callback when tool is requested',
async () => {
const toolCalls: Array<{
toolName: string;
input: Record<string, unknown>;
}> = [];
const q = query({
prompt: 'Write a js hello world to file.',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
toolCalls.push({ toolName, input });
console.log(toolName, input);
/*
{
behavior: 'allow',
updatedInput: input,
};
*/
return {
behavior: 'deny',
message: 'Tool execution denied by user.',
};
},
},
});
try {
let hasToolUse = false;
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (toolUseBlock) {
hasToolUse = true;
}
}
}
expect(hasToolUse).toBe(true);
expect(toolCalls.length).toBeGreaterThan(0);
expect(toolCalls[0].toolName).toBeDefined();
expect(toolCalls[0].input).toBeDefined();
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should allow tool execution when canUseTool returns allow',
async () => {
let callbackInvoked = false;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
callbackInvoked = true;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (
Array.isArray(message.message.content) &&
message.message.content.some(
(block) => block.type === 'tool_result',
)
) {
hasToolResult = true;
}
}
if (isCLIResultMessage(message)) {
break;
}
}
expect(callbackInvoked).toBe(true);
expect(hasToolResult).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should deny tool execution when canUseTool returns deny',
async () => {
let callbackInvoked = false;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async () => {
callbackInvoked = true;
return {
behavior: 'deny',
message: 'Tool execution denied by test',
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(callbackInvoked).toBe(true);
// Tool use might still appear, but execution should be denied
// The exact behavior depends on CLI implementation
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should pass suggestions to canUseTool callback',
async () => {
let receivedSuggestions: unknown = null;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input, options) => {
receivedSuggestions = options.suggestions;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
// Suggestions may be null or an array, depending on CLI implementation
expect(receivedSuggestions !== undefined).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should pass abort signal to canUseTool callback',
async () => {
let receivedSignal: AbortSignal | undefined = undefined;
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input, options) => {
receivedSignal = options.signal;
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(receivedSignal).toBeDefined();
expect(receivedSignal).toBeInstanceOf(AbortSignal);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should allow updatedInput modification in canUseTool callback',
async () => {
const originalInputs: Record<string, unknown>[] = [];
const updatedInputs: Record<string, unknown>[] = [];
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
originalInputs.push({ ...input });
const updatedInput = {
...input,
modified: true,
testKey: 'testValue',
};
updatedInputs.push(updatedInput);
return {
behavior: 'allow',
updatedInput,
};
},
},
});
try {
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
expect(originalInputs.length).toBeGreaterThan(0);
expect(updatedInputs.length).toBeGreaterThan(0);
expect(updatedInputs[0]?.['modified']).toBe(true);
expect(updatedInputs[0]?.['testKey']).toBe('testValue');
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should default to deny when canUseTool is not provided',
async () => {
const q = query({
prompt: 'List files in the current directory',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
// canUseTool not provided
},
});
try {
// When canUseTool is not provided, tools should be denied by default
// The exact behavior depends on CLI implementation
for await (const message of q) {
if (isCLIResultMessage(message)) {
break;
}
}
// Test passes if no errors occur
expect(true).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
});
describe('setPermissionMode API', () => {
it(
'should change permission mode from default to yolo',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('yolo');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should change permission mode from yolo to plan',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'yolo',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('plan');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should change permission mode to auto-edit',
async () => {
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
await q.setPermissionMode('auto-edit');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
it(
'should throw error when setPermissionMode is called on closed query',
async () => {
const q = query({
prompt: 'Hello',
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
},
});
await q.close();
await expect(q.setPermissionMode('yolo')).rejects.toThrow(
'Query is closed',
);
},
TEST_TIMEOUT,
);
});
describe('canUseTool and setPermissionMode integration', () => {
it(
'should work together - canUseTool callback with dynamic permission mode change',
async () => {
const toolCalls: Array<{
toolName: string;
input: Record<string, unknown>;
}> = [];
const { generator, resume } = createStreamingInputWithControlPoint(
'List files in the current directory',
'Now read the package.json file',
);
const q = query({
prompt: generator,
options: {
...SHARED_TEST_OPTIONS,
permissionMode: 'default',
canUseTool: async (toolName, input) => {
toolCalls.push({ toolName, input });
return {
behavior: 'allow',
updatedInput: input,
};
},
},
});
try {
const resolvers: {
first?: () => void;
second?: () => void;
} = {};
const firstResponsePromise = new Promise<void>((resolve) => {
resolvers.first = resolve;
});
const secondResponsePromise = new Promise<void>((resolve) => {
resolvers.second = resolve;
});
let firstResponseReceived = false;
let secondResponseReceived = false;
(async () => {
for await (const message of q) {
if (
isCLIAssistantMessage(message) ||
isCLIResultMessage(message)
) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
} else if (!secondResponseReceived) {
secondResponseReceived = true;
resolvers.second?.();
}
}
}
})();
await Promise.race([
firstResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for first response')),
TEST_TIMEOUT,
),
),
]);
expect(firstResponseReceived).toBe(true);
expect(toolCalls.length).toBeGreaterThan(0);
await q.setPermissionMode('yolo');
resume();
await Promise.race([
secondResponsePromise,
new Promise((_, reject) =>
setTimeout(
() => reject(new Error('Timeout waiting for second response')),
TEST_TIMEOUT,
),
),
]);
expect(secondResponseReceived).toBe(true);
} finally {
await q.close();
}
},
TEST_TIMEOUT,
);
});
});