mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-01-23 17:26:23 +00:00
Compare commits
4 Commits
refactor/a
...
mingholy/f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f578ff07a2 | ||
|
|
6eb16c0bcf | ||
|
|
7fa1dcb0e6 | ||
|
|
03f12bfa3f |
@@ -11,8 +11,10 @@ import {
|
||||
AbortError,
|
||||
isAbortError,
|
||||
isSDKAssistantMessage,
|
||||
isSDKResultMessage,
|
||||
type TextBlock,
|
||||
type ContentBlock,
|
||||
type SDKUserMessage,
|
||||
} from '@qwen-code/sdk';
|
||||
import { SDKTestHelper, createSharedTestOptions } from './test-helper.js';
|
||||
|
||||
@@ -250,6 +252,161 @@ describe('AbortController and Process Lifecycle (E2E)', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Closed stdin behavior (asyncGenerator prompt)', () => {
|
||||
it('should reject control requests after stdin closes', async () => {
|
||||
async function* createPrompt(): AsyncIterable<SDKUserMessage> {
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: crypto.randomUUID(),
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Say "OK".',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
}
|
||||
|
||||
const q = query({
|
||||
prompt: createPrompt(),
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: false,
|
||||
},
|
||||
});
|
||||
|
||||
let firstResultReceived = false;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
if (isSDKResultMessage(message)) {
|
||||
firstResultReceived = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
expect(firstResultReceived).toBe(true);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
await expect(q.setPermissionMode('default')).rejects.toThrow(
|
||||
'Input stream closed',
|
||||
);
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle control responses when stdin closes before replies', async () => {
|
||||
await helper.createFile('test.txt', 'original content');
|
||||
|
||||
let canUseToolCalledResolve: () => void = () => {};
|
||||
const canUseToolCalledPromise = new Promise<void>((resolve, reject) => {
|
||||
canUseToolCalledResolve = resolve;
|
||||
setTimeout(() => {
|
||||
reject(new Error('canUseTool callback not called'));
|
||||
}, 15000);
|
||||
});
|
||||
|
||||
let inputStreamDoneResolve: () => void = () => {};
|
||||
const inputStreamDonePromise = new Promise<void>((resolve, reject) => {
|
||||
inputStreamDoneResolve = resolve;
|
||||
setTimeout(() => {
|
||||
reject(new Error('inputStreamDonePromise timeout'));
|
||||
}, 15000);
|
||||
});
|
||||
|
||||
let firstResultResolve: () => void = () => {};
|
||||
const firstResultPromise = new Promise<void>((resolve) => {
|
||||
firstResultResolve = resolve;
|
||||
});
|
||||
|
||||
let secondResultResolve: () => void = () => {};
|
||||
const secondResultPromise = new Promise<void>((resolve, reject) => {
|
||||
secondResultResolve = resolve;
|
||||
});
|
||||
|
||||
async function* createPrompt(): AsyncIterable<SDKUserMessage> {
|
||||
const sessionId = crypto.randomUUID();
|
||||
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: sessionId,
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Say "OK".',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
|
||||
await firstResultPromise;
|
||||
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: sessionId,
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Write "updated" to test.txt.',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
await inputStreamDonePromise;
|
||||
}
|
||||
|
||||
const q = query({
|
||||
prompt: createPrompt(),
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
permissionMode: 'default',
|
||||
coreTools: ['read_file', 'write_file'],
|
||||
canUseTool: async (toolName, input) => {
|
||||
inputStreamDoneResolve();
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
canUseToolCalledResolve();
|
||||
|
||||
return {
|
||||
behavior: 'allow',
|
||||
updatedInput: input,
|
||||
};
|
||||
},
|
||||
debug: false,
|
||||
},
|
||||
});
|
||||
|
||||
try {
|
||||
const loop = async () => {
|
||||
let resultCount = 0;
|
||||
for await (const _message of q) {
|
||||
console.log(JSON.stringify(_message, null, 2));
|
||||
// Consume messages until completion.
|
||||
if (isSDKResultMessage(_message)) {
|
||||
resultCount += 1;
|
||||
if (resultCount === 1) {
|
||||
firstResultResolve();
|
||||
}
|
||||
if (resultCount === 2) {
|
||||
secondResultResolve();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
loop();
|
||||
|
||||
await firstResultPromise;
|
||||
await canUseToolCalledPromise;
|
||||
await secondResultPromise;
|
||||
|
||||
const content = await helper.readFile('test.txt');
|
||||
expect(content).toBe('original content');
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling and Recovery', () => {
|
||||
it('should handle invalid executable path', async () => {
|
||||
try {
|
||||
|
||||
@@ -12,7 +12,12 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { query, isSDKAssistantMessage, type SDKMessage } from '@qwen-code/sdk';
|
||||
import {
|
||||
query,
|
||||
isSDKAssistantMessage,
|
||||
type SDKMessage,
|
||||
type SDKUserMessage,
|
||||
} from '@qwen-code/sdk';
|
||||
import {
|
||||
SDKTestHelper,
|
||||
extractText,
|
||||
@@ -739,4 +744,229 @@ describe('Tool Control Parameters (E2E)', () => {
|
||||
TEST_TIMEOUT,
|
||||
);
|
||||
});
|
||||
|
||||
describe('canUseTool with asyncGenerator prompt', () => {
|
||||
it(
|
||||
'should invoke canUseTool callback when using asyncGenerator as prompt',
|
||||
async () => {
|
||||
await helper.createFile('test.txt', 'original content');
|
||||
|
||||
const canUseToolCalls: Array<{
|
||||
toolName: string;
|
||||
input: Record<string, unknown>;
|
||||
}> = [];
|
||||
|
||||
// Create an async generator that yields a single message
|
||||
async function* createPrompt(): AsyncIterable<SDKUserMessage> {
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: crypto.randomUUID(),
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Read test.txt and then write "updated" to it.',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
}
|
||||
|
||||
const q = query({
|
||||
prompt: createPrompt(),
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
permissionMode: 'default',
|
||||
coreTools: ['read_file', 'write_file'],
|
||||
allowedTools: [],
|
||||
canUseTool: async (toolName, input) => {
|
||||
canUseToolCalls.push({ toolName, input });
|
||||
return {
|
||||
behavior: 'allow',
|
||||
updatedInput: input,
|
||||
};
|
||||
},
|
||||
debug: false,
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
const toolCalls = findToolCalls(messages);
|
||||
const toolNames = toolCalls.map((tc) => tc.toolUse.name);
|
||||
|
||||
// Both tools should have been executed
|
||||
expect(toolNames).toContain('read_file');
|
||||
expect(toolNames).toContain('write_file');
|
||||
|
||||
const toolsCalledInCallback = canUseToolCalls.map(
|
||||
(call) => call.toolName,
|
||||
);
|
||||
expect(toolsCalledInCallback).toContain('write_file');
|
||||
|
||||
const writeFileResults = findToolResults(messages, 'write_file');
|
||||
expect(writeFileResults.length).toBeGreaterThan(0);
|
||||
|
||||
// Verify file was modified
|
||||
const content = await helper.readFile('test.txt');
|
||||
expect(content).toBe('updated');
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
},
|
||||
TEST_TIMEOUT,
|
||||
);
|
||||
|
||||
it(
|
||||
'should deny tool when canUseTool returns deny with asyncGenerator prompt',
|
||||
async () => {
|
||||
await helper.createFile('test.txt', 'original content');
|
||||
|
||||
// Create an async generator that yields a single message
|
||||
async function* createPrompt(): AsyncIterable<SDKUserMessage> {
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: crypto.randomUUID(),
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Write "modified" to test.txt.',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
}
|
||||
|
||||
const q = query({
|
||||
prompt: createPrompt(),
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
permissionMode: 'default',
|
||||
coreTools: ['read_file', 'write_file'],
|
||||
canUseTool: async (toolName) => {
|
||||
if (toolName === 'write_file') {
|
||||
return {
|
||||
behavior: 'deny',
|
||||
message: 'Write operations are not allowed',
|
||||
};
|
||||
}
|
||||
return { behavior: 'allow', updatedInput: {} };
|
||||
},
|
||||
debug: false,
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
// write_file should have been attempted but stream was closed
|
||||
const writeFileResults = findToolResults(messages, 'write_file');
|
||||
expect(writeFileResults.length).toBeGreaterThan(0);
|
||||
for (const result of writeFileResults) {
|
||||
expect(result.content).toContain(
|
||||
'[Operation Cancelled] Reason: Write operations are not allowed',
|
||||
);
|
||||
}
|
||||
|
||||
// File content should remain unchanged (because write was denied)
|
||||
const content = await helper.readFile('test.txt');
|
||||
expect(content).toBe('original content');
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
},
|
||||
TEST_TIMEOUT,
|
||||
);
|
||||
|
||||
it(
|
||||
'should support multi-turn conversation with canUseTool using asyncGenerator',
|
||||
async () => {
|
||||
await helper.createFile('data.txt', 'initial data');
|
||||
|
||||
const canUseToolCalls: string[] = [];
|
||||
|
||||
// Create an async generator that yields multiple messages
|
||||
async function* createMultiTurnPrompt(): AsyncIterable<SDKUserMessage> {
|
||||
const sessionId = crypto.randomUUID();
|
||||
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: sessionId,
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Read data.txt and tell me what it contains.',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
|
||||
// Small delay to simulate multi-turn conversation
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
yield {
|
||||
type: 'user',
|
||||
session_id: sessionId,
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'Now append " - updated" to the file content.',
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
};
|
||||
}
|
||||
|
||||
const q = query({
|
||||
prompt: createMultiTurnPrompt(),
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
permissionMode: 'default',
|
||||
coreTools: ['read_file', 'write_file'],
|
||||
canUseTool: async (toolName) => {
|
||||
canUseToolCalls.push(toolName);
|
||||
return { behavior: 'allow', updatedInput: {} };
|
||||
},
|
||||
debug: false,
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
const toolCalls = findToolCalls(messages);
|
||||
const toolNames = toolCalls.map((tc) => tc.toolUse.name);
|
||||
|
||||
// Should have read_file and write_file calls
|
||||
expect(toolNames).toContain('read_file');
|
||||
expect(toolNames).toContain('write_file');
|
||||
|
||||
// canUseTool should not be called once stream is closed
|
||||
expect(canUseToolCalls).toHaveLength(0);
|
||||
|
||||
const writeFileResults = findToolResults(messages, 'write_file');
|
||||
expect(writeFileResults.length).toBeGreaterThan(0);
|
||||
for (const result of writeFileResults) {
|
||||
expect(result.content).toContain('Error: Input closed');
|
||||
}
|
||||
|
||||
const content = await helper.readFile('data.txt');
|
||||
expect(content).toBe('initial data');
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
},
|
||||
TEST_TIMEOUT,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -35,6 +35,7 @@ export interface IControlContext {
|
||||
permissionMode: PermissionMode;
|
||||
sdkMcpServers: Set<string>;
|
||||
mcpClients: Map<string, { client: Client; config: MCPServerConfig }>;
|
||||
inputClosed: boolean;
|
||||
|
||||
onInterrupt?: () => void;
|
||||
}
|
||||
@@ -52,6 +53,7 @@ export class ControlContext implements IControlContext {
|
||||
permissionMode: PermissionMode;
|
||||
sdkMcpServers: Set<string>;
|
||||
mcpClients: Map<string, { client: Client; config: MCPServerConfig }>;
|
||||
inputClosed: boolean;
|
||||
|
||||
onInterrupt?: () => void;
|
||||
|
||||
@@ -71,6 +73,7 @@ export class ControlContext implements IControlContext {
|
||||
this.permissionMode = options.permissionMode || 'default';
|
||||
this.sdkMcpServers = new Set();
|
||||
this.mcpClients = new Map();
|
||||
this.inputClosed = false;
|
||||
this.onInterrupt = options.onInterrupt;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ function createMockContext(debugMode: boolean = false): IControlContext {
|
||||
permissionMode: 'default',
|
||||
sdkMcpServers: new Set<string>(),
|
||||
mcpClients: new Map(),
|
||||
inputClosed: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -637,6 +638,130 @@ describe('ControlDispatcher', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('markInputClosed', () => {
|
||||
it('should reject all pending outgoing requests when input closes', () => {
|
||||
const requestId1 = 'reject-req-1';
|
||||
const requestId2 = 'reject-req-2';
|
||||
const resolve1 = vi.fn();
|
||||
const resolve2 = vi.fn();
|
||||
const reject1 = vi.fn();
|
||||
const reject2 = vi.fn();
|
||||
const timeoutId1 = setTimeout(() => {}, 1000);
|
||||
const timeoutId2 = setTimeout(() => {}, 1000);
|
||||
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
|
||||
|
||||
const register = (
|
||||
dispatcher as unknown as {
|
||||
registerOutgoingRequest: (
|
||||
id: string,
|
||||
controller: string,
|
||||
resolve: (response: ControlResponse) => void,
|
||||
reject: (error: Error) => void,
|
||||
timeoutId: NodeJS.Timeout,
|
||||
) => void;
|
||||
}
|
||||
).registerOutgoingRequest.bind(dispatcher);
|
||||
|
||||
register(requestId1, 'SystemController', resolve1, reject1, timeoutId1);
|
||||
register(requestId2, 'SystemController', resolve2, reject2, timeoutId2);
|
||||
|
||||
dispatcher.markInputClosed();
|
||||
|
||||
expect(reject1).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ message: 'Input closed' }),
|
||||
);
|
||||
expect(reject2).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ message: 'Input closed' }),
|
||||
);
|
||||
expect(clearTimeoutSpy).toHaveBeenCalledWith(timeoutId1);
|
||||
expect(clearTimeoutSpy).toHaveBeenCalledWith(timeoutId2);
|
||||
});
|
||||
|
||||
it('should mark input as closed on context', () => {
|
||||
dispatcher.markInputClosed();
|
||||
expect(mockContext.inputClosed).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle empty pending requests gracefully', () => {
|
||||
expect(() => dispatcher.markInputClosed()).not.toThrow();
|
||||
});
|
||||
|
||||
it('should be idempotent when called multiple times', () => {
|
||||
const requestId = 'idempotent-req';
|
||||
const resolve = vi.fn();
|
||||
const reject = vi.fn();
|
||||
const timeoutId = setTimeout(() => {}, 1000);
|
||||
|
||||
(
|
||||
dispatcher as unknown as {
|
||||
registerOutgoingRequest: (
|
||||
id: string,
|
||||
controller: string,
|
||||
resolve: (response: ControlResponse) => void,
|
||||
reject: (error: Error) => void,
|
||||
timeoutId: NodeJS.Timeout,
|
||||
) => void;
|
||||
}
|
||||
).registerOutgoingRequest(
|
||||
requestId,
|
||||
'SystemController',
|
||||
resolve,
|
||||
reject,
|
||||
timeoutId,
|
||||
);
|
||||
|
||||
dispatcher.markInputClosed();
|
||||
const firstRejectCount = vi.mocked(reject).mock.calls.length;
|
||||
|
||||
// Call again - should not reject again
|
||||
dispatcher.markInputClosed();
|
||||
const secondRejectCount = vi.mocked(reject).mock.calls.length;
|
||||
|
||||
expect(secondRejectCount).toBe(firstRejectCount);
|
||||
});
|
||||
|
||||
it('should log input closure in debug mode', () => {
|
||||
const context = createMockContext(true);
|
||||
const consoleSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const dispatcherWithDebug = new ControlDispatcher(context);
|
||||
const requestId = 'reject-req-debug';
|
||||
const resolve = vi.fn();
|
||||
const reject = vi.fn();
|
||||
const timeoutId = setTimeout(() => {}, 1000);
|
||||
|
||||
(
|
||||
dispatcherWithDebug as unknown as {
|
||||
registerOutgoingRequest: (
|
||||
id: string,
|
||||
controller: string,
|
||||
resolve: (response: ControlResponse) => void,
|
||||
reject: (error: Error) => void,
|
||||
timeoutId: NodeJS.Timeout,
|
||||
) => void;
|
||||
}
|
||||
).registerOutgoingRequest(
|
||||
requestId,
|
||||
'SystemController',
|
||||
resolve,
|
||||
reject,
|
||||
timeoutId,
|
||||
);
|
||||
|
||||
dispatcherWithDebug.markInputClosed();
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'[ControlDispatcher] Input closed, rejecting 1 pending outgoing requests',
|
||||
),
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('shutdown', () => {
|
||||
it('should cancel all pending incoming requests', () => {
|
||||
const requestId1 = 'shutdown-req-1';
|
||||
|
||||
@@ -207,6 +207,36 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks stdin as closed and rejects all pending outgoing requests.
|
||||
* After this is called, new outgoing requests will be rejected immediately.
|
||||
* This should be called when stdin closes to avoid waiting for responses.
|
||||
*/
|
||||
markInputClosed(): void {
|
||||
if (this.context.inputClosed) {
|
||||
return; // Already marked as closed
|
||||
}
|
||||
|
||||
this.context.inputClosed = true;
|
||||
|
||||
const requestIds = Array.from(this.pendingOutgoingRequests.keys());
|
||||
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[ControlDispatcher] Input closed, rejecting ${requestIds.length} pending outgoing requests`,
|
||||
);
|
||||
}
|
||||
|
||||
// Reject all currently pending outgoing requests
|
||||
for (const id of requestIds) {
|
||||
const pending = this.pendingOutgoingRequests.get(id);
|
||||
if (pending) {
|
||||
this.deregisterOutgoingRequest(id);
|
||||
pending.reject(new Error('Input closed'));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops all pending requests and cleans up all controllers
|
||||
*/
|
||||
@@ -243,7 +273,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers an incoming request in the pending registry
|
||||
* Registers an incoming request in the pending registry.
|
||||
*/
|
||||
registerIncomingRequest(
|
||||
requestId: string,
|
||||
|
||||
@@ -124,6 +124,11 @@ export abstract class BaseController {
|
||||
timeoutMs: number = DEFAULT_REQUEST_TIMEOUT_MS,
|
||||
signal?: AbortSignal,
|
||||
): Promise<ControlResponse> {
|
||||
// Check if stream is closed
|
||||
if (this.context.inputClosed) {
|
||||
throw new Error('Input closed');
|
||||
}
|
||||
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
|
||||
@@ -469,21 +469,27 @@ export class PermissionController extends BaseController {
|
||||
error,
|
||||
);
|
||||
}
|
||||
// On error, use default cancel message
|
||||
|
||||
// Extract error message
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
|
||||
// On error, pass error message as cancel message
|
||||
// Only pass payload for exec and mcp types that support it
|
||||
const confirmationType = toolCall.confirmationDetails.type;
|
||||
if (['edit', 'exec', 'mcp'].includes(confirmationType)) {
|
||||
const execOrMcpDetails = toolCall.confirmationDetails as
|
||||
| ToolExecuteConfirmationDetails
|
||||
| ToolMcpConfirmationDetails;
|
||||
await execOrMcpDetails.onConfirm(
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
undefined,
|
||||
);
|
||||
await execOrMcpDetails.onConfirm(ToolConfirmationOutcome.Cancel, {
|
||||
cancelMessage: `Error: ${errorMessage}`,
|
||||
});
|
||||
} else {
|
||||
// For other types, don't pass payload (backward compatible)
|
||||
await toolCall.confirmationDetails.onConfirm(
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
{
|
||||
cancelMessage: `Error: ${errorMessage}`,
|
||||
},
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
|
||||
@@ -153,6 +153,7 @@ describe('runNonInteractiveStreamJson', () => {
|
||||
handleControlResponse: ReturnType<typeof vi.fn>;
|
||||
handleCancel: ReturnType<typeof vi.fn>;
|
||||
shutdown: ReturnType<typeof vi.fn>;
|
||||
markInputClosed: ReturnType<typeof vi.fn>;
|
||||
getPendingIncomingRequestCount: ReturnType<typeof vi.fn>;
|
||||
waitForPendingIncomingRequests: ReturnType<typeof vi.fn>;
|
||||
sdkMcpController: {
|
||||
@@ -192,6 +193,7 @@ describe('runNonInteractiveStreamJson', () => {
|
||||
handleControlResponse: vi.fn(),
|
||||
handleCancel: vi.fn(),
|
||||
shutdown: vi.fn(),
|
||||
markInputClosed: vi.fn(),
|
||||
getPendingIncomingRequestCount: vi.fn().mockReturnValue(0),
|
||||
waitForPendingIncomingRequests: vi.fn().mockResolvedValue(undefined),
|
||||
sdkMcpController: {
|
||||
|
||||
@@ -596,7 +596,14 @@ class Session {
|
||||
throw streamError;
|
||||
}
|
||||
|
||||
// Stream ended - wait for all pending work before shutdown
|
||||
// Stdin closed - mark input as closed in dispatcher
|
||||
// This will reject all current pending outgoing requests AND any future requests
|
||||
// that might be registered by async message handlers still running
|
||||
if (this.dispatcher) {
|
||||
this.dispatcher.markInputClosed();
|
||||
}
|
||||
|
||||
// Wait for all pending work before shutdown
|
||||
await this.waitForAllPendingWork();
|
||||
await this.shutdown();
|
||||
} catch (error) {
|
||||
|
||||
@@ -102,16 +102,14 @@ export const QWEN_OAUTH_ALLOWED_MODELS = [
|
||||
export const QWEN_OAUTH_MODELS: ModelConfig[] = [
|
||||
{
|
||||
id: 'coder-model',
|
||||
name: 'Qwen Coder',
|
||||
description:
|
||||
'The latest Qwen Coder model from Alibaba Cloud ModelStudio (version: qwen3-coder-plus-2025-09-23)',
|
||||
name: 'coder-model',
|
||||
description: 'The latest Qwen Coder model from Alibaba Cloud ModelStudio',
|
||||
capabilities: { vision: false },
|
||||
},
|
||||
{
|
||||
id: 'vision-model',
|
||||
name: 'Qwen Vision',
|
||||
description:
|
||||
'The latest Qwen Vision model from Alibaba Cloud ModelStudio (version: qwen3-vl-plus-2025-09-23)',
|
||||
name: 'vision-model',
|
||||
description: 'The latest Qwen Vision model from Alibaba Cloud ModelStudio',
|
||||
capabilities: { vision: true },
|
||||
},
|
||||
];
|
||||
|
||||
@@ -663,7 +663,21 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
},
|
||||
);
|
||||
|
||||
this.transport.write(serializeJsonLine(request));
|
||||
try {
|
||||
this.transport.write(serializeJsonLine(request));
|
||||
} catch (error) {
|
||||
const pending = this.pendingControlRequests.get(requestId);
|
||||
if (pending) {
|
||||
clearTimeout(pending.timeout);
|
||||
this.pendingControlRequests.delete(requestId);
|
||||
}
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Failed to send control request: ${errorMsg}`);
|
||||
return Promise.reject(
|
||||
new Error(`Failed to send control request: ${errorMsg}`),
|
||||
);
|
||||
}
|
||||
|
||||
return responsePromise;
|
||||
}
|
||||
|
||||
@@ -687,7 +701,15 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
},
|
||||
};
|
||||
|
||||
this.transport.write(serializeJsonLine(response));
|
||||
try {
|
||||
this.transport.write(serializeJsonLine(response));
|
||||
} catch (error) {
|
||||
// Write failed - log and ignore since response cannot be delivered
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
logger.warn(
|
||||
`Failed to send control response for request ${requestId}: ${errorMsg}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
@@ -790,11 +812,7 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
* The timeout ensures we don't hang indefinitely - either the turn proceeds
|
||||
* normally, or it fails with a timeout, but Promise.race will always resolve.
|
||||
*/
|
||||
if (
|
||||
!this.isSingleTurn &&
|
||||
this.sdkMcpTransports.size > 0 &&
|
||||
this.firstResultReceivedPromise
|
||||
) {
|
||||
if (this.firstResultReceivedPromise) {
|
||||
const streamCloseTimeout =
|
||||
this.options.timeout?.streamClose ?? DEFAULT_STREAM_CLOSE_TIMEOUT;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
@@ -18,6 +18,7 @@ export class ProcessTransport implements Transport {
|
||||
private ready = false;
|
||||
private _exitError: Error | null = null;
|
||||
private closed = false;
|
||||
private inputClosed = false;
|
||||
private abortController: AbortController;
|
||||
private processExitHandler: (() => void) | null = null;
|
||||
private abortHandler: (() => void) | null = null;
|
||||
@@ -210,6 +211,7 @@ export class ProcessTransport implements Transport {
|
||||
|
||||
this.ready = false;
|
||||
this.closed = true;
|
||||
this.inputClosed = true;
|
||||
}
|
||||
|
||||
async waitForExit(): Promise<void> {
|
||||
@@ -273,8 +275,16 @@ export class ProcessTransport implements Transport {
|
||||
throw new Error('Cannot write to closed transport');
|
||||
}
|
||||
|
||||
if (this.childStdin.writableEnded) {
|
||||
throw new Error('Cannot write to ended stream');
|
||||
if (this.inputClosed) {
|
||||
throw new Error('Input stream closed');
|
||||
}
|
||||
|
||||
if (this.childStdin.writableEnded || this.childStdin.destroyed) {
|
||||
this.inputClosed = true;
|
||||
logger.warn(
|
||||
`Cannot write to ${this.childStdin.writableEnded ? 'ended' : 'destroyed'} stdin stream, ignoring write`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.childProcess?.killed || this.childProcess?.exitCode !== null) {
|
||||
@@ -301,10 +311,25 @@ export class ProcessTransport implements Transport {
|
||||
logger.debug(`Write successful (${message.length} bytes)`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if this is a stream-closed error (EPIPE, ERR_STREAM_WRITE_AFTER_END, etc.)
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
const isStreamClosedError =
|
||||
errorMsg.includes('EPIPE') ||
|
||||
errorMsg.includes('ERR_STREAM_WRITE_AFTER_END') ||
|
||||
errorMsg.includes('write after end');
|
||||
|
||||
if (isStreamClosedError) {
|
||||
// Soft-fail: log and return without throwing or changing ready state
|
||||
this.inputClosed = true;
|
||||
logger.warn(`Stream closed, cannot write: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// For other errors, maintain original behavior
|
||||
this.ready = false;
|
||||
const errorMsg = `Failed to write to stdin: ${error instanceof Error ? error.message : String(error)}`;
|
||||
logger.error(errorMsg);
|
||||
throw new Error(errorMsg);
|
||||
const fullErrorMsg = `Failed to write to stdin: ${errorMsg}`;
|
||||
logger.error(fullErrorMsg);
|
||||
throw new Error(fullErrorMsg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,6 +369,7 @@ export class ProcessTransport implements Transport {
|
||||
endInput(): void {
|
||||
if (this.childStdin) {
|
||||
this.childStdin.end();
|
||||
this.inputClosed = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -647,7 +647,7 @@ describe('ProcessTransport', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if writing to ended stream', () => {
|
||||
it('should not throw when writing to ended stream (soft-fail)', () => {
|
||||
mockPrepareSpawnInfo.mockReturnValue({
|
||||
command: 'qwen',
|
||||
args: [],
|
||||
@@ -664,9 +664,8 @@ describe('ProcessTransport', () => {
|
||||
|
||||
mockStdin.end();
|
||||
|
||||
expect(() => transport.write('test')).toThrow(
|
||||
'Cannot write to ended stream',
|
||||
);
|
||||
// Should not throw - soft-fail behavior
|
||||
expect(() => transport.write('test')).not.toThrow();
|
||||
});
|
||||
|
||||
it('should throw if writing to terminated process', () => {
|
||||
|
||||
@@ -261,6 +261,20 @@ function createControlCancel(requestId: string): ControlCancelRequest {
|
||||
};
|
||||
}
|
||||
|
||||
async function respondToInitialize(
|
||||
transport: MockTransport,
|
||||
query: Query,
|
||||
): Promise<void> {
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest = transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
await query.initialized;
|
||||
}
|
||||
|
||||
describe('Query', () => {
|
||||
let transport: MockTransport;
|
||||
|
||||
@@ -295,6 +309,7 @@ describe('Query', () => {
|
||||
expect(initRequest.type).toBe('control_request');
|
||||
expect(initRequest.request.subtype).toBe('initialize');
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
});
|
||||
|
||||
@@ -307,6 +322,8 @@ describe('Query', () => {
|
||||
|
||||
expect(query1.getSessionId()).not.toBe(query2.getSessionId());
|
||||
|
||||
await respondToInitialize(transport, query1);
|
||||
await respondToInitialize(transport2, query2);
|
||||
await query1.close();
|
||||
await query2.close();
|
||||
await transport2.close();
|
||||
@@ -338,6 +355,8 @@ describe('Query', () => {
|
||||
it('should route user messages to output stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const userMsg = createUserMessage('Hello');
|
||||
transport.simulateMessage(userMsg);
|
||||
|
||||
@@ -351,6 +370,8 @@ describe('Query', () => {
|
||||
it('should route assistant messages to output stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const assistantMsg = createAssistantMessage('Response');
|
||||
transport.simulateMessage(assistantMsg);
|
||||
|
||||
@@ -364,6 +385,8 @@ describe('Query', () => {
|
||||
it('should route system messages to output stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const systemMsg = createSystemMessage('session_start');
|
||||
transport.simulateMessage(systemMsg);
|
||||
|
||||
@@ -377,6 +400,8 @@ describe('Query', () => {
|
||||
it('should route result messages to output stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const resultMsg = createResultMessage(true);
|
||||
transport.simulateMessage(resultMsg);
|
||||
|
||||
@@ -390,6 +415,8 @@ describe('Query', () => {
|
||||
it('should route partial assistant messages to output stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const partialMsg = createPartialMessage();
|
||||
transport.simulateMessage(partialMsg);
|
||||
|
||||
@@ -403,6 +430,8 @@ describe('Query', () => {
|
||||
it('should handle unknown message types', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const unknownMsg = { type: 'unknown', data: 'test' };
|
||||
transport.simulateMessage(unknownMsg);
|
||||
|
||||
@@ -416,6 +445,8 @@ describe('Query', () => {
|
||||
it('should yield messages in order', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const msg1 = createUserMessage('First');
|
||||
const msg2 = createAssistantMessage('Second');
|
||||
const msg3 = createResultMessage(true);
|
||||
@@ -445,6 +476,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -469,6 +502,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-1');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -495,6 +530,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-2');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -519,6 +556,8 @@ describe('Query', () => {
|
||||
cwd: '/test',
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-3');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -554,6 +593,8 @@ describe('Query', () => {
|
||||
},
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-4');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -583,6 +624,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-5');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -613,6 +656,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-6');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -644,6 +689,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'perm-req-7');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -684,6 +731,8 @@ describe('Query', () => {
|
||||
canUseTool,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const controlReq = createControlRequest('can_use_tool', 'cancel-req-1');
|
||||
transport.simulateMessage(controlReq);
|
||||
|
||||
@@ -703,6 +752,8 @@ describe('Query', () => {
|
||||
cwd: '/test',
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
// Send cancel for non-existent request
|
||||
transport.simulateMessage(createControlCancel('unknown-req'));
|
||||
|
||||
@@ -717,24 +768,16 @@ describe('Query', () => {
|
||||
it('should support streamInput() for follow-up messages', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
async function* messageGenerator() {
|
||||
yield createUserMessage('Follow-up 1');
|
||||
yield createUserMessage('Follow-up 2');
|
||||
}
|
||||
|
||||
await query.streamInput(messageGenerator());
|
||||
const streamPromise = query.streamInput(messageGenerator());
|
||||
transport.simulateMessage(createResultMessage(true));
|
||||
await streamPromise;
|
||||
|
||||
const messages = transport.getAllWrittenMessages();
|
||||
const userMessages = messages.filter(
|
||||
@@ -753,24 +796,16 @@ describe('Query', () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
const sessionId = query.getSessionId();
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
async function* messageGenerator() {
|
||||
yield createUserMessage('Turn 1', sessionId);
|
||||
yield createUserMessage('Turn 2', sessionId);
|
||||
}
|
||||
|
||||
await query.streamInput(messageGenerator());
|
||||
const streamPromise = query.streamInput(messageGenerator());
|
||||
transport.simulateMessage(createResultMessage(true));
|
||||
await streamPromise;
|
||||
|
||||
const messages = transport.getAllWrittenMessages();
|
||||
const userMessages = messages.filter(
|
||||
@@ -790,6 +825,7 @@ describe('Query', () => {
|
||||
|
||||
it('should throw if streamInput() called on closed query', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
|
||||
async function* messageGenerator() {
|
||||
@@ -808,17 +844,7 @@ describe('Query', () => {
|
||||
abortController,
|
||||
});
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
async function* messageGenerator() {
|
||||
yield createUserMessage('Message 1');
|
||||
@@ -826,7 +852,9 @@ describe('Query', () => {
|
||||
yield createUserMessage('Message 2'); // Should not be sent
|
||||
}
|
||||
|
||||
await query.streamInput(messageGenerator());
|
||||
const streamPromise = query.streamInput(messageGenerator());
|
||||
transport.simulateMessage(createResultMessage(true));
|
||||
await streamPromise;
|
||||
|
||||
await query.close();
|
||||
});
|
||||
@@ -835,6 +863,8 @@ describe('Query', () => {
|
||||
describe('Lifecycle Management', () => {
|
||||
it('should close transport on close()', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
|
||||
expect(transport.closed).toBe(true);
|
||||
@@ -842,6 +872,7 @@ describe('Query', () => {
|
||||
|
||||
it('should mark query as closed', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
await respondToInitialize(transport, query);
|
||||
expect(query.isClosed()).toBe(false);
|
||||
|
||||
await query.close();
|
||||
@@ -851,6 +882,8 @@ describe('Query', () => {
|
||||
it('should complete output stream on close()', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const iterationPromise = (async () => {
|
||||
const messages: SDKMessage[] = [];
|
||||
for await (const msg of query) {
|
||||
@@ -869,6 +902,8 @@ describe('Query', () => {
|
||||
it('should be idempotent when closing multiple times', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
await query.close();
|
||||
await query.close();
|
||||
await query.close();
|
||||
@@ -883,6 +918,8 @@ describe('Query', () => {
|
||||
abortController,
|
||||
});
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
abortController.abort();
|
||||
|
||||
await vi.waitFor(() => {
|
||||
@@ -909,6 +946,8 @@ describe('Query', () => {
|
||||
it('should support for await loop', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
const iterationPromise = (async () => {
|
||||
for await (const msg of query) {
|
||||
@@ -931,6 +970,8 @@ describe('Query', () => {
|
||||
it('should complete iteration when query closes', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
const iterationPromise = (async () => {
|
||||
for await (const msg of query) {
|
||||
@@ -953,6 +994,8 @@ describe('Query', () => {
|
||||
it('should propagate transport errors', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const iterationPromise = (async () => {
|
||||
for await (const msg of query) {
|
||||
void msg;
|
||||
@@ -971,17 +1014,7 @@ describe('Query', () => {
|
||||
it('should provide interrupt() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const interruptPromise = query.interrupt();
|
||||
|
||||
@@ -1011,17 +1044,7 @@ describe('Query', () => {
|
||||
it('should provide setPermissionMode() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const setModePromise = query.setPermissionMode('yolo');
|
||||
|
||||
@@ -1051,17 +1074,7 @@ describe('Query', () => {
|
||||
it('should provide setModel() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const setModelPromise = query.setModel('new-model');
|
||||
|
||||
@@ -1091,17 +1104,7 @@ describe('Query', () => {
|
||||
it('should provide supportedCommands() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const commandsPromise = query.supportedCommands();
|
||||
|
||||
@@ -1135,17 +1138,7 @@ describe('Query', () => {
|
||||
it('should provide mcpServerStatus() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const statusPromise = query.mcpServerStatus();
|
||||
|
||||
@@ -1180,6 +1173,7 @@ describe('Query', () => {
|
||||
|
||||
it('should throw if methods called on closed query', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
|
||||
await expect(query.interrupt()).rejects.toThrow('Query is closed');
|
||||
@@ -1198,6 +1192,8 @@ describe('Query', () => {
|
||||
it('should propagate transport errors to stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const error = new Error('Transport failure');
|
||||
transport.simulateError(error);
|
||||
|
||||
@@ -1214,17 +1210,7 @@ describe('Query', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
// Call interrupt but don't respond - should timeout
|
||||
const interruptPromise = query.interrupt();
|
||||
@@ -1237,17 +1223,7 @@ describe('Query', () => {
|
||||
it('should handle malformed control responses', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const interruptPromise = query.interrupt();
|
||||
|
||||
@@ -1284,6 +1260,8 @@ describe('Query', () => {
|
||||
it('should handle CLI sending error result message', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const errorResult = createResultMessage(false);
|
||||
transport.simulateMessage(errorResult);
|
||||
|
||||
@@ -1303,6 +1281,8 @@ describe('Query', () => {
|
||||
true, // singleTurn = true
|
||||
);
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const resultMsg = createResultMessage(true);
|
||||
transport.simulateMessage(resultMsg);
|
||||
|
||||
@@ -1320,6 +1300,8 @@ describe('Query', () => {
|
||||
false, // singleTurn = false
|
||||
);
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const resultMsg = createResultMessage(true);
|
||||
transport.simulateMessage(resultMsg);
|
||||
|
||||
@@ -1332,19 +1314,23 @@ describe('Query', () => {
|
||||
});
|
||||
|
||||
describe('State Management', () => {
|
||||
it('should track session ID', () => {
|
||||
it('should track session ID', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
const sessionId = query.getSessionId();
|
||||
|
||||
expect(sessionId).toBeTruthy();
|
||||
expect(typeof sessionId).toBe('string');
|
||||
expect(sessionId.length).toBeGreaterThan(0);
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
});
|
||||
|
||||
it('should track closed state', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
expect(query.isClosed()).toBe(false);
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
expect(query.isClosed()).toBe(true);
|
||||
});
|
||||
@@ -1352,17 +1338,7 @@ describe('Query', () => {
|
||||
it('should provide endInput() method', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
// Respond to initialize
|
||||
await vi.waitFor(() => {
|
||||
expect(transport.writtenMessages.length).toBeGreaterThan(0);
|
||||
});
|
||||
const initRequest =
|
||||
transport.getLastWrittenMessage() as CLIControlRequest;
|
||||
transport.simulateMessage(
|
||||
createControlResponse(initRequest.request_id, true, {}),
|
||||
);
|
||||
|
||||
await query.initialized;
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
query.endInput();
|
||||
expect(transport.endInputCalled).toBe(true);
|
||||
@@ -1372,6 +1348,7 @@ describe('Query', () => {
|
||||
|
||||
it('should throw if endInput() called on closed query', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
await respondToInitialize(transport, query);
|
||||
await query.close();
|
||||
|
||||
expect(() => query.endInput()).toThrow('Query is closed');
|
||||
@@ -1382,6 +1359,8 @@ describe('Query', () => {
|
||||
it('should handle empty message stream', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
transport.simulateClose();
|
||||
|
||||
const result = await query.next();
|
||||
@@ -1393,6 +1372,8 @@ describe('Query', () => {
|
||||
it('should handle rapid message flow', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
// Simulate rapid messages
|
||||
for (let i = 0; i < 100; i++) {
|
||||
transport.simulateMessage(createUserMessage(`Message ${i}`));
|
||||
@@ -1414,6 +1395,8 @@ describe('Query', () => {
|
||||
it('should handle close during message iteration', async () => {
|
||||
const query = new Query(transport, { cwd: '/test' });
|
||||
|
||||
await respondToInitialize(transport, query);
|
||||
|
||||
const iterationPromise = (async () => {
|
||||
const messages: SDKMessage[] = [];
|
||||
for await (const msg of query) {
|
||||
|
||||
Reference in New Issue
Block a user