Merge pull request #1147 from QwenLM/mingholy/feat/cli-sdk-stage-2

Custom tools support via SDK controlled MCP servers
This commit is contained in:
Mingholy
2025-12-05 21:19:58 +08:00
committed by GitHub
33 changed files with 2597 additions and 862 deletions

View File

@@ -132,6 +132,24 @@ jobs:
OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
- name: 'Build CLI for Integration Tests'
if: |-
${{ github.event.inputs.force_skip_tests != 'true' }}
run: |
npm run build
npm run bundle
- name: 'Run SDK Integration Tests'
if: |-
${{ github.event.inputs.force_skip_tests != 'true' }}
run: |
npm run test:integration:sdk:sandbox:none
npm run test:integration:sdk:sandbox:docker
env:
OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}'
OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
- name: 'Configure Git User' - name: 'Configure Git User'
run: | run: |
git config user.name "github-actions[bot]" git config user.name "github-actions[bot]"

View File

@@ -532,7 +532,6 @@ describe('Configuration Options (E2E)', () => {
cwd: testDir, cwd: testDir,
authType: 'openai', authType: 'openai',
debug: true, debug: true,
logLevel: 'debug',
stderr: (msg: string) => { stderr: (msg: string) => {
stderrMessages.push(msg); stderrMessages.push(msg);
}, },

View File

@@ -555,6 +555,15 @@ describe('Permission Control (E2E)', () => {
...SHARED_TEST_OPTIONS, ...SHARED_TEST_OPTIONS,
cwd: testDir, cwd: testDir,
permissionMode: 'default', permissionMode: 'default',
timeout: {
/**
* We use a short control request timeout and
* wait till the time exceeded to test if
* an immediate close() will raise an query close
* error and no other uncaught timeout error
*/
controlRequest: 5000,
},
}, },
}); });
@@ -563,7 +572,9 @@ describe('Permission Control (E2E)', () => {
await expect(q.setPermissionMode('yolo')).rejects.toThrow( await expect(q.setPermissionMode('yolo')).rejects.toThrow(
'Query is closed', 'Query is closed',
); );
});
await new Promise((resolve) => setTimeout(resolve, 8000));
}, 10_000);
}); });
describe('canUseTool and setPermissionMode integration', () => { describe('canUseTool and setPermissionMode integration', () => {

View File

@@ -0,0 +1,465 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* E2E tests for SDK-embedded MCP servers
*
* Tests that the SDK can create and manage MCP servers running in the SDK process
* using the tool() and createSdkMcpServer() APIs.
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { z } from 'zod';
import {
query,
tool,
createSdkMcpServer,
isSDKAssistantMessage,
isSDKResultMessage,
isSDKSystemMessage,
type SDKMessage,
type SDKSystemMessage,
} from '@qwen-code/sdk-typescript';
import {
SDKTestHelper,
extractText,
findToolUseBlocks,
createSharedTestOptions,
} from './test-helper.js';
const SHARED_TEST_OPTIONS = {
...createSharedTestOptions(),
permissionMode: 'yolo' as const,
};
describe('SDK MCP Server Integration (E2E)', () => {
let helper: SDKTestHelper;
let testDir: string;
beforeEach(async () => {
helper = new SDKTestHelper();
testDir = await helper.setup('sdk-mcp-server-integration');
});
afterEach(async () => {
await helper.cleanup();
});
describe('Basic SDK MCP Tool Usage', () => {
it('should use SDK MCP tool to perform a simple calculation', async () => {
// Define a simple calculator tool using the tool() API with Zod schema
console.log(
z.object({
a: z.number().describe('First number'),
b: z.number().describe('Second number'),
}),
);
const calculatorTool = tool(
'calculate_sum',
'Calculate the sum of two numbers',
z.object({
a: z.number().describe('First number'),
b: z.number().describe('Second number'),
}).shape,
async (args) => ({
content: [{ type: 'text', text: String(args.a + args.b) }],
}),
);
// Create SDK MCP server with the tool
const serverConfig = createSdkMcpServer({
name: 'sdk-calculator',
version: '1.0.0',
tools: [calculatorTool],
});
const q = query({
prompt:
'Use the calculate_sum tool to add 25 and 17. Output the result of tool only.',
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
stderr: (message) => console.error(message),
mcpServers: {
'sdk-calculator': serverConfig,
},
},
});
const messages: SDKMessage[] = [];
let assistantText = '';
let foundToolUse = false;
try {
for await (const message of q) {
messages.push(message);
console.log(JSON.stringify(message, null, 2));
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'calculate_sum');
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
assistantText += extractText(message.message.content);
}
}
// Validate tool was called
expect(foundToolUse).toBe(true);
// Validate result contains expected answer: 25 + 17 = 42
expect(assistantText).toMatch(/42/);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isSDKResultMessage(lastMessage)).toBe(true);
if (isSDKResultMessage(lastMessage)) {
expect(lastMessage.subtype).toBe('success');
}
} finally {
await q.close();
}
});
it('should use SDK MCP tool with string operations', async () => {
// Define a string manipulation tool with Zod schema
const stringTool = tool(
'reverse_string',
'Reverse a string',
{
text: z.string().describe('The text to reverse'),
},
async (args) => ({
content: [
{ type: 'text', text: args.text.split('').reverse().join('') },
],
}),
);
const serverConfig = createSdkMcpServer({
name: 'sdk-string-utils',
version: '1.0.0',
tools: [stringTool],
});
const q = query({
prompt: `Use the 'reverse_string' tool to process the word "hello world". Output the tool result only.`,
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
mcpServers: {
'sdk-string-utils': serverConfig,
},
},
});
const messages: SDKMessage[] = [];
let assistantText = '';
let foundToolUse = false;
try {
for await (const message of q) {
messages.push(message);
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'reverse_string');
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
assistantText += extractText(message.message.content);
}
}
console.log(JSON.stringify(messages, null, 2));
// Validate tool was called
expect(foundToolUse).toBe(true);
// Validate result contains reversed string: "olleh"
expect(assistantText.toLowerCase()).toMatch(/olleh/);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
});
});
describe('Multiple SDK MCP Tools', () => {
it('should use multiple tools from the same SDK MCP server', async () => {
// Define the Zod schema shape for two numbers
const twoNumbersSchema = {
a: z.number().describe('First number'),
b: z.number().describe('Second number'),
};
// Define multiple tools
const addTool = tool(
'sdk_add',
'Add two numbers',
twoNumbersSchema,
async (args) => ({
content: [{ type: 'text', text: String(args.a + args.b) }],
}),
);
const multiplyTool = tool(
'sdk_multiply',
'Multiply two numbers',
twoNumbersSchema,
async (args) => ({
content: [{ type: 'text', text: String(args.a * args.b) }],
}),
);
const serverConfig = createSdkMcpServer({
name: 'sdk-math',
version: '1.0.0',
tools: [addTool, multiplyTool],
});
const q = query({
prompt:
'First use sdk_add to calculate 10 + 5, then use sdk_multiply to multiply the result by 3. Give me the final answer.',
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
debug: false,
mcpServers: {
'sdk-math': serverConfig,
},
},
});
const messages: SDKMessage[] = [];
let assistantText = '';
const toolCalls: string[] = [];
try {
for await (const message of q) {
messages.push(message);
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message);
toolUseBlocks.forEach((block) => {
toolCalls.push(block.name);
});
assistantText += extractText(message.message.content);
}
}
// Validate both tools were called
expect(toolCalls).toContain('sdk_add');
expect(toolCalls).toContain('sdk_multiply');
// Validate result: (10 + 5) * 3 = 45
expect(assistantText).toMatch(/45/);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
});
});
describe('SDK MCP Server Discovery', () => {
it('should list SDK MCP servers in system init message', async () => {
// Define echo tool with Zod schema
const echoTool = tool(
'echo',
'Echo a message',
{
message: z.string().describe('Message to echo'),
},
async (args) => ({
content: [{ type: 'text', text: args.message }],
}),
);
const serverConfig = createSdkMcpServer({
name: 'sdk-echo',
version: '1.0.0',
tools: [echoTool],
});
const q = query({
prompt: 'Hello',
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
debug: false,
mcpServers: {
'sdk-echo': serverConfig,
},
},
});
let systemMessage: SDKSystemMessage | null = null;
try {
for await (const message of q) {
if (isSDKSystemMessage(message) && message.subtype === 'init') {
systemMessage = message;
break;
}
}
// Validate MCP server is listed
expect(systemMessage).not.toBeNull();
expect(systemMessage!.mcp_servers).toBeDefined();
expect(Array.isArray(systemMessage!.mcp_servers)).toBe(true);
// Find our SDK MCP server
const sdkServer = systemMessage!.mcp_servers?.find(
(server) => server.name === 'sdk-echo',
);
expect(sdkServer).toBeDefined();
} finally {
await q.close();
}
});
});
describe('SDK MCP Tool Error Handling', () => {
it('should handle tool errors gracefully', async () => {
// Define a tool that throws an error with Zod schema
const errorTool = tool(
'maybe_fail',
'A tool that may fail based on input',
{
shouldFail: z.boolean().describe('If true, the tool will fail'),
},
async (args) => {
if (args.shouldFail) {
throw new Error('Tool intentionally failed');
}
return { content: [{ type: 'text', text: 'Success!' }] };
},
);
const serverConfig = createSdkMcpServer({
name: 'sdk-error-test',
version: '1.0.0',
tools: [errorTool],
});
const q = query({
prompt:
'Use the maybe_fail tool with shouldFail set to true. Tell me what happens.',
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
debug: false,
mcpServers: {
'sdk-error-test': serverConfig,
},
},
});
const messages: SDKMessage[] = [];
let foundToolUse = false;
try {
for await (const message of q) {
messages.push(message);
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(message, 'maybe_fail');
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
}
}
// Tool should be called
expect(foundToolUse).toBe(true);
// Query should complete (even with tool error)
const lastMessage = messages[messages.length - 1];
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
});
});
describe('Async Tool Handlers', () => {
it('should handle async tool handlers with delays', async () => {
// Define a tool with async delay using Zod schema
const delayedTool = tool(
'delayed_response',
'Returns a value after a delay',
{
delay: z.number().describe('Delay in milliseconds (max 100)'),
value: z.string().describe('Value to return'),
},
async (args) => {
// Cap delay at 100ms for test performance
const actualDelay = Math.min(args.delay, 100);
await new Promise((resolve) => setTimeout(resolve, actualDelay));
return {
content: [{ type: 'text', text: `Delayed result: ${args.value}` }],
};
},
);
const serverConfig = createSdkMcpServer({
name: 'sdk-async',
version: '1.0.0',
tools: [delayedTool],
});
const q = query({
prompt:
'Use the delayed_response tool with delay=50 and value="test_async". Tell me the result.',
options: {
...SHARED_TEST_OPTIONS,
cwd: testDir,
debug: false,
mcpServers: {
'sdk-async': serverConfig,
},
},
});
const messages: SDKMessage[] = [];
let assistantText = '';
let foundToolUse = false;
try {
for await (const message of q) {
messages.push(message);
if (isSDKAssistantMessage(message)) {
const toolUseBlocks = findToolUseBlocks(
message,
'delayed_response',
);
if (toolUseBlocks.length > 0) {
foundToolUse = true;
}
assistantText += extractText(message.message.content);
}
}
// Validate tool was called
expect(foundToolUse).toBe(true);
// Validate result contains the delayed response
expect(assistantText.toLowerCase()).toMatch(/test_async/i);
// Validate successful completion
const lastMessage = messages[messages.length - 1];
expect(isSDKResultMessage(lastMessage)).toBe(true);
} finally {
await q.close();
}
});
});
});

View File

@@ -44,7 +44,6 @@ describe('Single-Turn Query (E2E)', () => {
...SHARED_TEST_OPTIONS, ...SHARED_TEST_OPTIONS,
cwd: testDir, cwd: testDir,
debug: true, debug: true,
logLevel: 'debug',
}, },
}); });

View File

@@ -37,6 +37,10 @@
"test:integration:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests", "test:integration:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests",
"test:integration:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests", "test:integration:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests",
"test:integration:sandbox:podman": "cross-env GEMINI_SANDBOX=podman vitest run --root ./integration-tests", "test:integration:sandbox:podman": "cross-env GEMINI_SANDBOX=podman vitest run --root ./integration-tests",
"test:integration:sdk:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests --dir sdk-typescript",
"test:integration:sdk:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests --dir sdk-typescript",
"test:integration:cli:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'",
"test:integration:cli:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'",
"test:terminal-bench": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests", "test:terminal-bench": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests",
"test:terminal-bench:oracle": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'oracle'", "test:terminal-bench:oracle": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'oracle'",
"test:terminal-bench:qwen": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'qwen'", "test:terminal-bench:qwen": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'qwen'",

View File

@@ -276,8 +276,11 @@ export async function main() {
process.exit(1); process.exit(1);
} }
} }
// For stream-json mode, don't read stdin here - it should be forwarded to the sandbox
// and consumed by StreamJsonInputReader inside the container
const inputFormat = argv.inputFormat as string | undefined;
let stdinData = ''; let stdinData = '';
if (!process.stdin.isTTY) { if (!process.stdin.isTTY && inputFormat !== 'stream-json') {
stdinData = await readStdin(); stdinData = await readStdin();
} }

View File

