refactor: rename ambiguous exported types

This commit is contained in:
mingholy.lmh
2025-11-27 14:50:40 +08:00
parent 638b7bb466
commit 56957a687b
14 changed files with 1188 additions and 614 deletions

View File

@@ -22,6 +22,7 @@ export default tseslint.config(
'bundle/**', 'bundle/**',
'package/bundle/**', 'package/bundle/**',
'.integration-tests/**', '.integration-tests/**',
'packages/**/.integration-test/**',
'dist/**', 'dist/**',
], ],
}, },

View File

@@ -12,20 +12,20 @@ export type {
ThinkingBlock, ThinkingBlock,
ToolUseBlock, ToolUseBlock,
ToolResultBlock, ToolResultBlock,
CLIUserMessage, SDKUserMessage,
CLIAssistantMessage, SDKAssistantMessage,
CLISystemMessage, SDKSystemMessage,
CLIResultMessage, SDKResultMessage,
CLIPartialAssistantMessage, SDKPartialAssistantMessage,
CLIMessage, SDKMessage,
} from './types/protocol.js'; } from './types/protocol.js';
export { export {
isCLIUserMessage, isSDKUserMessage,
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, isSDKSystemMessage,
isCLIResultMessage, isSDKResultMessage,
isCLIPartialAssistantMessage, isSDKPartialAssistantMessage,
} from './types/protocol.js'; } from './types/protocol.js';
export type { export type {

View File

@@ -13,19 +13,19 @@ const STREAM_CLOSE_TIMEOUT = 10000;
import { randomUUID } from 'node:crypto'; import { randomUUID } from 'node:crypto';
import { SdkLogger } from '../utils/logger.js'; import { SdkLogger } from '../utils/logger.js';
import type { import type {
CLIMessage, SDKMessage,
CLIUserMessage, SDKUserMessage,
CLIControlRequest, CLIControlRequest,
CLIControlResponse, CLIControlResponse,
ControlCancelRequest, ControlCancelRequest,
PermissionSuggestion, PermissionSuggestion,
} from '../types/protocol.js'; } from '../types/protocol.js';
import { import {
isCLIUserMessage, isSDKUserMessage,
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, isSDKSystemMessage,
isCLIResultMessage, isSDKResultMessage,
isCLIPartialAssistantMessage, isSDKPartialAssistantMessage,
isControlRequest, isControlRequest,
isControlResponse, isControlResponse,
isControlCancel, isControlCancel,
@@ -52,12 +52,12 @@ interface TransportWithEndInput extends Transport {
const logger = SdkLogger.createLogger('Query'); const logger = SdkLogger.createLogger('Query');
export class Query implements AsyncIterable<CLIMessage> { export class Query implements AsyncIterable<SDKMessage> {
private transport: Transport; private transport: Transport;
private options: QueryOptions; private options: QueryOptions;
private sessionId: string; private sessionId: string;
private inputStream: Stream<CLIMessage>; private inputStream: Stream<SDKMessage>;
private sdkMessages: AsyncGenerator<CLIMessage>; private sdkMessages: AsyncGenerator<SDKMessage>;
private abortController: AbortController; private abortController: AbortController;
private pendingControlRequests: Map<string, PendingControlRequest> = private pendingControlRequests: Map<string, PendingControlRequest> =
new Map(); new Map();
@@ -79,7 +79,7 @@ export class Query implements AsyncIterable<CLIMessage> {
this.transport = transport; this.transport = transport;
this.options = options; this.options = options;
this.sessionId = randomUUID(); this.sessionId = randomUUID();
this.inputStream = new Stream<CLIMessage>(); this.inputStream = new Stream<SDKMessage>();
this.abortController = options.abortController ?? new AbortController(); this.abortController = options.abortController ?? new AbortController();
this.isSingleTurn = singleTurn; this.isSingleTurn = singleTurn;
@@ -187,7 +187,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return; return;
} }
if (isCLISystemMessage(message)) { if (isSDKSystemMessage(message)) {
/** /**
* SystemMessage contains session info (cwd, tools, model, etc.) * SystemMessage contains session info (cwd, tools, model, etc.)
* that should be passed to user. * that should be passed to user.
@@ -196,7 +196,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return; return;
} }
if (isCLIResultMessage(message)) { if (isSDKResultMessage(message)) {
if (this.firstResultReceivedResolve) { if (this.firstResultReceivedResolve) {
this.firstResultReceivedResolve(); this.firstResultReceivedResolve();
} }
@@ -212,16 +212,16 @@ export class Query implements AsyncIterable<CLIMessage> {
} }
if ( if (
isCLIAssistantMessage(message) || isSDKAssistantMessage(message) ||
isCLIUserMessage(message) || isSDKUserMessage(message) ||
isCLIPartialAssistantMessage(message) isSDKPartialAssistantMessage(message)
) { ) {
this.inputStream.enqueue(message); this.inputStream.enqueue(message);
return; return;
} }
logger.warn('Unknown message type:', message); logger.warn('Unknown message type:', message);
this.inputStream.enqueue(message as CLIMessage); this.inputStream.enqueue(message as SDKMessage);
} }
private async handleControlRequest( private async handleControlRequest(
@@ -560,29 +560,29 @@ export class Query implements AsyncIterable<CLIMessage> {
logger.info('Query closed'); logger.info('Query closed');
} }
private async *readSdkMessages(): AsyncGenerator<CLIMessage> { private async *readSdkMessages(): AsyncGenerator<SDKMessage> {
for await (const message of this.inputStream) { for await (const message of this.inputStream) {
yield message; yield message;
} }
} }
async next(...args: [] | [unknown]): Promise<IteratorResult<CLIMessage>> { async next(...args: [] | [unknown]): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.next(...args); return this.sdkMessages.next(...args);
} }
async return(value?: unknown): Promise<IteratorResult<CLIMessage>> { async return(value?: unknown): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.return(value); return this.sdkMessages.return(value);
} }
async throw(e?: unknown): Promise<IteratorResult<CLIMessage>> { async throw(e?: unknown): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.throw(e); return this.sdkMessages.throw(e);
} }
[Symbol.asyncIterator](): AsyncIterator<CLIMessage> { [Symbol.asyncIterator](): AsyncIterator<SDKMessage> {
return this.sdkMessages; return this.sdkMessages;
} }
async streamInput(messages: AsyncIterable<CLIUserMessage>): Promise<void> { async streamInput(messages: AsyncIterable<SDKUserMessage>): Promise<void> {
if (this.closed) { if (this.closed) {
throw new Error('Query is closed'); throw new Error('Query is closed');
} }

View File

@@ -2,7 +2,7 @@
* Factory function for creating Query instances. * Factory function for creating Query instances.
*/ */
import type { CLIUserMessage } from '../types/protocol.js'; import type { SDKUserMessage } from '../types/protocol.js';
import { serializeJsonLine } from '../utils/jsonLines.js'; import { serializeJsonLine } from '../utils/jsonLines.js';
import { ProcessTransport } from '../transport/ProcessTransport.js'; import { ProcessTransport } from '../transport/ProcessTransport.js';
import { parseExecutableSpec } from '../utils/cliPath.js'; import { parseExecutableSpec } from '../utils/cliPath.js';
@@ -22,11 +22,11 @@ export function query({
/** /**
* The prompt to send to the Qwen Code CLI process. * The prompt to send to the Qwen Code CLI process.
* - `string` for single-turn query, * - `string` for single-turn query,
* - `AsyncIterable<CLIUserMessage>` for multi-turn query. * - `AsyncIterable<SDKUserMessage>` for multi-turn query.
* *
* The transport will remain open until the prompt is done. * The transport will remain open until the prompt is done.
*/ */
prompt: string | AsyncIterable<CLIUserMessage>; prompt: string | AsyncIterable<SDKUserMessage>;
/** /**
* Configuration options for the query. * Configuration options for the query.
*/ */
@@ -67,7 +67,7 @@ export function query({
if (isSingleTurn) { if (isSingleTurn) {
const stringPrompt = prompt as string; const stringPrompt = prompt as string;
const message: CLIUserMessage = { const message: SDKUserMessage = {
type: 'user', type: 'user',
session_id: queryInstance.getSessionId(), session_id: queryInstance.getSessionId(),
message: { message: {
@@ -87,7 +87,7 @@ export function query({
})(); })();
} else { } else {
queryInstance queryInstance
.streamInput(prompt as AsyncIterable<CLIUserMessage>) .streamInput(prompt as AsyncIterable<SDKUserMessage>)
.catch((err) => { .catch((err) => {
logger.error('Error streaming input:', err); logger.error('Error streaming input:', err);
}); });

View File

@@ -89,7 +89,7 @@ export interface APIAssistantMessage {
usage: Usage; usage: Usage;
} }
export interface CLIUserMessage { export interface SDKUserMessage {
type: 'user'; type: 'user';
uuid?: string; uuid?: string;
session_id: string; session_id: string;
@@ -98,7 +98,7 @@ export interface CLIUserMessage {
options?: Record<string, unknown>; options?: Record<string, unknown>;
} }
export interface CLIAssistantMessage { export interface SDKAssistantMessage {
type: 'assistant'; type: 'assistant';
uuid: string; uuid: string;
session_id: string; session_id: string;
@@ -106,7 +106,7 @@ export interface CLIAssistantMessage {
parent_tool_use_id: string | null; parent_tool_use_id: string | null;
} }
export interface CLISystemMessage { export interface SDKSystemMessage {
type: 'system'; type: 'system';
subtype: string; subtype: string;
uuid: string; uuid: string;
@@ -133,7 +133,7 @@ export interface CLISystemMessage {
}; };
} }
export interface CLIResultMessageSuccess { export interface SDKResultMessageSuccess {
type: 'result'; type: 'result';
subtype: 'success'; subtype: 'success';
uuid: string; uuid: string;
@@ -149,7 +149,7 @@ export interface CLIResultMessageSuccess {
[key: string]: unknown; [key: string]: unknown;
} }
export interface CLIResultMessageError { export interface SDKResultMessageError {
type: 'result'; type: 'result';
subtype: 'error_max_turns' | 'error_during_execution'; subtype: 'error_max_turns' | 'error_during_execution';
uuid: string; uuid: string;
@@ -169,7 +169,7 @@ export interface CLIResultMessageError {
[key: string]: unknown; [key: string]: unknown;
} }
export type CLIResultMessage = CLIResultMessageSuccess | CLIResultMessageError; export type SDKResultMessage = SDKResultMessageSuccess | SDKResultMessageError;
export interface MessageStartStreamEvent { export interface MessageStartStreamEvent {
type: 'message_start'; type: 'message_start';
@@ -222,7 +222,7 @@ export type StreamEvent =
| ContentBlockStopEvent | ContentBlockStopEvent
| MessageStopStreamEvent; | MessageStopStreamEvent;
export interface CLIPartialAssistantMessage { export interface SDKPartialAssistantMessage {
type: 'stream_event'; type: 'stream_event';
uuid: string; uuid: string;
session_id: string; session_id: string;
@@ -389,22 +389,22 @@ export type ControlMessage =
| ControlCancelRequest; | ControlCancelRequest;
/** /**
* Union of all CLI message types * Union of all SDK message types
*/ */
export type CLIMessage = export type SDKMessage =
| CLIUserMessage | SDKUserMessage
| CLIAssistantMessage | SDKAssistantMessage
| CLISystemMessage | SDKSystemMessage
| CLIResultMessage | SDKResultMessage
| CLIPartialAssistantMessage; | SDKPartialAssistantMessage;
export function isCLIUserMessage(msg: any): msg is CLIUserMessage { export function isSDKUserMessage(msg: any): msg is SDKUserMessage {
return ( return (
msg && typeof msg === 'object' && msg.type === 'user' && 'message' in msg msg && typeof msg === 'object' && msg.type === 'user' && 'message' in msg
); );
} }
export function isCLIAssistantMessage(msg: any): msg is CLIAssistantMessage { export function isSDKAssistantMessage(msg: any): msg is SDKAssistantMessage {
return ( return (
msg && msg &&
typeof msg === 'object' && typeof msg === 'object' &&
@@ -416,7 +416,7 @@ export function isCLIAssistantMessage(msg: any): msg is CLIAssistantMessage {
); );
} }
export function isCLISystemMessage(msg: any): msg is CLISystemMessage { export function isSDKSystemMessage(msg: any): msg is SDKSystemMessage {
return ( return (
msg && msg &&
typeof msg === 'object' && typeof msg === 'object' &&
@@ -427,7 +427,7 @@ export function isCLISystemMessage(msg: any): msg is CLISystemMessage {
); );
} }
export function isCLIResultMessage(msg: any): msg is CLIResultMessage { export function isSDKResultMessage(msg: any): msg is SDKResultMessage {
return ( return (
msg && msg &&
typeof msg === 'object' && typeof msg === 'object' &&
@@ -440,9 +440,9 @@ export function isCLIResultMessage(msg: any): msg is CLIResultMessage {
); );
} }
export function isCLIPartialAssistantMessage( export function isSDKPartialAssistantMessage(
msg: any, msg: any,
): msg is CLIPartialAssistantMessage { ): msg is SDKPartialAssistantMessage {
return ( return (
msg && msg &&
typeof msg === 'object' && typeof msg === 'object' &&

View File

@@ -14,14 +14,15 @@ const e2eTestsDir = join(rootDir, '.integration-tests');
let runDir = ''; let runDir = '';
export async function setup() { export async function setup() {
runDir = join(e2eTestsDir, `${Date.now()}`); runDir = join(e2eTestsDir, `sdk-e2e-${Date.now()}`);
await mkdir(runDir, { recursive: true }); await mkdir(runDir, { recursive: true });
// Clean up old test runs, but keep the latest few for debugging // Clean up old test runs, but keep the latest few for debugging
try { try {
const testRuns = await readdir(e2eTestsDir); const testRuns = await readdir(e2eTestsDir);
if (testRuns.length > 5) { const sdkTestRuns = testRuns.filter((run) => run.startsWith('sdk-e2e-'));
const oldRuns = testRuns.sort().slice(0, testRuns.length - 5); if (sdkTestRuns.length > 5) {
const oldRuns = sdkTestRuns.sort().slice(0, sdkTestRuns.length - 5);
await Promise.all( await Promise.all(
oldRuns.map((oldRun) => oldRuns.map((oldRun) =>
rm(join(e2eTestsDir, oldRun), { rm(join(e2eTestsDir, oldRun), {
@@ -44,7 +45,7 @@ export async function setup() {
} }
process.env['VERBOSE'] = process.env['VERBOSE'] ?? 'false'; process.env['VERBOSE'] = process.env['VERBOSE'] ?? 'false';
console.log(`\nE2E test output directory: ${runDir}`); console.log(`\nSDK E2E test output directory: ${runDir}`);
console.log(`CLI path: ${process.env['TEST_CLI_PATH']}`); console.log(`CLI path: ${process.env['TEST_CLI_PATH']}`);
} }

View File

@@ -9,234 +9,48 @@
* Tests that the SDK can properly interact with MCP servers configured in qwen-code * Tests that the SDK can properly interact with MCP servers configured in qwen-code
*/ */
import { describe, it, expect, beforeAll } from 'vitest'; import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIAssistantMessage, isSDKAssistantMessage,
isCLIResultMessage, isSDKResultMessage,
isCLISystemMessage, isSDKSystemMessage,
isCLIUserMessage, isSDKUserMessage,
type TextBlock, type SDKMessage,
type ContentBlock,
type CLIMessage,
type ToolUseBlock, type ToolUseBlock,
type CLISystemMessage, type SDKSystemMessage,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
import { writeFileSync, mkdirSync, chmodSync } from 'node:fs'; import {
import { join } from 'node:path'; SDKTestHelper,
createMCPServer,
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!; extractText,
const E2E_TEST_FILE_DIR = process.env['E2E_TEST_FILE_DIR']!; findToolUseBlocks,
createSharedTestOptions,
} from './test-helper.js';
const SHARED_TEST_OPTIONS = { const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH, ...createSharedTestOptions(),
permissionMode: 'yolo' as const, permissionMode: 'yolo' as const,
}; };
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
/**
* Minimal MCP server implementation that doesn't require external dependencies
* This implements the MCP protocol directly using Node.js built-ins
*/
const MCP_SERVER_SCRIPT = `#!/usr/bin/env node
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
const readline = require('readline');
const fs = require('fs');
// Debug logging to stderr (only when MCP_DEBUG or VERBOSE is set)
const debugEnabled = process.env['MCP_DEBUG'] === 'true' || process.env['VERBOSE'] === 'true';
function debug(msg) {
if (debugEnabled) {
fs.writeSync(2, \`[MCP-DEBUG] \${msg}\\n\`);
}
}
debug('MCP server starting...');
// Simple JSON-RPC implementation for MCP
class SimpleJSONRPC {
constructor() {
this.handlers = new Map();
this.rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
terminal: false
});
this.rl.on('line', (line) => {
debug(\`Received line: \${line}\`);
try {
const message = JSON.parse(line);
debug(\`Parsed message: \${JSON.stringify(message)}\`);
this.handleMessage(message);
} catch (e) {
debug(\`Parse error: \${e.message}\`);
}
});
}
send(message) {
const msgStr = JSON.stringify(message);
debug(\`Sending message: \${msgStr}\`);
process.stdout.write(msgStr + '\\n');
}
async handleMessage(message) {
if (message.method && this.handlers.has(message.method)) {
try {
const result = await this.handlers.get(message.method)(message.params || {});
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
result
});
}
} catch (error) {
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32603,
message: error.message
}
});
}
}
} else if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32601,
message: 'Method not found'
}
});
}
}
on(method, handler) {
this.handlers.set(method, handler);
}
}
// Create MCP server
const rpc = new SimpleJSONRPC();
// Handle initialize
rpc.on('initialize', async (params) => {
debug('Handling initialize request');
return {
protocolVersion: '2024-11-05',
capabilities: {
tools: {}
},
serverInfo: {
name: 'test-math-server',
version: '1.0.0'
}
};
});
// Handle tools/list
rpc.on('tools/list', async () => {
debug('Handling tools/list request');
return {
tools: [
{
name: 'add',
description: 'Add two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
},
{
name: 'multiply',
description: 'Multiply two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
}
]
};
});
// Handle tools/call
rpc.on('tools/call', async (params) => {
debug(\`Handling tools/call request for tool: \${params.name}\`);
if (params.name === 'add') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a + b)
}]
};
}
if (params.name === 'multiply') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a * b)
}]
};
}
throw new Error('Unknown tool: ' + params.name);
});
// Send initialization notification
rpc.send({
jsonrpc: '2.0',
method: 'initialized'
});
`;
describe('MCP Server Integration (E2E)', () => { describe('MCP Server Integration (E2E)', () => {
let testDir: string; let helper: SDKTestHelper;
let serverScriptPath: string; let serverScriptPath: string;
let testDir: string;
beforeAll(() => { beforeAll(async () => {
// Use the centralized E2E test directory from globalSetup // Create isolated test environment using SDKTestHelper
testDir = join(E2E_TEST_FILE_DIR, 'mcp-server-test'); helper = new SDKTestHelper();
mkdirSync(testDir, { recursive: true }); testDir = await helper.setup('mcp-server-integration');
// Write MCP server script // Create MCP server using the helper utility
serverScriptPath = join(testDir, 'mcp-server.cjs'); const mcpServer = await createMCPServer(helper, 'math', 'test-math-server');
writeFileSync(serverScriptPath, MCP_SERVER_SCRIPT); serverScriptPath = mcpServer.scriptPath;
});
// Make script executable on Unix-like systems afterAll(async () => {
if (process.platform !== 'win32') { // Cleanup test directory
chmodSync(serverScriptPath, 0o755); await helper.cleanup();
}
}); });
describe('Basic MCP Tool Usage', () => { describe('Basic MCP Tool Usage', () => {
@@ -257,7 +71,7 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
let foundToolUse = false; let foundToolUse = false;
@@ -265,12 +79,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlock = message.message.content.find( const toolUseBlocks = findToolUseBlocks(message, 'add');
(block: ContentBlock): block is ToolUseBlock => if (toolUseBlocks.length > 0) {
block.type === 'tool_use',
);
if (toolUseBlock && toolUseBlock.name === 'add') {
foundToolUse = true; foundToolUse = true;
} }
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
@@ -285,8 +96,8 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true); expect(isSDKResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) { if (isSDKResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success'); expect(lastMessage.subtype).toBe('success');
} }
} finally { } finally {
@@ -311,7 +122,7 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
let foundToolUse = false; let foundToolUse = false;
@@ -319,12 +130,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlock = message.message.content.find( const toolUseBlocks = findToolUseBlocks(message, 'multiply');
(block: ContentBlock): block is ToolUseBlock => if (toolUseBlocks.length > 0) {
block.type === 'tool_use',
);
if (toolUseBlock && toolUseBlock.name === 'multiply') {
foundToolUse = true; foundToolUse = true;
} }
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
@@ -339,7 +147,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true); expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally { } finally {
await q.close(); await q.close();
} }
@@ -363,11 +171,11 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; let systemMessage: SDKSystemMessage | null = null;
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message; systemMessage = message;
break; break;
} }
@@ -410,7 +218,7 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
const toolCalls: string[] = []; const toolCalls: string[] = [];
@@ -418,11 +226,8 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlocks = message.message.content.filter( const toolUseBlocks = findToolUseBlocks(message);
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
toolUseBlocks.forEach((block) => { toolUseBlocks.forEach((block) => {
toolCalls.push(block.name); toolCalls.push(block.name);
}); });
@@ -439,7 +244,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true); expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally { } finally {
await q.close(); await q.close();
} }
@@ -462,7 +267,7 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
const addToolCalls: ToolUseBlock[] = []; const addToolCalls: ToolUseBlock[] = [];
@@ -470,16 +275,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlocks = message.message.content.filter( const toolUseBlocks = findToolUseBlocks(message, 'add');
(block: ContentBlock): block is ToolUseBlock => addToolCalls.push(...toolUseBlocks);
block.type === 'tool_use',
);
toolUseBlocks.forEach((block) => {
if (block.name === 'add') {
addToolCalls.push(block);
}
});
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
} }
} }
@@ -493,7 +291,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true); expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally { } finally {
await q.close(); await q.close();
} }
@@ -525,19 +323,16 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messageTypes.push(message.type); messageTypes.push(message.type);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlock = message.message.content.find( const toolUseBlocks = findToolUseBlocks(message);
(block: ContentBlock): block is ToolUseBlock => if (toolUseBlocks.length > 0) {
block.type === 'tool_use',
);
if (toolUseBlock) {
foundToolUse = true; foundToolUse = true;
expect(toolUseBlock.name).toBe('add'); expect(toolUseBlocks[0].name).toBe('add');
expect(toolUseBlock.input).toBeDefined(); expect(toolUseBlocks[0].input).toBeDefined();
} }
} }
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
const content = message.message.content; const content = message.message.content;
const contentArray = Array.isArray(content) const contentArray = Array.isArray(content)
? content ? content
@@ -584,21 +379,21 @@ describe('MCP Server Integration (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
} }
} }
// Should complete without crashing // Should complete without crashing
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true); expect(isSDKResultMessage(lastMessage)).toBe(true);
// Assistant should indicate tool is not available or provide alternative // Assistant should indicate tool is not available or provide alternative
expect(assistantText.length).toBeGreaterThan(0); expect(assistantText.length).toBeGreaterThan(0);

View File

@@ -6,19 +6,19 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIUserMessage, isSDKUserMessage,
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, isSDKSystemMessage,
isCLIResultMessage, isSDKResultMessage,
isCLIPartialAssistantMessage, isSDKPartialAssistantMessage,
isControlRequest, isControlRequest,
isControlResponse, isControlResponse,
isControlCancel, isControlCancel,
type CLIUserMessage, type SDKUserMessage,
type CLIAssistantMessage, type SDKAssistantMessage,
type TextBlock, type TextBlock,
type ContentBlock, type ContentBlock,
type CLIMessage, type SDKMessage,
type ControlMessage, type ControlMessage,
type ToolUseBlock, type ToolUseBlock,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
@@ -31,16 +31,16 @@ const SHARED_TEST_OPTIONS = {
/** /**
* Determine the message type using protocol type guards * Determine the message type using protocol type guards
*/ */
function getMessageType(message: CLIMessage | ControlMessage): string { function getMessageType(message: SDKMessage | ControlMessage): string {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
return '🧑 USER'; return '🧑 USER';
} else if (isCLIAssistantMessage(message)) { } else if (isSDKAssistantMessage(message)) {
return '🤖 ASSISTANT'; return '🤖 ASSISTANT';
} else if (isCLISystemMessage(message)) { } else if (isSDKSystemMessage(message)) {
return `🖥️ SYSTEM(${message.subtype})`; return `🖥️ SYSTEM(${message.subtype})`;
} else if (isCLIResultMessage(message)) { } else if (isSDKResultMessage(message)) {
return `✅ RESULT(${message.subtype})`; return `✅ RESULT(${message.subtype})`;
} else if (isCLIPartialAssistantMessage(message)) { } else if (isSDKPartialAssistantMessage(message)) {
return '⏳ STREAM_EVENT'; return '⏳ STREAM_EVENT';
} else if (isControlRequest(message)) { } else if (isControlRequest(message)) {
return `🎮 CONTROL_REQUEST(${message.request.subtype})`; return `🎮 CONTROL_REQUEST(${message.request.subtype})`;
@@ -67,7 +67,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('AsyncIterable Prompt Support', () => { describe('AsyncIterable Prompt Support', () => {
it('should handle multi-turn conversation using AsyncIterable prompt', async () => { it('should handle multi-turn conversation using AsyncIterable prompt', async () => {
// Create multi-turn conversation generator // Create multi-turn conversation generator
async function* createMultiTurnConversation(): AsyncIterable<CLIUserMessage> { async function* createMultiTurnConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -78,7 +78,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 1 + 1?', content: 'What is 1 + 1?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100)); await new Promise((resolve) => setTimeout(resolve, 100));
@@ -90,7 +90,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 2 + 2?', content: 'What is 2 + 2?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100)); await new Promise((resolve) => setTimeout(resolve, 100));
@@ -102,7 +102,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 3 + 3?', content: 'What is 3 + 3?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
// Create multi-turn query using AsyncIterable prompt // Create multi-turn query using AsyncIterable prompt
@@ -114,15 +114,15 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
const assistantMessages: CLIAssistantMessage[] = []; const assistantMessages: SDKAssistantMessage[] = [];
const assistantTexts: string[] = []; const assistantTexts: string[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessages.push(message); assistantMessages.push(message);
const text = extractText(message.message.content); const text = extractText(message.message.content);
assistantTexts.push(text); assistantTexts.push(text);
@@ -142,7 +142,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
}); });
it('should maintain session context across turns', async () => { it('should maintain session context across turns', async () => {
async function* createContextualConversation(): AsyncIterable<CLIUserMessage> { async function* createContextualConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -154,7 +154,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
'Suppose we have 3 rabbits and 4 carrots. How many animals are there?', 'Suppose we have 3 rabbits and 4 carrots. How many animals are there?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
@@ -166,7 +166,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'How many animals are there? Only output the number', content: 'How many animals are there? Only output the number',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -177,11 +177,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const assistantMessages: CLIAssistantMessage[] = []; const assistantMessages: SDKAssistantMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessages.push(message); assistantMessages.push(message);
} }
} }
@@ -201,7 +201,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Tool Usage in Multi-Turn', () => { describe('Tool Usage in Multi-Turn', () => {
it('should handle tool usage across multiple turns', async () => { it('should handle tool usage across multiple turns', async () => {
async function* createToolConversation(): AsyncIterable<CLIUserMessage> { async function* createToolConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -212,7 +212,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Create a file named test.txt with content "hello"', content: 'Create a file named test.txt with content "hello"',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
@@ -224,7 +224,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Now read the test.txt file', content: 'Now read the test.txt file',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -237,15 +237,15 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let toolUseCount = 0; let toolUseCount = 0;
const assistantMessages: CLIAssistantMessage[] = []; const assistantMessages: SDKAssistantMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessages.push(message); assistantMessages.push(message);
const hasToolUseBlock = message.message.content.some( const hasToolUseBlock = message.message.content.some(
(block: ContentBlock): block is ToolUseBlock => (block: ContentBlock): block is ToolUseBlock =>
@@ -274,7 +274,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Message Flow and Sequencing', () => { describe('Message Flow and Sequencing', () => {
it('should process messages in correct sequence', async () => { it('should process messages in correct sequence', async () => {
async function* createSequentialConversation(): AsyncIterable<CLIUserMessage> { async function* createSequentialConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -285,7 +285,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'First question: What is 1 + 1?', content: 'First question: What is 1 + 1?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100)); await new Promise((resolve) => setTimeout(resolve, 100));
@@ -297,7 +297,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Second question: What is 2 + 2?', content: 'Second question: What is 2 + 2?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -316,7 +316,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
const messageType = getMessageType(message); const messageType = getMessageType(message);
messageSequence.push(messageType); messageSequence.push(messageType);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const text = extractText(message.message.content); const text = extractText(message.message.content);
assistantResponses.push(text); assistantResponses.push(text);
} }
@@ -338,7 +338,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
}); });
it('should handle conversation completion correctly', async () => { it('should handle conversation completion correctly', async () => {
async function* createSimpleConversation(): AsyncIterable<CLIUserMessage> { async function* createSimpleConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -349,7 +349,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Hello', content: 'Hello',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100)); await new Promise((resolve) => setTimeout(resolve, 100));
@@ -361,7 +361,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Goodbye', content: 'Goodbye',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -379,7 +379,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messageCount++; messageCount++;
if (isCLIResultMessage(message)) { if (isSDKResultMessage(message)) {
completedNaturally = true; completedNaturally = true;
expect(message.subtype).toBe('success'); expect(message.subtype).toBe('success');
} }
@@ -395,11 +395,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Error Handling in Multi-Turn', () => { describe('Error Handling in Multi-Turn', () => {
it('should handle empty conversation gracefully', async () => { it('should handle empty conversation gracefully', async () => {
async function* createEmptyConversation(): AsyncIterable<CLIUserMessage> { async function* createEmptyConversation(): AsyncIterable<SDKUserMessage> {
// Generator that yields nothing // Generator that yields nothing
/* eslint-disable no-constant-condition */ /* eslint-disable no-constant-condition */
if (false) { if (false) {
yield {} as CLIUserMessage; // Unreachable, but satisfies TypeScript yield {} as SDKUserMessage; // Unreachable, but satisfies TypeScript
} }
} }
@@ -411,7 +411,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
@@ -426,7 +426,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
}); });
it('should handle conversation with delays', async () => { it('should handle conversation with delays', async () => {
async function* createDelayedConversation(): AsyncIterable<CLIUserMessage> { async function* createDelayedConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -437,7 +437,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'First message', content: 'First message',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
// Longer delay to test patience // Longer delay to test patience
await new Promise((resolve) => setTimeout(resolve, 500)); await new Promise((resolve) => setTimeout(resolve, 500));
@@ -450,7 +450,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Second message after delay', content: 'Second message after delay',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -461,11 +461,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const assistantMessages: CLIAssistantMessage[] = []; const assistantMessages: SDKAssistantMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessages.push(message); assistantMessages.push(message);
} }
} }
@@ -479,7 +479,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Partial Messages in Multi-Turn', () => { describe('Partial Messages in Multi-Turn', () => {
it('should receive partial messages when includePartialMessages is enabled', async () => { it('should receive partial messages when includePartialMessages is enabled', async () => {
async function* createMultiTurnConversation(): AsyncIterable<CLIUserMessage> { async function* createMultiTurnConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
yield { yield {
@@ -490,7 +490,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 1 + 1?', content: 'What is 1 + 1?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100)); await new Promise((resolve) => setTimeout(resolve, 100));
@@ -502,7 +502,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 2 + 2?', content: 'What is 2 + 2?',
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
} }
const q = query({ const q = query({
@@ -514,7 +514,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let partialMessageCount = 0; let partialMessageCount = 0;
let assistantMessageCount = 0; let assistantMessageCount = 0;
@@ -522,11 +522,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIPartialAssistantMessage(message)) { if (isSDKPartialAssistantMessage(message)) {
partialMessageCount++; partialMessageCount++;
} }
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessageCount++; assistantMessageCount++;
} }
} }

View File

@@ -7,10 +7,10 @@
import { describe, it, expect, beforeAll, afterAll } from 'vitest'; import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIAssistantMessage, isSDKAssistantMessage,
isCLIResultMessage, isSDKResultMessage,
isCLIUserMessage, isSDKUserMessage,
type CLIUserMessage, type SDKUserMessage,
type ToolUseBlock, type ToolUseBlock,
type ContentBlock, type ContentBlock,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
@@ -32,7 +32,7 @@ function createStreamingInputWithControlPoint(
firstMessage: string, firstMessage: string,
secondMessage: string, secondMessage: string,
): { ): {
generator: AsyncIterable<CLIUserMessage>; generator: AsyncIterable<SDKUserMessage>;
resume: () => void; resume: () => void;
} { } {
let resumeResolve: (() => void) | null = null; let resumeResolve: (() => void) | null = null;
@@ -51,7 +51,7 @@ function createStreamingInputWithControlPoint(
content: firstMessage, content: firstMessage,
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
@@ -67,7 +67,7 @@ function createStreamingInputWithControlPoint(
content: secondMessage, content: secondMessage,
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
})(); })();
const resume = () => { const resume = () => {
@@ -120,7 +120,7 @@ describe('Permission Control (E2E)', () => {
try { try {
let hasToolUse = false; let hasToolUse = false;
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
const toolUseBlock = message.message.content.find( const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock => (block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use', block.type === 'tool_use',
@@ -162,7 +162,7 @@ describe('Permission Control (E2E)', () => {
try { try {
let hasToolResult = false; let hasToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if ( if (
Array.isArray(message.message.content) && Array.isArray(message.message.content) &&
message.message.content.some( message.message.content.some(
@@ -372,7 +372,7 @@ describe('Permission Control (E2E)', () => {
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) { if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) { if (!firstResponseReceived) {
firstResponseReceived = true; firstResponseReceived = true;
resolvers.first?.(); resolvers.first?.();
@@ -447,7 +447,7 @@ describe('Permission Control (E2E)', () => {
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) { if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) { if (!firstResponseReceived) {
firstResponseReceived = true; firstResponseReceived = true;
resolvers.first?.(); resolvers.first?.();
@@ -522,7 +522,7 @@ describe('Permission Control (E2E)', () => {
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) { if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) { if (!firstResponseReceived) {
firstResponseReceived = true; firstResponseReceived = true;
resolvers.first?.(); resolvers.first?.();
@@ -628,7 +628,7 @@ describe('Permission Control (E2E)', () => {
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLIResultMessage(message)) { if (isSDKResultMessage(message)) {
if (!firstResponseReceived) { if (!firstResponseReceived) {
firstResponseReceived = true; firstResponseReceived = true;
resolvers.first?.(); resolvers.first?.();
@@ -695,7 +695,7 @@ describe('Permission Control (E2E)', () => {
let hasErrorInResult = false; let hasErrorInResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -752,7 +752,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false; let hasSuccessfulToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -798,7 +798,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false; let hasToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -838,7 +838,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false; let hasSuccessfulToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -891,7 +891,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false; let hasToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -929,7 +929,7 @@ describe('Permission Control (E2E)', () => {
let hasCommandResult = false; let hasCommandResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -968,7 +968,7 @@ describe('Permission Control (E2E)', () => {
let hasPlanModeMessage = false; let hasPlanModeMessage = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1014,7 +1014,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false; let hasSuccessfulToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1066,7 +1066,7 @@ describe('Permission Control (E2E)', () => {
let hasPlanModeBlock = false; let hasPlanModeBlock = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1114,7 +1114,7 @@ describe('Permission Control (E2E)', () => {
let hasDeniedTool = false; let hasDeniedTool = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1169,7 +1169,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false; let hasSuccessfulToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1214,7 +1214,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false; let hasToolResult = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',
@@ -1270,7 +1270,7 @@ describe('Permission Control (E2E)', () => {
let toolExecuted = false; let toolExecuted = false;
for await (const message of q) { for await (const message of q) {
if (isCLIUserMessage(message)) { if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) { if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find( const toolResult = message.message.content.find(
(block) => block.type === 'tool_result', (block) => block.type === 'tool_result',

View File

@@ -6,31 +6,22 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, isSDKSystemMessage,
isCLIResultMessage, isSDKResultMessage,
isCLIPartialAssistantMessage, isSDKPartialAssistantMessage,
type TextBlock, type SDKMessage,
type ContentBlock, type SDKSystemMessage,
type CLIMessage, type SDKAssistantMessage,
type CLISystemMessage,
type CLIAssistantMessage,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!; import {
extractText,
createSharedTestOptions,
assertSuccessfulCompletion,
collectMessagesByType,
} from './test-helper.js';
const SHARED_TEST_OPTIONS = { const SHARED_TEST_OPTIONS = createSharedTestOptions();
pathToQwenExecutable: TEST_CLI_PATH,
};
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
describe('Single-Turn Query (E2E)', () => { describe('Single-Turn Query (E2E)', () => {
describe('Simple Text Queries', () => { describe('Simple Text Queries', () => {
@@ -44,14 +35,14 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
} }
} }
@@ -64,11 +55,7 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText).toMatch(/4/); expect(assistantText).toMatch(/4/);
// Validate message flow ends with success // Validate message flow ends with success
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally { } finally {
await q.close(); await q.close();
} }
@@ -83,14 +70,14 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
} }
} }
@@ -100,8 +87,7 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText.toLowerCase()).toContain('paris'); expect(assistantText.toLowerCase()).toContain('paris');
// Validate completion // Validate completion
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
} finally { } finally {
await q.close(); await q.close();
} }
@@ -116,14 +102,14 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let assistantText = ''; let assistantText = '';
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content); assistantText += extractText(message.message.content);
} }
} }
@@ -133,7 +119,10 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText.toLowerCase()).toMatch(/hello|hi|greetings/); expect(assistantText.toLowerCase()).toMatch(/hello|hi|greetings/);
// Validate message types // Validate message types
const assistantMessages = messages.filter(isCLIAssistantMessage); const assistantMessages = collectMessagesByType(
messages,
isSDKAssistantMessage,
);
expect(assistantMessages.length).toBeGreaterThan(0); expect(assistantMessages.length).toBeGreaterThan(0);
} finally { } finally {
await q.close(); await q.close();
@@ -151,14 +140,14 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let systemMessage: CLISystemMessage | null = null; let systemMessage: SDKSystemMessage | null = null;
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLISystemMessage(message) && message.subtype === 'init') { if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message; systemMessage = message;
} }
} }
@@ -180,7 +169,7 @@ describe('Single-Turn Query (E2E)', () => {
// Validate system message appears early in sequence // Validate system message appears early in sequence
const systemMessageIndex = messages.findIndex( const systemMessageIndex = messages.findIndex(
(msg) => isCLISystemMessage(msg) && msg.subtype === 'init', (msg) => isSDKSystemMessage(msg) && msg.subtype === 'init',
); );
expect(systemMessageIndex).toBeGreaterThanOrEqual(0); expect(systemMessageIndex).toBeGreaterThanOrEqual(0);
expect(systemMessageIndex).toBeLessThan(3); expect(systemMessageIndex).toBeLessThan(3);
@@ -198,12 +187,12 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; let systemMessage: SDKSystemMessage | null = null;
const sessionId = q.getSessionId(); const sessionId = q.getSessionId();
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message; systemMessage = message;
} }
} }
@@ -262,7 +251,7 @@ describe('Single-Turn Query (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messageCount++; messageCount++;
if (isCLIResultMessage(message)) { if (isSDKResultMessage(message)) {
completedNaturally = true; completedNaturally = true;
expect(message.subtype).toBe('success'); expect(message.subtype).toBe('success');
} }
@@ -319,7 +308,7 @@ describe('Single-Turn Query (E2E)', () => {
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
hasResponse = true; hasResponse = true;
} }
} }
@@ -340,7 +329,7 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let partialMessageCount = 0; let partialMessageCount = 0;
let assistantMessageCount = 0; let assistantMessageCount = 0;
@@ -348,11 +337,11 @@ describe('Single-Turn Query (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIPartialAssistantMessage(message)) { if (isSDKPartialAssistantMessage(message)) {
partialMessageCount++; partialMessageCount++;
} }
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessageCount++; assistantMessageCount++;
} }
} }
@@ -376,7 +365,7 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
@@ -384,9 +373,18 @@ describe('Single-Turn Query (E2E)', () => {
} }
// Validate type guards work correctly // Validate type guards work correctly
const assistantMessages = messages.filter(isCLIAssistantMessage); const assistantMessages = collectMessagesByType(
const resultMessages = messages.filter(isCLIResultMessage); messages,
const systemMessages = messages.filter(isCLISystemMessage); isSDKAssistantMessage,
);
const resultMessages = collectMessagesByType(
messages,
isSDKResultMessage,
);
const systemMessages = collectMessagesByType(
messages,
isSDKSystemMessage,
);
expect(assistantMessages.length).toBeGreaterThan(0); expect(assistantMessages.length).toBeGreaterThan(0);
expect(resultMessages.length).toBeGreaterThan(0); expect(resultMessages.length).toBeGreaterThan(0);
@@ -414,11 +412,11 @@ describe('Single-Turn Query (E2E)', () => {
}, },
}); });
let assistantMessage: CLIAssistantMessage | null = null; let assistantMessage: SDKAssistantMessage | null = null;
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
assistantMessage = message; assistantMessage = message;
} }
} }
@@ -426,17 +424,9 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantMessage).not.toBeNull(); expect(assistantMessage).not.toBeNull();
expect(assistantMessage!.message.content).toBeDefined(); expect(assistantMessage!.message.content).toBeDefined();
// Extract text blocks
const textBlocks = assistantMessage!.message.content.filter(
(block: ContentBlock): block is TextBlock => block.type === 'text',
);
expect(textBlocks.length).toBeGreaterThan(0);
expect(textBlocks[0].text).toBeDefined();
expect(textBlocks[0].text.length).toBeGreaterThan(0);
// Validate content contains expected numbers // Validate content contains expected numbers
const text = extractText(assistantMessage!.message.content); const text = extractText(assistantMessage!.message.content);
expect(text.length).toBeGreaterThan(0);
expect(text).toMatch(/1/); expect(text).toMatch(/1/);
expect(text).toMatch(/2/); expect(text).toMatch(/2/);
expect(text).toMatch(/3/); expect(text).toMatch(/3/);

View File

@@ -9,50 +9,42 @@
* Tests subagent delegation and task completion * Tests subagent delegation and task completion
*/ */
import { describe, it, expect, beforeAll } from 'vitest'; import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, type SDKMessage,
isCLIResultMessage,
type TextBlock,
type ContentBlock,
type CLIMessage,
type CLISystemMessage,
type SubagentConfig, type SubagentConfig,
type ContentBlock,
type ToolUseBlock, type ToolUseBlock,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
import { writeFile, mkdir } from 'node:fs/promises'; import {
import { join } from 'node:path'; SDKTestHelper,
extractText,
createSharedTestOptions,
findToolUseBlocks,
assertSuccessfulCompletion,
findSystemMessage,
} from './test-helper.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!; const SHARED_TEST_OPTIONS = createSharedTestOptions();
const E2E_TEST_FILE_DIR = process.env['E2E_TEST_FILE_DIR']!;
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
};
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
describe('Subagents (E2E)', () => { describe('Subagents (E2E)', () => {
let helper: SDKTestHelper;
let testWorkDir: string; let testWorkDir: string;
beforeAll(async () => { beforeAll(async () => {
// Create a test working directory // Create isolated test environment using SDKTestHelper
testWorkDir = join(E2E_TEST_FILE_DIR, 'subagent-tests'); helper = new SDKTestHelper();
await mkdir(testWorkDir, { recursive: true }); testWorkDir = await helper.setup('subagent-tests');
// Create a simple test file for subagent to work with // Create a simple test file for subagent to work with
const testFilePath = join(testWorkDir, 'test.txt'); await helper.createFile('test.txt', 'Hello from test file\n');
await writeFile(testFilePath, 'Hello from test file\n', 'utf-8'); });
afterAll(async () => {
// Cleanup test directory
await helper.cleanup();
}); });
describe('Subagent Configuration', () => { describe('Subagent Configuration', () => {
@@ -75,29 +67,21 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
const messages: CLIMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
} }
// Validate system message includes the subagent // Validate system message includes the subagent
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('simple-greeter'); expect(systemMessage!.agents).toContain('simple-greeter');
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally { } finally {
await q.close(); await q.close();
} }
@@ -128,16 +112,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate both subagents are registered // Validate both subagents are registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('greeter'); expect(systemMessage!.agents).toContain('greeter');
@@ -170,16 +153,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate subagent is registered // Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('custom-model-agent'); expect(systemMessage!.agents).toContain('custom-model-agent');
@@ -210,16 +192,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate subagent is registered // Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('limited-agent'); expect(systemMessage!.agents).toContain('limited-agent');
@@ -248,16 +229,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate subagent is registered // Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('read-only-agent'); expect(systemMessage!.agents).toContain('read-only-agent');
@@ -277,7 +257,7 @@ describe('Subagents (E2E)', () => {
tools: ['read_file', 'list_directory'], tools: ['read_file', 'list_directory'],
}; };
const testFile = join(testWorkDir, 'test.txt'); const testFile = helper.getPath('test.txt');
const q = query({ const q = query({
prompt: `Use the file-reader subagent to read the file at ${testFile} and tell me what it contains.`, prompt: `Use the file-reader subagent to read the file at ${testFile} and tell me what it contains.`,
options: { options: {
@@ -289,7 +269,7 @@ describe('Subagents (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let foundTaskTool = false; let foundTaskTool = false;
let taskToolUseId: string | null = null; let taskToolUseId: string | null = null;
let foundSubagentToolCall = false; let foundSubagentToolCall = false;
@@ -299,25 +279,19 @@ describe('Subagents (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
// Check for task tool use in content blocks (main agent calling subagent) // Check for task tool use in content blocks (main agent calling subagent)
const toolUseBlock = message.message.content.find( const taskToolBlocks = findToolUseBlocks(message, 'task');
(block: ContentBlock): block is ToolUseBlock => if (taskToolBlocks.length > 0) {
block.type === 'tool_use' && block.name === 'task',
);
if (toolUseBlock) {
foundTaskTool = true; foundTaskTool = true;
taskToolUseId = toolUseBlock.id; taskToolUseId = taskToolBlocks[0].id;
} }
// Check if this message is from a subagent (has parent_tool_use_id) // Check if this message is from a subagent (has parent_tool_use_id)
if (message.parent_tool_use_id !== null) { if (message.parent_tool_use_id !== null) {
// This is a subagent message // This is a subagent message
const subagentToolUse = message.message.content.find( const subagentToolBlocks = findToolUseBlocks(message);
(block: ContentBlock): block is ToolUseBlock => if (subagentToolBlocks.length > 0) {
block.type === 'tool_use',
);
if (subagentToolUse) {
foundSubagentToolCall = true; foundSubagentToolCall = true;
// Verify parent_tool_use_id matches the task tool use id // Verify parent_tool_use_id matches the task tool use id
expect(message.parent_tool_use_id).toBe(taskToolUseId); expect(message.parent_tool_use_id).toBe(taskToolUseId);
@@ -339,11 +313,7 @@ describe('Subagents (E2E)', () => {
expect(assistantText.length).toBeGreaterThan(0); expect(assistantText.length).toBeGreaterThan(0);
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally { } finally {
await q.close(); await q.close();
} }
@@ -369,7 +339,7 @@ describe('Subagents (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let foundTaskTool = false; let foundTaskTool = false;
let assistantText = ''; let assistantText = '';
@@ -377,7 +347,7 @@ describe('Subagents (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
// Check for task tool use (main agent delegating to subagent) // Check for task tool use (main agent delegating to subagent)
const toolUseBlock = message.message.content.find( const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock => (block: ContentBlock): block is ToolUseBlock =>
@@ -398,11 +368,7 @@ describe('Subagents (E2E)', () => {
expect(assistantText.length).toBeGreaterThan(0); expect(assistantText.length).toBeGreaterThan(0);
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally { } finally {
await q.close(); await q.close();
} }
@@ -429,7 +395,7 @@ describe('Subagents (E2E)', () => {
}, },
}); });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
let taskToolUseId: string | null = null; let taskToolUseId: string | null = null;
const subagentToolCalls: ToolUseBlock[] = []; const subagentToolCalls: ToolUseBlock[] = [];
const mainAgentToolCalls: ToolUseBlock[] = []; const mainAgentToolCalls: ToolUseBlock[] = [];
@@ -438,7 +404,7 @@ describe('Subagents (E2E)', () => {
for await (const message of q) { for await (const message of q) {
messages.push(message); messages.push(message);
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
// Collect all tool use blocks // Collect all tool use blocks
const toolUseBlocks = message.message.content.filter( const toolUseBlocks = message.message.content.filter(
(block: ContentBlock): block is ToolUseBlock => (block: ContentBlock): block is ToolUseBlock =>
@@ -471,8 +437,8 @@ describe('Subagents (E2E)', () => {
// Verify all subagent messages have the correct parent_tool_use_id // Verify all subagent messages have the correct parent_tool_use_id
const subagentMessages = messages.filter( const subagentMessages = messages.filter(
(msg): msg is CLIMessage & { parent_tool_use_id: string } => (msg): msg is SDKMessage & { parent_tool_use_id: string } =>
isCLIAssistantMessage(msg) && msg.parent_tool_use_id !== null, isSDKAssistantMessage(msg) && msg.parent_tool_use_id !== null,
); );
expect(subagentMessages.length).toBeGreaterThan(0); expect(subagentMessages.length).toBeGreaterThan(0);
@@ -482,23 +448,19 @@ describe('Subagents (E2E)', () => {
// Verify no main agent tool calls (except task) have parent_tool_use_id // Verify no main agent tool calls (except task) have parent_tool_use_id
const mainAgentMessages = messages.filter( const mainAgentMessages = messages.filter(
(msg): msg is CLIMessage => (msg): msg is SDKMessage =>
isCLIAssistantMessage(msg) && msg.parent_tool_use_id === null, isSDKAssistantMessage(msg) && msg.parent_tool_use_id === null,
); );
for (const mainMsg of mainAgentMessages) { for (const mainMsg of mainAgentMessages) {
if (isCLIAssistantMessage(mainMsg)) { if (isSDKAssistantMessage(mainMsg)) {
// Main agent messages should not have parent_tool_use_id // Main agent messages should not have parent_tool_use_id
expect(mainMsg.parent_tool_use_id).toBeNull(); expect(mainMsg.parent_tool_use_id).toBeNull();
} }
} }
// Validate successful completion // Validate successful completion
const lastMessage = messages[messages.length - 1]; assertSuccessfulCompletion(messages);
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally { } finally {
await q.close(); await q.close();
} }
@@ -517,16 +479,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Should still work with empty agents array // Should still work with empty agents array
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
} finally { } finally {
@@ -552,16 +513,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate minimal agent is registered // Validate minimal agent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('minimal-agent'); expect(systemMessage!.agents).toContain('minimal-agent');
@@ -596,16 +556,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate subagent works with debug mode // Validate subagent works with debug mode
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined(); expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('test-agent'); expect(systemMessage!.agents).toContain('test-agent');
@@ -633,16 +592,15 @@ describe('Subagents (E2E)', () => {
}, },
}); });
let systemMessage: CLISystemMessage | null = null; const messages: SDKMessage[] = [];
try { try {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') { messages.push(message);
systemMessage = message;
}
} }
// Validate session consistency // Validate session consistency
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull(); expect(systemMessage).not.toBeNull();
expect(systemMessage!.session_id).toBeDefined(); expect(systemMessage!.session_id).toBeDefined();
expect(systemMessage!.uuid).toBeDefined(); expect(systemMessage!.uuid).toBeDefined();

View File

@@ -6,9 +6,9 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js'; import { query } from '../../src/index.js';
import { import {
isCLIAssistantMessage, isSDKAssistantMessage,
isCLISystemMessage, isSDKSystemMessage,
type CLIUserMessage, type SDKUserMessage,
} from '../../src/types/protocol.js'; } from '../../src/types/protocol.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!; const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!;
@@ -30,7 +30,7 @@ function createStreamingInputWithControlPoint(
firstMessage: string, firstMessage: string,
secondMessage: string, secondMessage: string,
): { ): {
generator: AsyncIterable<CLIUserMessage>; generator: AsyncIterable<SDKUserMessage>;
resume: () => void; resume: () => void;
} { } {
let resumeResolve: (() => void) | null = null; let resumeResolve: (() => void) | null = null;
@@ -49,7 +49,7 @@ function createStreamingInputWithControlPoint(
content: firstMessage, content: firstMessage,
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
@@ -65,7 +65,7 @@ function createStreamingInputWithControlPoint(
content: secondMessage, content: secondMessage,
}, },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
})(); })();
const resume = () => { const resume = () => {
@@ -113,10 +113,10 @@ describe('System Control (E2E)', () => {
// Consume messages in a single loop // Consume messages in a single loop
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message)) { if (isSDKSystemMessage(message)) {
systemMessages.push({ model: message.model }); systemMessages.push({ model: message.model });
} }
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
if (!firstResponseReceived) { if (!firstResponseReceived) {
firstResponseReceived = true; firstResponseReceived = true;
resolvers.first?.(); resolvers.first?.();
@@ -186,7 +186,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId, session_id: sessionId,
message: { role: 'user', content: 'First message' }, message: { role: 'user', content: 'First message' },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise1; await resumePromise1;
@@ -197,7 +197,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId, session_id: sessionId,
message: { role: 'user', content: 'Second message' }, message: { role: 'user', content: 'Second message' },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200)); await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise2; await resumePromise2;
@@ -208,7 +208,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId, session_id: sessionId,
message: { role: 'user', content: 'Third message' }, message: { role: 'user', content: 'Third message' },
parent_tool_use_id: null, parent_tool_use_id: null,
} as CLIUserMessage; } as SDKUserMessage;
})(); })();
const q = query({ const q = query({
@@ -232,10 +232,10 @@ describe('System Control (E2E)', () => {
(async () => { (async () => {
for await (const message of q) { for await (const message of q) {
if (isCLISystemMessage(message)) { if (isSDKSystemMessage(message)) {
systemMessages.push({ model: message.model }); systemMessages.push({ model: message.model });
} }
if (isCLIAssistantMessage(message)) { if (isSDKAssistantMessage(message)) {
if (responseCount < resolvers.length) { if (responseCount < resolvers.length) {
resolvers[responseCount]?.(); resolvers[responseCount]?.();
responseCount++; responseCount++;

View File

@@ -0,0 +1,829 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* SDK E2E Test Helper
* Provides utilities for SDK e2e tests including test isolation,
* file management, MCP server setup, and common test utilities.
*/
import { mkdir, writeFile, readFile, rm, chmod } from 'node:fs/promises';
import { join } from 'node:path';
import { existsSync } from 'node:fs';
import type {
SDKMessage,
SDKAssistantMessage,
SDKSystemMessage,
SDKUserMessage,
ContentBlock,
TextBlock,
ToolUseBlock,
} from '../../src/types/protocol.js';
import {
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
} from '../../src/types/protocol.js';
// ============================================================================
// Core Test Helper Class
// ============================================================================
export interface SDKTestHelperOptions {
/**
* Optional settings for .qwen/settings.json
*/
settings?: Record<string, unknown>;
/**
* Whether to create .qwen/settings.json
*/
createQwenConfig?: boolean;
}
/**
* Helper class for SDK E2E tests
* Provides isolated test environments for each test case
*/
export class SDKTestHelper {
testDir: string | null = null;
testName?: string;
private baseDir: string;
constructor() {
this.baseDir = process.env['E2E_TEST_FILE_DIR']!;
if (!this.baseDir) {
throw new Error('E2E_TEST_FILE_DIR environment variable not set');
}
}
/**
* Setup an isolated test directory for a specific test
*/
async setup(
testName: string,
options: SDKTestHelperOptions = {},
): Promise<string> {
this.testName = testName;
const sanitizedName = this.sanitizeTestName(testName);
this.testDir = join(this.baseDir, sanitizedName);
await mkdir(this.testDir, { recursive: true });
// Optionally create .qwen/settings.json for CLI configuration
if (options.createQwenConfig) {
const qwenDir = join(this.testDir, '.qwen');
await mkdir(qwenDir, { recursive: true });
const settings = {
telemetry: {
enabled: false, // SDK tests don't need telemetry
},
...options.settings,
};
await writeFile(
join(qwenDir, 'settings.json'),
JSON.stringify(settings, null, 2),
'utf-8',
);
}
return this.testDir;
}
/**
* Create a file in the test directory
*/
async createFile(fileName: string, content: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
await writeFile(filePath, content, 'utf-8');
return filePath;
}
/**
* Read a file from the test directory
*/
async readFile(fileName: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
return await readFile(filePath, 'utf-8');
}
/**
* Create a subdirectory in the test directory
*/
async mkdir(dirName: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const dirPath = join(this.testDir, dirName);
await mkdir(dirPath, { recursive: true });
return dirPath;
}
/**
* Check if a file exists in the test directory
*/
fileExists(fileName: string): boolean {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
return existsSync(filePath);
}
/**
* Get the full path to a file in the test directory
*/
getPath(fileName: string): string {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
return join(this.testDir, fileName);
}
/**
* Cleanup test directory
*/
async cleanup(): Promise<void> {
if (this.testDir && process.env['KEEP_OUTPUT'] !== 'true') {
try {
await rm(this.testDir, { recursive: true, force: true });
} catch (error) {
if (process.env['VERBOSE'] === 'true') {
console.warn('Cleanup warning:', (error as Error).message);
}
}
}
}
/**
* Sanitize test name to create valid directory name
*/
private sanitizeTestName(name: string): string {
return name
.toLowerCase()
.replace(/[^a-z0-9]/g, '-')
.replace(/-+/g, '-')
.substring(0, 100); // Limit length
}
}
// ============================================================================
// MCP Server Utilities
// ============================================================================
export interface MCPServerConfig {
command: string;
args: string[];
}
export interface MCPServerResult {
scriptPath: string;
config: MCPServerConfig;
}
/**
* Built-in MCP server template: Math server with add and multiply tools
*/
const MCP_MATH_SERVER_SCRIPT = `#!/usr/bin/env node
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
const readline = require('readline');
const fs = require('fs');
// Debug logging to stderr (only when MCP_DEBUG or VERBOSE is set)
const debugEnabled = process.env['MCP_DEBUG'] === 'true' || process.env['VERBOSE'] === 'true';
function debug(msg) {
if (debugEnabled) {
fs.writeSync(2, \`[MCP-DEBUG] \${msg}\\n\`);
}
}
debug('MCP server starting...');
// Simple JSON-RPC implementation for MCP
class SimpleJSONRPC {
constructor() {
this.handlers = new Map();
this.rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
terminal: false
});
this.rl.on('line', (line) => {
debug(\`Received line: \${line}\`);
try {
const message = JSON.parse(line);
debug(\`Parsed message: \${JSON.stringify(message)}\`);
this.handleMessage(message);
} catch (e) {
debug(\`Parse error: \${e.message}\`);
}
});
}
send(message) {
const msgStr = JSON.stringify(message);
debug(\`Sending message: \${msgStr}\`);
process.stdout.write(msgStr + '\\n');
}
async handleMessage(message) {
if (message.method && this.handlers.has(message.method)) {
try {
const result = await this.handlers.get(message.method)(message.params || {});
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
result
});
}
} catch (error) {
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32603,
message: error.message
}
});
}
}
} else if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32601,
message: 'Method not found'
}
});
}
}
on(method, handler) {
this.handlers.set(method, handler);
}
}
// Create MCP server
const rpc = new SimpleJSONRPC();
// Handle initialize
rpc.on('initialize', async (params) => {
debug('Handling initialize request');
return {
protocolVersion: '2024-11-05',
capabilities: {
tools: {}
},
serverInfo: {
name: 'test-math-server',
version: '1.0.0'
}
};
});
// Handle tools/list
rpc.on('tools/list', async () => {
debug('Handling tools/list request');
return {
tools: [
{
name: 'add',
description: 'Add two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
},
{
name: 'multiply',
description: 'Multiply two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
}
]
};
});
// Handle tools/call
rpc.on('tools/call', async (params) => {
debug(\`Handling tools/call request for tool: \${params.name}\`);
if (params.name === 'add') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a + b)
}]
};
}
if (params.name === 'multiply') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a * b)
}]
};
}
throw new Error('Unknown tool: ' + params.name);
});
// Send initialization notification
rpc.send({
jsonrpc: '2.0',
method: 'initialized'
});
`;
/**
* Create an MCP server script in the test directory
* @param helper - SDKTestHelper instance
* @param type - Type of MCP server ('math' or provide custom script)
* @param serverName - Name of the MCP server (default: 'test-math-server')
* @param customScript - Custom MCP server script (if type is not 'math')
* @returns Object with scriptPath and config
*/
export async function createMCPServer(
helper: SDKTestHelper,
type: 'math' | 'custom' = 'math',
serverName: string = 'test-math-server',
customScript?: string,
): Promise<MCPServerResult> {
if (!helper.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const script = type === 'math' ? MCP_MATH_SERVER_SCRIPT : customScript;
if (!script) {
throw new Error('Custom script required when type is "custom"');
}
const scriptPath = join(helper.testDir, `${serverName}.cjs`);
await writeFile(scriptPath, script, 'utf-8');
// Make script executable on Unix-like systems
if (process.platform !== 'win32') {
await chmod(scriptPath, 0o755);
}
return {
scriptPath,
config: {
command: 'node',
args: [scriptPath],
},
};
}
// ============================================================================
// Message & Content Utilities
// ============================================================================
/**
* Extract text from ContentBlock array
*/
export function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
/**
* Collect messages by type
*/
export function collectMessagesByType<T extends SDKMessage>(
messages: SDKMessage[],
predicate: (msg: SDKMessage) => msg is T,
): T[] {
return messages.filter(predicate);
}
/**
* Find tool use blocks in a message
*/
export function findToolUseBlocks(
message: SDKAssistantMessage,
toolName?: string,
): ToolUseBlock[] {
const toolUseBlocks = message.message.content.filter(
(block): block is ToolUseBlock => block.type === 'tool_use',
);
if (toolName) {
return toolUseBlocks.filter((block) => block.name === toolName);
}
return toolUseBlocks;
}
/**
* Extract all assistant text from messages
*/
export function getAssistantText(messages: SDKMessage[]): string {
return messages
.filter(isSDKAssistantMessage)
.map((msg) => extractText(msg.message.content))
.join('');
}
/**
* Find system message with optional subtype filter
*/
export function findSystemMessage(
messages: SDKMessage[],
subtype?: string,
): SDKSystemMessage | null {
const systemMessages = messages.filter(isSDKSystemMessage);
if (subtype) {
return systemMessages.find((msg) => msg.subtype === subtype) || null;
}
return systemMessages[0] || null;
}
/**
* Find all tool calls in messages
*/
export function findToolCalls(
messages: SDKMessage[],
toolName?: string,
): Array<{ message: SDKAssistantMessage; toolUse: ToolUseBlock }> {
const results: Array<{
message: SDKAssistantMessage;
toolUse: ToolUseBlock;
}> = [];
for (const message of messages) {
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, toolName);
for (const toolUse of toolUseBlocks) {
results.push({ message, toolUse });
}
}
}
return results;
}
// ============================================================================
// Streaming Input Utilities
// ============================================================================
/**
* Create a simple streaming input from an array of message contents
*/
export async function* createStreamingInput(
messageContents: string[],
sessionId?: string,
): AsyncIterable<SDKUserMessage> {
const sid = sessionId || crypto.randomUUID();
for (const content of messageContents) {
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: content,
},
parent_tool_use_id: null,
} as SDKUserMessage;
// Small delay between messages
await new Promise((resolve) => setTimeout(resolve, 100));
}
}
/**
* Create a controlled streaming input with pause/resume capability
*/
export function createControlledStreamingInput(
messageContents: string[],
sessionId?: string,
): {
generator: AsyncIterable<SDKUserMessage>;
resume: () => void;
resumeAll: () => void;
} {
const sid = sessionId || crypto.randomUUID();
const resumeResolvers: Array<() => void> = [];
const resumePromises: Array<Promise<void>> = [];
// Create a resume promise for each message after the first
for (let i = 1; i < messageContents.length; i++) {
const promise = new Promise<void>((resolve) => {
resumeResolvers.push(resolve);
});
resumePromises.push(promise);
}
const generator = (async function* () {
// Yield first message immediately
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: messageContents[0],
},
parent_tool_use_id: null,
} as SDKUserMessage;
// For subsequent messages, wait for resume
for (let i = 1; i < messageContents.length; i++) {
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromises[i - 1];
await new Promise((resolve) => setTimeout(resolve, 200));
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: messageContents[i],
},
parent_tool_use_id: null,
} as SDKUserMessage;
}
})();
let currentResumeIndex = 0;
return {
generator,
resume: () => {
if (currentResumeIndex < resumeResolvers.length) {
resumeResolvers[currentResumeIndex]();
currentResumeIndex++;
}
},
resumeAll: () => {
resumeResolvers.forEach((resolve) => resolve());
currentResumeIndex = resumeResolvers.length;
},
};
}
// ============================================================================
// Assertion Utilities
// ============================================================================
/**
* Assert that messages follow expected type sequence
*/
export function assertMessageSequence(
messages: SDKMessage[],
expectedTypes: string[],
): void {
const actualTypes = messages.map((msg) => msg.type);
if (actualTypes.length < expectedTypes.length) {
throw new Error(
`Expected at least ${expectedTypes.length} messages, got ${actualTypes.length}`,
);
}
for (let i = 0; i < expectedTypes.length; i++) {
if (actualTypes[i] !== expectedTypes[i]) {
throw new Error(
`Expected message ${i} to be type '${expectedTypes[i]}', got '${actualTypes[i]}'`,
);
}
}
}
/**
* Assert that a specific tool was called
*/
export function assertToolCalled(
messages: SDKMessage[],
toolName: string,
): void {
const toolCalls = findToolCalls(messages, toolName);
if (toolCalls.length === 0) {
const allToolCalls = findToolCalls(messages);
const allToolNames = allToolCalls.map((tc) => tc.toolUse.name);
throw new Error(
`Expected tool '${toolName}' to be called. Found tools: ${allToolNames.length > 0 ? allToolNames.join(', ') : 'none'}`,
);
}
}
/**
* Assert that the conversation completed successfully
*/
export function assertSuccessfulCompletion(messages: SDKMessage[]): void {
const lastMessage = messages[messages.length - 1];
if (!isSDKResultMessage(lastMessage)) {
throw new Error(
`Expected last message to be a result message, got '${lastMessage.type}'`,
);
}
if (lastMessage.subtype !== 'success') {
throw new Error(
`Expected successful completion, got result subtype '${lastMessage.subtype}'`,
);
}
}
/**
* Wait for a condition to be true with timeout
*/
export async function waitFor(
predicate: () => boolean | Promise<boolean>,
options: {
timeout?: number;
interval?: number;
errorMessage?: string;
} = {},
): Promise<void> {
const {
timeout = 5000,
interval = 100,
errorMessage = 'Condition not met within timeout',
} = options;
const startTime = Date.now();
while (Date.now() - startTime < timeout) {
const result = await predicate();
if (result) {
return;
}
await new Promise((resolve) => setTimeout(resolve, interval));
}
throw new Error(errorMessage);
}
// ============================================================================
// Debug and Validation Utilities
// ============================================================================
/**
* Validate model output and warn about unexpected content
* Inspired by integration-tests test-helper
*/
export function validateModelOutput(
result: string,
expectedContent: string | (string | RegExp)[] | null = null,
testName = '',
): boolean {
// First, check if there's any output at all
if (!result || result.trim().length === 0) {
throw new Error('Expected model to return some output');
}
// If expectedContent is provided, check for it and warn if missing
if (expectedContent) {
const contents = Array.isArray(expectedContent)
? expectedContent
: [expectedContent];
const missingContent = contents.filter((content) => {
if (typeof content === 'string') {
return !result.toLowerCase().includes(content.toLowerCase());
} else if (content instanceof RegExp) {
return !content.test(result);
}
return false;
});
if (missingContent.length > 0) {
console.warn(
`Warning: Model did not include expected content in response: ${missingContent.join(', ')}.`,
'This is not ideal but not a test failure.',
);
console.warn(
'The tool was called successfully, which is the main requirement.',
);
return false;
} else if (process.env['VERBOSE'] === 'true') {
console.log(`${testName}: Model output validated successfully.`);
}
return true;
}
return true;
}
/**
* Print debug information when tests fail
*/
export function printDebugInfo(
messages: SDKMessage[],
context: Record<string, unknown> = {},
): void {
console.error('Test failed - Debug info:');
console.error('Message count:', messages.length);
// Print message types
const messageTypes = messages.map((m) => m.type);
console.error('Message types:', messageTypes.join(', '));
// Print assistant text
const assistantText = getAssistantText(messages);
console.error(
'Assistant text (first 500 chars):',
assistantText.substring(0, 500),
);
if (assistantText.length > 500) {
console.error(
'Assistant text (last 500 chars):',
assistantText.substring(assistantText.length - 500),
);
}
// Print tool calls
const toolCalls = findToolCalls(messages);
console.error(
'Tool calls found:',
toolCalls.map((tc) => tc.toolUse.name),
);
// Print any additional context provided
Object.entries(context).forEach(([key, value]) => {
console.error(`${key}:`, value);
});
}
/**
* Create detailed error message for tool call expectations
*/
export function createToolCallErrorMessage(
expectedTools: string | string[],
foundTools: string[],
messages: SDKMessage[],
): string {
const expectedStr = Array.isArray(expectedTools)
? expectedTools.join(' or ')
: expectedTools;
const assistantText = getAssistantText(messages);
const preview = assistantText
? assistantText.substring(0, 200) + '...'
: 'no output';
return (
`Expected to find ${expectedStr} tool call(s). ` +
`Found: ${foundTools.length > 0 ? foundTools.join(', ') : 'none'}. ` +
`Output preview: ${preview}`
);
}
// ============================================================================
// Shared Test Options Helper
// ============================================================================
/**
* Create shared test options with CLI path
*/
export function createSharedTestOptions(
overrides: Record<string, unknown> = {},
) {
const TEST_CLI_PATH = process.env['TEST_CLI_PATH'];
if (!TEST_CLI_PATH) {
throw new Error('TEST_CLI_PATH environment variable not set');
}
return {
pathToQwenExecutable: TEST_CLI_PATH,
...overrides,
};
}

View File

@@ -7,12 +7,12 @@ import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
import { Query } from '../../src/query/Query.js'; import { Query } from '../../src/query/Query.js';
import type { Transport } from '../../src/transport/Transport.js'; import type { Transport } from '../../src/transport/Transport.js';
import type { import type {
CLIMessage, SDKMessage,
CLIUserMessage, SDKUserMessage,
CLIAssistantMessage, SDKAssistantMessage,
CLISystemMessage, SDKSystemMessage,
CLIResultMessage, SDKResultMessage,
CLIPartialAssistantMessage, SDKPartialAssistantMessage,
CLIControlRequest, CLIControlRequest,
CLIControlResponse, CLIControlResponse,
ControlCancelRequest, ControlCancelRequest,
@@ -118,7 +118,7 @@ function findControlRequest(
function createUserMessage( function createUserMessage(
content: string, content: string,
sessionId = 'test-session', sessionId = 'test-session',
): CLIUserMessage { ): SDKUserMessage {
return { return {
type: 'user', type: 'user',
session_id: sessionId, session_id: sessionId,
@@ -133,7 +133,7 @@ function createUserMessage(
function createAssistantMessage( function createAssistantMessage(
content: string, content: string,
sessionId = 'test-session', sessionId = 'test-session',
): CLIAssistantMessage { ): SDKAssistantMessage {
return { return {
type: 'assistant', type: 'assistant',
uuid: 'msg-123', uuid: 'msg-123',
@@ -153,7 +153,7 @@ function createAssistantMessage(
function createSystemMessage( function createSystemMessage(
subtype: string, subtype: string,
sessionId = 'test-session', sessionId = 'test-session',
): CLISystemMessage { ): SDKSystemMessage {
return { return {
type: 'system', type: 'system',
subtype, subtype,
@@ -168,7 +168,7 @@ function createSystemMessage(
function createResultMessage( function createResultMessage(
success: boolean, success: boolean,
sessionId = 'test-session', sessionId = 'test-session',
): CLIResultMessage { ): SDKResultMessage {
if (success) { if (success) {
return { return {
type: 'result', type: 'result',
@@ -202,7 +202,7 @@ function createResultMessage(
function createPartialMessage( function createPartialMessage(
sessionId = 'test-session', sessionId = 'test-session',
): CLIPartialAssistantMessage { ): SDKPartialAssistantMessage {
return { return {
type: 'stream_event', type: 'stream_event',
uuid: 'stream-123', uuid: 'stream-123',
@@ -816,7 +816,7 @@ describe('Query', () => {
msg !== null && msg !== null &&
'type' in msg && 'type' in msg &&
msg.type === 'user', msg.type === 'user',
) as CLIUserMessage[]; ) as SDKUserMessage[];
userMessages.forEach((msg) => { userMessages.forEach((msg) => {
expect(msg.session_id).toBe(sessionId); expect(msg.session_id).toBe(sessionId);
@@ -889,7 +889,7 @@ describe('Query', () => {
const query = new Query(transport, { cwd: '/test' }); const query = new Query(transport, { cwd: '/test' });
const iterationPromise = (async () => { const iterationPromise = (async () => {
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
for await (const msg of query) { for await (const msg of query) {
messages.push(msg); messages.push(msg);
} }
@@ -946,7 +946,7 @@ describe('Query', () => {
it('should support for await loop', async () => { it('should support for await loop', async () => {
const query = new Query(transport, { cwd: '/test' }); const query = new Query(transport, { cwd: '/test' });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
const iterationPromise = (async () => { const iterationPromise = (async () => {
for await (const msg of query) { for await (const msg of query) {
messages.push(msg); messages.push(msg);
@@ -960,7 +960,7 @@ describe('Query', () => {
await iterationPromise; await iterationPromise;
expect(messages).toHaveLength(2); expect(messages).toHaveLength(2);
expect((messages[0] as CLIUserMessage).message.content).toBe('First'); expect((messages[0] as SDKUserMessage).message.content).toBe('First');
await query.close(); await query.close();
}); });
@@ -968,7 +968,7 @@ describe('Query', () => {
it('should complete iteration when query closes', async () => { it('should complete iteration when query closes', async () => {
const query = new Query(transport, { cwd: '/test' }); const query = new Query(transport, { cwd: '/test' });
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
const iterationPromise = (async () => { const iterationPromise = (async () => {
for await (const msg of query) { for await (const msg of query) {
messages.push(msg); messages.push(msg);
@@ -1321,7 +1321,7 @@ describe('Query', () => {
const result = await query.next(); const result = await query.next();
expect(result.done).toBe(false); expect(result.done).toBe(false);
expect((result.value as CLIResultMessage).is_error).toBe(true); expect((result.value as SDKResultMessage).is_error).toBe(true);
await query.close(); await query.close();
}); });
@@ -1430,7 +1430,7 @@ describe('Query', () => {
transport.simulateMessage(createUserMessage(`Message ${i}`)); transport.simulateMessage(createUserMessage(`Message ${i}`));
} }
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
for (let i = 0; i < 100; i++) { for (let i = 0; i < 100; i++) {
const result = await query.next(); const result = await query.next();
if (!result.done) { if (!result.done) {
@@ -1447,7 +1447,7 @@ describe('Query', () => {
const query = new Query(transport, { cwd: '/test' }); const query = new Query(transport, { cwd: '/test' });
const iterationPromise = (async () => { const iterationPromise = (async () => {
const messages: CLIMessage[] = []; const messages: SDKMessage[] = [];
for await (const msg of query) { for await (const msg of query) {
messages.push(msg); messages.push(msg);
if (messages.length === 2) { if (messages.length === 2) {