refactor: rename ambiguous exported types

This commit is contained in:
mingholy.lmh
2025-11-27 14:50:40 +08:00
parent 638b7bb466
commit 56957a687b
14 changed files with 1188 additions and 614 deletions

View File

@@ -12,20 +12,20 @@ export type {
ThinkingBlock,
ToolUseBlock,
ToolResultBlock,
CLIUserMessage,
CLIAssistantMessage,
CLISystemMessage,
CLIResultMessage,
CLIPartialAssistantMessage,
CLIMessage,
SDKUserMessage,
SDKAssistantMessage,
SDKSystemMessage,
SDKResultMessage,
SDKPartialAssistantMessage,
SDKMessage,
} from './types/protocol.js';
export {
isCLIUserMessage,
isCLIAssistantMessage,
isCLISystemMessage,
isCLIResultMessage,
isCLIPartialAssistantMessage,
isSDKUserMessage,
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
isSDKPartialAssistantMessage,
} from './types/protocol.js';
export type {

View File

@@ -13,19 +13,19 @@ const STREAM_CLOSE_TIMEOUT = 10000;
import { randomUUID } from 'node:crypto';
import { SdkLogger } from '../utils/logger.js';
import type {
CLIMessage,
CLIUserMessage,
SDKMessage,
SDKUserMessage,
CLIControlRequest,
CLIControlResponse,
ControlCancelRequest,
PermissionSuggestion,
} from '../types/protocol.js';
import {
isCLIUserMessage,
isCLIAssistantMessage,
isCLISystemMessage,
isCLIResultMessage,
isCLIPartialAssistantMessage,
isSDKUserMessage,
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
isSDKPartialAssistantMessage,
isControlRequest,
isControlResponse,
isControlCancel,
@@ -52,12 +52,12 @@ interface TransportWithEndInput extends Transport {
const logger = SdkLogger.createLogger('Query');
export class Query implements AsyncIterable<CLIMessage> {
export class Query implements AsyncIterable<SDKMessage> {
private transport: Transport;
private options: QueryOptions;
private sessionId: string;
private inputStream: Stream<CLIMessage>;
private sdkMessages: AsyncGenerator<CLIMessage>;
private inputStream: Stream<SDKMessage>;
private sdkMessages: AsyncGenerator<SDKMessage>;
private abortController: AbortController;
private pendingControlRequests: Map<string, PendingControlRequest> =
new Map();
@@ -79,7 +79,7 @@ export class Query implements AsyncIterable<CLIMessage> {
this.transport = transport;
this.options = options;
this.sessionId = randomUUID();
this.inputStream = new Stream<CLIMessage>();
this.inputStream = new Stream<SDKMessage>();
this.abortController = options.abortController ?? new AbortController();
this.isSingleTurn = singleTurn;
@@ -187,7 +187,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return;
}
if (isCLISystemMessage(message)) {
if (isSDKSystemMessage(message)) {
/**
* SystemMessage contains session info (cwd, tools, model, etc.)
* that should be passed to user.
@@ -196,7 +196,7 @@ export class Query implements AsyncIterable<CLIMessage> {
return;
}
if (isCLIResultMessage(message)) {
if (isSDKResultMessage(message)) {
if (this.firstResultReceivedResolve) {
this.firstResultReceivedResolve();
}
@@ -212,16 +212,16 @@ export class Query implements AsyncIterable<CLIMessage> {
}
if (
isCLIAssistantMessage(message) ||
isCLIUserMessage(message) ||
isCLIPartialAssistantMessage(message)
isSDKAssistantMessage(message) ||
isSDKUserMessage(message) ||
isSDKPartialAssistantMessage(message)
) {
this.inputStream.enqueue(message);
return;
}
logger.warn('Unknown message type:', message);
this.inputStream.enqueue(message as CLIMessage);
this.inputStream.enqueue(message as SDKMessage);
}
private async handleControlRequest(
@@ -560,29 +560,29 @@ export class Query implements AsyncIterable<CLIMessage> {
logger.info('Query closed');
}
private async *readSdkMessages(): AsyncGenerator<CLIMessage> {
private async *readSdkMessages(): AsyncGenerator<SDKMessage> {
for await (const message of this.inputStream) {
yield message;
}
}
async next(...args: [] | [unknown]): Promise<IteratorResult<CLIMessage>> {
async next(...args: [] | [unknown]): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.next(...args);
}
async return(value?: unknown): Promise<IteratorResult<CLIMessage>> {
async return(value?: unknown): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.return(value);
}
async throw(e?: unknown): Promise<IteratorResult<CLIMessage>> {
async throw(e?: unknown): Promise<IteratorResult<SDKMessage>> {
return this.sdkMessages.throw(e);
}
[Symbol.asyncIterator](): AsyncIterator<CLIMessage> {
[Symbol.asyncIterator](): AsyncIterator<SDKMessage> {
return this.sdkMessages;
}
async streamInput(messages: AsyncIterable<CLIUserMessage>): Promise<void> {
async streamInput(messages: AsyncIterable<SDKUserMessage>): Promise<void> {
if (this.closed) {
throw new Error('Query is closed');
}

View File

@@ -2,7 +2,7 @@
* Factory function for creating Query instances.
*/
import type { CLIUserMessage } from '../types/protocol.js';
import type { SDKUserMessage } from '../types/protocol.js';
import { serializeJsonLine } from '../utils/jsonLines.js';
import { ProcessTransport } from '../transport/ProcessTransport.js';
import { parseExecutableSpec } from '../utils/cliPath.js';
@@ -22,11 +22,11 @@ export function query({
/**
* The prompt to send to the Qwen Code CLI process.
* - `string` for single-turn query,
* - `AsyncIterable<CLIUserMessage>` for multi-turn query.
* - `AsyncIterable<SDKUserMessage>` for multi-turn query.
*
* The transport will remain open until the prompt is done.
*/
prompt: string | AsyncIterable<CLIUserMessage>;
prompt: string | AsyncIterable<SDKUserMessage>;
/**
* Configuration options for the query.
*/
@@ -67,7 +67,7 @@ export function query({
if (isSingleTurn) {
const stringPrompt = prompt as string;
const message: CLIUserMessage = {
const message: SDKUserMessage = {
type: 'user',
session_id: queryInstance.getSessionId(),
message: {
@@ -87,7 +87,7 @@ export function query({
})();
} else {
queryInstance
.streamInput(prompt as AsyncIterable<CLIUserMessage>)
.streamInput(prompt as AsyncIterable<SDKUserMessage>)
.catch((err) => {
logger.error('Error streaming input:', err);
});

View File

@@ -89,7 +89,7 @@ export interface APIAssistantMessage {
usage: Usage;
}
export interface CLIUserMessage {
export interface SDKUserMessage {
type: 'user';
uuid?: string;
session_id: string;
@@ -98,7 +98,7 @@ export interface CLIUserMessage {
options?: Record<string, unknown>;
}
export interface CLIAssistantMessage {
export interface SDKAssistantMessage {
type: 'assistant';
uuid: string;
session_id: string;
@@ -106,7 +106,7 @@ export interface CLIAssistantMessage {
parent_tool_use_id: string | null;
}
export interface CLISystemMessage {
export interface SDKSystemMessage {
type: 'system';
subtype: string;
uuid: string;
@@ -133,7 +133,7 @@ export interface CLISystemMessage {
};
}
export interface CLIResultMessageSuccess {
export interface SDKResultMessageSuccess {
type: 'result';
subtype: 'success';
uuid: string;
@@ -149,7 +149,7 @@ export interface CLIResultMessageSuccess {
[key: string]: unknown;
}
export interface CLIResultMessageError {
export interface SDKResultMessageError {
type: 'result';
subtype: 'error_max_turns' | 'error_during_execution';
uuid: string;
@@ -169,7 +169,7 @@ export interface CLIResultMessageError {
[key: string]: unknown;
}
export type CLIResultMessage = CLIResultMessageSuccess | CLIResultMessageError;
export type SDKResultMessage = SDKResultMessageSuccess | SDKResultMessageError;
export interface MessageStartStreamEvent {
type: 'message_start';
@@ -222,7 +222,7 @@ export type StreamEvent =
| ContentBlockStopEvent
| MessageStopStreamEvent;
export interface CLIPartialAssistantMessage {
export interface SDKPartialAssistantMessage {
type: 'stream_event';
uuid: string;
session_id: string;
@@ -389,22 +389,22 @@ export type ControlMessage =
| ControlCancelRequest;
/**
* Union of all CLI message types
* Union of all SDK message types
*/
export type CLIMessage =
| CLIUserMessage
| CLIAssistantMessage
| CLISystemMessage
| CLIResultMessage
| CLIPartialAssistantMessage;
export type SDKMessage =
| SDKUserMessage
| SDKAssistantMessage
| SDKSystemMessage
| SDKResultMessage
| SDKPartialAssistantMessage;
export function isCLIUserMessage(msg: any): msg is CLIUserMessage {
export function isSDKUserMessage(msg: any): msg is SDKUserMessage {
return (
msg && typeof msg === 'object' && msg.type === 'user' && 'message' in msg
);
}
export function isCLIAssistantMessage(msg: any): msg is CLIAssistantMessage {
export function isSDKAssistantMessage(msg: any): msg is SDKAssistantMessage {
return (
msg &&
typeof msg === 'object' &&
@@ -416,7 +416,7 @@ export function isCLIAssistantMessage(msg: any): msg is CLIAssistantMessage {
);
}
export function isCLISystemMessage(msg: any): msg is CLISystemMessage {
export function isSDKSystemMessage(msg: any): msg is SDKSystemMessage {
return (
msg &&
typeof msg === 'object' &&
@@ -427,7 +427,7 @@ export function isCLISystemMessage(msg: any): msg is CLISystemMessage {
);
}
export function isCLIResultMessage(msg: any): msg is CLIResultMessage {
export function isSDKResultMessage(msg: any): msg is SDKResultMessage {
return (
msg &&
typeof msg === 'object' &&
@@ -440,9 +440,9 @@ export function isCLIResultMessage(msg: any): msg is CLIResultMessage {
);
}
export function isCLIPartialAssistantMessage(
export function isSDKPartialAssistantMessage(
msg: any,
): msg is CLIPartialAssistantMessage {
): msg is SDKPartialAssistantMessage {
return (
msg &&
typeof msg === 'object' &&

View File

@@ -14,14 +14,15 @@ const e2eTestsDir = join(rootDir, '.integration-tests');
let runDir = '';
export async function setup() {
runDir = join(e2eTestsDir, `${Date.now()}`);
runDir = join(e2eTestsDir, `sdk-e2e-${Date.now()}`);
await mkdir(runDir, { recursive: true });
// Clean up old test runs, but keep the latest few for debugging
try {
const testRuns = await readdir(e2eTestsDir);
if (testRuns.length > 5) {
const oldRuns = testRuns.sort().slice(0, testRuns.length - 5);
const sdkTestRuns = testRuns.filter((run) => run.startsWith('sdk-e2e-'));
if (sdkTestRuns.length > 5) {
const oldRuns = sdkTestRuns.sort().slice(0, sdkTestRuns.length - 5);
await Promise.all(
oldRuns.map((oldRun) =>
rm(join(e2eTestsDir, oldRun), {
@@ -44,7 +45,7 @@ export async function setup() {
}
process.env['VERBOSE'] = process.env['VERBOSE'] ?? 'false';
console.log(`\nE2E test output directory: ${runDir}`);
console.log(`\nSDK E2E test output directory: ${runDir}`);
console.log(`CLI path: ${process.env['TEST_CLI_PATH']}`);
}

View File

@@ -9,234 +9,48 @@
* Tests that the SDK can properly interact with MCP servers configured in qwen-code
*/
import { describe, it, expect, beforeAll } from 'vitest';
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLIResultMessage,
isCLISystemMessage,
isCLIUserMessage,
type TextBlock,
type ContentBlock,
type CLIMessage,
isSDKAssistantMessage,
isSDKResultMessage,
isSDKSystemMessage,
isSDKUserMessage,
type SDKMessage,
type ToolUseBlock,
type CLISystemMessage,
type SDKSystemMessage,
} from '../../src/types/protocol.js';
import { writeFileSync, mkdirSync, chmodSync } from 'node:fs';
import { join } from 'node:path';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!;
const E2E_TEST_FILE_DIR = process.env['E2E_TEST_FILE_DIR']!;
import {
SDKTestHelper,
createMCPServer,
extractText,
findToolUseBlocks,
createSharedTestOptions,
} from './test-helper.js';
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
...createSharedTestOptions(),
permissionMode: 'yolo' as const,
};
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
/**
* Minimal MCP server implementation that doesn't require external dependencies
* This implements the MCP protocol directly using Node.js built-ins
*/
const MCP_SERVER_SCRIPT = `#!/usr/bin/env node
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
const readline = require('readline');
const fs = require('fs');
// Debug logging to stderr (only when MCP_DEBUG or VERBOSE is set)
const debugEnabled = process.env['MCP_DEBUG'] === 'true' || process.env['VERBOSE'] === 'true';
function debug(msg) {
if (debugEnabled) {
fs.writeSync(2, \`[MCP-DEBUG] \${msg}\\n\`);
}
}
debug('MCP server starting...');
// Simple JSON-RPC implementation for MCP
class SimpleJSONRPC {
constructor() {
this.handlers = new Map();
this.rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
terminal: false
});
this.rl.on('line', (line) => {
debug(\`Received line: \${line}\`);
try {
const message = JSON.parse(line);
debug(\`Parsed message: \${JSON.stringify(message)}\`);
this.handleMessage(message);
} catch (e) {
debug(\`Parse error: \${e.message}\`);
}
});
}
send(message) {
const msgStr = JSON.stringify(message);
debug(\`Sending message: \${msgStr}\`);
process.stdout.write(msgStr + '\\n');
}
async handleMessage(message) {
if (message.method && this.handlers.has(message.method)) {
try {
const result = await this.handlers.get(message.method)(message.params || {});
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
result
});
}
} catch (error) {
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32603,
message: error.message
}
});
}
}
} else if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32601,
message: 'Method not found'
}
});
}
}
on(method, handler) {
this.handlers.set(method, handler);
}
}
// Create MCP server
const rpc = new SimpleJSONRPC();
// Handle initialize
rpc.on('initialize', async (params) => {
debug('Handling initialize request');
return {
protocolVersion: '2024-11-05',
capabilities: {
tools: {}
},
serverInfo: {
name: 'test-math-server',
version: '1.0.0'
}
};
});
// Handle tools/list
rpc.on('tools/list', async () => {
debug('Handling tools/list request');
return {
tools: [
{
name: 'add',
description: 'Add two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
},
{
name: 'multiply',
description: 'Multiply two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
}
]
};
});
// Handle tools/call
rpc.on('tools/call', async (params) => {
debug(\`Handling tools/call request for tool: \${params.name}\`);
if (params.name === 'add') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a + b)
}]
};
}
if (params.name === 'multiply') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a * b)
}]
};
}
throw new Error('Unknown tool: ' + params.name);
});
// Send initialization notification
rpc.send({
jsonrpc: '2.0',
method: 'initialized'
});
`;
describe('MCP Server Integration (E2E)', () => {
let testDir: string;
let helper: SDKTestHelper;
let serverScriptPath: string;
let testDir: string;
beforeAll(() => {
// Use the centralized E2E test directory from globalSetup
testDir = join(E2E_TEST_FILE_DIR, 'mcp-server-test');
mkdirSync(testDir, { recursive: true });
beforeAll(async () => {
// Create isolated test environment using SDKTestHelper
helper = new SDKTestHelper();
testDir = await helper.setup('mcp-server-integration');
// Write MCP server script
serverScriptPath = join(testDir, 'mcp-server.cjs');
writeFileSync(serverScriptPath, MCP_SERVER_SCRIPT);
// Create MCP server using the helper utility
const mcpServer = await createMCPServer(helper, 'math', 'test-math-server');
serverScriptPath = mcpServer.scriptPath;
});
// Make script executable on Unix-like systems
if (process.platform !== 'win32') {
chmodSync(serverScriptPath, 0o755);
}
afterAll(async () => {
// Cleanup test directory
await helper.cleanup();
});
describe('Basic MCP Tool Usage', () => {
@@ -257,7 +71,7 @@ describe('MCP Server Integration (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
let foundToolUse = false;
@@ -265,12 +79,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (toolUseBlock && toolUseBlock.name === 'add') {
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'add');
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
assistantText += extractText(message.message.content);
@@ -285,8 +96,8 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(isSDKResultMessage(lastMessage)).toBe(true);
if (isSDKResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally {
@@ -311,7 +122,7 @@ describe('MCP Server Integration (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
let foundToolUse = false;
@@ -319,12 +130,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (toolUseBlock && toolUseBlock.name === 'multiply') {
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'multiply');
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
assistantText += extractText(message.message.content);
@@ -339,7 +147,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
@@ -363,11 +171,11 @@ describe('MCP Server Integration (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
let systemMessage: SDKSystemMessage | null = null;
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
break;
}
@@ -410,7 +218,7 @@ describe('MCP Server Integration (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
const toolCalls: string[] = [];
@@ -418,11 +226,8 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
const toolUseBlocks = message.message.content.filter(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message);
toolUseBlocks.forEach((block) => {
toolCalls.push(block.name);
});
@@ -439,7 +244,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
@@ -462,7 +267,7 @@ describe('MCP Server Integration (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
const addToolCalls: ToolUseBlock[] = [];
@@ -470,16 +275,9 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
const toolUseBlocks = message.message.content.filter(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
toolUseBlocks.forEach((block) => {
if (block.name === 'add') {
addToolCalls.push(block);
}
});
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'add');
addToolCalls.push(...toolUseBlocks);
assistantText += extractText(message.message.content);
}
}
@@ -493,7 +291,7 @@ describe('MCP Server Integration (E2E)', () => {
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
@@ -525,19 +323,16 @@ describe('MCP Server Integration (E2E)', () => {
for await (const message of q) {
messageTypes.push(message.type);
if (isCLIAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (toolUseBlock) {
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message);
if (toolUseBlocks.length > 0) {
foundToolUse = true;
expect(toolUseBlock.name).toBe('add');
expect(toolUseBlock.input).toBeDefined();
expect(toolUseBlocks[0].name).toBe('add');
expect(toolUseBlocks[0].input).toBeDefined();
}
}
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
const content = message.message.content;
const contentArray = Array.isArray(content)
? content
@@ -584,21 +379,21 @@ describe('MCP Server Integration (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content);
}
}
// Should complete without crashing
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
expect(isSDKResultMessage(lastMessage)).toBe(true);
// Assistant should indicate tool is not available or provide alternative
expect(assistantText.length).toBeGreaterThan(0);

View File

@@ -6,19 +6,19 @@
import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIUserMessage,
isCLIAssistantMessage,
isCLISystemMessage,
isCLIResultMessage,
isCLIPartialAssistantMessage,
isSDKUserMessage,
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
isSDKPartialAssistantMessage,
isControlRequest,
isControlResponse,
isControlCancel,
type CLIUserMessage,
type CLIAssistantMessage,
type SDKUserMessage,
type SDKAssistantMessage,
type TextBlock,
type ContentBlock,
type CLIMessage,
type SDKMessage,
type ControlMessage,
type ToolUseBlock,
} from '../../src/types/protocol.js';
@@ -31,16 +31,16 @@ const SHARED_TEST_OPTIONS = {
/**
* Determine the message type using protocol type guards
*/
function getMessageType(message: CLIMessage | ControlMessage): string {
if (isCLIUserMessage(message)) {
function getMessageType(message: SDKMessage | ControlMessage): string {
if (isSDKUserMessage(message)) {
return '🧑 USER';
} else if (isCLIAssistantMessage(message)) {
} else if (isSDKAssistantMessage(message)) {
return '🤖 ASSISTANT';
} else if (isCLISystemMessage(message)) {
} else if (isSDKSystemMessage(message)) {
return `🖥️ SYSTEM(${message.subtype})`;
} else if (isCLIResultMessage(message)) {
} else if (isSDKResultMessage(message)) {
return `✅ RESULT(${message.subtype})`;
} else if (isCLIPartialAssistantMessage(message)) {
} else if (isSDKPartialAssistantMessage(message)) {
return '⏳ STREAM_EVENT';
} else if (isControlRequest(message)) {
return `🎮 CONTROL_REQUEST(${message.request.subtype})`;
@@ -67,7 +67,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('AsyncIterable Prompt Support', () => {
it('should handle multi-turn conversation using AsyncIterable prompt', async () => {
// Create multi-turn conversation generator
async function* createMultiTurnConversation(): AsyncIterable<CLIUserMessage> {
async function* createMultiTurnConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -78,7 +78,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 1 + 1?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -90,7 +90,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 2 + 2?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -102,7 +102,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 3 + 3?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
// Create multi-turn query using AsyncIterable prompt
@@ -114,15 +114,15 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const assistantMessages: CLIAssistantMessage[] = [];
const messages: SDKMessage[] = [];
const assistantMessages: SDKAssistantMessage[] = [];
const assistantTexts: string[] = [];
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessages.push(message);
const text = extractText(message.message.content);
assistantTexts.push(text);
@@ -142,7 +142,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
});
it('should maintain session context across turns', async () => {
async function* createContextualConversation(): AsyncIterable<CLIUserMessage> {
async function* createContextualConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -154,7 +154,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
'Suppose we have 3 rabbits and 4 carrots. How many animals are there?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
@@ -166,7 +166,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'How many animals are there? Only output the number',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -177,11 +177,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const assistantMessages: CLIAssistantMessage[] = [];
const assistantMessages: SDKAssistantMessage[] = [];
try {
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessages.push(message);
}
}
@@ -201,7 +201,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Tool Usage in Multi-Turn', () => {
it('should handle tool usage across multiple turns', async () => {
async function* createToolConversation(): AsyncIterable<CLIUserMessage> {
async function* createToolConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -212,7 +212,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Create a file named test.txt with content "hello"',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
@@ -224,7 +224,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Now read the test.txt file',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -237,15 +237,15 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let toolUseCount = 0;
const assistantMessages: CLIAssistantMessage[] = [];
const assistantMessages: SDKAssistantMessage[] = [];
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessages.push(message);
const hasToolUseBlock = message.message.content.some(
(block: ContentBlock): block is ToolUseBlock =>
@@ -274,7 +274,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Message Flow and Sequencing', () => {
it('should process messages in correct sequence', async () => {
async function* createSequentialConversation(): AsyncIterable<CLIUserMessage> {
async function* createSequentialConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -285,7 +285,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'First question: What is 1 + 1?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -297,7 +297,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Second question: What is 2 + 2?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -316,7 +316,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
const messageType = getMessageType(message);
messageSequence.push(messageType);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
const text = extractText(message.message.content);
assistantResponses.push(text);
}
@@ -338,7 +338,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
});
it('should handle conversation completion correctly', async () => {
async function* createSimpleConversation(): AsyncIterable<CLIUserMessage> {
async function* createSimpleConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -349,7 +349,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Hello',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -361,7 +361,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Goodbye',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -379,7 +379,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
for await (const message of q) {
messageCount++;
if (isCLIResultMessage(message)) {
if (isSDKResultMessage(message)) {
completedNaturally = true;
expect(message.subtype).toBe('success');
}
@@ -395,11 +395,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Error Handling in Multi-Turn', () => {
it('should handle empty conversation gracefully', async () => {
async function* createEmptyConversation(): AsyncIterable<CLIUserMessage> {
async function* createEmptyConversation(): AsyncIterable<SDKUserMessage> {
// Generator that yields nothing
/* eslint-disable no-constant-condition */
if (false) {
yield {} as CLIUserMessage; // Unreachable, but satisfies TypeScript
yield {} as SDKUserMessage; // Unreachable, but satisfies TypeScript
}
}
@@ -411,7 +411,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
@@ -426,7 +426,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
});
it('should handle conversation with delays', async () => {
async function* createDelayedConversation(): AsyncIterable<CLIUserMessage> {
async function* createDelayedConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -437,7 +437,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'First message',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
// Longer delay to test patience
await new Promise((resolve) => setTimeout(resolve, 500));
@@ -450,7 +450,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'Second message after delay',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -461,11 +461,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const assistantMessages: CLIAssistantMessage[] = [];
const assistantMessages: SDKAssistantMessage[] = [];
try {
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessages.push(message);
}
}
@@ -479,7 +479,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
describe('Partial Messages in Multi-Turn', () => {
it('should receive partial messages when includePartialMessages is enabled', async () => {
async function* createMultiTurnConversation(): AsyncIterable<CLIUserMessage> {
async function* createMultiTurnConversation(): AsyncIterable<SDKUserMessage> {
const sessionId = crypto.randomUUID();
yield {
@@ -490,7 +490,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 1 + 1?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -502,7 +502,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
content: 'What is 2 + 2?',
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
}
const q = query({
@@ -514,7 +514,7 @@ describe('Multi-Turn Conversations (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let partialMessageCount = 0;
let assistantMessageCount = 0;
@@ -522,11 +522,11 @@ describe('Multi-Turn Conversations (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIPartialAssistantMessage(message)) {
if (isSDKPartialAssistantMessage(message)) {
partialMessageCount++;
}
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessageCount++;
}
}

View File

@@ -7,10 +7,10 @@
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLIResultMessage,
isCLIUserMessage,
type CLIUserMessage,
isSDKAssistantMessage,
isSDKResultMessage,
isSDKUserMessage,
type SDKUserMessage,
type ToolUseBlock,
type ContentBlock,
} from '../../src/types/protocol.js';
@@ -32,7 +32,7 @@ function createStreamingInputWithControlPoint(
firstMessage: string,
secondMessage: string,
): {
generator: AsyncIterable<CLIUserMessage>;
generator: AsyncIterable<SDKUserMessage>;
resume: () => void;
} {
let resumeResolve: (() => void) | null = null;
@@ -51,7 +51,7 @@ function createStreamingInputWithControlPoint(
content: firstMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
@@ -67,7 +67,7 @@ function createStreamingInputWithControlPoint(
content: secondMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
})();
const resume = () => {
@@ -120,7 +120,7 @@ describe('Permission Control (E2E)', () => {
try {
let hasToolUse = false;
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
@@ -162,7 +162,7 @@ describe('Permission Control (E2E)', () => {
try {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (
Array.isArray(message.message.content) &&
message.message.content.some(
@@ -372,7 +372,7 @@ describe('Permission Control (E2E)', () => {
(async () => {
for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) {
if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
@@ -447,7 +447,7 @@ describe('Permission Control (E2E)', () => {
(async () => {
for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) {
if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
@@ -522,7 +522,7 @@ describe('Permission Control (E2E)', () => {
(async () => {
for await (const message of q) {
if (isCLIAssistantMessage(message) || isCLIResultMessage(message)) {
if (isSDKAssistantMessage(message) || isSDKResultMessage(message)) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
@@ -628,7 +628,7 @@ describe('Permission Control (E2E)', () => {
(async () => {
for await (const message of q) {
if (isCLIResultMessage(message)) {
if (isSDKResultMessage(message)) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
@@ -695,7 +695,7 @@ describe('Permission Control (E2E)', () => {
let hasErrorInResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -752,7 +752,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -798,7 +798,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -838,7 +838,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -891,7 +891,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -929,7 +929,7 @@ describe('Permission Control (E2E)', () => {
let hasCommandResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -968,7 +968,7 @@ describe('Permission Control (E2E)', () => {
let hasPlanModeMessage = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1014,7 +1014,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1066,7 +1066,7 @@ describe('Permission Control (E2E)', () => {
let hasPlanModeBlock = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1114,7 +1114,7 @@ describe('Permission Control (E2E)', () => {
let hasDeniedTool = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1169,7 +1169,7 @@ describe('Permission Control (E2E)', () => {
let hasSuccessfulToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1214,7 +1214,7 @@ describe('Permission Control (E2E)', () => {
let hasToolResult = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',
@@ -1270,7 +1270,7 @@ describe('Permission Control (E2E)', () => {
let toolExecuted = false;
for await (const message of q) {
if (isCLIUserMessage(message)) {
if (isSDKUserMessage(message)) {
if (Array.isArray(message.message.content)) {
const toolResult = message.message.content.find(
(block) => block.type === 'tool_result',

View File

@@ -6,31 +6,22 @@
import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLISystemMessage,
isCLIResultMessage,
isCLIPartialAssistantMessage,
type TextBlock,
type ContentBlock,
type CLIMessage,
type CLISystemMessage,
type CLIAssistantMessage,
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
isSDKPartialAssistantMessage,
type SDKMessage,
type SDKSystemMessage,
type SDKAssistantMessage,
} from '../../src/types/protocol.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!;
import {
extractText,
createSharedTestOptions,
assertSuccessfulCompletion,
collectMessagesByType,
} from './test-helper.js';
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
};
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
const SHARED_TEST_OPTIONS = createSharedTestOptions();
describe('Single-Turn Query (E2E)', () => {
describe('Simple Text Queries', () => {
@@ -44,14 +35,14 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content);
}
}
@@ -64,11 +55,7 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText).toMatch(/4/);
// Validate message flow ends with success
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -83,14 +70,14 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content);
}
}
@@ -100,8 +87,7 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText.toLowerCase()).toContain('paris');
// Validate completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -116,14 +102,14 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let assistantText = '';
try {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantText += extractText(message.message.content);
}
}
@@ -133,7 +119,10 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantText.toLowerCase()).toMatch(/hello|hi|greetings/);
// Validate message types
const assistantMessages = messages.filter(isCLIAssistantMessage);
const assistantMessages = collectMessagesByType(
messages,
isSDKAssistantMessage,
);
expect(assistantMessages.length).toBeGreaterThan(0);
} finally {
await q.close();
@@ -151,14 +140,14 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
let systemMessage: SDKSystemMessage | null = null;
try {
for await (const message of q) {
messages.push(message);
if (isCLISystemMessage(message) && message.subtype === 'init') {
if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
}
@@ -180,7 +169,7 @@ describe('Single-Turn Query (E2E)', () => {
// Validate system message appears early in sequence
const systemMessageIndex = messages.findIndex(
(msg) => isCLISystemMessage(msg) && msg.subtype === 'init',
(msg) => isSDKSystemMessage(msg) && msg.subtype === 'init',
);
expect(systemMessageIndex).toBeGreaterThanOrEqual(0);
expect(systemMessageIndex).toBeLessThan(3);
@@ -198,12 +187,12 @@ describe('Single-Turn Query (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
let systemMessage: SDKSystemMessage | null = null;
const sessionId = q.getSessionId();
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
}
@@ -262,7 +251,7 @@ describe('Single-Turn Query (E2E)', () => {
for await (const message of q) {
messageCount++;
if (isCLIResultMessage(message)) {
if (isSDKResultMessage(message)) {
completedNaturally = true;
expect(message.subtype).toBe('success');
}
@@ -319,7 +308,7 @@ describe('Single-Turn Query (E2E)', () => {
try {
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
hasResponse = true;
}
}
@@ -340,7 +329,7 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let partialMessageCount = 0;
let assistantMessageCount = 0;
@@ -348,11 +337,11 @@ describe('Single-Turn Query (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIPartialAssistantMessage(message)) {
if (isSDKPartialAssistantMessage(message)) {
partialMessageCount++;
}
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessageCount++;
}
}
@@ -376,7 +365,7 @@ describe('Single-Turn Query (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
@@ -384,9 +373,18 @@ describe('Single-Turn Query (E2E)', () => {
}
// Validate type guards work correctly
const assistantMessages = messages.filter(isCLIAssistantMessage);
const resultMessages = messages.filter(isCLIResultMessage);
const systemMessages = messages.filter(isCLISystemMessage);
const assistantMessages = collectMessagesByType(
messages,
isSDKAssistantMessage,
);
const resultMessages = collectMessagesByType(
messages,
isSDKResultMessage,
);
const systemMessages = collectMessagesByType(
messages,
isSDKSystemMessage,
);
expect(assistantMessages.length).toBeGreaterThan(0);
expect(resultMessages.length).toBeGreaterThan(0);
@@ -414,11 +412,11 @@ describe('Single-Turn Query (E2E)', () => {
},
});
let assistantMessage: CLIAssistantMessage | null = null;
let assistantMessage: SDKAssistantMessage | null = null;
try {
for await (const message of q) {
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
assistantMessage = message;
}
}
@@ -426,17 +424,9 @@ describe('Single-Turn Query (E2E)', () => {
expect(assistantMessage).not.toBeNull();
expect(assistantMessage!.message.content).toBeDefined();
// Extract text blocks
const textBlocks = assistantMessage!.message.content.filter(
(block: ContentBlock): block is TextBlock => block.type === 'text',
);
expect(textBlocks.length).toBeGreaterThan(0);
expect(textBlocks[0].text).toBeDefined();
expect(textBlocks[0].text.length).toBeGreaterThan(0);
// Validate content contains expected numbers
const text = extractText(assistantMessage!.message.content);
expect(text.length).toBeGreaterThan(0);
expect(text).toMatch(/1/);
expect(text).toMatch(/2/);
expect(text).toMatch(/3/);

View File

@@ -9,50 +9,42 @@
* Tests subagent delegation and task completion
*/
import { describe, it, expect, beforeAll } from 'vitest';
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLISystemMessage,
isCLIResultMessage,
type TextBlock,
type ContentBlock,
type CLIMessage,
type CLISystemMessage,
isSDKAssistantMessage,
type SDKMessage,
type SubagentConfig,
type ContentBlock,
type ToolUseBlock,
} from '../../src/types/protocol.js';
import { writeFile, mkdir } from 'node:fs/promises';
import { join } from 'node:path';
import {
SDKTestHelper,
extractText,
createSharedTestOptions,
findToolUseBlocks,
assertSuccessfulCompletion,
findSystemMessage,
} from './test-helper.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!;
const E2E_TEST_FILE_DIR = process.env['E2E_TEST_FILE_DIR']!;
const SHARED_TEST_OPTIONS = {
pathToQwenExecutable: TEST_CLI_PATH,
};
/**
* Helper to extract text from ContentBlock array
*/
function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
const SHARED_TEST_OPTIONS = createSharedTestOptions();
describe('Subagents (E2E)', () => {
let helper: SDKTestHelper;
let testWorkDir: string;
beforeAll(async () => {
// Create a test working directory
testWorkDir = join(E2E_TEST_FILE_DIR, 'subagent-tests');
await mkdir(testWorkDir, { recursive: true });
// Create isolated test environment using SDKTestHelper
helper = new SDKTestHelper();
testWorkDir = await helper.setup('subagent-tests');
// Create a simple test file for subagent to work with
const testFilePath = join(testWorkDir, 'test.txt');
await writeFile(testFilePath, 'Hello from test file\n', 'utf-8');
await helper.createFile('test.txt', 'Hello from test file\n');
});
afterAll(async () => {
// Cleanup test directory
await helper.cleanup();
});
describe('Subagent Configuration', () => {
@@ -75,29 +67,21 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
messages.push(message);
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
}
// Validate system message includes the subagent
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('simple-greeter');
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -128,16 +112,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate both subagents are registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('greeter');
@@ -170,16 +153,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('custom-model-agent');
@@ -210,16 +192,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('limited-agent');
@@ -248,16 +229,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate subagent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('read-only-agent');
@@ -277,7 +257,7 @@ describe('Subagents (E2E)', () => {
tools: ['read_file', 'list_directory'],
};
const testFile = join(testWorkDir, 'test.txt');
const testFile = helper.getPath('test.txt');
const q = query({
prompt: `Use the file-reader subagent to read the file at ${testFile} and tell me what it contains.`,
options: {
@@ -289,7 +269,7 @@ describe('Subagents (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let foundTaskTool = false;
let taskToolUseId: string | null = null;
let foundSubagentToolCall = false;
@@ -299,25 +279,19 @@ describe('Subagents (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
// Check for task tool use in content blocks (main agent calling subagent)
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use' && block.name === 'task',
);
if (toolUseBlock) {
const taskToolBlocks = findToolUseBlocks(message, 'task');
if (taskToolBlocks.length > 0) {
foundTaskTool = true;
taskToolUseId = toolUseBlock.id;
taskToolUseId = taskToolBlocks[0].id;
}
// Check if this message is from a subagent (has parent_tool_use_id)
if (message.parent_tool_use_id !== null) {
// This is a subagent message
const subagentToolUse = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
block.type === 'tool_use',
);
if (subagentToolUse) {
const subagentToolBlocks = findToolUseBlocks(message);
if (subagentToolBlocks.length > 0) {
foundSubagentToolCall = true;
// Verify parent_tool_use_id matches the task tool use id
expect(message.parent_tool_use_id).toBe(taskToolUseId);
@@ -339,11 +313,7 @@ describe('Subagents (E2E)', () => {
expect(assistantText.length).toBeGreaterThan(0);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -369,7 +339,7 @@ describe('Subagents (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let foundTaskTool = false;
let assistantText = '';
@@ -377,7 +347,7 @@ describe('Subagents (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
// Check for task tool use (main agent delegating to subagent)
const toolUseBlock = message.message.content.find(
(block: ContentBlock): block is ToolUseBlock =>
@@ -398,11 +368,7 @@ describe('Subagents (E2E)', () => {
expect(assistantText.length).toBeGreaterThan(0);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -429,7 +395,7 @@ describe('Subagents (E2E)', () => {
},
});
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
let taskToolUseId: string | null = null;
const subagentToolCalls: ToolUseBlock[] = [];
const mainAgentToolCalls: ToolUseBlock[] = [];
@@ -438,7 +404,7 @@ describe('Subagents (E2E)', () => {
for await (const message of q) {
messages.push(message);
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
// Collect all tool use blocks
const toolUseBlocks = message.message.content.filter(
(block: ContentBlock): block is ToolUseBlock =>
@@ -471,8 +437,8 @@ describe('Subagents (E2E)', () => {
// Verify all subagent messages have the correct parent_tool_use_id
const subagentMessages = messages.filter(
(msg): msg is CLIMessage & { parent_tool_use_id: string } =>
isCLIAssistantMessage(msg) && msg.parent_tool_use_id !== null,
(msg): msg is SDKMessage & { parent_tool_use_id: string } =>
isSDKAssistantMessage(msg) && msg.parent_tool_use_id !== null,
);
expect(subagentMessages.length).toBeGreaterThan(0);
@@ -482,23 +448,19 @@ describe('Subagents (E2E)', () => {
// Verify no main agent tool calls (except task) have parent_tool_use_id
const mainAgentMessages = messages.filter(
(msg): msg is CLIMessage =>
isCLIAssistantMessage(msg) && msg.parent_tool_use_id === null,
(msg): msg is SDKMessage =>
isSDKAssistantMessage(msg) && msg.parent_tool_use_id === null,
);
for (const mainMsg of mainAgentMessages) {
if (isCLIAssistantMessage(mainMsg)) {
if (isSDKAssistantMessage(mainMsg)) {
// Main agent messages should not have parent_tool_use_id
expect(mainMsg.parent_tool_use_id).toBeNull();
}
}
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isCLIResultMessage(lastMessage)).toBe(true);
if (isCLIResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
assertSuccessfulCompletion(messages);
} finally {
await q.close();
}
@@ -517,16 +479,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Should still work with empty agents array
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
} finally {
@@ -552,16 +513,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate minimal agent is registered
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('minimal-agent');
@@ -596,16 +556,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate subagent works with debug mode
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.agents).toBeDefined();
expect(systemMessage!.agents).toContain('test-agent');
@@ -633,16 +592,15 @@ describe('Subagents (E2E)', () => {
},
});
let systemMessage: CLISystemMessage | null = null;
const messages: SDKMessage[] = [];
try {
for await (const message of q) {
if (isCLISystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
}
messages.push(message);
}
// Validate session consistency
const systemMessage = findSystemMessage(messages, 'init');
expect(systemMessage).not.toBeNull();
expect(systemMessage!.session_id).toBeDefined();
expect(systemMessage!.uuid).toBeDefined();

View File

@@ -6,9 +6,9 @@
import { describe, it, expect } from 'vitest';
import { query } from '../../src/index.js';
import {
isCLIAssistantMessage,
isCLISystemMessage,
type CLIUserMessage,
isSDKAssistantMessage,
isSDKSystemMessage,
type SDKUserMessage,
} from '../../src/types/protocol.js';
const TEST_CLI_PATH = process.env['TEST_CLI_PATH']!;
@@ -30,7 +30,7 @@ function createStreamingInputWithControlPoint(
firstMessage: string,
secondMessage: string,
): {
generator: AsyncIterable<CLIUserMessage>;
generator: AsyncIterable<SDKUserMessage>;
resume: () => void;
} {
let resumeResolve: (() => void) | null = null;
@@ -49,7 +49,7 @@ function createStreamingInputWithControlPoint(
content: firstMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
@@ -65,7 +65,7 @@ function createStreamingInputWithControlPoint(
content: secondMessage,
},
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
})();
const resume = () => {
@@ -113,10 +113,10 @@ describe('System Control (E2E)', () => {
// Consume messages in a single loop
(async () => {
for await (const message of q) {
if (isCLISystemMessage(message)) {
if (isSDKSystemMessage(message)) {
systemMessages.push({ model: message.model });
}
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
if (!firstResponseReceived) {
firstResponseReceived = true;
resolvers.first?.();
@@ -186,7 +186,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId,
message: { role: 'user', content: 'First message' },
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise1;
@@ -197,7 +197,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId,
message: { role: 'user', content: 'Second message' },
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromise2;
@@ -208,7 +208,7 @@ describe('System Control (E2E)', () => {
session_id: sessionId,
message: { role: 'user', content: 'Third message' },
parent_tool_use_id: null,
} as CLIUserMessage;
} as SDKUserMessage;
})();
const q = query({
@@ -232,10 +232,10 @@ describe('System Control (E2E)', () => {
(async () => {
for await (const message of q) {
if (isCLISystemMessage(message)) {
if (isSDKSystemMessage(message)) {
systemMessages.push({ model: message.model });
}
if (isCLIAssistantMessage(message)) {
if (isSDKAssistantMessage(message)) {
if (responseCount < resolvers.length) {
resolvers[responseCount]?.();
responseCount++;

View File

@@ -0,0 +1,829 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* SDK E2E Test Helper
* Provides utilities for SDK e2e tests including test isolation,
* file management, MCP server setup, and common test utilities.
*/
import { mkdir, writeFile, readFile, rm, chmod } from 'node:fs/promises';
import { join } from 'node:path';
import { existsSync } from 'node:fs';
import type {
SDKMessage,
SDKAssistantMessage,
SDKSystemMessage,
SDKUserMessage,
ContentBlock,
TextBlock,
ToolUseBlock,
} from '../../src/types/protocol.js';
import {
isSDKAssistantMessage,
isSDKSystemMessage,
isSDKResultMessage,
} from '../../src/types/protocol.js';
// ============================================================================
// Core Test Helper Class
// ============================================================================
export interface SDKTestHelperOptions {
/**
* Optional settings for .qwen/settings.json
*/
settings?: Record<string, unknown>;
/**
* Whether to create .qwen/settings.json
*/
createQwenConfig?: boolean;
}
/**
* Helper class for SDK E2E tests
* Provides isolated test environments for each test case
*/
export class SDKTestHelper {
testDir: string | null = null;
testName?: string;
private baseDir: string;
constructor() {
this.baseDir = process.env['E2E_TEST_FILE_DIR']!;
if (!this.baseDir) {
throw new Error('E2E_TEST_FILE_DIR environment variable not set');
}
}
/**
* Setup an isolated test directory for a specific test
*/
async setup(
testName: string,
options: SDKTestHelperOptions = {},
): Promise<string> {
this.testName = testName;
const sanitizedName = this.sanitizeTestName(testName);
this.testDir = join(this.baseDir, sanitizedName);
await mkdir(this.testDir, { recursive: true });
// Optionally create .qwen/settings.json for CLI configuration
if (options.createQwenConfig) {
const qwenDir = join(this.testDir, '.qwen');
await mkdir(qwenDir, { recursive: true });
const settings = {
telemetry: {
enabled: false, // SDK tests don't need telemetry
},
...options.settings,
};
await writeFile(
join(qwenDir, 'settings.json'),
JSON.stringify(settings, null, 2),
'utf-8',
);
}
return this.testDir;
}
/**
* Create a file in the test directory
*/
async createFile(fileName: string, content: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
await writeFile(filePath, content, 'utf-8');
return filePath;
}
/**
* Read a file from the test directory
*/
async readFile(fileName: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
return await readFile(filePath, 'utf-8');
}
/**
* Create a subdirectory in the test directory
*/
async mkdir(dirName: string): Promise<string> {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const dirPath = join(this.testDir, dirName);
await mkdir(dirPath, { recursive: true });
return dirPath;
}
/**
* Check if a file exists in the test directory
*/
fileExists(fileName: string): boolean {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const filePath = join(this.testDir, fileName);
return existsSync(filePath);
}
/**
* Get the full path to a file in the test directory
*/
getPath(fileName: string): string {
if (!this.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
return join(this.testDir, fileName);
}
/**
* Cleanup test directory
*/
async cleanup(): Promise<void> {
if (this.testDir && process.env['KEEP_OUTPUT'] !== 'true') {
try {
await rm(this.testDir, { recursive: true, force: true });
} catch (error) {
if (process.env['VERBOSE'] === 'true') {
console.warn('Cleanup warning:', (error as Error).message);
}
}
}
}
/**
* Sanitize test name to create valid directory name
*/
private sanitizeTestName(name: string): string {
return name
.toLowerCase()
.replace(/[^a-z0-9]/g, '-')
.replace(/-+/g, '-')
.substring(0, 100); // Limit length
}
}
// ============================================================================
// MCP Server Utilities
// ============================================================================
export interface MCPServerConfig {
command: string;
args: string[];
}
export interface MCPServerResult {
scriptPath: string;
config: MCPServerConfig;
}
/**
* Built-in MCP server template: Math server with add and multiply tools
*/
const MCP_MATH_SERVER_SCRIPT = `#!/usr/bin/env node
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
const readline = require('readline');
const fs = require('fs');
// Debug logging to stderr (only when MCP_DEBUG or VERBOSE is set)
const debugEnabled = process.env['MCP_DEBUG'] === 'true' || process.env['VERBOSE'] === 'true';
function debug(msg) {
if (debugEnabled) {
fs.writeSync(2, \`[MCP-DEBUG] \${msg}\\n\`);
}
}
debug('MCP server starting...');
// Simple JSON-RPC implementation for MCP
class SimpleJSONRPC {
constructor() {
this.handlers = new Map();
this.rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
terminal: false
});
this.rl.on('line', (line) => {
debug(\`Received line: \${line}\`);
try {
const message = JSON.parse(line);
debug(\`Parsed message: \${JSON.stringify(message)}\`);
this.handleMessage(message);
} catch (e) {
debug(\`Parse error: \${e.message}\`);
}
});
}
send(message) {
const msgStr = JSON.stringify(message);
debug(\`Sending message: \${msgStr}\`);
process.stdout.write(msgStr + '\\n');
}
async handleMessage(message) {
if (message.method && this.handlers.has(message.method)) {
try {
const result = await this.handlers.get(message.method)(message.params || {});
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
result
});
}
} catch (error) {
if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32603,
message: error.message
}
});
}
}
} else if (message.id !== undefined) {
this.send({
jsonrpc: '2.0',
id: message.id,
error: {
code: -32601,
message: 'Method not found'
}
});
}
}
on(method, handler) {
this.handlers.set(method, handler);
}
}
// Create MCP server
const rpc = new SimpleJSONRPC();
// Handle initialize
rpc.on('initialize', async (params) => {
debug('Handling initialize request');
return {
protocolVersion: '2024-11-05',
capabilities: {
tools: {}
},
serverInfo: {
name: 'test-math-server',
version: '1.0.0'
}
};
});
// Handle tools/list
rpc.on('tools/list', async () => {
debug('Handling tools/list request');
return {
tools: [
{
name: 'add',
description: 'Add two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
},
{
name: 'multiply',
description: 'Multiply two numbers together',
inputSchema: {
type: 'object',
properties: {
a: { type: 'number', description: 'First number' },
b: { type: 'number', description: 'Second number' }
},
required: ['a', 'b']
}
}
]
};
});
// Handle tools/call
rpc.on('tools/call', async (params) => {
debug(\`Handling tools/call request for tool: \${params.name}\`);
if (params.name === 'add') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a + b)
}]
};
}
if (params.name === 'multiply') {
const { a, b } = params.arguments;
return {
content: [{
type: 'text',
text: String(a * b)
}]
};
}
throw new Error('Unknown tool: ' + params.name);
});
// Send initialization notification
rpc.send({
jsonrpc: '2.0',
method: 'initialized'
});
`;
/**
* Create an MCP server script in the test directory
* @param helper - SDKTestHelper instance
* @param type - Type of MCP server ('math' or provide custom script)
* @param serverName - Name of the MCP server (default: 'test-math-server')
* @param customScript - Custom MCP server script (if type is not 'math')
* @returns Object with scriptPath and config
*/
export async function createMCPServer(
helper: SDKTestHelper,
type: 'math' | 'custom' = 'math',
serverName: string = 'test-math-server',
customScript?: string,
): Promise<MCPServerResult> {
if (!helper.testDir) {
throw new Error('Test directory not initialized. Call setup() first.');
}
const script = type === 'math' ? MCP_MATH_SERVER_SCRIPT : customScript;
if (!script) {
throw new Error('Custom script required when type is "custom"');
}
const scriptPath = join(helper.testDir, `${serverName}.cjs`);
await writeFile(scriptPath, script, 'utf-8');
// Make script executable on Unix-like systems
if (process.platform !== 'win32') {
await chmod(scriptPath, 0o755);
}
return {
scriptPath,
config: {
command: 'node',
args: [scriptPath],
},
};
}
// ============================================================================
// Message & Content Utilities
// ============================================================================
/**
* Extract text from ContentBlock array
*/
export function extractText(content: ContentBlock[]): string {
return content
.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text)
.join('');
}
/**
* Collect messages by type
*/
export function collectMessagesByType<T extends SDKMessage>(
messages: SDKMessage[],
predicate: (msg: SDKMessage) => msg is T,
): T[] {
return messages.filter(predicate);
}
/**
* Find tool use blocks in a message
*/
export function findToolUseBlocks(
message: SDKAssistantMessage,
toolName?: string,
): ToolUseBlock[] {
const toolUseBlocks = message.message.content.filter(
(block): block is ToolUseBlock => block.type === 'tool_use',
);
if (toolName) {
return toolUseBlocks.filter((block) => block.name === toolName);
}
return toolUseBlocks;
}
/**
* Extract all assistant text from messages
*/
export function getAssistantText(messages: SDKMessage[]): string {
return messages
.filter(isSDKAssistantMessage)
.map((msg) => extractText(msg.message.content))
.join('');
}
/**
* Find system message with optional subtype filter
*/
export function findSystemMessage(
messages: SDKMessage[],
subtype?: string,
): SDKSystemMessage | null {
const systemMessages = messages.filter(isSDKSystemMessage);
if (subtype) {
return systemMessages.find((msg) => msg.subtype === subtype) || null;
}
return systemMessages[0] || null;
}
/**
* Find all tool calls in messages
*/
export function findToolCalls(
messages: SDKMessage[],
toolName?: string,
): Array<{ message: SDKAssistantMessage; toolUse: ToolUseBlock }> {
const results: Array<{
message: SDKAssistantMessage;
toolUse: ToolUseBlock;
}> = [];
for (const message of messages) {
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, toolName);
for (const toolUse of toolUseBlocks) {
results.push({ message, toolUse });
}
}
}
return results;
}
// ============================================================================
// Streaming Input Utilities
// ============================================================================
/**
* Create a simple streaming input from an array of message contents
*/
export async function* createStreamingInput(
messageContents: string[],
sessionId?: string,
): AsyncIterable<SDKUserMessage> {
const sid = sessionId || crypto.randomUUID();
for (const content of messageContents) {
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: content,
},
parent_tool_use_id: null,
} as SDKUserMessage;
// Small delay between messages
await new Promise((resolve) => setTimeout(resolve, 100));
}
}
/**
* Create a controlled streaming input with pause/resume capability
*/
export function createControlledStreamingInput(
messageContents: string[],
sessionId?: string,
): {
generator: AsyncIterable<SDKUserMessage>;
resume: () => void;
resumeAll: () => void;
} {
const sid = sessionId || crypto.randomUUID();
const resumeResolvers: Array<() => void> = [];
const resumePromises: Array<Promise<void>> = [];
// Create a resume promise for each message after the first
for (let i = 1; i < messageContents.length; i++) {
const promise = new Promise<void>((resolve) => {
resumeResolvers.push(resolve);
});
resumePromises.push(promise);
}
const generator = (async function* () {
// Yield first message immediately
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: messageContents[0],
},
parent_tool_use_id: null,
} as SDKUserMessage;
// For subsequent messages, wait for resume
for (let i = 1; i < messageContents.length; i++) {
await new Promise((resolve) => setTimeout(resolve, 200));
await resumePromises[i - 1];
await new Promise((resolve) => setTimeout(resolve, 200));
yield {
type: 'user',
session_id: sid,
message: {
role: 'user',
content: messageContents[i],
},
parent_tool_use_id: null,
} as SDKUserMessage;
}
})();
let currentResumeIndex = 0;
return {
generator,
resume: () => {
if (currentResumeIndex < resumeResolvers.length) {
resumeResolvers[currentResumeIndex]();
currentResumeIndex++;
}
},
resumeAll: () => {
resumeResolvers.forEach((resolve) => resolve());
currentResumeIndex = resumeResolvers.length;
},
};
}
// ============================================================================
// Assertion Utilities
// ============================================================================
/**
* Assert that messages follow expected type sequence
*/
export function assertMessageSequence(
messages: SDKMessage[],
expectedTypes: string[],
): void {
const actualTypes = messages.map((msg) => msg.type);
if (actualTypes.length < expectedTypes.length) {
throw new Error(
`Expected at least ${expectedTypes.length} messages, got ${actualTypes.length}`,
);
}
for (let i = 0; i < expectedTypes.length; i++) {
if (actualTypes[i] !== expectedTypes[i]) {
throw new Error(
`Expected message ${i} to be type '${expectedTypes[i]}', got '${actualTypes[i]}'`,
);
}
}
}
/**
* Assert that a specific tool was called
*/
export function assertToolCalled(
messages: SDKMessage[],
toolName: string,
): void {
const toolCalls = findToolCalls(messages, toolName);
if (toolCalls.length === 0) {
const allToolCalls = findToolCalls(messages);
const allToolNames = allToolCalls.map((tc) => tc.toolUse.name);
throw new Error(
`Expected tool '${toolName}' to be called. Found tools: ${allToolNames.length > 0 ? allToolNames.join(', ') : 'none'}`,
);
}
}
/**
* Assert that the conversation completed successfully
*/
export function assertSuccessfulCompletion(messages: SDKMessage[]): void {
const lastMessage = messages[messages.length - 1];
if (!isSDKResultMessage(lastMessage)) {
throw new Error(
`Expected last message to be a result message, got '${lastMessage.type}'`,
);
}
if (lastMessage.subtype !== 'success') {
throw new Error(
`Expected successful completion, got result subtype '${lastMessage.subtype}'`,
);
}
}
/**
* Wait for a condition to be true with timeout
*/
export async function waitFor(
predicate: () => boolean | Promise<boolean>,
options: {
timeout?: number;
interval?: number;
errorMessage?: string;
} = {},
): Promise<void> {
const {
timeout = 5000,
interval = 100,
errorMessage = 'Condition not met within timeout',
} = options;
const startTime = Date.now();
while (Date.now() - startTime < timeout) {
const result = await predicate();
if (result) {
return;
}
await new Promise((resolve) => setTimeout(resolve, interval));
}
throw new Error(errorMessage);
}
// ============================================================================
// Debug and Validation Utilities
// ============================================================================
/**
* Validate model output and warn about unexpected content
* Inspired by integration-tests test-helper
*/
export function validateModelOutput(
result: string,
expectedContent: string | (string | RegExp)[] | null = null,
testName = '',
): boolean {
// First, check if there's any output at all
if (!result || result.trim().length === 0) {
throw new Error('Expected model to return some output');
}
// If expectedContent is provided, check for it and warn if missing
if (expectedContent) {
const contents = Array.isArray(expectedContent)
? expectedContent
: [expectedContent];
const missingContent = contents.filter((content) => {
if (typeof content === 'string') {
return !result.toLowerCase().includes(content.toLowerCase());
} else if (content instanceof RegExp) {
return !content.test(result);
}
return false;
});
if (missingContent.length > 0) {
console.warn(
`Warning: Model did not include expected content in response: ${missingContent.join(', ')}.`,
'This is not ideal but not a test failure.',
);
console.warn(
'The tool was called successfully, which is the main requirement.',
);
return false;
} else if (process.env['VERBOSE'] === 'true') {
console.log(`${testName}: Model output validated successfully.`);
}
return true;
}
return true;
}
/**
* Print debug information when tests fail
*/
export function printDebugInfo(
messages: SDKMessage[],
context: Record<string, unknown> = {},
): void {
console.error('Test failed - Debug info:');
console.error('Message count:', messages.length);
// Print message types
const messageTypes = messages.map((m) => m.type);
console.error('Message types:', messageTypes.join(', '));
// Print assistant text
const assistantText = getAssistantText(messages);
console.error(
'Assistant text (first 500 chars):',
assistantText.substring(0, 500),
);
if (assistantText.length > 500) {
console.error(
'Assistant text (last 500 chars):',
assistantText.substring(assistantText.length - 500),
);
}
// Print tool calls
const toolCalls = findToolCalls(messages);
console.error(
'Tool calls found:',
toolCalls.map((tc) => tc.toolUse.name),
);
// Print any additional context provided
Object.entries(context).forEach(([key, value]) => {
console.error(`${key}:`, value);
});
}
/**
* Create detailed error message for tool call expectations
*/
export function createToolCallErrorMessage(
expectedTools: string | string[],
foundTools: string[],
messages: SDKMessage[],
): string {
const expectedStr = Array.isArray(expectedTools)
? expectedTools.join(' or ')
: expectedTools;
const assistantText = getAssistantText(messages);
const preview = assistantText
? assistantText.substring(0, 200) + '...'
: 'no output';
return (
`Expected to find ${expectedStr} tool call(s). ` +
`Found: ${foundTools.length > 0 ? foundTools.join(', ') : 'none'}. ` +
`Output preview: ${preview}`
);
}
// ============================================================================
// Shared Test Options Helper
// ============================================================================
/**
* Create shared test options with CLI path
*/
export function createSharedTestOptions(
overrides: Record<string, unknown> = {},
) {
const TEST_CLI_PATH = process.env['TEST_CLI_PATH'];
if (!TEST_CLI_PATH) {
throw new Error('TEST_CLI_PATH environment variable not set');
}
return {
pathToQwenExecutable: TEST_CLI_PATH,
...overrides,
};
}

View File

@@ -7,12 +7,12 @@ import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
import { Query } from '../../src/query/Query.js';
import type { Transport } from '../../src/transport/Transport.js';
import type {
CLIMessage,
CLIUserMessage,
CLIAssistantMessage,
CLISystemMessage,
CLIResultMessage,
CLIPartialAssistantMessage,
SDKMessage,
SDKUserMessage,
SDKAssistantMessage,
SDKSystemMessage,
SDKResultMessage,
SDKPartialAssistantMessage,
CLIControlRequest,
CLIControlResponse,
ControlCancelRequest,
@@ -118,7 +118,7 @@ function findControlRequest(
function createUserMessage(
content: string,
sessionId = 'test-session',
): CLIUserMessage {
): SDKUserMessage {
return {
type: 'user',
session_id: sessionId,
@@ -133,7 +133,7 @@ function createUserMessage(
function createAssistantMessage(
content: string,
sessionId = 'test-session',
): CLIAssistantMessage {
): SDKAssistantMessage {
return {
type: 'assistant',
uuid: 'msg-123',
@@ -153,7 +153,7 @@ function createAssistantMessage(
function createSystemMessage(
subtype: string,
sessionId = 'test-session',
): CLISystemMessage {
): SDKSystemMessage {
return {
type: 'system',
subtype,
@@ -168,7 +168,7 @@ function createSystemMessage(
function createResultMessage(
success: boolean,
sessionId = 'test-session',
): CLIResultMessage {
): SDKResultMessage {
if (success) {
return {
type: 'result',
@@ -202,7 +202,7 @@ function createResultMessage(
function createPartialMessage(
sessionId = 'test-session',
): CLIPartialAssistantMessage {
): SDKPartialAssistantMessage {
return {
type: 'stream_event',
uuid: 'stream-123',
@@ -816,7 +816,7 @@ describe('Query', () => {
msg !== null &&
'type' in msg &&
msg.type === 'user',
) as CLIUserMessage[];
) as SDKUserMessage[];
userMessages.forEach((msg) => {
expect(msg.session_id).toBe(sessionId);
@@ -889,7 +889,7 @@ describe('Query', () => {
const query = new Query(transport, { cwd: '/test' });
const iterationPromise = (async () => {
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
for await (const msg of query) {
messages.push(msg);
}
@@ -946,7 +946,7 @@ describe('Query', () => {
it('should support for await loop', async () => {
const query = new Query(transport, { cwd: '/test' });
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
const iterationPromise = (async () => {
for await (const msg of query) {
messages.push(msg);
@@ -960,7 +960,7 @@ describe('Query', () => {
await iterationPromise;
expect(messages).toHaveLength(2);
expect((messages[0] as CLIUserMessage).message.content).toBe('First');
expect((messages[0] as SDKUserMessage).message.content).toBe('First');
await query.close();
});
@@ -968,7 +968,7 @@ describe('Query', () => {
it('should complete iteration when query closes', async () => {
const query = new Query(transport, { cwd: '/test' });
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
const iterationPromise = (async () => {
for await (const msg of query) {
messages.push(msg);
@@ -1321,7 +1321,7 @@ describe('Query', () => {
const result = await query.next();
expect(result.done).toBe(false);
expect((result.value as CLIResultMessage).is_error).toBe(true);
expect((result.value as SDKResultMessage).is_error).toBe(true);
await query.close();
});
@@ -1430,7 +1430,7 @@ describe('Query', () => {
transport.simulateMessage(createUserMessage(`Message ${i}`));
}
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
for (let i = 0; i < 100; i++) {
const result = await query.next();
if (!result.done) {
@@ -1447,7 +1447,7 @@ describe('Query', () => {
const query = new Query(transport, { cwd: '/test' });
const iterationPromise = (async () => {
const messages: CLIMessage[] = [];
const messages: SDKMessage[] = [];
for await (const msg of query) {
messages.push(msg);
if (messages.length === 2) {