@@ -16,9 +16,12 @@
* Controllers: * Controllers:
* - SystemController: initialize, interrupt, set_model, supported_commands * - SystemController: initialize, interrupt, set_model, supported_commands
* - PermissionController: can_use_tool, set_permission_mode * - PermissionController: can_use_tool, set_permission_mode
* - MCPController: mcp_message, mcp_server_status * - SdkMcpController: mcp_server_status (mcp_message handled via callback)
* - HookController: hook_callback * - HookController: hook_callback
* *
* Note: mcp_message requests are NOT routed through the dispatcher. CLI MCP
* clients send messages via SdkMcpController.createSendSdkMcpMessage() callback.
*
* Note: Control request types are centrally defined in the ControlRequestType * Note: Control request types are centrally defined in the ControlRequestType
* enum in packages/sdk/typescript/src/types/controlRequests.ts * enum in packages/sdk/typescript/src/types/controlRequests.ts
*/ */
@@ -27,7 +30,7 @@ import type { IControlContext } from './ControlContext.js';
import type { IPendingRequestRegistry } from './controllers/baseController.js'; import type { IPendingRequestRegistry } from './controllers/baseController.js';
import { SystemController } from './controllers/systemController.js'; import { SystemController } from './controllers/systemController.js';
import { PermissionController } from './controllers/permissionController.js'; import { PermissionController } from './controllers/permissionController.js';
// import { MCPController } from './controllers/mcpController.js'; import { SdkMcpController } from './controllers/sdkMcpController.js';
// import { HookController } from './controllers/hookController.js'; // import { HookController } from './controllers/hookController.js';
import type { import type {
CLIControlRequest, CLIControlRequest,
@@ -65,7 +68,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
// Make controllers publicly accessible // Make controllers publicly accessible
readonly systemController: SystemController; readonly systemController: SystemController;
readonly permissionController: PermissionController; readonly permissionController: PermissionController;
// readonly mcpController: MCPController; readonly sdkMcpController: SdkMcpController;
// readonly hookController: HookController; // readonly hookController: HookController;
// Central pending request registries // Central pending request registries
@@ -88,7 +91,11 @@ export class ControlDispatcher implements IPendingRequestRegistry {
this, this,
'PermissionController', 'PermissionController',
); );
// this.mcpController = new MCPController(context, this, 'MCPController'); this.sdkMcpController = new SdkMcpController(
context,
this,
'SdkMcpController',
);
// this.hookController = new HookController(context, this, 'HookController'); // this.hookController = new HookController(context, this, 'HookController');
// Listen for main abort signal // Listen for main abort signal
@@ -228,10 +235,10 @@ export class ControlDispatcher implements IPendingRequestRegistry {
} }
this.pendingOutgoingRequests.clear(); this.pendingOutgoingRequests.clear();
// Cleanup controllers (MCP controller will close all clients) // Cleanup controllers
this.systemController.cleanup(); this.systemController.cleanup();
this.permissionController.cleanup(); this.permissionController.cleanup();
// this.mcpController.cleanup(); this.sdkMcpController.cleanup();
// this.hookController.cleanup(); // this.hookController.cleanup();
} }
@@ -291,6 +298,47 @@ export class ControlDispatcher implements IPendingRequestRegistry {
} }
} }
/**
* Get count of pending incoming requests (for debugging)
*/
getPendingIncomingRequestCount(): number {
return this.pendingIncomingRequests.size;
}
/**
* Wait for all incoming request handlers to complete.
*
* Uses polling since we don't have direct Promise references to handlers.
* The pendingIncomingRequests map is managed by BaseController:
* - Registered when handler starts (in handleRequest)
* - Deregistered when handler completes (success or error)
*
* @param pollIntervalMs - How often to check (default 50ms)
* @param timeoutMs - Maximum wait time (default 30s)
*/
async waitForPendingIncomingRequests(
pollIntervalMs: number = 50,
timeoutMs: number = 30000,
): Promise<void> {
const startTime = Date.now();
while (this.pendingIncomingRequests.size > 0) {
if (Date.now() - startTime > timeoutMs) {
if (this.context.debugMode) {
console.error(
`[ControlDispatcher] Timeout waiting for ${this.pendingIncomingRequests.size} pending incoming requests`,
);
}
break;
}
await new Promise((resolve) => setTimeout(resolve, pollIntervalMs));
}
if (this.context.debugMode && this.pendingIncomingRequests.size === 0) {
console.error('[ControlDispatcher] All incoming requests completed');
}
}
/** /**
* Returns the controller that handles the given request subtype * Returns the controller that handles the given request subtype
*/ */
@@ -306,9 +354,8 @@ export class ControlDispatcher implements IPendingRequestRegistry {
case 'set_permission_mode': case 'set_permission_mode':
return this.permissionController; return this.permissionController;
// case 'mcp_message': case 'mcp_server_status':
// case 'mcp_server_status': return this.sdkMcpController;
// return this.mcpController;
// case 'hook_callback': // case 'hook_callback':
// return this.hookController; // return this.hookController;

View File

@@ -117,16 +117,41 @@ export abstract class BaseController {
* Send an outgoing control request to SDK * Send an outgoing control request to SDK
* *
* Manages lifecycle: register -> send -> wait for response -> deregister * Manages lifecycle: register -> send -> wait for response -> deregister
* Respects the provided AbortSignal for cancellation.
*/ */
async sendControlRequest( async sendControlRequest(
payload: ControlRequestPayload, payload: ControlRequestPayload,
timeoutMs: number = DEFAULT_REQUEST_TIMEOUT_MS, timeoutMs: number = DEFAULT_REQUEST_TIMEOUT_MS,
signal?: AbortSignal,
): Promise<ControlResponse> { ): Promise<ControlResponse> {
// Check if already aborted
if (signal?.aborted) {
throw new Error('Request aborted');
}
const requestId = randomUUID(); const requestId = randomUUID();
return new Promise<ControlResponse>((resolve, reject) => { return new Promise<ControlResponse>((resolve, reject) => {
// Setup abort handler
const abortHandler = () => {
this.registry.deregisterOutgoingRequest(requestId);
reject(new Error('Request aborted'));
if (this.context.debugMode) {
console.error(
`[${this.controllerName}] Outgoing request aborted: ${requestId}`,
);
}
};
if (signal) {
signal.addEventListener('abort', abortHandler, { once: true });
}
// Setup timeout // Setup timeout
const timeoutId = setTimeout(() => { const timeoutId = setTimeout(() => {
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
this.registry.deregisterOutgoingRequest(requestId); this.registry.deregisterOutgoingRequest(requestId);
reject(new Error('Control request timeout')); reject(new Error('Control request timeout'));
if (this.context.debugMode) { if (this.context.debugMode) {
@@ -136,12 +161,27 @@ export abstract class BaseController {
} }
}, timeoutMs); }, timeoutMs);
// Wrap resolve/reject to clean up abort listener
const wrappedResolve = (response: ControlResponse) => {
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
resolve(response);
};
const wrappedReject = (error: Error) => {
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
reject(error);
};
// Register with central registry // Register with central registry
this.registry.registerOutgoingRequest( this.registry.registerOutgoingRequest(
requestId, requestId,
this.controllerName, this.controllerName,
resolve, wrappedResolve,
reject, wrappedReject,
timeoutId, timeoutId,
); );
@@ -155,6 +195,9 @@ export abstract class BaseController {
try { try {
this.context.streamJson.send(request); this.context.streamJson.send(request);
} catch (error) { } catch (error) {
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
this.registry.deregisterOutgoingRequest(requestId); this.registry.deregisterOutgoingRequest(requestId);
reject(error); reject(error);
} }

View File

@@ -1,287 +0,0 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* MCP Controller
*
* Handles MCP-related control requests:
* - mcp_message: Route MCP messages
* - mcp_server_status: Return MCP server status
*/
import { BaseController } from './baseController.js';
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { ResultSchema } from '@modelcontextprotocol/sdk/types.js';
import type {
ControlRequestPayload,
CLIControlMcpMessageRequest,
} from '../../types.js';
import type {
MCPServerConfig,
WorkspaceContext,
} from '@qwen-code/qwen-code-core';
import {
connectToMcpServer,
MCP_DEFAULT_TIMEOUT_MSEC,
} from '@qwen-code/qwen-code-core';
export class MCPController extends BaseController {
/**
* Handle MCP control requests
*/
protected async handleRequestPayload(
payload: ControlRequestPayload,
_signal: AbortSignal,
): Promise<Record<string, unknown>> {
switch (payload.subtype) {
case 'mcp_message':
return this.handleMcpMessage(payload as CLIControlMcpMessageRequest);
case 'mcp_server_status':
return this.handleMcpStatus();
default:
throw new Error(`Unsupported request subtype in MCPController`);
}
}
/**
* Handle mcp_message request
*
* Routes JSON-RPC messages to MCP servers
*/
private async handleMcpMessage(
payload: CLIControlMcpMessageRequest,
): Promise<Record<string, unknown>> {
const serverNameRaw = payload.server_name;
if (
typeof serverNameRaw !== 'string' ||
serverNameRaw.trim().length === 0
) {
throw new Error('Missing server_name in mcp_message request');
}
const message = payload.message;
if (!message || typeof message !== 'object') {
throw new Error(
'Missing or invalid message payload for mcp_message request',
);
}
// Get or create MCP client
let clientEntry: { client: Client; config: MCPServerConfig };
try {
clientEntry = await this.getOrCreateMcpClient(serverNameRaw.trim());
} catch (error) {
throw new Error(
error instanceof Error
? error.message
: 'Failed to connect to MCP server',
);
}
const method = message.method;
if (typeof method !== 'string' || method.trim().length === 0) {
throw new Error('Invalid MCP message: missing method');
}
const jsonrpcVersion =
typeof message.jsonrpc === 'string' ? message.jsonrpc : '2.0';
const messageId = message.id;
const params = message.params;
const timeout =
typeof clientEntry.config.timeout === 'number'
? clientEntry.config.timeout
: MCP_DEFAULT_TIMEOUT_MSEC;
try {
// Handle notification (no id)
if (messageId === undefined) {
await clientEntry.client.notification({
method,
params,
});
return {
subtype: 'mcp_message',
mcp_response: {
jsonrpc: jsonrpcVersion,
id: null,
result: { success: true, acknowledged: true },
},
};
}
// Handle request (with id)
const result = await clientEntry.client.request(
{
method,
params,
},
ResultSchema,
{ timeout },
);
return {
subtype: 'mcp_message',
mcp_response: {
jsonrpc: jsonrpcVersion,
id: messageId,
result,
},
};
} catch (error) {
// If connection closed, remove from cache
if (error instanceof Error && /closed/i.test(error.message)) {
this.context.mcpClients.delete(serverNameRaw.trim());
}
const errorCode =
typeof (error as { code?: unknown })?.code === 'number'
? ((error as { code: number }).code as number)
: -32603;
const errorMessage =
error instanceof Error
? error.message
: 'Failed to execute MCP request';
const errorData = (error as { data?: unknown })?.data;
const errorBody: Record<string, unknown> = {
code: errorCode,
message: errorMessage,
};
if (errorData !== undefined) {
errorBody['data'] = errorData;
}
return {
subtype: 'mcp_message',
mcp_response: {
jsonrpc: jsonrpcVersion,
id: messageId ?? null,
error: errorBody,
},
};
}
}
/**
* Handle mcp_server_status request
*
* Returns status of registered MCP servers
*/
private async handleMcpStatus(): Promise<Record<string, unknown>> {
const status: Record<string, string> = {};
// Include SDK MCP servers
for (const serverName of this.context.sdkMcpServers) {
status[serverName] = 'connected';
}
// Include CLI-managed MCP clients
for (const serverName of this.context.mcpClients.keys()) {
status[serverName] = 'connected';
}
if (this.context.debugMode) {
console.error(
`[MCPController] MCP status: ${Object.keys(status).length} servers`,
);
}
return status;
}
/**
* Get or create MCP client for a server
*
* Implements lazy connection and caching
*/
private async getOrCreateMcpClient(
serverName: string,
): Promise<{ client: Client; config: MCPServerConfig }> {
// Check cache first
const cached = this.context.mcpClients.get(serverName);
if (cached) {
return cached;
}
// Get server configuration
const provider = this.context.config as unknown as {
getMcpServers?: () => Record<string, MCPServerConfig> | undefined;
getDebugMode?: () => boolean;
getWorkspaceContext?: () => unknown;
};
if (typeof provider.getMcpServers !== 'function') {
throw new Error(`MCP server "${serverName}" is not configured`);
}
const servers = provider.getMcpServers() ?? {};
const serverConfig = servers[serverName];
if (!serverConfig) {
throw new Error(`MCP server "${serverName}" is not configured`);
}
const debugMode =
typeof provider.getDebugMode === 'function'
? provider.getDebugMode()
: false;
const workspaceContext =
typeof provider.getWorkspaceContext === 'function'
? provider.getWorkspaceContext()
: undefined;
if (!workspaceContext) {
throw new Error('Workspace context is not available for MCP connection');
}
// Connect to MCP server
const client = await connectToMcpServer(
serverName,
serverConfig,
debugMode,
workspaceContext as WorkspaceContext,
);
// Cache the client
const entry = { client, config: serverConfig };
this.context.mcpClients.set(serverName, entry);
if (this.context.debugMode) {
console.error(`[MCPController] Connected to MCP server: ${serverName}`);
}
return entry;
}
/**
* Cleanup MCP clients
*/
override cleanup(): void {
if (this.context.debugMode) {
console.error(
`[MCPController] Cleaning up ${this.context.mcpClients.size} MCP clients`,
);
}
// Close all MCP clients
for (const [serverName, { client }] of this.context.mcpClients.entries()) {
try {
client.close();
} catch (error) {
if (this.context.debugMode) {
console.error(
`[MCPController] Failed to close MCP client ${serverName}:`,
error,
);
}
}
}
this.context.mcpClients.clear();
}
}

View File

@@ -44,15 +44,23 @@ export class PermissionController extends BaseController {
*/ */
protected async handleRequestPayload( protected async handleRequestPayload(
payload: ControlRequestPayload, payload: ControlRequestPayload,
_signal: AbortSignal, signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
switch (payload.subtype) { switch (payload.subtype) {
case 'can_use_tool': case 'can_use_tool':
return this.handleCanUseTool(payload as CLIControlPermissionRequest); return this.handleCanUseTool(
payload as CLIControlPermissionRequest,
signal,
);
case 'set_permission_mode': case 'set_permission_mode':
return this.handleSetPermissionMode( return this.handleSetPermissionMode(
payload as CLIControlSetPermissionModeRequest, payload as CLIControlSetPermissionModeRequest,
signal,
); );
default: default:
@@ -70,7 +78,12 @@ export class PermissionController extends BaseController {
*/ */
private async handleCanUseTool( private async handleCanUseTool(
payload: CLIControlPermissionRequest, payload: CLIControlPermissionRequest,
signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
const toolName = payload.tool_name; const toolName = payload.tool_name;
if ( if (
!toolName || !toolName ||
@@ -192,7 +205,12 @@ export class PermissionController extends BaseController {
*/ */
private async handleSetPermissionMode( private async handleSetPermissionMode(
payload: CLIControlSetPermissionModeRequest, payload: CLIControlSetPermissionModeRequest,
signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
const mode = payload.mode; const mode = payload.mode;
const validModes: PermissionMode[] = [ const validModes: PermissionMode[] = [
'default', 'default',
@@ -373,6 +391,14 @@ export class PermissionController extends BaseController {
toolCall: WaitingToolCall, toolCall: WaitingToolCall,
): Promise<void> { ): Promise<void> {
try { try {
// Check if already aborted
if (this.context.abortSignal?.aborted) {
await toolCall.confirmationDetails.onConfirm(
ToolConfirmationOutcome.Cancel,
);
return;
}
const inputFormat = this.context.config.getInputFormat?.(); const inputFormat = this.context.config.getInputFormat?.();
const isStreamJsonMode = inputFormat === InputFormat.STREAM_JSON; const isStreamJsonMode = inputFormat === InputFormat.STREAM_JSON;
@@ -392,14 +418,18 @@ export class PermissionController extends BaseController {
toolCall.confirmationDetails, toolCall.confirmationDetails,
); );
const response = await this.sendControlRequest({ const response = await this.sendControlRequest(
subtype: 'can_use_tool', {
tool_name: toolCall.request.name, subtype: 'can_use_tool',
tool_use_id: toolCall.request.callId, tool_name: toolCall.request.name,
input: toolCall.request.args, tool_use_id: toolCall.request.callId,
permission_suggestions: permissionSuggestions, input: toolCall.request.args,
blocked_path: null, permission_suggestions: permissionSuggestions,
} as CLIControlPermissionRequest); blocked_path: null,
} as CLIControlPermissionRequest,
undefined, // use default timeout
this.context.abortSignal,
);
if (response.subtype !== 'success') { if (response.subtype !== 'success') {
await toolCall.confirmationDetails.onConfirm( await toolCall.confirmationDetails.onConfirm(

View File

@@ -0,0 +1,138 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* SDK MCP Controller
*
* Handles MCP communication between CLI MCP clients and SDK MCP servers:
* - Provides sendSdkMcpMessage callback for CLI → SDK MCP message routing
* - mcp_server_status: Returns status of SDK MCP servers
*
* Message Flow (CLI MCP Client → SDK MCP Server):
* CLI MCP Client → SdkControlClientTransport.send() →
* sendSdkMcpMessage callback → control_request (mcp_message) → SDK →
* SDK MCP Server processes → control_response → CLI MCP Client
*/
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import { BaseController } from './baseController.js';
import type {
ControlRequestPayload,
CLIControlMcpMessageRequest,
} from '../../types.js';
const MCP_REQUEST_TIMEOUT = 30_000; // 30 seconds
export class SdkMcpController extends BaseController {
/**
* Handle SDK MCP control requests from ControlDispatcher
*
* Note: mcp_message requests are NOT handled here. CLI MCP clients
* send messages via the sendSdkMcpMessage callback directly, not
* through the control dispatcher.
*/
protected async handleRequestPayload(
payload: ControlRequestPayload,
signal: AbortSignal,
): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
switch (payload.subtype) {
case 'mcp_server_status':
return this.handleMcpStatus();
default:
throw new Error(`Unsupported request subtype in SdkMcpController`);
}
}
/**
* Handle mcp_server_status request
*
* Returns status of all registered SDK MCP servers.
* SDK servers are considered "connected" if they are registered.
*/
private async handleMcpStatus(): Promise<Record<string, unknown>> {
const status: Record<string, string> = {};
for (const serverName of this.context.sdkMcpServers) {
// SDK MCP servers are "connected" once registered since they run in SDK process
status[serverName] = 'connected';
}
return {
subtype: 'mcp_server_status',
status,
};
}
/**
* Send MCP message to SDK server via control plane
*
* @param serverName - Name of the SDK MCP server
* @param message - MCP JSON-RPC message to send
* @returns MCP JSON-RPC response from SDK server
*/
private async sendMcpMessageToSdk(
serverName: string,
message: JSONRPCMessage,
): Promise<JSONRPCMessage> {
if (this.context.debugMode) {
console.error(
`[SdkMcpController] Sending MCP message to SDK server '${serverName}':`,
JSON.stringify(message),
);
}
// Send control request to SDK with the MCP message
const response = await this.sendControlRequest(
{
subtype: 'mcp_message',
server_name: serverName,
message: message as CLIControlMcpMessageRequest['message'],
},
MCP_REQUEST_TIMEOUT,
this.context.abortSignal,
);
// Extract MCP response from control response
const responsePayload = response.response as Record<string, unknown>;
const mcpResponse = responsePayload?.['mcp_response'] as JSONRPCMessage;
if (!mcpResponse) {
throw new Error(
`Invalid MCP response from SDK for server '${serverName}'`,
);
}
if (this.context.debugMode) {
console.error(
`[SdkMcpController] Received MCP response from SDK server '${serverName}':`,
JSON.stringify(mcpResponse),
);
}
return mcpResponse;
}
/**
* Create a callback function for sending MCP messages to SDK servers.
*
* This callback is used by McpClientManager/SdkControlClientTransport to send
* MCP messages from CLI MCP clients to SDK MCP servers via the control plane.
*
* @returns A function that sends MCP messages to SDK and returns the response
*/
createSendSdkMcpMessage(): (
serverName: string,
message: JSONRPCMessage,
) => Promise<JSONRPCMessage> {
return (serverName: string, message: JSONRPCMessage) =>
this.sendMcpMessageToSdk(serverName, message);
}
}

View File

@@ -18,9 +18,15 @@ import type {
ControlRequestPayload, ControlRequestPayload,
CLIControlInitializeRequest, CLIControlInitializeRequest,
CLIControlSetModelRequest, CLIControlSetModelRequest,
CLIMcpServerConfig,
} from '../../types.js'; } from '../../types.js';
import { CommandService } from '../../../services/CommandService.js'; import { CommandService } from '../../../services/CommandService.js';
import { BuiltinCommandLoader } from '../../../services/BuiltinCommandLoader.js'; import { BuiltinCommandLoader } from '../../../services/BuiltinCommandLoader.js';
import {
MCPServerConfig,
AuthProviderType,
type MCPOAuthConfig,
} from '@qwen-code/qwen-code-core';
export class SystemController extends BaseController { export class SystemController extends BaseController {
/** /**
@@ -28,20 +34,30 @@ export class SystemController extends BaseController {
*/ */
protected async handleRequestPayload( protected async handleRequestPayload(
payload: ControlRequestPayload, payload: ControlRequestPayload,
_signal: AbortSignal, signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
switch (payload.subtype) { switch (payload.subtype) {
case 'initialize': case 'initialize':
return this.handleInitialize(payload as CLIControlInitializeRequest); return this.handleInitialize(
payload as CLIControlInitializeRequest,
signal,
);
case 'interrupt': case 'interrupt':
return this.handleInterrupt(); return this.handleInterrupt();
case 'set_model': case 'set_model':
return this.handleSetModel(payload as CLIControlSetModelRequest); return this.handleSetModel(
payload as CLIControlSetModelRequest,
signal,
);
case 'supported_commands': case 'supported_commands':
return this.handleSupportedCommands(); return this.handleSupportedCommands(signal);
default: default:
throw new Error(`Unsupported request subtype in SystemController`); throw new Error(`Unsupported request subtype in SystemController`);
@@ -51,46 +67,110 @@ export class SystemController extends BaseController {
/** /**
* Handle initialize request * Handle initialize request
* *
* Registers SDK MCP servers and returns capabilities * Processes SDK MCP servers config.
* SDK servers are registered in context.sdkMcpServers
* and added to config.mcpServers with the sdk type flag.
* External MCP servers are configured separately in settings.
*/ */
private async handleInitialize( private async handleInitialize(
payload: CLIControlInitializeRequest, payload: CLIControlInitializeRequest,
signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
this.context.config.setSdkMode(true); this.context.config.setSdkMode(true);
if (payload.sdkMcpServers && typeof payload.sdkMcpServers === 'object') { // Process SDK MCP servers
for (const serverName of Object.keys(payload.sdkMcpServers)) { if (
this.context.sdkMcpServers.add(serverName); payload.sdkMcpServers &&
typeof payload.sdkMcpServers === 'object' &&
payload.sdkMcpServers !== null
) {
const sdkServers: Record<string, MCPServerConfig> = {};
for (const [key, wireConfig] of Object.entries(payload.sdkMcpServers)) {
const name =
typeof wireConfig?.name === 'string' && wireConfig.name.trim().length
? wireConfig.name
: key;
this.context.sdkMcpServers.add(name);
sdkServers[name] = new MCPServerConfig(
undefined, // command
undefined, // args
undefined, // env
undefined, // cwd
undefined, // url
undefined, // httpUrl
undefined, // headers
undefined, // tcp
undefined, // timeout
true, // trust - SDK servers are trusted
undefined, // description
undefined, // includeTools
undefined, // excludeTools
undefined, // extensionName
undefined, // oauth
undefined, // authProviderType
undefined, // targetAudience
undefined, // targetServiceAccount
'sdk', // type
);
} }
try { const sdkServerCount = Object.keys(sdkServers).length;
this.context.config.addMcpServers(payload.sdkMcpServers); if (sdkServerCount > 0) {
if (this.context.debugMode) { try {
console.error( this.context.config.addMcpServers(sdkServers);
`[SystemController] Added ${Object.keys(payload.sdkMcpServers).length} SDK MCP servers to config`, if (this.context.debugMode) {
); console.error(
} `[SystemController] Added ${sdkServerCount} SDK MCP servers to config`,
} catch (error) { );
if (this.context.debugMode) { }
console.error( } catch (error) {
'[SystemController] Failed to add SDK MCP servers:', if (this.context.debugMode) {
error, console.error(
); '[SystemController] Failed to add SDK MCP servers:',
error,
);
}
} }
} }
} }
if (payload.mcpServers && typeof payload.mcpServers === 'object') { if (
try { payload.mcpServers &&
this.context.config.addMcpServers(payload.mcpServers); typeof payload.mcpServers === 'object' &&
if (this.context.debugMode) { payload.mcpServers !== null
console.error( ) {
`[SystemController] Added ${Object.keys(payload.mcpServers).length} MCP servers to config`, const externalServers: Record<string, MCPServerConfig> = {};
); for (const [name, serverConfig] of Object.entries(payload.mcpServers)) {
const normalized = this.normalizeMcpServerConfig(
name,
serverConfig as CLIMcpServerConfig | undefined,
);
if (normalized) {
externalServers[name] = normalized;
} }
} catch (error) { }
if (this.context.debugMode) {
console.error('[SystemController] Failed to add MCP servers:', error); const externalCount = Object.keys(externalServers).length;
if (externalCount > 0) {
try {
this.context.config.addMcpServers(externalServers);
if (this.context.debugMode) {
console.error(
`[SystemController] Added ${externalCount} external MCP servers to config`,
);
}
} catch (error) {
if (this.context.debugMode) {
console.error(
'[SystemController] Failed to add external MCP servers:',
error,
);
}
} }
} }
} }
@@ -143,13 +223,96 @@ export class SystemController extends BaseController {
can_set_permission_mode: can_set_permission_mode:
typeof this.context.config.setApprovalMode === 'function', typeof this.context.config.setApprovalMode === 'function',
can_set_model: typeof this.context.config.setModel === 'function', can_set_model: typeof this.context.config.setModel === 'function',
/* TODO: sdkMcpServers support */ // SDK MCP servers are supported - messages routed through control plane
can_handle_mcp_message: false, can_handle_mcp_message: true,
}; };
return capabilities; return capabilities;
} }
private normalizeMcpServerConfig(
serverName: string,
config?: CLIMcpServerConfig,
): MCPServerConfig | null {
if (!config || typeof config !== 'object') {
if (this.context.debugMode) {
console.error(
`[SystemController] Ignoring invalid MCP server config for '${serverName}'`,
);
}
return null;
}
const authProvider = this.normalizeAuthProviderType(
config.authProviderType,
);
const oauthConfig = this.normalizeOAuthConfig(config.oauth);
return new MCPServerConfig(
config.command,
config.args,
config.env,
config.cwd,
config.url,
config.httpUrl,
config.headers,
config.tcp,
config.timeout,
config.trust,
config.description,
config.includeTools,
config.excludeTools,
config.extensionName,
oauthConfig,
authProvider,
config.targetAudience,
config.targetServiceAccount,
);
}
private normalizeAuthProviderType(
value?: string,
): AuthProviderType | undefined {
if (!value) {
return undefined;
}
switch (value) {
case AuthProviderType.DYNAMIC_DISCOVERY:
case AuthProviderType.GOOGLE_CREDENTIALS:
case AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION:
return value;
default:
if (this.context.debugMode) {
console.error(
`[SystemController] Unsupported authProviderType '${value}', skipping`,
);
}
return undefined;
}
}
private normalizeOAuthConfig(
oauth?: CLIMcpServerConfig['oauth'],
): MCPOAuthConfig | undefined {
if (!oauth) {
return undefined;
}
return {
enabled: oauth.enabled,
clientId: oauth.clientId,
clientSecret: oauth.clientSecret,
authorizationUrl: oauth.authorizationUrl,
tokenUrl: oauth.tokenUrl,
scopes: oauth.scopes,
audiences: oauth.audiences,
redirectUri: oauth.redirectUri,
tokenParamName: oauth.tokenParamName,
registrationUrl: oauth.registrationUrl,
};
}
/** /**
* Handle interrupt request * Handle interrupt request
* *
@@ -183,7 +346,12 @@ export class SystemController extends BaseController {
*/ */
private async handleSetModel( private async handleSetModel(
payload: CLIControlSetModelRequest, payload: CLIControlSetModelRequest,
signal: AbortSignal,
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
const model = payload.model; const model = payload.model;
// Validate model parameter // Validate model parameter
@@ -223,8 +391,14 @@ export class SystemController extends BaseController {
* *
* Returns list of supported slash commands loaded dynamically * Returns list of supported slash commands loaded dynamically
*/ */
private async handleSupportedCommands(): Promise<Record<string, unknown>> { private async handleSupportedCommands(
const slashCommands = await this.loadSlashCommandNames(); signal: AbortSignal,
): Promise<Record<string, unknown>> {
if (signal.aborted) {
throw new Error('Request aborted');
}
const slashCommands = await this.loadSlashCommandNames(signal);
return { return {
subtype: 'supported_commands', subtype: 'supported_commands',
@@ -235,15 +409,24 @@ export class SystemController extends BaseController {
/** /**
* Load slash command names using CommandService * Load slash command names using CommandService
* *
* @param signal - AbortSignal to respect for cancellation
* @returns Promise resolving to array of slash command names * @returns Promise resolving to array of slash command names
*/ */
private async loadSlashCommandNames(): Promise<string[]> { private async loadSlashCommandNames(signal: AbortSignal): Promise<string[]> {
const controller = new AbortController(); if (signal.aborted) {
return [];
}
try { try {
const service = await CommandService.create( const service = await CommandService.create(
[new BuiltinCommandLoader(this.context.config)], [new BuiltinCommandLoader(this.context.config)],
controller.signal, signal,
); );
if (signal.aborted) {
return [];
}
const names = new Set<string>(); const names = new Set<string>();
const commands = service.getCommands(); const commands = service.getCommands();
for (const command of commands) { for (const command of commands) {
@@ -251,6 +434,11 @@ export class SystemController extends BaseController {
} }
return Array.from(names).sort(); return Array.from(names).sort();
} catch (error) { } catch (error) {
// Check if the error is due to abort
if (signal.aborted) {
return [];
}
if (this.context.debugMode) { if (this.context.debugMode) {
console.error( console.error(
'[SystemController] Failed to load slash commands:', '[SystemController] Failed to load slash commands:',
@@ -258,8 +446,6 @@ export class SystemController extends BaseController {
); );
} }
return []; return [];
} finally {
controller.abort();
} }
} }
} }

View File

@@ -153,6 +153,11 @@ describe('runNonInteractiveStreamJson', () => {
handleControlResponse: ReturnType<typeof vi.fn>; handleControlResponse: ReturnType<typeof vi.fn>;
handleCancel: ReturnType<typeof vi.fn>; handleCancel: ReturnType<typeof vi.fn>;
shutdown: ReturnType<typeof vi.fn>; shutdown: ReturnType<typeof vi.fn>;
getPendingIncomingRequestCount: ReturnType<typeof vi.fn>;
waitForPendingIncomingRequests: ReturnType<typeof vi.fn>;
sdkMcpController: {
createSendSdkMcpMessage: ReturnType<typeof vi.fn>;
};
}; };
let mockConsolePatcher: { let mockConsolePatcher: {
patch: ReturnType<typeof vi.fn>; patch: ReturnType<typeof vi.fn>;
@@ -187,6 +192,11 @@ describe('runNonInteractiveStreamJson', () => {
handleControlResponse: vi.fn(), handleControlResponse: vi.fn(),
handleCancel: vi.fn(), handleCancel: vi.fn(),
shutdown: vi.fn(), shutdown: vi.fn(),
getPendingIncomingRequestCount: vi.fn().mockReturnValue(0),
waitForPendingIncomingRequests: vi.fn().mockResolvedValue(undefined),
sdkMcpController: {
createSendSdkMcpMessage: vi.fn().mockReturnValue(vi.fn()),
},
}; };
( (
ControlDispatcher as unknown as ReturnType<typeof vi.fn> ControlDispatcher as unknown as ReturnType<typeof vi.fn>

View File

@@ -4,7 +4,10 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { Config } from '@qwen-code/qwen-code-core'; import type {
Config,
ConfigInitializeOptions,
} from '@qwen-code/qwen-code-core';
import { StreamJsonInputReader } from './io/StreamJsonInputReader.js'; import { StreamJsonInputReader } from './io/StreamJsonInputReader.js';
import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js'; import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js';
import { ControlContext } from './control/ControlContext.js'; import { ControlContext } from './control/ControlContext.js';
@@ -50,6 +53,12 @@ class Session {
private isShuttingDown: boolean = false; private isShuttingDown: boolean = false;
private configInitialized: boolean = false; private configInitialized: boolean = false;
// Single initialization promise that resolves when session is ready for user messages.
// Created lazily once initialization actually starts.
private initializationPromise: Promise<void> | null = null;
private initializationResolve: (() => void) | null = null;
private initializationReject: ((error: Error) => void) | null = null;
constructor(config: Config, initialPrompt?: CLIUserMessage) { constructor(config: Config, initialPrompt?: CLIUserMessage) {
this.config = config; this.config = config;
this.sessionId = config.getSessionId(); this.sessionId = config.getSessionId();
@@ -66,12 +75,32 @@ class Session {
this.setupSignalHandlers(); this.setupSignalHandlers();
} }
private ensureInitializationPromise(): void {
if (this.initializationPromise) {
return;
}
this.initializationPromise = new Promise<void>((resolve, reject) => {
this.initializationResolve = () => {
resolve();
this.initializationResolve = null;
this.initializationReject = null;
};
this.initializationReject = (error: Error) => {
reject(error);
this.initializationResolve = null;
this.initializationReject = null;
};
});
}
private getNextPromptId(): string { private getNextPromptId(): string {
this.promptIdCounter++; this.promptIdCounter++;
return `${this.sessionId}########${this.promptIdCounter}`; return `${this.sessionId}########${this.promptIdCounter}`;
} }
private async ensureConfigInitialized(): Promise<void> { private async ensureConfigInitialized(
options?: ConfigInitializeOptions,
): Promise<void> {
if (this.configInitialized) { if (this.configInitialized) {
return; return;
} }
@@ -81,7 +110,7 @@ class Session {
} }
try { try {
await this.config.initialize(); await this.config.initialize(options);
this.configInitialized = true; this.configInitialized = true;
} catch (error) { } catch (error) {
if (this.debugMode) { if (this.debugMode) {
@@ -91,6 +120,44 @@ class Session {
} }
} }
/**
* Mark initialization as complete
*/
private completeInitialization(): void {
if (this.initializationResolve) {
if (this.debugMode) {
console.error('[Session] Initialization complete');
}
this.initializationResolve();
this.initializationResolve = null;
this.initializationReject = null;
}
}
/**
* Mark initialization as failed
*/
private failInitialization(error: Error): void {
if (this.initializationReject) {
if (this.debugMode) {
console.error('[Session] Initialization failed:', error);
}
this.initializationReject(error);
this.initializationResolve = null;
this.initializationReject = null;
}
}
/**
* Wait for session to be ready for user messages
*/
private async waitForInitialization(): Promise<void> {
if (!this.initializationPromise) {
return;
}
await this.initializationPromise;
}
private ensureControlSystem(): void { private ensureControlSystem(): void {
if (this.controlContext && this.dispatcher && this.controlService) { if (this.controlContext && this.dispatcher && this.controlService) {
return; return;
@@ -120,49 +187,114 @@ class Session {
return this.dispatcher; return this.dispatcher;
} }
private async handleFirstMessage( /**
* Handle the first message to determine session mode (SDK vs direct).
* This is synchronous from the message loop's perspective - it starts
* async work but does not return a promise that the loop awaits.
*
* The initialization completes asynchronously and resolves initializationPromise
* when ready for user messages.
*/
private handleFirstMessage(
message: message:
| CLIMessage | CLIMessage
| CLIControlRequest | CLIControlRequest
| CLIControlResponse | CLIControlResponse
| ControlCancelRequest, | ControlCancelRequest,
): Promise<boolean> { ): void {
if (isControlRequest(message)) { if (isControlRequest(message)) {
const request = message as CLIControlRequest; const request = message as CLIControlRequest;
this.controlSystemEnabled = true; this.controlSystemEnabled = true;
this.ensureControlSystem(); this.ensureControlSystem();
if (request.request.subtype === 'initialize') {
// Dispatch the initialize request first
await this.dispatcher?.dispatch(request);
// After handling initialize control request, initialize the config if (request.request.subtype === 'initialize') {
// This is the SDK mode where config initialization is deferred // Start SDK mode initialization (fire-and-forget from loop perspective)
await this.ensureConfigInitialized(); void this.initializeSdkMode(request);
return true; return;
} }
if (this.debugMode) { if (this.debugMode) {
console.error( console.error(
'[Session] Ignoring non-initialize control request during initialization', '[Session] Ignoring non-initialize control request during initialization',
); );
} }
return true; return;
} }
if (isCLIUserMessage(message)) { if (isCLIUserMessage(message)) {
this.controlSystemEnabled = false; this.controlSystemEnabled = false;
// For non-SDK mode (direct user message), initialize config if not already done // Start direct mode initialization (fire-and-forget from loop perspective)
await this.ensureConfigInitialized(); void this.initializeDirectMode(message as CLIUserMessage);
this.enqueueUserMessage(message as CLIUserMessage); return;
return true;
} }
this.controlSystemEnabled = false; this.controlSystemEnabled = false;
return false;
} }
private async handleControlRequest( /**
request: CLIControlRequest, * SDK mode initialization flow
* Dispatches initialize request and initializes config with MCP support
*/
private async initializeSdkMode(request: CLIControlRequest): Promise<void> {
this.ensureInitializationPromise();
try {
// Dispatch the initialize request first
// This registers SDK MCP servers in the control context
await this.dispatcher?.dispatch(request);
// Get sendSdkMcpMessage callback from SdkMcpController
// This callback is used by McpClientManager to send MCP messages
// from CLI MCP clients to SDK MCP servers via the control plane
const sendSdkMcpMessage =
this.dispatcher?.sdkMcpController.createSendSdkMcpMessage();
// Initialize config with SDK MCP message support
await this.ensureConfigInitialized({ sendSdkMcpMessage });
// Initialization complete!
this.completeInitialization();
} catch (error) {
if (this.debugMode) {
console.error('[Session] SDK mode initialization failed:', error);
}
this.failInitialization(
error instanceof Error ? error : new Error(String(error)),
);
}
}
/**
* Direct mode initialization flow
* Initializes config and enqueues the first user message
*/
private async initializeDirectMode(
userMessage: CLIUserMessage,
): Promise<void> { ): Promise<void> {
this.ensureInitializationPromise();
try {
// Initialize config
await this.ensureConfigInitialized();
// Initialization complete!
this.completeInitialization();
// Enqueue the first user message for processing
this.enqueueUserMessage(userMessage);
} catch (error) {
if (this.debugMode) {
console.error('[Session] Direct mode initialization failed:', error);
}
this.failInitialization(
error instanceof Error ? error : new Error(String(error)),
);
}
}
/**
* Handle control request asynchronously (fire-and-forget from main loop).
* Errors are handled internally and responses sent by dispatcher.
*/
private handleControlRequestAsync(request: CLIControlRequest): void {
const dispatcher = this.getDispatcher(); const dispatcher = this.getDispatcher();
if (!dispatcher) { if (!dispatcher) {
if (this.debugMode) { if (this.debugMode) {
@@ -171,9 +303,20 @@ class Session {
return; return;
} }
await dispatcher.dispatch(request); // Fire-and-forget: dispatch runs concurrently
// The dispatcher's pendingIncomingRequests tracks completion
void dispatcher.dispatch(request).catch((error) => {
if (this.debugMode) {
console.error('[Session] Control request dispatch error:', error);
}
// Error response is already sent by dispatcher.dispatch()
});
} }
/**
* Handle control response - MUST be synchronous
* This resolves pending outgoing requests, breaking the deadlock cycle.
*/
private handleControlResponse(response: CLIControlResponse): void { private handleControlResponse(response: CLIControlResponse): void {
const dispatcher = this.getDispatcher(); const dispatcher = this.getDispatcher();
if (!dispatcher) { if (!dispatcher) {
@@ -201,8 +344,8 @@ class Session {
return; return;
} }
// Ensure config is initialized before processing user messages // Wait for initialization to complete before processing user messages
await this.ensureConfigInitialized(); await this.waitForInitialization();
const promptId = this.getNextPromptId(); const promptId = this.getNextPromptId();
@@ -307,6 +450,45 @@ class Session {
process.on('SIGTERM', this.shutdownHandler); process.on('SIGTERM', this.shutdownHandler);
} }
/**
* Wait for all pending work to complete before shutdown
*/
private async waitForAllPendingWork(): Promise<void> {
// 1. Wait for initialization to complete (or fail)
try {
await this.waitForInitialization();
} catch (error) {
if (this.debugMode) {
console.error('[Session] Initialization error during shutdown:', error);
}
}
// 2. Wait for all control request handlers using dispatcher's tracking
if (this.dispatcher) {
const pendingCount = this.dispatcher.getPendingIncomingRequestCount();
if (pendingCount > 0 && this.debugMode) {
console.error(
`[Session] Waiting for ${pendingCount} pending control request handlers`,
);
}
await this.dispatcher.waitForPendingIncomingRequests();
}
// 3. Wait for user message processing queue
while (this.processingPromise) {
if (this.debugMode) {
console.error('[Session] Waiting for user message processing');
}
try {
await this.processingPromise;
} catch (error) {
if (this.debugMode) {
console.error('[Session] Error in user message processing:', error);
}
}
}
}
private async shutdown(): Promise<void> { private async shutdown(): Promise<void> {
if (this.debugMode) { if (this.debugMode) {
console.error('[Session] Shutting down'); console.error('[Session] Shutting down');
@@ -314,18 +496,8 @@ class Session {
this.isShuttingDown = true; this.isShuttingDown = true;
if (this.processingPromise) { // Wait for all pending work
try { await this.waitForAllPendingWork();
await this.processingPromise;
} catch (error) {
if (this.debugMode) {
console.error(
'[Session] Error waiting for processing to complete:',
error,
);
}
}
}
this.dispatcher?.shutdown(); this.dispatcher?.shutdown();
this.cleanupSignalHandlers(); this.cleanupSignalHandlers();
@@ -339,18 +511,30 @@ class Session {
} }
} }
/**
* Main message processing loop
*
* CRITICAL: This loop must NEVER await handlers that might need to
* send control requests and wait for responses. Such handlers must
* be started in fire-and-forget mode, allowing the loop to continue
* reading responses that resolve pending requests.
*
* Message handling order:
* 1. control_response - FIRST, synchronously resolves pending requests
* 2. First message - determines mode, starts async initialization
* 3. control_request - fire-and-forget, tracked by dispatcher
* 4. control_cancel - synchronous
* 5. user_message - enqueued for processing
*/
async run(): Promise<void> { async run(): Promise<void> {
try { try {
if (this.debugMode) { if (this.debugMode) {
console.error('[Session] Starting session', this.sessionId); console.error('[Session] Starting session', this.sessionId);
} }
// Handle initial prompt if provided (fire-and-forget)
if (this.initialPrompt !== null) { if (this.initialPrompt !== null) {
const handled = await this.handleFirstMessage(this.initialPrompt); this.handleFirstMessage(this.initialPrompt);
if (handled && this.isShuttingDown) {
await this.shutdown();
return;
}
} }
try { try {
@@ -359,23 +543,33 @@ class Session {
break; break;
} }
if (this.controlSystemEnabled === null) { // ============================================================
const handled = await this.handleFirstMessage(message); // CRITICAL: Handle control_response FIRST and SYNCHRONOUSLY
if (handled) { // This resolves pending outgoing requests, breaking deadlock.
if (this.isShuttingDown) { // ============================================================
break; if (isControlResponse(message)) {
} this.handleControlResponse(message as CLIControlResponse);
continue; continue;
}
} }
// Handle first message to determine session mode
if (this.controlSystemEnabled === null) {
this.handleFirstMessage(message);
continue;
}
// ============================================================
// CRITICAL: Handle control_request in FIRE-AND-FORGET mode
// DON'T await - let handler run concurrently while loop continues
// Dispatcher's pendingIncomingRequests tracks completion
// ============================================================
if (isControlRequest(message)) { if (isControlRequest(message)) {
await this.handleControlRequest(message as CLIControlRequest); this.handleControlRequestAsync(message as CLIControlRequest);
} else if (isControlResponse(message)) {
this.handleControlResponse(message as CLIControlResponse);
} else if (isControlCancel(message)) { } else if (isControlCancel(message)) {
// Cancel is synchronous - OK to handle inline
this.handleControlCancel(message as ControlCancelRequest); this.handleControlCancel(message as ControlCancelRequest);
} else if (isCLIUserMessage(message)) { } else if (isCLIUserMessage(message)) {
// User messages are enqueued, processing runs separately
this.enqueueUserMessage(message as CLIUserMessage); this.enqueueUserMessage(message as CLIUserMessage);
} else if (this.debugMode) { } else if (this.debugMode) {
if ( if (
@@ -402,19 +596,8 @@ class Session {
throw streamError; throw streamError;
} }
while (this.processingPromise) { // Stream ended - wait for all pending work before shutdown
if (this.debugMode) { await this.waitForAllPendingWork();
console.error('[Session] Waiting for final processing to complete');
}
try {
await this.processingPromise;
} catch (error) {
if (this.debugMode) {
console.error('[Session] Error in final processing:', error);
}
}
}
await this.shutdown(); await this.shutdown();
} catch (error) { } catch (error) {
if (this.debugMode) { if (this.debugMode) {

View File

@@ -1,8 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import type { import type { SubagentConfig } from '@qwen-code/qwen-code-core';
MCPServerConfig,
SubagentConfig,
} from '@qwen-code/qwen-code-core';
/** /**
* Annotation for attaching metadata to content blocks * Annotation for attaching metadata to content blocks
@@ -298,11 +295,68 @@ export interface CLIControlPermissionRequest {
blocked_path: string | null; blocked_path: string | null;
} }
/**
* Wire format for SDK MCP server config in initialization request.
* The actual Server instance stays in the SDK process.
*/
export interface SDKMcpServerConfig {
type: 'sdk';
name: string;
}
/**
* Wire format for external MCP server config in initialization request.
* Represents stdio/SSE/HTTP/TCP transports that must run in the CLI process.
*/
export interface CLIMcpServerConfig {
command?: string;
args?: string[];
env?: Record<string, string>;
cwd?: string;
url?: string;
httpUrl?: string;
headers?: Record<string, string>;
tcp?: string;
timeout?: number;
trust?: boolean;
description?: string;
includeTools?: string[];
excludeTools?: string[];
extensionName?: string;
oauth?: {
enabled?: boolean;
clientId?: string;
clientSecret?: string;
authorizationUrl?: string;
tokenUrl?: string;
scopes?: string[];
audiences?: string[];
redirectUri?: string;
tokenParamName?: string;
registrationUrl?: string;
};
authProviderType?:
| 'dynamic_discovery'
| 'google_credentials'
| 'service_account_impersonation';
targetAudience?: string;
targetServiceAccount?: string;
}
export interface CLIControlInitializeRequest { export interface CLIControlInitializeRequest {
subtype: 'initialize'; subtype: 'initialize';
hooks?: HookRegistration[] | null; hooks?: HookRegistration[] | null;
sdkMcpServers?: Record<string, MCPServerConfig>; /**
mcpServers?: Record<string, MCPServerConfig>; * SDK MCP servers config
* These are MCP servers running in the SDK process, connected via control plane.
* External MCP servers are configured separately in settings, not via initialization.
*/
sdkMcpServers?: Record<string, Omit<SDKMcpServerConfig, 'instance'>>;
/**
* External MCP servers that the SDK wants the CLI to manage.
* These run outside the SDK process and require CLI-side transport setup.
*/
mcpServers?: Record<string, CLIMcpServerConfig>;
agents?: SubagentConfig[]; agents?: SubagentConfig[];
} }

View File

@@ -63,6 +63,7 @@ vi.mock('../tools/tool-registry', () => {
ToolRegistryMock.prototype.registerTool = vi.fn(); ToolRegistryMock.prototype.registerTool = vi.fn();
ToolRegistryMock.prototype.discoverAllTools = vi.fn(); ToolRegistryMock.prototype.discoverAllTools = vi.fn();
ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed
ToolRegistryMock.prototype.getAllToolNames = vi.fn(() => []);
ToolRegistryMock.prototype.getTool = vi.fn(); ToolRegistryMock.prototype.getTool = vi.fn();
ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []); ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []);
return { ToolRegistry: ToolRegistryMock }; return { ToolRegistry: ToolRegistryMock };

View File

@@ -46,6 +46,7 @@ import { ExitPlanModeTool } from '../tools/exitPlanMode.js';
import { GlobTool } from '../tools/glob.js'; import { GlobTool } from '../tools/glob.js';
import { GrepTool } from '../tools/grep.js'; import { GrepTool } from '../tools/grep.js';
import { LSTool } from '../tools/ls.js'; import { LSTool } from '../tools/ls.js';
import type { SendSdkMcpMessage } from '../tools/mcp-client.js';
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
import { ReadFileTool } from '../tools/read-file.js'; import { ReadFileTool } from '../tools/read-file.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js';
@@ -239,9 +240,18 @@ export class MCPServerConfig {
readonly targetAudience?: string, readonly targetAudience?: string,
/* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */ /* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
readonly targetServiceAccount?: string, readonly targetServiceAccount?: string,
// SDK MCP server type - 'sdk' indicates server runs in SDK process
readonly type?: 'sdk',
) {} ) {}
} }
/**
* Check if an MCP server config represents an SDK server
*/
export function isSdkMcpServerConfig(config: MCPServerConfig): boolean {
return config.type === 'sdk';
}
export enum AuthProviderType { export enum AuthProviderType {
DYNAMIC_DISCOVERY = 'dynamic_discovery', DYNAMIC_DISCOVERY = 'dynamic_discovery',
GOOGLE_CREDENTIALS = 'google_credentials', GOOGLE_CREDENTIALS = 'google_credentials',
@@ -360,6 +370,17 @@ function normalizeConfigOutputFormat(
} }
} }
/**
* Options for Config.initialize()
*/
export interface ConfigInitializeOptions {
/**
* Callback for sending MCP messages to SDK servers via control plane.
* Required for SDK MCP server support in SDK mode.
*/
sendSdkMcpMessage?: SendSdkMcpMessage;
}
export class Config { export class Config {
private sessionId: string; private sessionId: string;
private sessionData?: ResumedSessionData; private sessionData?: ResumedSessionData;
@@ -599,8 +620,9 @@ export class Config {
/** /**
* Must only be called once, throws if called again. * Must only be called once, throws if called again.
* @param options Optional initialization options including sendSdkMcpMessage callback
*/ */
async initialize(): Promise<void> { async initialize(options?: ConfigInitializeOptions): Promise<void> {
if (this.initialized) { if (this.initialized) {
throw Error('Config was already initialized'); throw Error('Config was already initialized');
} }
@@ -619,7 +641,9 @@ export class Config {
this.subagentManager.loadSessionSubagents(this.sessionSubagents); this.subagentManager.loadSessionSubagents(this.sessionSubagents);
} }
this.toolRegistry = await this.createToolRegistry(); this.toolRegistry = await this.createToolRegistry(
options?.sendSdkMcpMessage,
);
await this.geminiClient.initialize(); await this.geminiClient.initialize();
@@ -1261,8 +1285,14 @@ export class Config {
return this.subagentManager; return this.subagentManager;
} }
async createToolRegistry(): Promise<ToolRegistry> { async createToolRegistry(
const registry = new ToolRegistry(this, this.eventEmitter); sendSdkMcpMessage?: SendSdkMcpMessage,
): Promise<ToolRegistry> {
const registry = new ToolRegistry(
this,
this.eventEmitter,
sendSdkMcpMessage,
);
const coreToolsConfig = this.getCoreTools(); const coreToolsConfig = this.getCoreTools();
const excludeToolsConfig = this.getExcludeTools(); const excludeToolsConfig = this.getExcludeTools();
@@ -1347,6 +1377,7 @@ export class Config {
} }
await registry.discoverAllTools(); await registry.discoverAllTools();
console.debug('ToolRegistry created', registry.getAllToolNames());
return registry; return registry;
} }
} }

View File

@@ -102,7 +102,9 @@ export * from './tools/shell.js';
export * from './tools/web-search/index.js'; export * from './tools/web-search/index.js';
export * from './tools/read-many-files.js'; export * from './tools/read-many-files.js';
export * from './tools/mcp-client.js'; export * from './tools/mcp-client.js';
export * from './tools/mcp-client-manager.js';
export * from './tools/mcp-tool.js'; export * from './tools/mcp-tool.js';
export * from './tools/sdk-control-client-transport.js';
export * from './tools/task.js'; export * from './tools/task.js';
export * from './tools/todoWrite.js'; export * from './tools/todoWrite.js';
export * from './tools/exitPlanMode.js'; export * from './tools/exitPlanMode.js';

View File

@@ -5,6 +5,7 @@
*/ */
import type { Config, MCPServerConfig } from '../config/config.js'; import type { Config, MCPServerConfig } from '../config/config.js';
import { isSdkMcpServerConfig } from '../config/config.js';
import type { ToolRegistry } from './tool-registry.js'; import type { ToolRegistry } from './tool-registry.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js';
import { import {
@@ -12,6 +13,7 @@ import {
MCPDiscoveryState, MCPDiscoveryState,
populateMcpServerCommand, populateMcpServerCommand,
} from './mcp-client.js'; } from './mcp-client.js';
import type { SendSdkMcpMessage } from './mcp-client.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import type { EventEmitter } from 'node:events'; import type { EventEmitter } from 'node:events';
import type { WorkspaceContext } from '../utils/workspaceContext.js'; import type { WorkspaceContext } from '../utils/workspaceContext.js';
@@ -31,6 +33,7 @@ export class McpClientManager {
private readonly workspaceContext: WorkspaceContext; private readonly workspaceContext: WorkspaceContext;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter; private readonly eventEmitter?: EventEmitter;
private readonly sendSdkMcpMessage?: SendSdkMcpMessage;
constructor( constructor(
mcpServers: Record<string, MCPServerConfig>, mcpServers: Record<string, MCPServerConfig>,
@@ -40,6 +43,7 @@ export class McpClientManager {
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
eventEmitter?: EventEmitter, eventEmitter?: EventEmitter,
sendSdkMcpMessage?: SendSdkMcpMessage,
) { ) {
this.mcpServers = mcpServers; this.mcpServers = mcpServers;
this.mcpServerCommand = mcpServerCommand; this.mcpServerCommand = mcpServerCommand;
@@ -48,6 +52,7 @@ export class McpClientManager {
this.debugMode = debugMode; this.debugMode = debugMode;
this.workspaceContext = workspaceContext; this.workspaceContext = workspaceContext;
this.eventEmitter = eventEmitter; this.eventEmitter = eventEmitter;
this.sendSdkMcpMessage = sendSdkMcpMessage;
} }
/** /**
@@ -71,6 +76,11 @@ export class McpClientManager {
this.eventEmitter?.emit('mcp-client-update', this.clients); this.eventEmitter?.emit('mcp-client-update', this.clients);
const discoveryPromises = Object.entries(servers).map( const discoveryPromises = Object.entries(servers).map(
async ([name, config]) => { async ([name, config]) => {
// For SDK MCP servers, pass the sendSdkMcpMessage callback
const sdkCallback = isSdkMcpServerConfig(config)
? this.sendSdkMcpMessage
: undefined;
const client = new McpClient( const client = new McpClient(
name, name,
config, config,
@@ -78,6 +88,7 @@ export class McpClientManager {
this.promptRegistry, this.promptRegistry,
this.workspaceContext, this.workspaceContext,
this.debugMode, this.debugMode,
sdkCallback,
); );
this.clients.set(name, client); this.clients.set(name, client);

View File

@@ -13,6 +13,7 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { import type {
GetPromptResult, GetPromptResult,
JSONRPCMessage,
Prompt, Prompt,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { import {
@@ -22,10 +23,11 @@ import {
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
import type { Config, MCPServerConfig } from '../config/config.js'; import type { Config, MCPServerConfig } from '../config/config.js';
import { AuthProviderType } from '../config/config.js'; import { AuthProviderType, isSdkMcpServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import { SdkControlClientTransport } from './sdk-control-client-transport.js';
import type { FunctionDeclaration } from '@google/genai'; import type { FunctionDeclaration } from '@google/genai';
import { mcpToTool } from '@google/genai'; import { mcpToTool } from '@google/genai';
@@ -42,6 +44,14 @@ import type {
} from '../utils/workspaceContext.js'; } from '../utils/workspaceContext.js';
import type { ToolRegistry } from './tool-registry.js'; import type { ToolRegistry } from './tool-registry.js';
/**
* Callback type for sending MCP messages to SDK servers via control plane
*/
export type SendSdkMcpMessage = (
serverName: string,
message: JSONRPCMessage,
) => Promise<JSONRPCMessage>;
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
export type DiscoveredMCPPrompt = Prompt & { export type DiscoveredMCPPrompt = Prompt & {
@@ -92,6 +102,7 @@ export class McpClient {
private readonly promptRegistry: PromptRegistry, private readonly promptRegistry: PromptRegistry,
private readonly workspaceContext: WorkspaceContext, private readonly workspaceContext: WorkspaceContext,
private readonly debugMode: boolean, private readonly debugMode: boolean,
private readonly sendSdkMcpMessage?: SendSdkMcpMessage,
) { ) {
this.client = new Client({ this.client = new Client({
name: `qwen-cli-mcp-client-${this.serverName}`, name: `qwen-cli-mcp-client-${this.serverName}`,
@@ -189,7 +200,12 @@ export class McpClient {
} }
private async createTransport(): Promise<Transport> { private async createTransport(): Promise<Transport> {
return createTransport(this.serverName, this.serverConfig, this.debugMode); return createTransport(
this.serverName,
this.serverConfig,
this.debugMode,
this.sendSdkMcpMessage,
);
} }
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> { private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
@@ -501,6 +517,7 @@ export function populateMcpServerCommand(
* @param mcpServerName The name identifier for this MCP server * @param mcpServerName The name identifier for this MCP server
* @param mcpServerConfig Configuration object containing connection details * @param mcpServerConfig Configuration object containing connection details
* @param toolRegistry The registry to register discovered tools with * @param toolRegistry The registry to register discovered tools with
* @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane.
* @returns Promise that resolves when discovery is complete * @returns Promise that resolves when discovery is complete
*/ */
export async function connectAndDiscover( export async function connectAndDiscover(
@@ -511,6 +528,7 @@ export async function connectAndDiscover(
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
cliConfig: Config, cliConfig: Config,
sendSdkMcpMessage?: SendSdkMcpMessage,
): Promise<void> { ): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -521,6 +539,7 @@ export async function connectAndDiscover(
mcpServerConfig, mcpServerConfig,
debugMode, debugMode,
workspaceContext, workspaceContext,
sendSdkMcpMessage,
); );
mcpClient.onerror = (error) => { mcpClient.onerror = (error) => {
@@ -744,6 +763,7 @@ export function hasNetworkTransport(config: MCPServerConfig): boolean {
* *
* @param mcpServerName The name of the MCP server, used for logging and identification. * @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server. * @param mcpServerConfig The configuration specifying how to connect to the server.
* @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane.
* @returns A promise that resolves to a connected MCP `Client` instance. * @returns A promise that resolves to a connected MCP `Client` instance.
* @throws An error if the connection fails or the configuration is invalid. * @throws An error if the connection fails or the configuration is invalid.
*/ */
@@ -752,6 +772,7 @@ export async function connectToMcpServer(
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
sendSdkMcpMessage?: SendSdkMcpMessage,
): Promise<Client> { ): Promise<Client> {
const mcpClient = new Client({ const mcpClient = new Client({
name: 'qwen-code-mcp-client', name: 'qwen-code-mcp-client',
@@ -808,6 +829,7 @@ export async function connectToMcpServer(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
debugMode, debugMode,
sendSdkMcpMessage,
); );
try { try {
await mcpClient.connect(transport, { await mcpClient.connect(transport, {
@@ -1172,7 +1194,21 @@ export async function createTransport(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
debugMode: boolean, debugMode: boolean,
sendSdkMcpMessage?: SendSdkMcpMessage,
): Promise<Transport> { ): Promise<Transport> {
if (isSdkMcpServerConfig(mcpServerConfig)) {
if (!sendSdkMcpMessage) {
throw new Error(
`SDK MCP server '${mcpServerName}' requires sendSdkMcpMessage callback`,
);
}
return new SdkControlClientTransport({
serverName: mcpServerName,
sendMcpMessage: sendSdkMcpMessage,
debugMode,
});
}
if ( if (
mcpServerConfig.authProviderType === mcpServerConfig.authProviderType ===
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION

View File

@@ -0,0 +1,163 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/**
* SdkControlClientTransport - MCP Client transport for SDK MCP servers
*
* This transport enables CLI's MCP client to connect to SDK MCP servers
* through the control plane. Messages are routed:
*
* CLI MCP Client → SdkControlClientTransport → sendMcpMessage() →
* control_request (mcp_message) → SDK → control_response → onmessage → CLI
*
* Unlike StdioClientTransport which spawns a subprocess, this transport
* communicates with SDK MCP servers running in the SDK process.
*/
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
/**
* Callback to send MCP messages to SDK via control plane
* Returns the MCP response from the SDK
*/
export type SendMcpMessageCallback = (
serverName: string,
message: JSONRPCMessage,
) => Promise<JSONRPCMessage>;
export interface SdkControlClientTransportOptions {
serverName: string;
sendMcpMessage: SendMcpMessageCallback;
debugMode?: boolean;
}
/**
* MCP Client Transport for SDK MCP servers
*
* Implements the @modelcontextprotocol/sdk Transport interface to enable
* CLI's MCP client to connect to SDK MCP servers via the control plane.
*/
export class SdkControlClientTransport {
private serverName: string;
private sendMcpMessage: SendMcpMessageCallback;
private debugMode: boolean;
private started = false;
// Transport interface callbacks
onmessage?: (message: JSONRPCMessage) => void;
onerror?: (error: Error) => void;
onclose?: () => void;
constructor(options: SdkControlClientTransportOptions) {
this.serverName = options.serverName;
this.sendMcpMessage = options.sendMcpMessage;
this.debugMode = options.debugMode ?? false;
}
/**
* Start the transport
* For SDK transport, this just marks it as ready - no subprocess to spawn
*/
async start(): Promise<void> {
if (this.started) {
return;
}
this.started = true;
if (this.debugMode) {
console.error(
`[SdkControlClientTransport] Started for server '${this.serverName}'`,
);
}
}
/**
* Send a message to the SDK MCP server via control plane
*
* Routes the message through the control plane and delivers
* the response via onmessage callback.
*/
async send(message: JSONRPCMessage): Promise<void> {
if (!this.started) {
throw new Error(
`SdkControlClientTransport (${this.serverName}) not started. Call start() first.`,
);
}
if (this.debugMode) {
console.error(
`[SdkControlClientTransport] Sending message to '${this.serverName}':`,
JSON.stringify(message),
);
}
try {
// Send message to SDK and wait for response
const response = await this.sendMcpMessage(this.serverName, message);
if (this.debugMode) {
console.error(
`[SdkControlClientTransport] Received response from '${this.serverName}':`,
JSON.stringify(response),
);
}
// Deliver response via onmessage callback
if (this.onmessage) {
this.onmessage(response);
}
} catch (error) {
if (this.debugMode) {
console.error(
`[SdkControlClientTransport] Error sending to '${this.serverName}':`,
error,
);
}
if (this.onerror) {
this.onerror(error instanceof Error ? error : new Error(String(error)));
}
throw error;
}
}
/**
* Close the transport
*/
async close(): Promise<void> {
if (!this.started) {
return;
}
this.started = false;
if (this.debugMode) {
console.error(
`[SdkControlClientTransport] Closed for server '${this.serverName}'`,
);
}
if (this.onclose) {
this.onclose();
}
}
/**
* Check if transport is started
*/
isStarted(): boolean {
return this.started;
}
/**
* Get server name
*/
getServerName(): string {
return this.serverName;
}
}

View File

@@ -16,6 +16,7 @@ import type { Config } from '../config/config.js';
import { spawn } from 'node:child_process'; import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder'; import { StringDecoder } from 'node:string_decoder';
import { connectAndDiscover } from './mcp-client.js'; import { connectAndDiscover } from './mcp-client.js';
import type { SendSdkMcpMessage } from './mcp-client.js';
import { McpClientManager } from './mcp-client-manager.js'; import { McpClientManager } from './mcp-client-manager.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
@@ -173,7 +174,11 @@ export class ToolRegistry {
private config: Config; private config: Config;
private mcpClientManager: McpClientManager; private mcpClientManager: McpClientManager;
constructor(config: Config, eventEmitter?: EventEmitter) { constructor(
config: Config,
eventEmitter?: EventEmitter,
sendSdkMcpMessage?: SendSdkMcpMessage,
) {
this.config = config; this.config = config;
this.mcpClientManager = new McpClientManager( this.mcpClientManager = new McpClientManager(
this.config.getMcpServers() ?? {}, this.config.getMcpServers() ?? {},
@@ -183,6 +188,7 @@ export class ToolRegistry {
this.config.getDebugMode(), this.config.getDebugMode(),
this.config.getWorkspaceContext(), this.config.getWorkspaceContext(),
eventEmitter, eventEmitter,
sendSdkMcpMessage,
); );
} }

View File

@@ -13,7 +13,7 @@ npm install @qwen-code/sdk-typescript
## Requirements ## Requirements
- Node.js >= 20.0.0 - Node.js >= 20.0.0
- [Qwen Code](https://github.com/QwenLM/qwen-code) installed and accessible in PATH - [Qwen Code](https://github.com/QwenLM/qwen-code) >= 0.4.0 (stable) installed and accessible in PATH
> **Note for nvm users**: If you use nvm to manage Node.js versions, the SDK may not be able to auto-detect the Qwen Code executable. You should explicitly set the `pathToQwenExecutable` option to the full path of the `qwen` binary. > **Note for nvm users**: If you use nvm to manage Node.js versions, the SDK may not be able to auto-detect the Qwen Code executable. You should explicitly set the `pathToQwenExecutable` option to the full path of the `qwen` binary.
@@ -61,7 +61,7 @@ Creates a new query session with the Qwen Code.
| `permissionMode` | `'default' \| 'plan' \| 'auto-edit' \| 'yolo'` | `'default'` | Permission mode controlling tool execution approval. See [Permission Modes](#permission-modes) for details. | | `permissionMode` | `'default' \| 'plan' \| 'auto-edit' \| 'yolo'` | `'default'` | Permission mode controlling tool execution approval. See [Permission Modes](#permission-modes) for details. |
| `canUseTool` | `CanUseTool` | - | Custom permission handler for tool execution approval. Invoked when a tool requires confirmation. Must respond within 30 seconds or the request will be auto-denied. See [Custom Permission Handler](#custom-permission-handler). | | `canUseTool` | `CanUseTool` | - | Custom permission handler for tool execution approval. Invoked when a tool requires confirmation. Must respond within 30 seconds or the request will be auto-denied. See [Custom Permission Handler](#custom-permission-handler). |
| `env` | `Record<string, string>` | - | Environment variables to pass to the Qwen Code process. Merged with the current process environment. | | `env` | `Record<string, string>` | - | Environment variables to pass to the Qwen Code process. Merged with the current process environment. |
| `mcpServers` | `Record<string, ExternalMcpServerConfig>` | - | External MCP (Model Context Protocol) servers to connect. Each server is identified by a unique name and configured with `command`, `args`, and `env`. | | `mcpServers` | `Record<string, McpServerConfig>` | - | MCP (Model Context Protocol) servers to connect. Supports external servers (stdio/SSE/HTTP) and SDK-embedded servers. External servers are configured with transport options like `command`, `args`, `url`, `httpUrl`, etc. SDK servers use `{ type: 'sdk', name: string, instance: Server }`. |
| `abortController` | `AbortController` | - | Controller to cancel the query session. Call `abortController.abort()` to terminate the session and cleanup resources. | | `abortController` | `AbortController` | - | Controller to cancel the query session. Call `abortController.abort()` to terminate the session and cleanup resources. |
| `debug` | `boolean` | `false` | Enable debug mode for verbose logging from the CLI process. | | `debug` | `boolean` | `false` | Enable debug mode for verbose logging from the CLI process. |
| `maxSessionTurns` | `number` | `-1` (unlimited) | Maximum number of conversation turns before the session automatically terminates. A turn consists of a user message and an assistant response. | | `maxSessionTurns` | `number` | `-1` (unlimited) | Maximum number of conversation turns before the session automatically terminates. A turn consists of a user message and an assistant response. |
@@ -74,12 +74,27 @@ Creates a new query session with the Qwen Code.
### Timeouts ### Timeouts
The SDK enforces the following timeouts: The SDK enforces the following default timeouts:
| Timeout | Duration | Description | | Timeout | Default | Description |
| ------------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------- | | ---------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------- |
| Permission Callback | 30 seconds | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. | | `canUseTool` | 30 seconds | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. |
| Control Request | 30 seconds | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. | | `mcpRequest` | 1 minute | Maximum time for SDK MCP tool calls to complete. |
| `controlRequest` | 30 seconds | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. |
| `streamClose` | 1 minute | Maximum time to wait for initialization to complete before closing CLI stdin in multi-turn mode with SDK MCP servers. |
You can customize these timeouts via the `timeout` option:
```typescript
const query = qwen.query('Your prompt', {
timeout: {
canUseTool: 60000, // 60 seconds for permission callback
mcpRequest: 600000, // 10 minutes for MCP tool calls
controlRequest: 60000, // 60 seconds for control requests
streamClose: 15000, // 15 seconds for stream close wait
},
});
```
### Message Types ### Message Types
@@ -212,7 +227,7 @@ const result = query({
}); });
``` ```
### With MCP Servers ### With External MCP Servers
```typescript ```typescript
import { query } from '@qwen-code/sdk-typescript'; import { query } from '@qwen-code/sdk-typescript';
@@ -231,6 +246,84 @@ const result = query({
}); });
``` ```
### With SDK-Embedded MCP Servers
The SDK provides `tool` and `createSdkMcpServer` to create MCP servers that run in the same process as your SDK application. This is useful when you want to expose custom tools to the AI without running a separate server process.
#### `tool(name, description, inputSchema, handler)`
Creates a tool definition with Zod schema type inference.
| Parameter | Type | Description |
| ------------- | ---------------------------------- | ------------------------------------------------------------------------ |
| `name` | `string` | Tool name (1-64 chars, starts with letter, alphanumeric and underscores) |
| `description` | `string` | Human-readable description of what the tool does |
| `inputSchema` | `ZodRawShape` | Zod schema object defining the tool's input parameters |
| `handler` | `(args, extra) => Promise<Result>` | Async function that executes the tool and returns MCP content blocks |
The handler must return a `CallToolResult` object with the following structure:
```typescript
{
content: Array<
| { type: 'text'; text: string }
| { type: 'image'; data: string; mimeType: string }
| { type: 'resource'; uri: string; mimeType?: string; text?: string }
>;
isError?: boolean;
}
```
#### `createSdkMcpServer(options)`
Creates an SDK-embedded MCP server instance.
| Option | Type | Default | Description |
| --------- | ------------------------ | --------- | ------------------------------------ |
| `name` | `string` | Required | Unique name for the MCP server |
| `version` | `string` | `'1.0.0'` | Server version |
| `tools` | `SdkMcpToolDefinition[]` | - | Array of tools created with `tool()` |
Returns a `McpSdkServerConfigWithInstance` object that can be passed directly to the `mcpServers` option.
#### Example
```typescript
import { z } from 'zod';
import { query, tool, createSdkMcpServer } from '@qwen-code/sdk-typescript';
// Define a tool with Zod schema
const calculatorTool = tool(
'calculate_sum',
'Add two numbers',
{ a: z.number(), b: z.number() },
async (args) => ({
content: [{ type: 'text', text: String(args.a + args.b) }],
}),
);
// Create the MCP server
const server = createSdkMcpServer({
name: 'calculator',
tools: [calculatorTool],
});
// Use the server in a query
const result = query({
prompt: 'What is 42 + 17?',
options: {
permissionMode: 'yolo',
mcpServers: {
calculator: server,
},
},
});
for await (const message of result) {
console.log(message);
}
```
### Abort a Query ### Abort a Query
```typescript ```typescript

View File

@@ -3,6 +3,17 @@ export { AbortError, isAbortError } from './types/errors.js';
export { Query } from './query/Query.js'; export { Query } from './query/Query.js';
export { SdkLogger } from './utils/logger.js'; export { SdkLogger } from './utils/logger.js';
// SDK MCP Server exports
export { tool } from './mcp/tool.js';
export { createSdkMcpServer } from './mcp/createSdkMcpServer.js';
export type { SdkMcpToolDefinition } from './mcp/tool.js';
export type {
CreateSdkMcpServerOptions,
McpSdkServerConfigWithInstance,
} from './mcp/createSdkMcpServer.js';
export type { QueryOptions } from './query/createQuery.js'; export type { QueryOptions } from './query/createQuery.js';
export type { LogLevel, LoggerConfig, ScopedLogger } from './utils/logger.js'; export type { LogLevel, LoggerConfig, ScopedLogger } from './utils/logger.js';
@@ -18,6 +29,7 @@ export type {
SDKResultMessage, SDKResultMessage,
SDKPartialAssistantMessage, SDKPartialAssistantMessage,
SDKMessage, SDKMessage,
SDKMcpServerConfig,
ControlMessage, ControlMessage,
CLIControlRequest, CLIControlRequest,
CLIControlResponse, CLIControlResponse,
@@ -43,6 +55,10 @@ export type {
PermissionMode, PermissionMode,
CanUseTool, CanUseTool,
PermissionResult, PermissionResult,
ExternalMcpServerConfig, CLIMcpServerConfig,
SdkMcpServerConfig, McpServerConfig,
McpOAuthConfig,
McpAuthProviderType,
} from './types/types.js'; } from './types/types.js';
export { isSdkMcpServerConfig } from './types/types.js';

View File

@@ -103,9 +103,3 @@ export class SdkControlServerTransport {
return this.serverName; return this.serverName;
} }
} }
export function createSdkControlServerTransport(
options: SdkControlServerTransportOptions,
): SdkControlServerTransport {
return new SdkControlServerTransport(options);
}

View File

@@ -1,29 +1,63 @@
/** /**
* Factory function to create SDK-embedded MCP servers * @license
* * Copyright 2025 Qwen Team
* Creates MCP Server instances that run in the user's Node.js process * SPDX-License-Identifier: Apache-2.0
* and are proxied to the CLI via the control plane.
*/ */
import { Server } from '@modelcontextprotocol/sdk/server/index.js'; /**
import { * Factory function to create SDK-embedded MCP servers
ListToolsRequestSchema, */
CallToolRequestSchema,
type CallToolResultSchema, import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
} from '@modelcontextprotocol/sdk/types.js'; import type { SdkMcpToolDefinition } from './tool.js';
import type { ToolDefinition } from '../types/types.js';
import { formatToolResult, formatToolError } from './formatters.js';
import { validateToolName } from './tool.js'; import { validateToolName } from './tool.js';
import type { z } from 'zod';
type CallToolResult = z.infer<typeof CallToolResultSchema>; /**
* Options for creating an SDK MCP server
*/
export type CreateSdkMcpServerOptions = {
name: string;
version?: string;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
tools?: Array<SdkMcpToolDefinition<any>>;
};
/**
* SDK MCP Server configuration with instance
*/
export type McpSdkServerConfigWithInstance = {
type: 'sdk';
name: string;
instance: McpServer;
};
/**
* Creates an MCP server instance that can be used with the SDK transport.
*
* @example
* ```typescript
* import { z } from 'zod';
* import { tool, createSdkMcpServer } from '@qwen-code/sdk-typescript';
*
* const calculatorTool = tool(
* 'calculate_sum',
* 'Add two numbers',
* { a: z.number(), b: z.number() },
* async (args) => ({ content: [{ type: 'text', text: String(args.a + args.b) }] })
* );
*
* const server = createSdkMcpServer({
* name: 'calculator',
* version: '1.0.0',
* tools: [calculatorTool],
* });
* ```
*/
export function createSdkMcpServer( export function createSdkMcpServer(
name: string, options: CreateSdkMcpServerOptions,
version: string, ): McpSdkServerConfigWithInstance {
tools: ToolDefinition[], const { name, version = '1.0.0', tools } = options;
): Server {
// Validate server name
if (!name || typeof name !== 'string') { if (!name || typeof name !== 'string') {
throw new Error('MCP server name must be a non-empty string'); throw new Error('MCP server name must be a non-empty string');
} }
@@ -32,78 +66,42 @@ export function createSdkMcpServer(
throw new Error('MCP server version must be a non-empty string'); throw new Error('MCP server version must be a non-empty string');
} }
if (!Array.isArray(tools)) { if (tools !== undefined && !Array.isArray(tools)) {
throw new Error('Tools must be an array'); throw new Error('Tools must be an array');
} }
// Validate tool names are unique
const toolNames = new Set<string>(); const toolNames = new Set<string>();
for (const tool of tools) { if (tools) {
validateToolName(tool.name); for (const t of tools) {
validateToolName(t.name);
if (toolNames.has(tool.name)) { if (toolNames.has(t.name)) {
throw new Error( throw new Error(
`Duplicate tool name '${tool.name}' in MCP server '${name}'`, `Duplicate tool name '${t.name}' in MCP server '${name}'`,
); );
}
toolNames.add(t.name);
} }
toolNames.add(tool.name);
} }
// Create MCP Server instance const server = new McpServer(
const server = new Server( { name, version },
{
name,
version,
},
{ {
capabilities: { capabilities: {
tools: {}, tools: tools ? {} : undefined,
}, },
}, },
); );
// Create tool map for fast lookup if (tools) {
const toolMap = new Map<string, ToolDefinition>(); tools.forEach((toolDef) => {
for (const tool of tools) { server.tool(
toolMap.set(tool.name, tool); toolDef.name,
toolDef.description,
toolDef.inputSchema,
toolDef.handler,
);
});
} }
// Register list_tools handler return { type: 'sdk', name, instance: server };
server.setRequestHandler(ListToolsRequestSchema, async () => ({
tools: tools.map((tool) => ({
name: tool.name,
description: tool.description,
inputSchema: tool.inputSchema,
})),
}));
// Register call_tool handler
server.setRequestHandler(CallToolRequestSchema, async (request) => {
const { name: toolName, arguments: toolArgs } = request.params;
// Find tool
const tool = toolMap.get(toolName);
if (!tool) {
return formatToolError(
new Error(`Tool '${toolName}' not found in server '${name}'`),
) as CallToolResult;
}
try {
// Invoke tool handler
const result = await tool.handler(toolArgs);
// Format result
return formatToolResult(result) as CallToolResult;
} catch (error) {
// Handle tool execution error
return formatToolError(
error instanceof Error
? error
: new Error(`Tool '${toolName}' failed: ${String(error)}`),
) as CallToolResult;
}
});
return server;
} }

View File

@@ -1,39 +1,76 @@
/** /**
* Tool definition helper for SDK-embedded MCP servers * @license
* * Copyright 2025 Qwen Team
* Provides type-safe tool definitions with generic input/output types. * SPDX-License-Identifier: Apache-2.0
*/ */
import type { ToolDefinition } from '../types/types.js'; /**
* Tool definition helper for SDK-embedded MCP servers
*/
export function tool<TInput = unknown, TOutput = unknown>( import type { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js';
def: ToolDefinition<TInput, TOutput>, import type { z, ZodRawShape, ZodObject, ZodTypeAny } from 'zod';
): ToolDefinition<TInput, TOutput> {
// Validate tool definition type CallToolResult = z.infer<typeof CallToolResultSchema>;
if (!def.name || typeof def.name !== 'string') {
throw new Error('Tool definition must have a name (string)'); /**
* SDK MCP Tool Definition with Zod schema type inference
*/
export type SdkMcpToolDefinition<Schema extends ZodRawShape = ZodRawShape> = {
name: string;
description: string;
inputSchema: Schema;
handler: (
args: z.infer<ZodObject<Schema, 'strip', ZodTypeAny>>,
extra: unknown,
) => Promise<CallToolResult>;
};
/**
* Create an SDK MCP tool definition with Zod schema inference
*
* @example
* ```typescript
* import { z } from 'zod';
* import { tool } from '@qwen-code/sdk-typescript';
*
* const calculatorTool = tool(
* 'calculate_sum',
* 'Calculate the sum of two numbers',
* { a: z.number(), b: z.number() },
* async (args) => {
* // args is inferred as { a: number, b: number }
* return { content: [{ type: 'text', text: String(args.a + args.b) }] };
* }
* );
* ```
*/
export function tool<Schema extends ZodRawShape>(
name: string,
description: string,
inputSchema: Schema,
handler: (
args: z.infer<ZodObject<Schema, 'strip', ZodTypeAny>>,
extra: unknown,
) => Promise<CallToolResult>,
): SdkMcpToolDefinition<Schema> {
if (!name || typeof name !== 'string') {
throw new Error('Tool name must be a non-empty string');
} }
if (!def.description || typeof def.description !== 'string') { if (!description || typeof description !== 'string') {
throw new Error( throw new Error(`Tool '${name}' must have a description (string)`);
`Tool definition for '${def.name}' must have a description (string)`,
);
} }
if (!def.inputSchema || typeof def.inputSchema !== 'object') { if (!inputSchema || typeof inputSchema !== 'object') {
throw new Error( throw new Error(`Tool '${name}' must have an inputSchema (object)`);
`Tool definition for '${def.name}' must have an inputSchema (object)`,
);
} }
if (!def.handler || typeof def.handler !== 'function') { if (!handler || typeof handler !== 'function') {
throw new Error( throw new Error(`Tool '${name}' must have a handler (function)`);
`Tool definition for '${def.name}' must have a handler (function)`,
);
} }
// Return definition (pass-through for type safety) return { name, description, inputSchema, handler };
return def;
} }
export function validateToolName(name: string): void { export function validateToolName(name: string): void {
@@ -53,39 +90,3 @@ export function validateToolName(name: string): void {
); );
} }
} }
export function validateInputSchema(schema: unknown): void {
if (!schema || typeof schema !== 'object') {
throw new Error('Input schema must be an object');
}
const schemaObj = schema as Record<string, unknown>;
if (!schemaObj.type) {
throw new Error('Input schema must have a type field');
}
// For object schemas, validate properties
if (schemaObj.type === 'object') {
if (schemaObj.properties && typeof schemaObj.properties !== 'object') {
throw new Error('Input schema properties must be an object');
}
if (schemaObj.required && !Array.isArray(schemaObj.required)) {
throw new Error('Input schema required must be an array');
}
}
}
export function createTool<TInput = unknown, TOutput = unknown>(
def: ToolDefinition<TInput, TOutput>,
): ToolDefinition<TInput, TOutput> {
// Validate via tool() function
const validated = tool(def);
// Additional validation
validateToolName(validated.name);
validateInputSchema(validated.inputSchema);
return validated;
}

View File

@@ -5,10 +5,10 @@
* Implements AsyncIterator protocol for message consumption. * Implements AsyncIterator protocol for message consumption.
*/ */
const PERMISSION_CALLBACK_TIMEOUT = 30000; const DEFAULT_CAN_USE_TOOL_TIMEOUT = 30_000;
const MCP_REQUEST_TIMEOUT = 30000; const DEFAULT_MCP_REQUEST_TIMEOUT = 60_000;
const CONTROL_REQUEST_TIMEOUT = 30000; const DEFAULT_CONTROL_REQUEST_TIMEOUT = 30_000;
const STREAM_CLOSE_TIMEOUT = 10000; const DEFAULT_STREAM_CLOSE_TIMEOUT = 60_000;
import { randomUUID } from 'node:crypto'; import { randomUUID } from 'node:crypto';
import { SdkLogger } from '../utils/logger.js'; import { SdkLogger } from '../utils/logger.js';
@@ -19,6 +19,7 @@ import type {
CLIControlResponse, CLIControlResponse,
ControlCancelRequest, ControlCancelRequest,
PermissionSuggestion, PermissionSuggestion,
WireSDKMcpServerConfig,
} from '../types/protocol.js'; } from '../types/protocol.js';
import { import {
isSDKUserMessage, isSDKUserMessage,
@@ -31,12 +32,17 @@ import {
isControlCancel, isControlCancel,
} from '../types/protocol.js'; } from '../types/protocol.js';
import type { Transport } from '../transport/Transport.js'; import type { Transport } from '../transport/Transport.js';
import type { QueryOptions } from '../types/types.js'; import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import type { QueryOptions, CLIMcpServerConfig } from '../types/types.js';
import { isSdkMcpServerConfig } from '../types/types.js';
import { Stream } from '../utils/Stream.js'; import { Stream } from '../utils/Stream.js';
import { serializeJsonLine } from '../utils/jsonLines.js'; import { serializeJsonLine } from '../utils/jsonLines.js';
import { AbortError } from '../types/errors.js'; import { AbortError } from '../types/errors.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type { SdkControlServerTransport } from '../mcp/SdkControlServerTransport.js'; import {
SdkControlServerTransport,
type SdkControlServerTransportOptions,
} from '../mcp/SdkControlServerTransport.js';
import { ControlRequestType } from '../types/protocol.js'; import { ControlRequestType } from '../types/protocol.js';
interface PendingControlRequest { interface PendingControlRequest {
@@ -46,6 +52,11 @@ interface PendingControlRequest {
abortController: AbortController; abortController: AbortController;
} }
interface PendingMcpResponse {
resolve: (response: JSONRPCMessage) => void;
reject: (error: Error) => void;
}
interface TransportWithEndInput extends Transport { interface TransportWithEndInput extends Transport {
endInput(): void; endInput(): void;
} }
@@ -61,7 +72,9 @@ export class Query implements AsyncIterable<SDKMessage> {
private abortController: AbortController; private abortController: AbortController;
private pendingControlRequests: Map<string, PendingControlRequest> = private pendingControlRequests: Map<string, PendingControlRequest> =
new Map(); new Map();
private pendingMcpResponses: Map<string, PendingMcpResponse> = new Map();
private sdkMcpTransports: Map<string, SdkControlServerTransport> = new Map(); private sdkMcpTransports: Map<string, SdkControlServerTransport> = new Map();
private sdkMcpServers: Map<string, McpServer> = new Map();
readonly initialized: Promise<void>; readonly initialized: Promise<void>;
private closed = false; private closed = false;
private messageRouterStarted = false; private messageRouterStarted = false;
@@ -92,6 +105,11 @@ export class Query implements AsyncIterable<SDKMessage> {
*/ */
this.sdkMessages = this.readSdkMessages(); this.sdkMessages = this.readSdkMessages();
/**
* Promise that resolves when the first SDKResultMessage is received.
* Used to coordinate endInput() timing - ensures all initialization
* (SDK MCP servers, control responses) is complete before closing CLI stdin.
*/
this.firstResultReceivedPromise = new Promise((resolve) => { this.firstResultReceivedPromise = new Promise((resolve) => {
this.firstResultReceivedResolve = resolve; this.firstResultReceivedResolve = resolve;
}); });
@@ -121,17 +139,152 @@ export class Query implements AsyncIterable<SDKMessage> {
this.startMessageRouter(); this.startMessageRouter();
} }
private async initializeSdkMcpServers(): Promise<void> {
if (!this.options.mcpServers) {
return;
}
const connectionPromises: Array<Promise<void>> = [];
// Extract SDK MCP servers from the unified mcpServers config
for (const [key, config] of Object.entries(this.options.mcpServers)) {
if (!isSdkMcpServerConfig(config)) {
continue; // Skip external MCP servers
}
// Use the name from SDKMcpServerConfig, fallback to key for backwards compatibility
const serverName = config.name || key;
const server = config.instance;
// Create transport options with callback to route MCP server responses
const transportOptions: SdkControlServerTransportOptions = {
sendToQuery: async (message: JSONRPCMessage) => {
this.handleMcpServerResponse(serverName, message);
},
serverName,
};
const sdkTransport = new SdkControlServerTransport(transportOptions);
// Connect server to transport and only register on success
const connectionPromise = server
.connect(sdkTransport)
.then(() => {
// Only add to maps after successful connection
this.sdkMcpServers.set(serverName, server);
this.sdkMcpTransports.set(serverName, sdkTransport);
logger.debug(`SDK MCP server '${serverName}' connected to transport`);
})
.catch((error) => {
logger.error(
`Failed to connect SDK MCP server '${serverName}' to transport:`,
error,
);
// Don't throw - one failed server shouldn't prevent others
});
connectionPromises.push(connectionPromise);
}
// Wait for all connection attempts to complete
await Promise.all(connectionPromises);
if (this.sdkMcpServers.size > 0) {
logger.info(
`Initialized ${this.sdkMcpServers.size} SDK MCP server(s): ${Array.from(this.sdkMcpServers.keys()).join(', ')}`,
);
}
}
/**
* Handle response messages from SDK MCP servers
*
* When an MCP server sends a response via transport.send(), this callback
* routes it back to the pending request that's waiting for it.
*/
private handleMcpServerResponse(
serverName: string,
message: JSONRPCMessage,
): void {
// Check if this is a response with an id
if ('id' in message && message.id !== null && message.id !== undefined) {
const key = `${serverName}:${message.id}`;
const pending = this.pendingMcpResponses.get(key);
if (pending) {
logger.debug(
`Routing MCP response for server '${serverName}', id: ${message.id}`,
);
pending.resolve(message);
this.pendingMcpResponses.delete(key);
return;
}
}
// If no pending request found, log a warning (this shouldn't happen normally)
logger.warn(
`Received MCP server response with no pending request: server='${serverName}'`,
message,
);
}
/**
* Get SDK MCP servers config for CLI initialization
*
* Only SDK servers are sent in the initialize request.
*/
private getSdkMcpServersForCli(): Record<string, WireSDKMcpServerConfig> {
const sdkServers: Record<string, WireSDKMcpServerConfig> = {};
for (const [name] of this.sdkMcpServers.entries()) {
sdkServers[name] = { type: 'sdk', name };
}
return sdkServers;
}
/**
* Get external MCP servers (non-SDK) that should be managed by the CLI
*/
private getMcpServersForCli(): Record<string, CLIMcpServerConfig> {
if (!this.options.mcpServers) {
return {};
}
const externalServers: Record<string, CLIMcpServerConfig> = {};
for (const [name, config] of Object.entries(this.options.mcpServers)) {
if (isSdkMcpServerConfig(config)) {
continue;
}
externalServers[name] = config as CLIMcpServerConfig;
}
return externalServers;
}
private async initialize(): Promise<void> { private async initialize(): Promise<void> {
try { try {
logger.debug('Initializing Query'); logger.debug('Initializing Query');
const sdkMcpServerNames = Array.from(this.sdkMcpTransports.keys()); // Initialize SDK MCP servers and wait for connections
await this.initializeSdkMcpServers();
// Get only successfully connected SDK servers for CLI
const sdkMcpServersForCli = this.getSdkMcpServersForCli();
const mcpServersForCli = this.getMcpServersForCli();
logger.debug('SDK MCP servers for CLI:', sdkMcpServersForCli);
logger.debug('External MCP servers for CLI:', mcpServersForCli);
await this.sendControlRequest(ControlRequestType.INITIALIZE, { await this.sendControlRequest(ControlRequestType.INITIALIZE, {
hooks: null, hooks: null,
sdkMcpServers: sdkMcpServers:
sdkMcpServerNames.length > 0 ? sdkMcpServerNames : undefined, Object.keys(sdkMcpServersForCli).length > 0
mcpServers: this.options.mcpServers, ? sdkMcpServersForCli
: undefined,
mcpServers:
Object.keys(mcpServersForCli).length > 0
? mcpServersForCli
: undefined,
agents: this.options.agents, agents: this.options.agents,
}); });
logger.info('Query initialized successfully'); logger.info('Query initialized successfully');
@@ -279,10 +432,12 @@ export class Query implements AsyncIterable<SDKMessage> {
} }
try { try {
const canUseToolTimeout =
this.options.timeout?.canUseTool ?? DEFAULT_CAN_USE_TOOL_TIMEOUT;
const timeoutPromise = new Promise<never>((_, reject) => { const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout( setTimeout(
() => reject(new Error('Permission callback timeout')), () => reject(new Error('Permission callback timeout')),
PERMISSION_CALLBACK_TIMEOUT, canUseToolTimeout,
); );
}); });
@@ -361,32 +516,45 @@ export class Query implements AsyncIterable<SDKMessage> {
} }
private handleMcpRequest( private handleMcpRequest(
_serverName: string, serverName: string,
message: JSONRPCMessage, message: JSONRPCMessage,
transport: SdkControlServerTransport, transport: SdkControlServerTransport,
): Promise<JSONRPCMessage> { ): Promise<JSONRPCMessage> {
const messageId = 'id' in message ? message.id : null;
const key = `${serverName}:${messageId}`;
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const mcpRequestTimeout =
this.options.timeout?.mcpRequest ?? DEFAULT_MCP_REQUEST_TIMEOUT;
const timeout = setTimeout(() => { const timeout = setTimeout(() => {
this.pendingMcpResponses.delete(key);
reject(new Error('MCP request timeout')); reject(new Error('MCP request timeout'));
}, MCP_REQUEST_TIMEOUT); }, mcpRequestTimeout);
const messageId = 'id' in message ? message.id : null; const cleanup = () => {
clearTimeout(timeout);
/** this.pendingMcpResponses.delete(key);
* Hook into transport to capture response.
* Temporarily replace sendToQuery to intercept the response message
* matching this request's ID, then restore the original handler.
*/
const originalSend = transport.sendToQuery;
transport.sendToQuery = async (responseMessage: JSONRPCMessage) => {
if ('id' in responseMessage && responseMessage.id === messageId) {
clearTimeout(timeout);
transport.sendToQuery = originalSend;
resolve(responseMessage);
}
return originalSend(responseMessage);
}; };
const resolveAndCleanup = (response: JSONRPCMessage) => {
cleanup();
resolve(response);
};
const rejectAndCleanup = (error: Error) => {
cleanup();
reject(error);
};
// Register pending response handler
this.pendingMcpResponses.set(key, {
resolve: resolveAndCleanup,
reject: rejectAndCleanup,
});
// Deliver message to MCP server via transport.onmessage
// The server will process it and call transport.send() with the response,
// which triggers handleMcpServerResponse to resolve our pending promise
transport.handleMessage(message); transport.handleMessage(message);
}); });
} }
@@ -452,6 +620,10 @@ export class Query implements AsyncIterable<SDKMessage> {
subtype: string, subtype: string,
data: Record<string, unknown> = {}, data: Record<string, unknown> = {},
): Promise<Record<string, unknown> | null> { ): Promise<Record<string, unknown> | null> {
if (this.closed) {
return Promise.reject(new Error('Query is closed'));
}
const requestId = randomUUID(); const requestId = randomUUID();
const request: CLIControlRequest = { const request: CLIControlRequest = {
@@ -466,10 +638,13 @@ export class Query implements AsyncIterable<SDKMessage> {
const responsePromise = new Promise<Record<string, unknown> | null>( const responsePromise = new Promise<Record<string, unknown> | null>(
(resolve, reject) => { (resolve, reject) => {
const abortController = new AbortController(); const abortController = new AbortController();
const controlRequestTimeout =
this.options.timeout?.controlRequest ??
DEFAULT_CONTROL_REQUEST_TIMEOUT;
const timeout = setTimeout(() => { const timeout = setTimeout(() => {
this.pendingControlRequests.delete(requestId); this.pendingControlRequests.delete(requestId);
reject(new Error(`Control request timeout: ${subtype}`)); reject(new Error(`Control request timeout: ${subtype}`));
}, CONTROL_REQUEST_TIMEOUT); }, controlRequestTimeout);
this.pendingControlRequests.set(requestId, { this.pendingControlRequests.set(requestId, {
resolve, resolve,
@@ -517,9 +692,16 @@ export class Query implements AsyncIterable<SDKMessage> {
for (const pending of this.pendingControlRequests.values()) { for (const pending of this.pendingControlRequests.values()) {
pending.abortController.abort(); pending.abortController.abort();
clearTimeout(pending.timeout); clearTimeout(pending.timeout);
pending.reject(new Error('Query is closed'));
} }
this.pendingControlRequests.clear(); this.pendingControlRequests.clear();
// Clean up pending MCP responses
for (const pending of this.pendingMcpResponses.values()) {
pending.reject(new Error('Query is closed'));
}
this.pendingMcpResponses.clear();
await this.transport.close(); await this.transport.close();
/** /**
@@ -542,7 +724,7 @@ export class Query implements AsyncIterable<SDKMessage> {
} }
} }
this.sdkMcpTransports.clear(); this.sdkMcpTransports.clear();
logger.info('Query closed'); logger.info('Query is closed');
} }
private async *readSdkMessages(): AsyncGenerator<SDKMessage> { private async *readSdkMessages(): AsyncGenerator<SDKMessage> {
@@ -588,22 +770,31 @@ export class Query implements AsyncIterable<SDKMessage> {
} }
/** /**
* In multi-turn mode with MCP servers, wait for first result * After all user messages are sent (for-await loop ended), determine when to
* to ensure MCP servers have time to process before next input. * close the CLI's stdin via endInput().
* This prevents race conditions where the next input arrives before *
* MCP servers have finished processing the current request. * - If a result message was already received: All initialization (SDK MCP servers,
* control responses, etc.) is complete, safe to close stdin immediately.
* - If no result yet: Wait for either the result to arrive, or the timeout to expire.
* This gives pending control_responses from SDK MCP servers or other modules
* time to complete their initialization before we close the input stream.
*
* The timeout ensures we don't hang indefinitely - either the turn proceeds
* normally, or it fails with a timeout, but Promise.race will always resolve.
*/ */
if ( if (
!this.isSingleTurn && !this.isSingleTurn &&
this.sdkMcpTransports.size > 0 && this.sdkMcpTransports.size > 0 &&
this.firstResultReceivedPromise this.firstResultReceivedPromise
) { ) {
const streamCloseTimeout =
this.options.timeout?.streamClose ?? DEFAULT_STREAM_CLOSE_TIMEOUT;
await Promise.race([ await Promise.race([
this.firstResultReceivedPromise, this.firstResultReceivedPromise,
new Promise<void>((resolve) => { new Promise<void>((resolve) => {
setTimeout(() => { setTimeout(() => {
resolve(); resolve();
}, STREAM_CLOSE_TIMEOUT); }, streamCloseTimeout);
}), }),
]); ]);
} }
@@ -635,28 +826,16 @@ export class Query implements AsyncIterable<SDKMessage> {
} }
async interrupt(): Promise<void> { async interrupt(): Promise<void> {
if (this.closed) {
throw new Error('Query is closed');
}
await this.sendControlRequest(ControlRequestType.INTERRUPT); await this.sendControlRequest(ControlRequestType.INTERRUPT);
} }
async setPermissionMode(mode: string): Promise<void> { async setPermissionMode(mode: string): Promise<void> {
if (this.closed) {
throw new Error('Query is closed');
}
await this.sendControlRequest(ControlRequestType.SET_PERMISSION_MODE, { await this.sendControlRequest(ControlRequestType.SET_PERMISSION_MODE, {
mode, mode,
}); });
} }
async setModel(model: string): Promise<void> { async setModel(model: string): Promise<void> {
if (this.closed) {
throw new Error('Query is closed');
}
await this.sendControlRequest(ControlRequestType.SET_MODEL, { model }); await this.sendControlRequest(ControlRequestType.SET_MODEL, { model });
} }
@@ -667,10 +846,6 @@ export class Query implements AsyncIterable<SDKMessage> {
* @throws Error if query is closed * @throws Error if query is closed
*/ */
async supportedCommands(): Promise<Record<string, unknown> | null> { async supportedCommands(): Promise<Record<string, unknown> | null> {
if (this.closed) {
throw new Error('Query is closed');
}
return this.sendControlRequest(ControlRequestType.SUPPORTED_COMMANDS); return this.sendControlRequest(ControlRequestType.SUPPORTED_COMMANDS);
} }
@@ -681,10 +856,6 @@ export class Query implements AsyncIterable<SDKMessage> {
* @throws Error if query is closed * @throws Error if query is closed
*/ */
async mcpServerStatus(): Promise<Record<string, unknown> | null> { async mcpServerStatus(): Promise<Record<string, unknown> | null> {
if (this.closed) {
throw new Error('Query is closed');
}
return this.sendControlRequest(ControlRequestType.MCP_SERVER_STATUS); return this.sendControlRequest(ControlRequestType.MCP_SERVER_STATUS);
} }

View File

@@ -1,5 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
export interface Annotation { export interface Annotation {
type: string; type: string;
value: string; value: string;
@@ -293,10 +294,44 @@ export interface MCPServerConfig {
targetServiceAccount?: string; targetServiceAccount?: string;
} }
/**
* SDK MCP Server configuration
*
* SDK MCP servers run in the SDK process and are connected via in-memory transport.
* Tool calls are routed through the control plane between SDK and CLI.
*/
export interface SDKMcpServerConfig {
/**
* Type identifier for SDK MCP servers
*/
type: 'sdk';
/**
* Server name for identification and routing
*/
name: string;
/**
* The MCP Server instance created by createSdkMcpServer()
*/
instance: McpServer;
}
/**
* Wire format for SDK MCP servers sent to the CLI
*/
export type WireSDKMcpServerConfig = Omit<SDKMcpServerConfig, 'instance'>;
export interface CLIControlInitializeRequest { export interface CLIControlInitializeRequest {
subtype: 'initialize'; subtype: 'initialize';
hooks?: HookRegistration[] | null; hooks?: HookRegistration[] | null;
sdkMcpServers?: Record<string, MCPServerConfig>; /**
* SDK MCP servers config
* These are MCP servers running in the SDK process, connected via control plane.
* External MCP servers are configured separately in settings, not via initialization.
*/
sdkMcpServers?: Record<string, WireSDKMcpServerConfig>;
/**
* External MCP servers that should be managed by the CLI.
*/
mcpServers?: Record<string, MCPServerConfig>; mcpServers?: Record<string, MCPServerConfig>;
agents?: SubagentConfig[]; agents?: SubagentConfig[];
} }

View File

@@ -2,19 +2,98 @@ import { z } from 'zod';
import type { CanUseTool } from './types.js'; import type { CanUseTool } from './types.js';
import type { SubagentConfig } from './protocol.js'; import type { SubagentConfig } from './protocol.js';
export const ExternalMcpServerConfigSchema = z.object({ /**
command: z.string().min(1, 'Command must be a non-empty string'), * OAuth configuration for MCP servers
*/
export const McpOAuthConfigSchema = z
.object({
enabled: z.boolean().optional(),
clientId: z
.string()
.min(1, 'clientId must be a non-empty string')
.optional(),
clientSecret: z.string().optional(),
scopes: z.array(z.string()).optional(),
redirectUri: z.string().optional(),
authorizationUrl: z.string().optional(),
tokenUrl: z.string().optional(),
audiences: z.array(z.string()).optional(),
tokenParamName: z.string().optional(),
registrationUrl: z.string().optional(),
})
.strict();
/**
* CLI MCP Server configuration schema
*
* Supports multiple transport types:
* - stdio: command, args, env, cwd
* - SSE: url
* - Streamable HTTP: httpUrl, headers
* - WebSocket: tcp
*/
export const CLIMcpServerConfigSchema = z.object({
// For stdio transport
command: z.string().optional(),
args: z.array(z.string()).optional(), args: z.array(z.string()).optional(),
env: z.record(z.string(), z.string()).optional(), env: z.record(z.string(), z.string()).optional(),
cwd: z.string().optional(),
// For SSE transport
url: z.string().optional(),
// For streamable HTTP transport
httpUrl: z.string().optional(),
headers: z.record(z.string(), z.string()).optional(),
// For WebSocket transport
tcp: z.string().optional(),
// Common
timeout: z.number().optional(),
trust: z.boolean().optional(),
// Metadata
description: z.string().optional(),
includeTools: z.array(z.string()).optional(),
excludeTools: z.array(z.string()).optional(),
extensionName: z.string().optional(),
// OAuth configuration
oauth: McpOAuthConfigSchema.optional(),
authProviderType: z
.enum([
'dynamic_discovery',
'google_credentials',
'service_account_impersonation',
])
.optional(),
// Service Account Configuration
targetAudience: z.string().optional(),
targetServiceAccount: z.string().optional(),
}); });
/**
* SDK MCP Server configuration schema
*/
export const SdkMcpServerConfigSchema = z.object({ export const SdkMcpServerConfigSchema = z.object({
connect: z.custom<(transport: unknown) => Promise<void>>( type: z.literal('sdk'),
(val) => typeof val === 'function', name: z.string().min(1, 'name must be a non-empty string'),
{ message: 'connect must be a function' }, instance: z.custom<{
connect(transport: unknown): Promise<void>;
close(): Promise<void>;
}>(
(val) =>
val &&
typeof val === 'object' &&
'connect' in val &&
typeof val.connect === 'function',
{ message: 'instance must be an MCP Server with connect method' },
), ),
}); });
/**
* Unified MCP Server configuration schema
*/
export const McpServerConfigSchema = z.union([
CLIMcpServerConfigSchema,
SdkMcpServerConfigSchema,
]);
export const ModelConfigSchema = z.object({ export const ModelConfigSchema = z.object({
model: z.string().optional(), model: z.string().optional(),
temp: z.number().optional(), temp: z.number().optional(),
@@ -37,6 +116,13 @@ export const SubagentConfigSchema = z.object({
isBuiltin: z.boolean().optional(), isBuiltin: z.boolean().optional(),
}); });
export const TimeoutConfigSchema = z.object({
canUseTool: z.number().positive().optional(),
mcpRequest: z.number().positive().optional(),
controlRequest: z.number().positive().optional(),
streamClose: z.number().positive().optional(),
});
export const QueryOptionsSchema = z export const QueryOptionsSchema = z
.object({ .object({
cwd: z.string().optional(), cwd: z.string().optional(),
@@ -49,7 +135,7 @@ export const QueryOptionsSchema = z
message: 'canUseTool must be a function', message: 'canUseTool must be a function',
}) })
.optional(), .optional(),
mcpServers: z.record(z.string(), ExternalMcpServerConfigSchema).optional(), mcpServers: z.record(z.string(), McpServerConfigSchema).optional(),
abortController: z.instanceof(AbortController).optional(), abortController: z.instanceof(AbortController).optional(),
debug: z.boolean().optional(), debug: z.boolean().optional(),
stderr: z stderr: z
@@ -78,5 +164,6 @@ export const QueryOptionsSchema = z
) )
.optional(), .optional(),
includePartialMessages: z.boolean().optional(), includePartialMessages: z.boolean().optional(),
timeout: TimeoutConfigSchema.optional(),
}) })
.strict(); .strict();

View File

@@ -2,25 +2,11 @@ import type {
PermissionMode, PermissionMode,
PermissionSuggestion, PermissionSuggestion,
SubagentConfig, SubagentConfig,
SDKMcpServerConfig,
} from './protocol.js'; } from './protocol.js';
export type { PermissionMode }; export type { PermissionMode };
type JSONSchema = {
type: string;
properties?: Record<string, unknown>;
required?: string[];
description?: string;
[key: string]: unknown;
};
export type ToolDefinition<TInput = unknown, TOutput = unknown> = {
name: string;
description: string;
inputSchema: JSONSchema;
handler: (input: TInput) => Promise<TOutput>;
};
export type TransportOptions = { export type TransportOptions = {
pathToQwenExecutable: string; pathToQwenExecutable: string;
cwd?: string; cwd?: string;
@@ -61,14 +47,115 @@ export type PermissionResult =
interrupt?: boolean; interrupt?: boolean;
}; };
export interface ExternalMcpServerConfig { /**
command: string; * OAuth configuration for MCP servers
args?: string[]; */
env?: Record<string, string>; export interface McpOAuthConfig {
enabled?: boolean;
clientId?: string;
clientSecret?: string;
scopes?: string[];
redirectUri?: string;
authorizationUrl?: string;
tokenUrl?: string;
audiences?: string[];
tokenParamName?: string;
registrationUrl?: string;
} }
export interface SdkMcpServerConfig { /**
connect: (transport: unknown) => Promise<void>; * Auth provider type for MCP servers
*/
export type McpAuthProviderType =
| 'dynamic_discovery'
| 'google_credentials'
| 'service_account_impersonation';
/**
* CLI MCP Server configuration
*
* Supports multiple transport types:
* - stdio: command, args, env, cwd
* - SSE: url
* - Streamable HTTP: httpUrl, headers
* - WebSocket: tcp
*
* This interface aligns with MCPServerConfig in @qwen-code/qwen-code-core.
*/
export interface CLIMcpServerConfig {
// For stdio transport
command?: string;
args?: string[];
env?: Record<string, string>;
cwd?: string;
// For SSE transport
url?: string;
// For streamable HTTP transport
httpUrl?: string;
headers?: Record<string, string>;
// For WebSocket transport
tcp?: string;
// Common
timeout?: number;
trust?: boolean;
// Metadata
description?: string;
includeTools?: string[];
excludeTools?: string[];
extensionName?: string;
// OAuth configuration
oauth?: McpOAuthConfig;
authProviderType?: McpAuthProviderType;
// Service Account Configuration
/** targetAudience format: CLIENT_ID.apps.googleusercontent.com */
targetAudience?: string;
/** targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
targetServiceAccount?: string;
}
/**
* Unified MCP Server configuration
*
* Supports both external MCP servers (stdio/SSE/HTTP/WebSocket) and SDK-embedded MCP servers.
*
* @example External MCP server (stdio)
* ```typescript
* mcpServers: {
* 'my-server': { command: 'node', args: ['server.js'] }
* }
* ```
*
* @example External MCP server (SSE)
* ```typescript
* mcpServers: {
* 'remote-server': { url: 'http://localhost:3000/sse' }
* }
* ```
*
* @example External MCP server (Streamable HTTP)
* ```typescript
* mcpServers: {
* 'http-server': { httpUrl: 'http://localhost:3000/mcp', headers: { 'Authorization': 'Bearer token' } }
* }
* ```
*
* @example SDK MCP server
* ```typescript
* const server = createSdkMcpServer('weather', '1.0.0', [weatherTool]);
* mcpServers: {
* 'weather': { type: 'sdk', name: 'weather', instance: server }
* }
* ```
*/
export type McpServerConfig = CLIMcpServerConfig | SDKMcpServerConfig;
/**
* Type guard to check if a config is an SDK MCP server
*/
export function isSdkMcpServerConfig(
config: McpServerConfig,
): config is SDKMcpServerConfig {
return 'type' in config && config.type === 'sdk';
} }
/** /**
@@ -174,11 +261,36 @@ export interface QueryOptions {
canUseTool?: CanUseTool; canUseTool?: CanUseTool;
/** /**
* External MCP (Model Context Protocol) servers to connect to. * MCP (Model Context Protocol) servers to connect to.
* Each server is identified by a unique name and configured with command, args, and environment. *
* @example { 'my-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } } } * Supports both external MCP servers and SDK-embedded MCP servers:
*
* **External MCP servers** - Run in separate processes, connected via stdio/SSE/HTTP:
* ```typescript
* mcpServers: {
* 'stdio-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } },
* 'sse-server': { url: 'http://localhost:3000/sse' },
* 'http-server': { httpUrl: 'http://localhost:3000/mcp' }
* }
* ```
*
* **SDK MCP servers** - Run in the SDK process, connected via in-memory transport:
* ```typescript
* const myTool = tool({
* name: 'my_tool',
* description: 'My custom tool',
* inputSchema: { type: 'object', properties: { input: { type: 'string' } } },
* handler: async (input) => ({ result: input.input.toUpperCase() }),
* });
*
* const server = createSdkMcpServer('my-server', '1.0.0', [myTool]);
*
* mcpServers: {
* 'my-server': { type: 'sdk', name: 'my-server', instance: server }
* }
* ```
*/ */
mcpServers?: Record<string, ExternalMcpServerConfig>; mcpServers?: Record<string, McpServerConfig>;
/** /**
* AbortController to cancel the query session. * AbortController to cancel the query session.
@@ -294,4 +406,43 @@ export interface QueryOptions {
* @default false * @default false
*/ */
includePartialMessages?: boolean; includePartialMessages?: boolean;
/**
* Timeout configuration for various SDK operations.
* All values are in milliseconds.
*/
timeout?: {
/**
* Timeout for the `canUseTool` callback.
* If the callback doesn't resolve within this time, the permission request
* will be denied with a timeout error (fail-safe behavior).
* @default 60000 (1 minute)
*/
canUseTool?: number;
/**
* Timeout for SDK MCP tool calls.
* This applies to tool calls made to SDK-embedded MCP servers.
* @default 60000 (1 minute)
*/
mcpRequest?: number;
/**
* Timeout for SDK→CLI control requests.
* This applies to internal control operations like initialize, interrupt,
* setPermissionMode, setModel, etc.
* @default 60000 (1 minute)
*/
controlRequest?: number;
/**
* Timeout for waiting before closing CLI's stdin after user messages are sent.
* In multi-turn mode with SDK MCP servers, after all user messages are processed,
* the SDK waits for the first result message to ensure all initialization
* (control responses, MCP server setup, etc.) is complete before closing stdin.
* This timeout is a fallback to avoid hanging indefinitely.
* @default 60000 (1 minute)
*/
streamClose?: number;
};
} }

View File

@@ -1,3 +1,9 @@
/**
* @license
* Copyright 2025 Qwen Team
* SPDX-License-Identifier: Apache-2.0
*/
/** /**
* Unit tests for createSdkMcpServer * Unit tests for createSdkMcpServer
* *
@@ -5,93 +11,112 @@
*/ */
import { describe, expect, it, vi } from 'vitest'; import { describe, expect, it, vi } from 'vitest';
import { z } from 'zod';
import { createSdkMcpServer } from '../../src/mcp/createSdkMcpServer.js'; import { createSdkMcpServer } from '../../src/mcp/createSdkMcpServer.js';
import { tool } from '../../src/mcp/tool.js'; import { tool } from '../../src/mcp/tool.js';
import type { ToolDefinition } from '../../src/types/config.js'; import type { SdkMcpToolDefinition } from '../../src/mcp/tool.js';
describe('createSdkMcpServer', () => { describe('createSdkMcpServer', () => {
describe('Server Creation', () => { describe('Server Creation', () => {
it('should create server with name and version', () => { it('should create server with name and version', () => {
const server = createSdkMcpServer('test-server', '1.0.0', []); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
expect(server.type).toBe('sdk');
expect(server.name).toBe('test-server');
expect(server.instance).toBeDefined();
});
it('should create server with default version', () => {
const server = createSdkMcpServer({
name: 'test-server',
});
expect(server).toBeDefined();
expect(server.name).toBe('test-server');
}); });
it('should throw error with invalid name', () => { it('should throw error with invalid name', () => {
expect(() => createSdkMcpServer('', '1.0.0', [])).toThrow( expect(() => createSdkMcpServer({ name: '', version: '1.0.0' })).toThrow(
'name must be a non-empty string', 'MCP server name must be a non-empty string',
); );
}); });
it('should throw error with invalid version', () => { it('should throw error with invalid version', () => {
expect(() => createSdkMcpServer('test', '', [])).toThrow( expect(() => createSdkMcpServer({ name: 'test', version: '' })).toThrow(
'version must be a non-empty string', 'MCP server version must be a non-empty string',
); );
}); });
it('should throw error with non-array tools', () => { it('should throw error with non-array tools', () => {
expect(() => expect(() =>
createSdkMcpServer('test', '1.0.0', {} as unknown as ToolDefinition[]), createSdkMcpServer({
name: 'test',
version: '1.0.0',
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
tools: {} as unknown as SdkMcpToolDefinition<any>[],
}),
).toThrow('Tools must be an array'); ).toThrow('Tools must be an array');
}); });
}); });
describe('Tool Registration', () => { describe('Tool Registration', () => {
it('should register single tool', () => { it('should register single tool', () => {
const testTool = tool({ const testTool = tool(
name: 'test_tool', 'test_tool',
description: 'A test tool', 'A test tool',
inputSchema: { { input: z.string() },
type: 'object', async () => ({
properties: { content: [{ type: 'text', text: 'result' }],
input: { type: 'string' }, }),
}, );
},
handler: async () => 'result',
});
const server = createSdkMcpServer('test-server', '1.0.0', [testTool]); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [testTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
it('should register multiple tools', () => { it('should register multiple tools', () => {
const tool1 = tool({ const tool1 = tool('tool1', 'Tool 1', {}, async () => ({
name: 'tool1', content: [{ type: 'text', text: 'result1' }],
description: 'Tool 1', }));
inputSchema: { type: 'object' },
handler: async () => 'result1',
});
const tool2 = tool({ const tool2 = tool('tool2', 'Tool 2', {}, async () => ({
name: 'tool2', content: [{ type: 'text', text: 'result2' }],
description: 'Tool 2', }));
inputSchema: { type: 'object' },
handler: async () => 'result2',
});
const server = createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [tool1, tool2],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
it('should throw error for duplicate tool names', () => { it('should throw error for duplicate tool names', () => {
const tool1 = tool({ const tool1 = tool('duplicate', 'Tool 1', {}, async () => ({
name: 'duplicate', content: [{ type: 'text', text: 'result1' }],
description: 'Tool 1', }));
inputSchema: { type: 'object' },
handler: async () => 'result1',
});
const tool2 = tool({ const tool2 = tool('duplicate', 'Tool 2', {}, async () => ({
name: 'duplicate', content: [{ type: 'text', text: 'result2' }],
description: 'Tool 2', }));
inputSchema: { type: 'object' },
handler: async () => 'result2',
});
expect(() => expect(() =>
createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]), createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [tool1, tool2],
}),
).toThrow("Duplicate tool name 'duplicate'"); ).toThrow("Duplicate tool name 'duplicate'");
}); });
@@ -99,36 +124,41 @@ describe('createSdkMcpServer', () => {
const invalidTool = { const invalidTool = {
name: '123invalid', // Starts with number name: '123invalid', // Starts with number
description: 'Invalid tool', description: 'Invalid tool',
inputSchema: { type: 'object' }, inputSchema: {},
handler: async () => 'result', handler: async () => ({
content: [{ type: 'text' as const, text: 'result' }],
}),
}; };
expect(() => expect(() =>
createSdkMcpServer('test-server', '1.0.0', [ createSdkMcpServer({
invalidTool as unknown as ToolDefinition, name: 'test-server',
]), version: '1.0.0',
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
tools: [invalidTool as unknown as SdkMcpToolDefinition<any>],
}),
).toThrow('Tool name'); ).toThrow('Tool name');
}); });
}); });
describe('Tool Handler Invocation', () => { describe('Tool Handler Invocation', () => {
it('should invoke tool handler with correct input', async () => { it('should invoke tool handler with correct input', async () => {
const handler = vi.fn().mockResolvedValue({ result: 'success' }); const handler = vi.fn().mockResolvedValue({
content: [{ type: 'text', text: 'success' }],
const testTool = tool({
name: 'test_tool',
description: 'A test tool',
inputSchema: {
type: 'object',
properties: {
value: { type: 'string' },
},
required: ['value'],
},
handler,
}); });
createSdkMcpServer('test-server', '1.0.0', [testTool]); const testTool = tool(
'test_tool',
'A test tool',
{ value: z.string() },
handler,
);
createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [testTool],
});
// Note: Actual invocation testing requires MCP SDK integration // Note: Actual invocation testing requires MCP SDK integration
// This test verifies the handler was properly registered // This test verifies the handler was properly registered
@@ -140,17 +170,18 @@ describe('createSdkMcpServer', () => {
.fn() .fn()
.mockImplementation(async (input: { value: string }) => { .mockImplementation(async (input: { value: string }) => {
await new Promise((resolve) => setTimeout(resolve, 10)); await new Promise((resolve) => setTimeout(resolve, 10));
return { processed: input.value }; return {
content: [{ type: 'text', text: `processed: ${input.value}` }],
};
}); });
const testTool = tool({ const testTool = tool('async_tool', 'An async tool', {}, handler);
name: 'async_tool',
description: 'An async tool',
inputSchema: { type: 'object' },
handler,
});
const server = createSdkMcpServer('test-server', '1.0.0', [testTool]); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [testTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
@@ -158,40 +189,29 @@ describe('createSdkMcpServer', () => {
describe('Type Safety', () => { describe('Type Safety', () => {
it('should preserve input type in handler', async () => { it('should preserve input type in handler', async () => {
type ToolInput = { const handler = vi.fn().mockImplementation(async (input) => {
name: string; return {
age: number; content: [
}; { type: 'text', text: `Hello ${input.name}, age ${input.age}` },
],
type ToolOutput = { };
greeting: string;
};
const handler = vi
.fn()
.mockImplementation(async (input: ToolInput): Promise<ToolOutput> => {
return {
greeting: `Hello ${input.name}, age ${input.age}`,
};
});
const typedTool = tool<ToolInput, ToolOutput>({
name: 'typed_tool',
description: 'A typed tool',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
},
required: ['name', 'age'],
},
handler,
}); });
const server = createSdkMcpServer('test-server', '1.0.0', [ const typedTool = tool(
typedTool as ToolDefinition, 'typed_tool',
]); 'A typed tool',
{
name: z.string(),
age: z.number(),
},
handler,
);
const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [typedTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
@@ -201,14 +221,13 @@ describe('createSdkMcpServer', () => {
it('should handle tool handler errors gracefully', async () => { it('should handle tool handler errors gracefully', async () => {
const handler = vi.fn().mockRejectedValue(new Error('Tool failed')); const handler = vi.fn().mockRejectedValue(new Error('Tool failed'));
const errorTool = tool({ const errorTool = tool('error_tool', 'A tool that errors', {}, handler);
name: 'error_tool',
description: 'A tool that errors',
inputSchema: { type: 'object' },
handler,
});
const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [errorTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
// Error handling occurs during tool invocation // Error handling occurs during tool invocation
@@ -219,14 +238,18 @@ describe('createSdkMcpServer', () => {
throw new Error('Sync error'); throw new Error('Sync error');
}); });
const errorTool = tool({ const errorTool = tool(
name: 'sync_error_tool', 'sync_error_tool',
description: 'A tool that errors synchronously', 'A tool that errors synchronously',
inputSchema: { type: 'object' }, {},
handler, handler,
}); );
const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]); const server = createSdkMcpServer({
name: 'test-server',
version: '1.0.0',
tools: [errorTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
@@ -234,69 +257,76 @@ describe('createSdkMcpServer', () => {
describe('Complex Tool Scenarios', () => { describe('Complex Tool Scenarios', () => {
it('should support tool with complex input schema', () => { it('should support tool with complex input schema', () => {
const complexTool = tool({ const complexTool = tool(
name: 'complex_tool', 'complex_tool',
description: 'A tool with complex schema', 'A tool with complex schema',
inputSchema: { {
type: 'object', query: z.string(),
properties: { filters: z
query: { type: 'string' }, .object({
filters: { category: z.string().optional(),
type: 'object', minPrice: z.number().optional(),
properties: { })
category: { type: 'string' }, .optional(),
minPrice: { type: 'number' }, options: z.array(z.string()).optional(),
},
},
options: {
type: 'array',
items: { type: 'string' },
},
},
required: ['query'],
}, },
handler: async (input: { filters?: unknown[] }) => { async (input) => {
return { return {
results: [], content: [
filters: input.filters, {
type: 'text',
text: JSON.stringify({ results: [], filters: input.filters }),
},
],
}; };
}, },
}); );
const server = createSdkMcpServer('test-server', '1.0.0', [ const server = createSdkMcpServer({
complexTool as ToolDefinition, name: 'test-server',
]); version: '1.0.0',
tools: [complexTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
it('should support tool returning complex output', () => { it('should support tool returning complex output', () => {
const complexOutputTool = tool({ const complexOutputTool = tool(
name: 'complex_output_tool', 'complex_output_tool',
description: 'Returns complex data', 'Returns complex data',
inputSchema: { type: 'object' }, {},
handler: async () => { async () => {
return { return {
data: [ content: [
{ id: 1, name: 'Item 1' }, {
{ id: 2, name: 'Item 2' }, type: 'text',
], text: JSON.stringify({
metadata: { data: [
total: 2, { id: 1, name: 'Item 1' },
page: 1, { id: 2, name: 'Item 2' },
}, ],
nested: { metadata: {
deep: { total: 2,
value: 'test', page: 1,
},
nested: {
deep: {
value: 'test',
},
},
}),
}, },
}, ],
}; };
}, },
}); );
const server = createSdkMcpServer('test-server', '1.0.0', [ const server = createSdkMcpServer({
complexOutputTool, name: 'test-server',
]); version: '1.0.0',
tools: [complexOutputTool],
});
expect(server).toBeDefined(); expect(server).toBeDefined();
}); });
@@ -304,44 +334,50 @@ describe('createSdkMcpServer', () => {
describe('Multiple Servers', () => { describe('Multiple Servers', () => {
it('should create multiple independent servers', () => { it('should create multiple independent servers', () => {
const tool1 = tool({ const tool1 = tool('tool1', 'Tool in server 1', {}, async () => ({
name: 'tool1', content: [{ type: 'text', text: 'result1' }],
description: 'Tool in server 1', }));
inputSchema: { type: 'object' },
handler: async () => 'result1',
});
const tool2 = tool({ const tool2 = tool('tool2', 'Tool in server 2', {}, async () => ({
name: 'tool2', content: [{ type: 'text', text: 'result2' }],
description: 'Tool in server 2', }));
inputSchema: { type: 'object' },
handler: async () => 'result2',
});
const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]); const server1 = createSdkMcpServer({
const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]); name: 'server1',
version: '1.0.0',
tools: [tool1],
});
const server2 = createSdkMcpServer({
name: 'server2',
version: '1.0.0',
tools: [tool2],
});
expect(server1).toBeDefined(); expect(server1).toBeDefined();
expect(server2).toBeDefined(); expect(server2).toBeDefined();
expect(server1.name).toBe('server1');
expect(server2.name).toBe('server2');
}); });
it('should allow same tool name in different servers', () => { it('should allow same tool name in different servers', () => {
const tool1 = tool({ const tool1 = tool('shared_name', 'Tool in server 1', {}, async () => ({
name: 'shared_name', content: [{ type: 'text', text: 'result1' }],
description: 'Tool in server 1', }));
inputSchema: { type: 'object' },
handler: async () => 'result1',
});
const tool2 = tool({ const tool2 = tool('shared_name', 'Tool in server 2', {}, async () => ({
name: 'shared_name', content: [{ type: 'text', text: 'result2' }],
description: 'Tool in server 2', }));
inputSchema: { type: 'object' },
handler: async () => 'result2',
});
const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]); const server1 = createSdkMcpServer({
const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]); name: 'server1',
version: '1.0.0',
tools: [tool1],
});
const server2 = createSdkMcpServer({
name: 'server2',
version: '1.0.0',
tools: [tool2],
});
expect(server1).toBeDefined(); expect(server1).toBeDefined();
expect(server2).toBeDefined(); expect(server2).toBeDefined();